@@ -3,11 +3,10 @@ from concurrent.futures import Future
 from dataclasses import dataclass
 from queue import Queue
 from threading import Thread
-from typing import Any, Awaitable, Dict, List, Optional, Sequence, Tuple
+from typing import Any, AsyncIterator, Awaitable, Dict, Iterable, List, Optional, Sequence, Tuple

 import torch
 import torch.nn as nn
-from multiaddr import Multiaddr
 from torch.autograd.function import once_differentiable

 import hivemind
@@ -15,7 +14,6 @@ from hivemind.compression import deserialize_torch_tensor, serialize_torch_tenso
 from hivemind.dht import DHT
 from hivemind.p2p import P2P, PeerInfo, StubBase
 from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
-from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID
 from hivemind.proto import runtime_pb2
 from hivemind.utils import (
     MSGPackSerializer,
@@ -26,7 +24,7 @@ from hivemind.utils import (
     nested_pack,
     switch_to_uvloop,
 )
-from hivemind.utils.grpc import gather_from_grpc, split_for_streaming
+from hivemind.utils.grpc import gather_from_rpc, split_for_streaming
 from hivemind.utils.mpfuture import MPFuture

 DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert
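
The `DUMMY` tensor above exists because `torch.autograd.Function.backward` only fires when at least one input to `apply` requires grad; threading this empty, `requires_grad=True` tensor through the call guarantees the backward pass reaches the remote call even if the user's inputs are detached. A toy sketch of the trick, with a hypothetical `Echo` function standing in for `_RemoteModuleCall`:

```python
import torch

DUMMY = torch.empty(0, requires_grad=True)

class Echo(torch.autograd.Function):
    @staticmethod
    def forward(ctx, dummy, x):
        # x may not require grad; the dummy input keeps autograd engaged
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        # one gradient per input: an empty placeholder for dummy, then x's grad
        return torch.zeros(0), grad_output

x = torch.ones(3)            # requires_grad=False
out = Echo.apply(DUMMY, x)
assert out.requires_grad     # True, solely because of DUMMY
out.sum().backward()         # Echo.backward still runs
```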
@@ -39,14 +37,7 @@ def _get_expert_stub(p2p: P2P, server_peer_info: PeerInfo): # -> ConnectionHand
 @dataclass(frozen=True)
 class RemoteExpertInfo:
     uid: str
-    peer_id: str
-    addrs: Sequence[str]
-
-    @property
-    def as_peer_info(self) -> Tuple[str, PeerInfo]:
-        return self.uid, PeerInfo(
-            peer_id=PeerID.from_base58(self.peer_id), addrs=tuple(Multiaddr(a) for a in self.addrs)
-        )
+    peer_info: PeerInfo


 class RemoteExpert(nn.Module):
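
With `RemoteExpertInfo` now carrying a ready-made `PeerInfo`, the base58/multiaddr parsing that used to hide in `as_peer_info` moves to whoever discovers the expert. A minimal sketch of that conversion at a call site, using the same calls the removed property made (the module path and the raw `peer_id`/`addrs` values are illustrative placeholders):

```python
from multiaddr import Multiaddr

from hivemind.moe.client.expert import RemoteExpertInfo  # path assumed
from hivemind.p2p import PeerInfo
from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID

raw_peer_id = "QmExamplePeerId"           # placeholder base58 string from the DHT
raw_addrs = ["/ip4/127.0.0.1/tcp/31337"]  # placeholder multiaddrs

# the same conversion the removed as_peer_info property performed
peer_info = PeerInfo(
    peer_id=PeerID.from_base58(raw_peer_id),
    addrs=tuple(Multiaddr(a) for a in raw_addrs),
)
expert_info = RemoteExpertInfo(uid="expert.0", peer_info=peer_info)
```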
@@ -58,10 +49,18 @@ class RemoteExpert(nn.Module):
     :param uid: unique expert identifier
     """

-    def __init__(self, uid, server_peer_info: PeerInfo, p2p: P2P):
+    def __init__(self, expert_info: RemoteExpertInfo, p2p: P2P):
         super().__init__()
-        self.uid, self.server_peer_info, self.p2p = uid, server_peer_info, p2p
-        self._info = None
+        self._info, self.p2p = expert_info, p2p
+        self._rpc_info = None
+
+    @property
+    def uid(self):
+        return self._info.uid
+
+    @property
+    def server_peer_info(self):
+        return self._info.peer_info

     @property
     def stub(self) -> StubBase:
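
`RemoteExpert` now receives the whole `RemoteExpertInfo`, and `uid`/`server_peer_info` become read-only views over it. Continuing the sketch above (the `p2p` handle is assumed to come from `RemoteExpertWorker`, as elsewhere in this file):

```python
expert = RemoteExpert(expert_info, p2p)

assert expert.uid == expert_info.uid                     # delegated property
assert expert.server_peer_info is expert_info.peer_info  # a view, not a copy
```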
@@ -86,10 +85,10 @@ class RemoteExpert(nn.Module):

     @property
     def info(self):
-        if self._info is None:
+        if self._rpc_info is None:
             outputs = RemoteExpertWorker.run_coroutine(self.stub.rpc_info(runtime_pb2.ExpertUID(uid=self.uid)))
-            self._info = MSGPackSerializer.loads(outputs.serialized_info)
-        return self._info
+            self._rpc_info = MSGPackSerializer.loads(outputs.serialized_info)
+        return self._rpc_info

     def extra_repr(self):
         return f"uid={self.uid}, server_peer_info={self.server_peer_info}"
@@ -139,8 +138,7 @@ class RemoteExpertWorker:
         experts: List[Optional[RemoteExpert]] = []
         for i in infos:
             if i is not None:
-                uid, peer_info = i.as_peer_info
-                experts.append(RemoteExpert(uid, peer_info, p2p))
+                experts.append(RemoteExpert(i, p2p))
             else:
                 experts.append(None)
         return experts
@@ -195,16 +193,16 @@ class _RemoteModuleCall(torch.autograd.Function):
         ctx.uid, ctx.stub, ctx.info = uid, stub, info
         ctx.save_for_backward(*inputs)

-        serialized_tensors = [
+        serialized_tensors = (
             serialize_torch_tensor(inp, proto.compression)
             for inp, proto in zip(inputs, nested_flatten(info["forward_schema"]))
-        ]
+        )

         size = 0
         for t in inputs:
             size += t.element_size() * t.nelement()
             if size >= DEFAULT_MAX_MSG_SIZE:
-                deserialized_outputs = cls.forward_partial(serialized_tensors, ctx, stub)
+                deserialized_outputs = cls.forward_stream(serialized_tensors, ctx, stub)
                 break
         else:
             deserialized_outputs = cls.forward_oneshot(serialized_tensors, ctx, stub)
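
Two details here are easy to miss: `serialized_tensors` is now a generator, consumed exactly once by whichever branch wins, so no tensor is serialized twice; and the `for ... else` loop switches to streaming the moment the running payload size crosses `DEFAULT_MAX_MSG_SIZE`, without measuring the remaining tensors. A standalone sketch of the same dispatch rule (names and threshold are illustrative):

```python
import torch

MAX_MSG_SIZE = 4 * 1024 * 1024  # placeholder for DEFAULT_MAX_MSG_SIZE

def choose_transport(tensors) -> str:
    """Pick "stream" once the cumulative byte size crosses the limit."""
    size = 0
    for t in tensors:
        size += t.element_size() * t.nelement()
        if size >= MAX_MSG_SIZE:
            return "stream"  # early exit, like the `break` above
    return "oneshot"         # the for/else fallthrough

assert choose_transport([torch.zeros(8)]) == "oneshot"
assert choose_transport([torch.zeros(MAX_MSG_SIZE)]) == "stream"  # 4 bytes/elem
```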
@@ -212,32 +210,27 @@ class _RemoteModuleCall(torch.autograd.Function):
         return tuple(deserialized_outputs)

     @classmethod
-    def forward_partial(cls, serialized_tensors: List[runtime_pb2.Tensor], ctx, stub) -> List[torch.Tensor]:
-        split = [p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2)]
+    def forward_stream(cls, serialized_tensors: Iterable[runtime_pb2.Tensor], ctx, stub) -> List[torch.Tensor]:
+        split = tuple(p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2))

         outputs = RemoteExpertWorker.run_coroutine(
-            stub.rpc_forward_partial(
+            stub.rpc_forward_stream(
                 amap_in_executor(
-                    lambda t: runtime_pb2.ExpertRequest(
-                        uid=ctx.uid,
-                        tensors=[
-                            t,
-                        ],
-                    ),
+                    lambda t: runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=[t]),
                     as_aiter(*split),
                 ),
             )
         )

         return RemoteExpertWorker.run_coroutine(
-            gather_from_grpc(outputs, lambda r: r.tensors, deserialize_torch_tensor)
+            gather_from_rpc(outputs, lambda r: r.tensors, deserialize_torch_tensor)
         )

     @classmethod
-    def forward_oneshot(cls, serialized_tensors: List[runtime_pb2.Tensor], ctx, stub) -> List[torch.Tensor]:
+    def forward_oneshot(cls, serialized_tensors: Iterable[runtime_pb2.Tensor], ctx, stub) -> List[torch.Tensor]:

         outputs = RemoteExpertWorker.run_coroutine(
-            stub.rpc_forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
+            stub.rpc_forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=list(serialized_tensors)))
         )

         return [deserialize_torch_tensor(t) for t in outputs.tensors]
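
`split_for_streaming` chops one serialized tensor into parts that fit under the message limit (half of it here, leaving headroom for protobuf framing), and the receiving side reassembles the stream part by part. A round-trip sketch that needs no server, assuming `combine_from_streaming`, the counterpart helper in the same module, is available:

```python
import torch

from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.proto import runtime_pb2
from hivemind.utils.grpc import combine_from_streaming, split_for_streaming

tensor = torch.randn(1024, 1024)  # 4 MiB of float32
serialized = serialize_torch_tensor(tensor, runtime_pb2.CompressionType.NONE)

parts = list(split_for_streaming(serialized, 2**20))  # <= 1 MiB per part
assert len(parts) > 1

restored = deserialize_torch_tensor(combine_from_streaming(parts))
assert torch.allclose(tensor, restored)
```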
@@ -248,16 +241,16 @@ class _RemoteModuleCall(torch.autograd.Function):
         grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
         inputs_and_grad_outputs = tuple(nested_flatten((ctx.saved_tensors, grad_outputs_cpu)))
         backward_schema = tuple(nested_flatten((ctx.info["forward_schema"], ctx.info["outputs_schema"])))
-        serialized_tensors = [
+        serialized_tensors = (
             serialize_torch_tensor(tensor, proto.compression)
             for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
-        ]
+        )

         size = 0
         for t in inputs_and_grad_outputs:
             size += t.element_size() * t.nelement()
             if size >= DEFAULT_MAX_MSG_SIZE:
-                deserialized_grad_inputs = cls.backward_partial(serialized_tensors, ctx)
+                deserialized_grad_inputs = cls.backward_stream(serialized_tensors, ctx)
                 break
         else:
             deserialized_grad_inputs = cls.backward_oneshot(serialized_tensors, ctx)
@@ -266,32 +259,27 @@ class _RemoteModuleCall(torch.autograd.Function):

     @classmethod
     @once_differentiable
-    def backward_partial(cls, serialized_tensors: List[runtime_pb2.Tensor], ctx) -> List[torch.Tensor]:
+    def backward_stream(cls, serialized_tensors: Iterable[runtime_pb2.Tensor], ctx) -> List[torch.Tensor]:
         split = tuple(p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2))

         grad_inputs = RemoteExpertWorker.run_coroutine(
-            ctx.stub.rpc_backward_partial(
+            ctx.stub.rpc_backward_stream(
                 amap_in_executor(
-                    lambda t: runtime_pb2.ExpertRequest(
-                        uid=ctx.uid,
-                        tensors=[
-                            t,
-                        ],
-                    ),
+                    lambda t: runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=[t]),
                     as_aiter(*split),
                 ),
             )
         )

         return RemoteExpertWorker.run_coroutine(
-            gather_from_grpc(grad_inputs, lambda r: r.tensors, deserialize_torch_tensor)
+            gather_from_rpc(grad_inputs, lambda r: r.tensors, deserialize_torch_tensor)
         )

     @classmethod
     @once_differentiable
-    def backward_oneshot(cls, serialized_tensors: List[runtime_pb2.Tensor], ctx) -> List[torch.Tensor]:
+    def backward_oneshot(cls, serialized_tensors: Iterable[runtime_pb2.Tensor], ctx) -> List[torch.Tensor]:
         grad_inputs = RemoteExpertWorker.run_coroutine(
-            ctx.stub.rpc_backward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
+            ctx.stub.rpc_backward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=list(serialized_tensors)))
         )

         return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]
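
Both streaming paths share one request pipeline: `as_aiter` turns the pre-split parts into an async iterator, `amap_in_executor` wraps each part in its own single-tensor `ExpertRequest` without blocking the event loop, and `gather_from_rpc` folds the reply stream back into tensors. A condensed sketch of the request side (the import path for the async helpers is an assumption; this file receives them via `hivemind.utils`):

```python
from typing import AsyncIterator, Sequence

from hivemind.proto import runtime_pb2
from hivemind.utils.asyncio import amap_in_executor, as_aiter  # path assumed

async def stream_requests(
    uid: str, split: Sequence[runtime_pb2.Tensor]
) -> AsyncIterator[runtime_pb2.ExpertRequest]:
    # one ExpertRequest per chunk; the lambda runs in an executor thread
    async for request in amap_in_executor(
        lambda t: runtime_pb2.ExpertRequest(uid=uid, tensors=[t]),
        as_aiter(*split),
    ):
        yield request
```

The reply direction is symmetric: `gather_from_rpc(stream, lambda r: r.tensors, deserialize_torch_tensor)` flattens each response's `tensors` field and deserializes the reassembled parts.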