
reverse stream methods in server

Denis Mazur, 4 years ago
commit 08a9c02e60
2 changed files with 18 additions and 18 deletions:
  1. hivemind/moe/client/expert.py (+11 -7)
  2. hivemind/moe/server/connection_handler.py (+7 -11)

+ 11 - 7
hivemind/moe/client/expert.py

@@ -12,7 +12,7 @@ import hivemind
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.p2p import P2P, PeerInfo, StubBase
 from hivemind.proto import runtime_pb2
-from hivemind.utils import as_aiter, nested_compare, nested_flatten, nested_pack, switch_to_uvloop
+from hivemind.utils import asingle, nested_compare, nested_flatten, nested_pack, switch_to_uvloop
 
 DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert
 
@@ -133,9 +133,11 @@ class _RemoteModuleCall(torch.autograd.Function):
             for inp, proto in zip(inputs, nested_flatten(info["forward_schema"]))
         ]
 
-        outputs = cls.run_coroutine(
-            stub.rpc_forward(as_aiter(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))),
-        )
+        async def func():
+            stream = stub.rpc_forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
+            return await asingle(stream)
+
+        outputs = cls.run_coroutine(func())
 
         deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]
 
@@ -152,9 +154,11 @@ class _RemoteModuleCall(torch.autograd.Function):
             for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
         ]
 
-        grad_inputs = cls.run_coroutine(
-            ctx.stub.rpc_backward(as_aiter(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))),
-        )
+        async def func():
+            stream = ctx.stub.rpc_backward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
+            return await asingle(stream)
+
+        grad_inputs = cls.run_coroutine(func())
 
         deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
         return (DUMMY, None, None, None, *deserialized_grad_inputs)
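
After this change, the client stubs return an async iterator of responses instead of accepting one, so the autograd function awaits exactly one item from the returned stream. A minimal sketch of the asingle helper this relies on (an assumption about its behavior; the real implementation lives in hivemind.utils):

    from typing import AsyncIterator, TypeVar

    T = TypeVar("T")

    async def asingle(aiter: AsyncIterator[T]) -> T:
        """Return the only item of an async iterator; fail if it yields zero or more than one."""
        result, count = None, 0
        async for item in aiter:
            count += 1
            if count > 1:
                raise ValueError("asingle() expected exactly one item, got more")
            result = item
        if count == 0:
            raise ValueError("asingle() expected exactly one item, got none")
        return result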

+ 7 - 11
hivemind/moe/server/connection_handler.py

@@ -10,7 +10,7 @@ from hivemind.dht import DHT
 from hivemind.moe.server.expert_backend import ExpertBackend
 from hivemind.p2p import P2PContext, ServicerBase
 from hivemind.proto import runtime_pb2
-from hivemind.utils import MPFuture, asingle, get_logger, nested_flatten
+from hivemind.utils import MPFuture, as_aiter, get_logger, nested_flatten
 from hivemind.utils.asyncio import switch_to_uvloop
 
 logger = get_logger(__name__)
@@ -57,10 +57,8 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
         return runtime_pb2.ExpertInfo(serialized_info=pickle.dumps(self.experts[request.uid].get_info()))
 
     async def rpc_forward(
-        self, stream: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
-    ) -> runtime_pb2.ExpertResponse:
-        request = await asingle(stream)
-
+        self, request: runtime_pb2.ExpertRequest, context: P2PContext
+    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
         inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         future = self.experts[request.uid].forward_pool.submit_task(*inputs)
         serialized_response = [
@@ -68,17 +66,15 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
             for tensor, proto in zip(await future, nested_flatten(self.experts[request.uid].outputs_schema))
         ]
 
-        return runtime_pb2.ExpertResponse(tensors=serialized_response)
+        return as_aiter(runtime_pb2.ExpertResponse(tensors=serialized_response))
 
     async def rpc_backward(
-        self, stream: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
-    ) -> runtime_pb2.ExpertResponse:
-        request = await asingle(stream)
-
+        self, request: runtime_pb2.ExpertRequest, context: P2PContext
+    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
         inputs_and_grad_outputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         future = self.experts[request.uid].backward_pool.submit_task(*inputs_and_grad_outputs)
         serialized_response = [
             serialize_torch_tensor(tensor, proto.compression, allow_inplace=True)
             for tensor, proto in zip(await future, nested_flatten(self.experts[request.uid].grad_inputs_schema))
         ]
-        return runtime_pb2.ExpertResponse(tensors=serialized_response)
+        return as_aiter(runtime_pb2.ExpertResponse(tensors=serialized_response))
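
On the server side, the handlers now accept a single ExpertRequest and return an async iterator over responses, which is what wrapping the single ExpertResponse in as_aiter accomplishes. A minimal sketch of such a helper, assuming it simply yields its arguments in order (the real one lives in hivemind.utils):

    from typing import AsyncIterator, TypeVar

    T = TypeVar("T")

    async def as_aiter(*items: T) -> AsyncIterator[T]:
        """Yield the given items one by one as an async iterator."""
        for item in items:
            yield item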