@@ -1,16 +1,22 @@
+from dataclasses import dataclass
 from concurrent.futures import Future
+from lib2to3.pgen2.token import OP
+from multiaddr import Multiaddr
+import os
 from queue import Queue
 from threading import Thread
-from typing import Any, Awaitable, Dict, List, Optional, Tuple
+from typing import Any, Awaitable, Dict, List, Optional, Sequence, Tuple, Union
 
 import torch
 import torch.nn as nn
 from torch.autograd.function import once_differentiable
 
 import hivemind
+from hivemind.dht import DHT
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.p2p import P2P, PeerInfo, StubBase
 from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
+from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID
 from hivemind.proto import runtime_pb2
 from hivemind.utils import (
     MSGPackSerializer,
@@ -22,6 +28,7 @@ from hivemind.utils import (
     switch_to_uvloop,
 )
 from hivemind.utils.grpc import gather_from_grpc, split_for_streaming
+from hivemind.utils.mpfuture import MPFuture
 
 DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd in RemoteExpert
 
@@ -29,6 +36,19 @@ DUMMY = torch.empty(0, requires_grad=True) # dummy tensor that triggers autogra
 def _get_expert_stub(p2p: P2P, server_peer_info: PeerInfo):  # -> ConnectionHandlerStub:
     return hivemind.moe.server.connection_handler.ConnectionHandler.get_stub(p2p, server_peer_info.peer_id)
 
+@dataclass(frozen=True)
+class RemoteExpertInfo:
+    uid: str
+    peer_id: str
+    addrs: Sequence[str]
+
+    @property
+    def as_peer_info(self) -> Tuple[str, PeerInfo]:
+        return self.uid, PeerInfo(
+            peer_id=PeerID.from_base58(self.peer_id),
+            addrs=tuple(Multiaddr(a) for a in self.addrs)
+        )
+
 
 class RemoteExpert(nn.Module):
     """
@@ -39,17 +59,11 @@ class RemoteExpert(nn.Module):
     :param uid: unique expert identifier
     """
 
-    def __init__(self, uid, server_peer_info: PeerInfo, p2p: Optional[P2P] = None, connect: bool = True):
+    def __init__(self, uid, server_peer_info: PeerInfo, p2p: P2P):
         super().__init__()
-        self.uid, self.server_peer_info = uid, server_peer_info
+        self.uid, self.server_peer_info, self.p2p = uid, server_peer_info, p2p
         self._info = None
 
-        if p2p is None:
-            self.p2p = _RemoteModuleCall.run_coroutine(P2P.create())
-            _RemoteModuleCall.run_coroutine(self.p2p._client.connect(server_peer_info.peer_id, server_peer_info.addrs))
-        else:
-            self.p2p = p2p
-
     @property
     def stub(self) -> StubBase:
         return _get_expert_stub(self.p2p, self.server_peer_info)
@@ -74,7 +88,7 @@ class RemoteExpert(nn.Module):
     @property
     def info(self):
         if self._info is None:
-            outputs = _RemoteModuleCall.run_coroutine(self.stub.rpc_info(runtime_pb2.ExpertUID(uid=self.uid)))
+            outputs = RemoteExpertWorker.run_coroutine(self.stub.rpc_info(runtime_pb2.ExpertUID(uid=self.uid)))
             self._info = MSGPackSerializer.loads(outputs.serialized_info)
         return self._info
 
@@ -82,11 +96,13 @@ class RemoteExpert(nn.Module):
         return f"uid={self.uid}, server_peer_info={self.server_peer_info}"
 
 
-class _RemoteModuleCall(torch.autograd.Function):
-    """Internal autograd-friendly call of a remote module. For applications, use RemoteExpert instead."""
+class RemoteExpertWorker:
+    """Local thread for managing async tasks related to RemoteExpert"""
 
     _task_queue: Queue = Queue()
     _event_thread: Optional[Thread] = None
+    _pid: int = 0
+
 
     @classmethod
     def _run(cls):
@@ -106,7 +122,8 @@ class _RemoteModuleCall(torch.autograd.Function):
 
     @classmethod
     def run_coroutine(cls, coro: Awaitable, return_future: bool = False):
-        if cls._event_thread is None:
+        if cls._event_thread is None or cls._pid != os.getpid():
+            cls._pid = os.getpid()
             cls._event_thread = Thread(target=cls._run, daemon=True)
             cls._event_thread.start()
 
@@ -119,6 +136,29 @@ class _RemoteModuleCall(torch.autograd.Function):
         result = future.result()
         return result
 
+    @classmethod
+    def spawn_experts_future(cls, infos: MPFuture[Sequence[Optional[RemoteExpertInfo]]], dht: DHT) -> MPFuture[List[Optional[RemoteExpert]]]:
+        async def _unpack():
+            return cls.spawn_experts(await infos, dht)
+        return cls.run_coroutine(_unpack(), True)
+
+    @classmethod
+    def spawn_experts(cls, infos: Sequence[Optional[RemoteExpertInfo]], dht: DHT) -> List[Optional[RemoteExpert]]:
+        p2p = cls.run_coroutine(dht.replicate_p2p())
+        experts: List[Optional[RemoteExpert]] = []
+        for i in infos:
+            if i is not None:
+                uid, peer_info = i.as_peer_info
+                experts.append(RemoteExpert(uid, peer_info, p2p))
+            else:
+                experts.append(None)
+        return experts
+
+
+
+class _RemoteModuleCall(torch.autograd.Function):
+    """Internal autograd-friendly call of a remote module. For applications, use RemoteExpert instead."""
+
     @classmethod
     def forward(
         cls,
@@ -155,7 +195,7 @@ class _RemoteModuleCall(torch.autograd.Function):
     def forward_partial(cls, serialized_tensors: List[runtime_pb2.Tensor], ctx, stub) -> List[torch.Tensor]:
         split = [p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2)]
 
-        outputs = cls.run_coroutine(
+        outputs = RemoteExpertWorker.run_coroutine(
             stub.rpc_forward_partial(
                 amap_in_executor(
                     lambda t: runtime_pb2.ExpertRequest(
@@ -169,12 +209,12 @@ class _RemoteModuleCall(torch.autograd.Function):
             )
         )
 
-        return cls.run_coroutine(gather_from_grpc(outputs, lambda r: r.tensors, deserialize_torch_tensor))
+        return RemoteExpertWorker.run_coroutine(gather_from_grpc(outputs, lambda r: r.tensors, deserialize_torch_tensor))
 
     @classmethod
     def forward_oneshot(cls, serialized_tensors: List[runtime_pb2.Tensor], ctx, stub) -> List[torch.Tensor]:
 
-        outputs = cls.run_coroutine(
+        outputs = RemoteExpertWorker.run_coroutine(
             stub.rpc_forward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
         )
 
@@ -207,7 +247,7 @@ class _RemoteModuleCall(torch.autograd.Function):
     def backward_partial(cls, serialized_tensors: List[runtime_pb2.Tensor], ctx) -> List[torch.Tensor]:
         split = tuple(p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE // 2))
 
-        grad_inputs = cls.run_coroutine(
+        grad_inputs = RemoteExpertWorker.run_coroutine(
             ctx.stub.rpc_backward_partial(
                 amap_in_executor(
                     lambda t: runtime_pb2.ExpertRequest(
@@ -221,12 +261,12 @@ class _RemoteModuleCall(torch.autograd.Function):
             )
         )
 
-        return cls.run_coroutine(gather_from_grpc(grad_inputs, lambda r: r.tensors, deserialize_torch_tensor))
+        return RemoteExpertWorker.run_coroutine(gather_from_grpc(grad_inputs, lambda r: r.tensors, deserialize_torch_tensor))
 
     @classmethod
     @once_differentiable
     def backward_oneshot(cls, serialized_tensors: List[runtime_pb2.Tensor], ctx) -> List[torch.Tensor]:
-        grad_inputs = cls.run_coroutine(
+        grad_inputs = RemoteExpertWorker.run_coroutine(
             ctx.stub.rpc_backward(runtime_pb2.ExpertRequest(uid=ctx.uid, tensors=serialized_tensors))
         )
 