@@ -24,7 +24,7 @@ from hivemind.utils import (
     nested_pack,
 )
 from hivemind.utils.mpfuture import MPFuture
-from hivemind.utils.streaming import gather_from_streaming, split_for_streaming
+from hivemind.utils.streaming import combine_and_deserialize_from_streaming, split_for_streaming
 
 DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert
 
@@ -35,6 +35,8 @@ def _get_expert_stub(p2p: P2P, server_peer_info: PeerInfo) -> "ConnectionHandler
 
 @dataclass(frozen=True)
 class RemoteExpertInfo:
+    """A simple data class containing the expert's uid and the server's PeerInfo"""
+
     uid: str
     peer_info: PeerInfo
 
@@ -45,7 +47,9 @@ class RemoteExpert(nn.Module):
     Works seamlessly with pytorch autograd. (this is essentially a simple RPC function)
     Warning: RemoteExpert currently assumes that you provide it with correct input shapes.
     Sending wrong input shapes can cause RemoteExpert to freeze indefinitely due to error in runtime.
-    :param uid: unique expert identifier
+
+    :param expert_info: a RemoteExpertInfo containing the expert's uid and the server's PeerInfo
+    :param p2p: a P2P instance connected to the running p2pd
     """
 
     def __init__(self, expert_info: RemoteExpertInfo, p2p: P2P):
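A minimal usage sketch of the constructor introduced above (not part of the diff; the import path, uid, and input shape are illustrative, and `p2p` / `peer_info` are assumed to have been obtained elsewhere, e.g. via the DHT):

    import torch
    from hivemind.moe.client.expert import RemoteExpert, RemoteExpertInfo

    # assumes `p2p` (a hivemind P2P instance) and `peer_info` (the serving peer's PeerInfo) already exist
    expert = RemoteExpert(RemoteExpertInfo(uid="expert.0", peer_info=peer_info), p2p)
    output = expert(torch.randn(1, 512))  # shapes must match the expert's forward_schema (see warning above)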
@@ -135,7 +139,7 @@ def batch_create_remote_experts(
 
 
 async def _backward_stream(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
-    split = (p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2))
+    split = (p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE))
 
     grad_inputs = await stub.rpc_backward_stream(
         amap_in_executor(
@@ -143,8 +147,8 @@ async def _backward_stream(uid: str, serialized_tensors: Iterable[runtime_pb2.Te
             iter_as_aiter(split),
         ),
     )
-
-    return await gather_from_streaming(grad_inputs, lambda r: r.tensors, deserialize_torch_tensor)
+    tensors_stream = amap_in_executor(lambda msg: msg.tensors, grad_inputs)
+    return await combine_and_deserialize_from_streaming(tensors_stream, deserialize_torch_tensor)
 
 
 async def _backward_unary(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
@@ -155,23 +159,19 @@ async def _backward_unary(uid: str, serialized_tensors: Iterable[runtime_pb2.Ten
 
 
 async def expert_backward(
-    uid: str, inputs_and_grads: Sequence[torch.Tensor], compressions: Iterable, stub
+    uid: str, inputs_and_grads: Sequence[torch.Tensor], serialized_tensors: Iterable[runtime_pb2.Tensor], stub
 ) -> List[torch.Tensor]:
-    serialized_tensors = (
-        serialize_torch_tensor(tensor, compression) for tensor, compression in zip(inputs_and_grads, compressions)
-    )
-
     size = 0
     for t in inputs_and_grads:
         size += t.element_size() * t.nelement()
-    if size >= DEFAULT_MAX_MSG_SIZE:
+    if size > DEFAULT_MAX_MSG_SIZE:
         return await _backward_stream(uid, serialized_tensors, stub)
     else:
         return await _backward_unary(uid, serialized_tensors, stub)
 
 
 async def _forward_stream(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
-    split = (p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2))
+    split = (p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE))
 
     outputs = await stub.rpc_forward_stream(
         amap_in_executor(
@@ -180,7 +180,8 @@ async def _forward_stream(uid: str, serialized_tensors: Iterable[runtime_pb2.Ten
         ),
     )
 
-    return await gather_from_streaming(outputs, lambda r: r.tensors, deserialize_torch_tensor)
+    tensors_stream = amap_in_executor(lambda msg: msg.tensors, outputs)
+    return await combine_and_deserialize_from_streaming(tensors_stream, deserialize_torch_tensor)
 
 
 async def _forward_unary(uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub) -> List[torch.Tensor]:
@@ -190,14 +191,13 @@ async def _forward_unary(uid: str, serialized_tensors: Iterable[runtime_pb2.Tens
     return [deserialize_torch_tensor(t) for t in outputs.tensors]
 
 
-async def expert_forward(uid: str, inputs: Sequence[torch.Tensor], compressions: Iterable, stub) -> List[torch.Tensor]:
-    serialized_tensors = (
-        serialize_torch_tensor(tensor, compression) for tensor, compression in zip(inputs, compressions)
-    )
+async def expert_forward(
+    uid: str, inputs: Sequence[torch.Tensor], serialized_tensors: Iterable[runtime_pb2.Tensor], stub
+) -> List[torch.Tensor]:
     size = 0
     for t in inputs:
         size += t.element_size() * t.nelement()
-    if size >= DEFAULT_MAX_MSG_SIZE:
+    if size > DEFAULT_MAX_MSG_SIZE:
         return await _forward_stream(uid, serialized_tensors, stub)
     else:
         return await _forward_unary(uid, serialized_tensors, stub)
@@ -220,10 +220,11 @@ class _RemoteModuleCall(torch.autograd.Function):
         inputs = tuple(tensor.cpu().detach() for tensor in inputs)
         ctx.uid, ctx.stub, ctx.info = uid, stub, info
         ctx.save_for_backward(*inputs)
-
-        deserialized_outputs = _RemoteExpertWorker.run_coroutine(
-            expert_forward(uid, inputs, (p.compression for p in nested_flatten(info["forward_schema"])), stub)
+        serialized_tensors = (
+            serialize_torch_tensor(tensor, proto.compression)
+            for tensor, proto in zip(inputs, nested_flatten(info["forward_schema"]))
         )
+        deserialized_outputs = _RemoteExpertWorker.run_coroutine(expert_forward(uid, inputs, serialized_tensors, stub))
 
         return tuple(deserialized_outputs)
 
@@ -233,9 +234,12 @@ class _RemoteModuleCall(torch.autograd.Function):
         grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
         inputs_and_grad_outputs = tuple(nested_flatten((ctx.saved_tensors, grad_outputs_cpu)))
         backward_schema = tuple(nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])))
-
+        serialized_tensors = (
+            serialize_torch_tensor(tensor, proto.compression)
+            for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
+        )
         deserialized_grad_inputs = _RemoteExpertWorker.run_coroutine(
-            expert_backward(ctx.uid, inputs_and_grad_outputs, (p.compression for p in backward_schema), ctx.stub)
+            expert_backward(ctx.uid, inputs_and_grad_outputs, serialized_tensors, ctx.stub)
         )
 
         return (DUMMY, None, None, None, *deserialized_grad_inputs)
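Taken together with the _RemoteModuleCall changes above, the new call path is: the autograd function serializes each tensor with the compression declared in the expert's schema, and expert_forward / expert_backward only choose between the unary and streaming RPC based on the raw payload size. A condensed sketch of the forward path (not a verbatim excerpt; it assumes the names already imported in this module, such as serialize_torch_tensor, nested_flatten, and expert_forward):

    async def call_remote_forward(uid, inputs, info, stub):
        # serialize inputs with the compression declared in the expert's forward schema
        serialized_tensors = (
            serialize_torch_tensor(tensor, proto.compression)
            for tensor, proto in zip(inputs, nested_flatten(info["forward_schema"]))
        )
        # expert_forward sums t.element_size() * t.nelement() over `inputs` and uses the
        # streaming RPC when the total exceeds DEFAULT_MAX_MSG_SIZE, the unary RPC otherwise
        return await expert_forward(uid, inputs, serialized_tensors, stub)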