@@ -12,7 +12,15 @@ from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.p2p import P2P, PeerInfo, StubBase
 from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
 from hivemind.proto import runtime_pb2
-from hivemind.utils import MSGPackSerializer, amap_in_executor, as_aiter, nested_compare, nested_flatten, nested_pack, switch_to_uvloop
+from hivemind.utils import (
+    MSGPackSerializer,
+    amap_in_executor,
+    as_aiter,
+    nested_compare,
+    nested_flatten,
+    nested_pack,
+    switch_to_uvloop
+)
 from hivemind.utils.grpc import gather_from_grpc, split_for_streaming
 
 DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert
@@ -144,28 +152,27 @@ class _RemoteModuleCall(torch.autograd.Function):
         return tuple(deserialized_outputs)
 
     @classmethod
-    def forward_partial(
-        cls, serialized_tensors: List[runtime_pb2.Tensor], ctx, stub
-    ) -> List[torch.Tensor]:
+    def forward_partial(cls, serialized_tensors: List[runtime_pb2.Tensor], ctx, stub) -> List[torch.Tensor]:
         split = [p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2)]
 
         outputs = cls.run_coroutine(
             stub.rpc_forward_partial(
                 amap_in_executor(
-                    lambda t: runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=[t, ]),
-                    as_aiter(*split)
+                    lambda t: runtime_pb2.ExpertRequest(
+                        uid=ctx.uid,
+                        tensors=[
+                            t,
+                        ],
+                    ),
+                    as_aiter(*split),
                 ),
             )
         )
 
-        return cls.run_coroutine(
-            gather_from_grpc(outputs, lambda r: r.tensors, deserialize_torch_tensor)
-        )
+        return cls.run_coroutine(gather_from_grpc(outputs, lambda r: r.tensors, deserialize_torch_tensor))
 
     @classmethod
-    def forward_oneshot(
-        cls, serialized_tensors: List[runtime_pb2.Tensor], ctx, stub
-    ) -> List[torch.Tensor]:
+    def forward_oneshot(cls, serialized_tensors: List[runtime_pb2.Tensor], ctx, stub) -> List[torch.Tensor]:
 
         outputs = cls.run_coroutine(
             stub.rpc_forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
@@ -197,29 +204,28 @@ class _RemoteModuleCall(torch.autograd.Function):
 
     @classmethod
     @once_differentiable
-    def backward_partial(
-        cls, serialized_tensors: List[runtime_pb2.Tensor], ctx
-    ) -> List[torch.Tensor]:
+    def backward_partial(cls, serialized_tensors: List[runtime_pb2.Tensor], ctx) -> List[torch.Tensor]:
         split = tuple(p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2))
 
         grad_inputs = cls.run_coroutine(
             ctx.stub.rpc_backward_partial(
                 amap_in_executor(
-                    lambda t: runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=[t, ]),
-                    as_aiter(*split)
+                    lambda t: runtime_pb2.ExpertRequest(
+                        uid=ctx.uid,
+                        tensors=[
+                            t,
+                        ],
+                    ),
+                    as_aiter(*split),
                 ),
             )
         )
 
-        return cls.run_coroutine(
-            gather_from_grpc(grad_inputs, lambda r: r.tensors, deserialize_torch_tensor)
-        )
+        return cls.run_coroutine(gather_from_grpc(grad_inputs, lambda r: r.tensors, deserialize_torch_tensor))
 
     @classmethod
     @once_differentiable
-    def backward_oneshot(
-        cls, serialized_tensors: List[runtime_pb2.Tensor], ctx
-    ) -> List[torch.Tensor]:
+    def backward_oneshot(cls, serialized_tensors: List[runtime_pb2.Tensor], ctx) -> List[torch.Tensor]:
         grad_inputs = cls.run_coroutine(
             ctx.stub.rpc_backward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
         )
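The backward hunk mirrors the forward one: `backward_partial` streams gradient chunks through `rpc_backward_partial`, while `backward_oneshot` sends everything in a single `rpc_backward` call. The call site that picks between the two variants is outside this diff; a plausible dispatch, assuming the decision is made on total serialized payload size (the helper names and threshold below are hypothetical, not from this PR):

```python
# Hypothetical dispatcher sketch: stream only when the request would not fit
# into a single message. run_forward and total_size are illustrative names.
from typing import List

import torch

from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
from hivemind.proto import runtime_pb2


def total_size(tensors: List[runtime_pb2.Tensor]) -> int:
    return sum(t.ByteSize() for t in tensors)  # standard protobuf size accessor


def run_forward(cls, serialized_tensors: List[runtime_pb2.Tensor], ctx, stub) -> List[torch.Tensor]:
    if total_size(serialized_tensors) > DEFAULT_MAX_MSG_SIZE:
        return cls.forward_partial(serialized_tensors, ctx, stub)  # streaming path
    return cls.forward_oneshot(serialized_tensors, ctx, stub)  # single-message path
```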