
rewrite to async iterators

Denis Mazur, 4 years ago
parent commit 582d6dc294

hivemind/moe/client/expert.py (+5, -1)

@@ -21,6 +21,10 @@ def _get_expert_stub(p2p: P2P, server_peer_info: PeerInfo):  # -> ConnectionHand
     return hivemind.moe.server.connection_handler.ConnectionHandler.get_stub(p2p, server_peer_info.peer_id)
 
 
+async def async_generate(inputs):
+    yield inputs
+
+
 class RemoteExpert(nn.Module):
     """
     A simple module that runs forward/backward of an expert hosted on a remote machine.
@@ -134,7 +138,7 @@ class _RemoteModuleCall(torch.autograd.Function):
         ]
 
         outputs = cls.run_coroutine(
-            stub.rpc_forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors)),
+            stub.rpc_forward(async_generate(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))),
         )
 
         deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]

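The client-side change above stops passing the request message directly to the stub: async_generate wraps it in a single-item async generator, so the now stream-based rpc_forward interface can still carry a unary call. A minimal, self-contained sketch of that wrapping pattern (the string payload and main driver are illustrative stand-ins, not part of this commit):

import asyncio
from typing import AsyncIterator, TypeVar

T = TypeVar("T")


async def async_generate(inputs: T) -> AsyncIterator[T]:
    # Wrap a single value in an async generator so it can be handed to an
    # interface that expects a stream of requests.
    yield inputs


async def main() -> None:
    stream = async_generate("ExpertRequest")
    # The consumer pulls the single item back off the stream.
    print(await stream.__anext__())  # -> ExpertRequest


asyncio.run(main())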
hivemind/moe/server/connection_handler.py (+9, -3)

@@ -1,7 +1,7 @@
 import asyncio
 import multiprocessing as mp
 import pickle
-from typing import Dict
+from typing import AsyncIterator, Dict
 
 import torch
 
@@ -56,7 +56,11 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
     async def rpc_info(self, request: runtime_pb2.ExpertUID, context: P2PContext) -> runtime_pb2.ExpertInfo:
         return runtime_pb2.ExpertInfo(serialized_info=pickle.dumps(self.experts[request.uid].get_info()))
 
-    async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
+    async def rpc_forward(
+        self, stream: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
+    ) -> runtime_pb2.ExpertResponse:
+        request = await stream.__anext__()
+
         inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         future = self.experts[request.uid].forward_pool.submit_task(*inputs)
         serialized_response = [
@@ -67,8 +71,10 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
         return runtime_pb2.ExpertResponse(tensors=serialized_response)
 
     async def rpc_backward(
-        self, request: runtime_pb2.ExpertRequest, context: P2PContext
+        self, stream: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> runtime_pb2.ExpertResponse:
+        request = await stream.__anext__()
+
         inputs_and_grad_outputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         future = self.experts[request.uid].backward_pool.submit_task(*inputs_and_grad_outputs)
         serialized_response = [
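On the server side, rpc_forward and rpc_backward now take the request as an AsyncIterator and pull the single message off it with __anext__(). A minimal sketch of that consumption pattern (request_stream and the string payload are hypothetical stand-ins, not hivemind code):

import asyncio
from typing import AsyncIterator


async def request_stream() -> AsyncIterator[str]:
    # Stand-in for the incoming stream; a unary call carries exactly one item.
    yield "ExpertRequest"


async def handle(stream: AsyncIterator[str]) -> str:
    # Take the first (and, for a unary call, only) request off the stream,
    # mirroring the `await stream.__anext__()` calls in the handlers above.
    request = await stream.__anext__()
    return f"processed {request}"


async def main() -> None:
    print(await handle(request_stream()))  # -> processed ExpertRequest


asyncio.run(main())

On Python 3.10+, await anext(stream) is the equivalent builtin spelling; the sketch uses __anext__() to match the diff.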