|
@@ -12,7 +12,7 @@ import hivemind
|
|
|
from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
|
|
|
from hivemind.p2p import P2P, PeerInfo, StubBase
|
|
|
from hivemind.proto import runtime_pb2
|
|
|
-from hivemind.utils import nested_compare, nested_flatten, nested_pack, switch_to_uvloop
|
|
|
+from hivemind.utils import as_aiter, nested_compare, nested_flatten, nested_pack, switch_to_uvloop
|
|
|
|
|
|
DUMMY = torch.empty(0, requires_grad=True) # dummy tensor that triggers autograd in RemoteExpert
|
|
|
|
|
@@ -21,10 +21,6 @@ def _get_expert_stub(p2p: P2P, server_peer_info: PeerInfo): # -> ConnectionHand
|
|
|
return hivemind.moe.server.connection_handler.ConnectionHandler.get_stub(p2p, server_peer_info.peer_id)
|
|
|
|
|
|
|
|
|
-async def async_generate(inputs):
|
|
|
- yield inputs
|
|
|
-
|
|
|
-
|
|
|
class RemoteExpert(nn.Module):
|
|
|
"""
|
|
|
A simple module that runs forward/backward of an expert hosted on a remote machine.
|
|
@@ -138,7 +134,7 @@ class _RemoteModuleCall(torch.autograd.Function):
|
|
|
]
|
|
|
|
|
|
outputs = cls.run_coroutine(
|
|
|
- stub.rpc_forward(async_generate(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))),
|
|
|
+ stub.rpc_forward(as_aiter([runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors)])),
|
|
|
)
|
|
|
|
|
|
deserialized_outputs = [deserialize_torch_tensor(tensor) for tensor in outputs.tensors]
|
|
@@ -157,7 +153,7 @@ class _RemoteModuleCall(torch.autograd.Function):
|
|
|
]
|
|
|
|
|
|
grad_inputs = cls.run_coroutine(
|
|
|
- ctx.stub.rpc_backward(async_generate(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))),
|
|
|
+ ctx.stub.rpc_backward(as_aiter([runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors)])),
|
|
|
)
|
|
|
|
|
|
deserialized_grad_inputs = [deserialize_torch_tensor(tensor) for tensor in grad_inputs.tensors]
|