@@ -1,19 +1,19 @@
-from dataclasses import dataclass
+import os
 from concurrent.futures import Future
+from dataclasses import dataclass
 from lib2to3.pgen2.token import OP
-from multiaddr import Multiaddr
-import os
 from queue import Queue
 from threading import Thread
-from typing import Any, Awaitable, Dict, List, Optional, Sequence, Tuple, Union
+from typing import Any, Awaitable, Dict, List, Optional, Sequence, Tuple
 
 import torch
 import torch.nn as nn
+from multiaddr import Multiaddr
 from torch.autograd.function import once_differentiable
 
 import hivemind
-from hivemind.dht import DHT
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
+from hivemind.dht import DHT
 from hivemind.p2p import P2P, PeerInfo, StubBase
 from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID
@@ -36,6 +36,7 @@ DUMMY = torch.empty(0, requires_grad=True)  # dummy tensor that triggers autograd
 def _get_expert_stub(p2p: P2P, server_peer_info: PeerInfo):  # -> ConnectionHandlerStub:
     return hivemind.moe.server.connection_handler.ConnectionHandler.get_stub(p2p, server_peer_info.peer_id)
 
+
 @dataclass(frozen=True)
 class RemoteExpertInfo:
     uid: str
@@ -45,8 +46,7 @@ class RemoteExpertInfo:
     @property
     def as_peer_info(self) -> Tuple[str, PeerInfo]:
         return self.uid, PeerInfo(
-            peer_id=PeerID.from_base58(self.peer_id),
-            addrs=tuple(Multiaddr(a) for a in self.addrs)
+            peer_id=PeerID.from_base58(self.peer_id), addrs=tuple(Multiaddr(a) for a in self.addrs)
         )
 
 
@@ -103,7 +103,6 @@ class RemoteExpertWorker:
     _event_thread: Optional[Thread] = None
     _pid: int = 0
 
-
     @classmethod
     def _run(cls):
         loop = switch_to_uvloop()
@@ -137,9 +136,12 @@ class RemoteExpertWorker:
         return result
 
     @classmethod
-    def spawn_experts_future(cls, infos: MPFuture[Sequence[Optional[RemoteExpertInfo]]], dht: DHT) -> MPFuture[List[Optional[RemoteExpert]]]:
+    def spawn_experts_future(
+        cls, infos: MPFuture[Sequence[Optional[RemoteExpertInfo]]], dht: DHT
+    ) -> MPFuture[List[Optional[RemoteExpert]]]:
         async def _unpack():
             return cls.spawn_experts(await infos, dht)
+
         return cls.run_coroutine(_unpack, True)
 
     @classmethod
@@ -155,7 +157,6 @@ class RemoteExpertWorker:
         return experts
 
 
-
 class _RemoteModuleCall(torch.autograd.Function):
     """Internal autograd-friendly call of a remote module. For applications, use RemoteExpert instead."""
 
@@ -209,7 +210,9 @@ class _RemoteModuleCall(torch.autograd.Function):
             )
         )
 
-        return RemoteExpertWorker.run_coroutine(gather_from_grpc(outputs, lambda r: r.tensors, deserialize_torch_tensor))
+        return RemoteExpertWorker.run_coroutine(
+            gather_from_grpc(outputs, lambda r: r.tensors, deserialize_torch_tensor)
+        )
 
     @classmethod
     def forward_oneshot(cls, serialized_tensors: List[runtime_pb2.Tensor], ctx, stub) -> List[torch.Tensor]:
@@ -261,7 +264,9 @@ class _RemoteModuleCall(torch.autograd.Function):
             )
         )
 
-        return RemoteExpertWorker.run_coroutine(gather_from_grpc(grad_inputs, lambda r: r.tensors, deserialize_torch_tensor))
+        return RemoteExpertWorker.run_coroutine(
+            gather_from_grpc(grad_inputs, lambda r: r.tensors, deserialize_torch_tensor)
+        )
 
     @classmethod
     @once_differentiable
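
Note on the reformatted as_peer_info property: its behavior is unchanged by this patch; the two keyword arguments were only joined onto one line. A minimal sketch of how the dataclass converts its wire-format fields into a PeerInfo follows; the import path and all values below are illustrative assumptions, not taken from this patch:

    # Hypothetical usage sketch: RemoteExpertInfo stores a base58-encoded peer ID
    # and a sequence of multiaddress strings; as_peer_info parses both into the
    # PeerInfo structure that _get_expert_stub expects.
    from hivemind.moe.client.expert import RemoteExpertInfo  # assumed module path

    info = RemoteExpertInfo(
        uid="expert.0",  # made-up expert UID
        peer_id="QmNnooDu7bfjPFoTZYxMNLWUQJyrVwtbZg5gBMjTezGAJN",  # sample base58 peer ID
        addrs=("/ip4/127.0.0.1/tcp/31337",),  # sample multiaddress string
    )
    uid, peer_info = info.as_peer_info  # ("expert.0", PeerInfo(peer_id=..., addrs=(...,)))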