fix iterators

Denis Mazur, 4 years ago
parent commit bcd4e85058

2 changed files with 13 additions and 12 deletions:
  1. hivemind/moe/client/expert.py (+10 −10)
  2. hivemind/moe/server/connection_handler.py (+3 −2)

hivemind/moe/client/expert.py (+10 −10)

@@ -133,11 +133,11 @@ class _RemoteModuleCall(torch.autograd.Function):
             for inp, proto in zip(inputs, nested_flatten(info["forward_schema"]))
         ]
 
-        async def func():
-            stream = stub.rpc_forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
-            return await stream.__anext__()
-
-        outputs = cls.run_coroutine(func())
+        outputs = cls.run_coroutine(
+            asingle(
+                stub.rpc_forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors)),
+            ),
+        )
 
         deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]
 
@@ -154,11 +154,11 @@ class _RemoteModuleCall(torch.autograd.Function):
             for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
         ]
 
-        async def func():
-            stream = ctx.stub.rpc_forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
-            return await asingle(stream)
-
-        grad_inputs = cls.run_coroutine(func())
+        grad_inputs = cls.run_coroutine(
+            asingle(
+                ctx.stub.rpc_forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors)),
+            ),
+        )
 
         deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
         return (DUMMY, None, None, None, *deserialized_grad_inputs)
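
What changed on the client: instead of wrapping each RPC call in a throwaway `async def func()` and pulling one item off the stream by hand (`stream.__anext__()` in forward, `asingle(stream)` in backward), both passes now feed the response stream directly to `asingle`, which awaits the single response and fails if the stream is empty or yields more than one item. Below is a minimal sketch of an `asingle`-style helper; the real one ships with hivemind, so the implementation and error messages here are assumptions that only illustrate the contract the diff relies on.

    import asyncio
    from typing import AsyncIterable, TypeVar

    T = TypeVar("T")

    async def asingle(aiter: AsyncIterable[T]) -> T:
        # Consume the async iterable and return its only item;
        # zero items, or a second item, is treated as an error.
        count = 0
        item = None
        async for item in aiter:
            count += 1
            if count > 1:
                raise ValueError("expected exactly one item, but got more")
        if count == 0:
            raise ValueError("expected exactly one item, but got none")
        return item

    async def _demo() -> None:
        async def single_response():
            yield "ExpertResponse"  # stands in for the one rpc_forward reply

        assert await asingle(single_response()) == "ExpertResponse"

    asyncio.run(_demo())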

hivemind/moe/server/connection_handler.py (+3 −2)

@@ -60,13 +60,14 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
         self, request: runtime_pb2.ExpertRequest, context: P2PContext
     ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
         inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
+
         future = self.experts[request.uid].forward_pool.submit_task(*inputs)
         serialized_response = [
             serialize_torch_tensor(tensor, proto.compression, allow_inplace=True)
             for tensor, proto in zip(await future, nested_flatten(self.experts[request.uid].outputs_schema))
         ]
 
-        return as_aiter(runtime_pb2.ExpertResponse(tensors=serialized_response))
+        yield runtime_pb2.ExpertResponse(tensors=serialized_response)
 
     async def rpc_backward(
         self, request: runtime_pb2.ExpertRequest, context: P2PContext
@@ -77,4 +78,4 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
             serialize_torch_tensor(tensor, proto.compression, allow_inplace=True)
             for tensor, proto in zip(await future, nested_flatten(self.experts[request.uid].grad_inputs_schema))
         ]
-        return as_aiter(runtime_pb2.ExpertResponse(tensors=serialized_response))
+        yield runtime_pb2.ExpertResponse(tensors=serialized_response)
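
On the server side, `rpc_forward` and `rpc_backward` are annotated as returning `AsyncIterator[runtime_pb2.ExpertResponse]`, but `return as_aiter(...)` made each of them a plain coroutine that only produces an iterator after an extra await. Replacing `return` with `yield` turns each handler into an async generator, which is itself the promised async iterator, so the P2P layer (and `asingle` on the client) can consume it directly. A toy sketch of the difference, with purely illustrative names:

    import asyncio
    from typing import AsyncIterator

    async def handler_with_return(value: str):
        # A coroutine: callers must await it first, then iterate the result separately.
        return iter([value])

    async def handler_with_yield(value: str) -> AsyncIterator[str]:
        # An async generator: calling it returns an AsyncIterator immediately,
        # matching the declared return type of the rpc_* handlers.
        yield value

    async def main() -> None:
        async for response in handler_with_yield("one response"):
            print(response)  # prints "one response"
        # By contrast, `async for ... in handler_with_return("one response")`
        # raises TypeError: a coroutine is not async-iterable without an await.

    asyncio.run(main())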