|
@@ -2,6 +2,7 @@ from concurrent.futures import Future
|
|
from functools import partial
|
|
from functools import partial
|
|
from typing import List, Optional, Union, Sequence
|
|
from typing import List, Optional, Union, Sequence
|
|
|
|
|
|
|
|
+import torch
|
|
from hivemind.moe.client import RemoteExpert
|
|
from hivemind.moe.client import RemoteExpert
|
|
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
|
|
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
|
|
from hivemind.moe.expert_uid import ExpertUID
|
|
from hivemind.moe.expert_uid import ExpertUID
|
|
@@ -10,18 +11,30 @@ from hivemind.p2p import StubBase, P2P
|
|
from hivemind.proto.runtime_pb2 import ExpertInfo
|
|
from hivemind.proto.runtime_pb2 import ExpertInfo
|
|
from hivemind.dht import DHT
|
|
from hivemind.dht import DHT
|
|
from hivemind.utils import MPFuture, DHTExpiration
|
|
from hivemind.utils import MPFuture, DHTExpiration
|
|
|
|
+
|
|
|
|
+from src import DistributedBloomConfig
|
|
|
|
+from src.server.backend import MAX_LENGTH
|
|
from src.server.handler import TransformerConnectionHandler
|
|
from src.server.handler import TransformerConnectionHandler
|
|
|
|
|
|
|
|
|
|
-class RemoteTransformerBlock(RemoteExpert):
|
|
|
|
|
|
class RemoteTransformerBlockSession(RemoteExpert):
    """
    Client-side handle for one specific remote server, used for forward/backward or inference.

    On top of whatever RemoteExpert sets up from (info, p2p), every instance keeps the model
    config, an input buffer of shape (1, MAX_LENGTH, config.hidden_size) and a slot for the
    currently active stream, which starts out as None.
    """

    def __init__(self, config: DistributedBloomConfig, info: ExpertInfo, p2p: P2P):
        """
        :param config: model configuration; its hidden_size and dtype define the input buffer
        :param info: expert description forwarded unchanged to RemoteExpert
        :param p2p: p2p instance forwarded unchanged to RemoteExpert
        """
        super().__init__(info, p2p)
        self._config = config
        cache_shape = (1, MAX_LENGTH, config.hidden_size)
        # torch.empty does not initialize memory: the buffer holds arbitrary values until written
        self._inputs_cache = torch.empty(cache_shape, dtype=config.dtype)
        self._active_stream: Optional[RemoteTransformerStream] = None

    @property
    def stub(self) -> StubBase:
        """Stub returned by TransformerConnectionHandler.get_stub for this session's p2p and peer id."""
        handler_cls = TransformerConnectionHandler
        return handler_cls.get_stub(self.p2p, self.peer_id)
|
|
|
|
|
|
|
|
|
|
|
|
+
|
|
def get_remote_module(
|
|
def get_remote_module(
|
|
dht: DHT, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration] = None, return_future: bool = False
|
|
dht: DHT, uids: List[ExpertUID], expiration_time: Optional[DHTExpiration] = None, return_future: bool = False
|
|
-) -> Union[List[Optional[RemoteTransformerBlock]], MPFuture[List[Optional[RemoteTransformerBlock]]]]:
|
|
|
|
|
|
+) -> Union[List[Optional[RemoteTransformerBlockSession]], MPFuture[List[Optional[RemoteTransformerBlockSession]]]]:
|
|
"""
|
|
"""
|
|
:param uids: find experts with these ids from across the DHT
|
|
:param uids: find experts with these ids from across the DHT
|
|
:param expiration_time: if specified, return experts that expire no sooner than this (based on get_dht_time)
|
|
:param expiration_time: if specified, return experts that expire no sooner than this (based on get_dht_time)
|
|
@@ -35,7 +48,7 @@ def get_remote_module(
|
|
|
|
|
|
def create_remote_module(
|
|
def create_remote_module(
|
|
infos: Union[Sequence[Optional[ExpertInfo]], MPFuture], dht: DHT, return_future: bool = False
|
|
infos: Union[Sequence[Optional[ExpertInfo]], MPFuture], dht: DHT, return_future: bool = False
|
|
-) -> Union[List[Optional[RemoteTransformerBlock]], Future]:
|
|
|
|
|
|
+) -> Union[List[Optional[RemoteTransformerBlockSession]], Future]:
|
|
if return_future:
|
|
if return_future:
|
|
|
|
|
|
async def _unpack(infos_future: MPFuture, dht: DHT):
|
|
async def _unpack(infos_future: MPFuture, dht: DHT):
|
|
@@ -48,10 +61,10 @@ def create_remote_module(
|
|
|
|
|
|
|
|
|
|
def _create_remote_experts(
    infos: Sequence[Optional[ExpertInfo]], p2p: P2P, config: Optional[DistributedBloomConfig] = None
) -> List[Optional[RemoteTransformerBlockSession]]:
    """
    Build one RemoteTransformerBlockSession per expert info, preserving the order of ``infos``.

    :param infos: expert descriptions; a None entry is passed through as None in the result
    :param p2p: p2p instance handed to every created session
    :param config: model config handed to every created session; RemoteTransformerBlockSession
      requires it, so it must be provided whenever ``infos`` contains a non-None entry
    :returns: a list as long as ``infos`` holding a session for each non-None info and None otherwise
    :raises TypeError: if a session has to be created but ``config`` was not provided
    """
    experts: List[Optional[RemoteTransformerBlockSession]] = []
    for info in infos:
        if info is None:
            experts.append(None)
            continue
        if config is None:
            # The session constructor takes (config, info, p2p). Calling it with only (info, p2p)
            # would bind the arguments to the wrong parameters and fail with a misleading
            # "missing argument: 'p2p'" error, so report the actual problem instead.
            raise TypeError(
                "_create_remote_experts requires `config` to create RemoteTransformerBlockSession instances"
            )
        experts.append(RemoteTransformerBlockSession(config, info, p2p))
    return experts
|