Denis Mazur 4 years ago
parent
commit
1a0fd0f202
2 changed files with 6 additions and 10 deletions
  1. +3 −7
      hivemind/moe/client/expert.py
  2. +3 −3
      hivemind/moe/server/connection_handler.py

+ 3 - 7
hivemind/moe/client/expert.py

@@ -12,7 +12,7 @@ import hivemind
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.p2p import P2P, PeerInfo, StubBase
 from hivemind.proto import runtime_pb2
-from hivemind.utils import nested_compare, nested_flatten, nested_pack, switch_to_uvloop
+from hivemind.utils import as_aiter, nested_compare, nested_flatten, nested_pack, switch_to_uvloop
 
 DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert
 
@@ -21,10 +21,6 @@ def _get_expert_stub(p2p: P2P, server_peer_info: PeerInfo):  # -> ConnectionHand
     return hivemind.moe.server.connection_handler.ConnectionHandler.get_stub(p2p, server_peer_info.peer_id)
 
 
-async def async_generate(inputs):
-    yield inputs
-
-
 class RemoteExpert(nn.Module):
     """
     A simple module that runs forward/backward of an expert hosted on a remote machine.
@@ -138,7 +134,7 @@ class _RemoteModuleCall(torch.autograd.Function):
         ]
 
         outputs = cls.run_coroutine(
-            stub.rpc_forward(async_generate(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))),
+            stub.rpc_forward(as_aiter([runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors)])),
         )
 
         deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]
@@ -157,7 +153,7 @@ class _RemoteModuleCall(torch.autograd.Function):
         ]
 
         grad_inputs = cls.run_coroutine(
-            ctx.stub.rpc_backward(async_generate(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))),
+            ctx.stub.rpc_backward(as_aiter([runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors)])),
         )
 
         deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]

+ 3 - 3
hivemind/moe/server/connection_handler.py

@@ -10,7 +10,7 @@ from hivemind.dht import DHT
 from hivemind.moe.server.expert_backend import ExpertBackend
 from hivemind.p2p import P2PContext, ServicerBase
 from hivemind.proto import runtime_pb2
-from hivemind.utils import MPFuture, get_logger, nested_flatten
+from hivemind.utils import MPFuture, asingle, get_logger, nested_flatten
 from hivemind.utils.asyncio import switch_to_uvloop
 
 logger = get_logger(__name__)
@@ -59,7 +59,7 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
     async def rpc_forward(
         self, stream: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> runtime_pb2.ExpertResponse:
-        request = await stream.__anext__()
+        request = await asingle(stream)
 
         inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         future = self.experts[request.uid].forward_pool.submit_task(*inputs)
@@ -73,7 +73,7 @@ class ConnectionHandler(mp.context.ForkProcess, ServicerBase):
     async def rpc_backward(
         self, stream: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
     ) -> runtime_pb2.ExpertResponse:
-        request = await stream.__anext__()
+        request = await asingle(stream)
 
         inputs_and_grad_outputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         future = self.experts[request.uid].backward_pool.submit_task(*inputs_and_grad_outputs)