|
@@ -9,7 +9,7 @@ from typing import Dict, List, Optional, Sequence, Union
|
|
|
|
|
|
from hivemind.dht import DHT, DHTNode, DHTValue
|
|
|
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
|
|
|
-from hivemind.p2p import P2P, PeerID
|
|
|
+from hivemind.p2p import PeerID
|
|
|
from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger, use_hivemind_log_handler
|
|
|
|
|
|
import src
|
|
@@ -72,34 +72,63 @@ async def _declare_active_modules(
|
|
|
)
|
|
|
|
|
|
|
|
|
+def get_remote_sequence(
|
|
|
+ dht: DHT,
|
|
|
+ start: int,
|
|
|
+ stop: int,
|
|
|
+ config: src.DistributedBloomConfig,
|
|
|
+ dht_prefix: Optional[str] = None,
|
|
|
+ return_future: bool = False,
|
|
|
+) -> Union[src.RemoteSequential, MPFuture]:
|
|
|
+ return RemoteExpertWorker.run_coroutine(
|
|
|
+ _get_remote_sequence(dht, start, stop, config, dht_prefix), return_future=return_future
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+async def _get_remote_sequence(
|
|
|
+ dht: DHT,
|
|
|
+ start: int,
|
|
|
+ stop: int,
|
|
|
+ config: src.DistributedBloomConfig,
|
|
|
+ dht_prefix: Optional[str] = None,
|
|
|
+) -> src.RemoteSequential:
|
|
|
+ uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(start, stop)]
|
|
|
+ p2p = await dht.replicate_p2p()
|
|
|
+ manager = src.RemoteSequenceManager(dht, uids, p2p)
|
|
|
+ return src.RemoteSequential(config, dht, dht_prefix, p2p, manager)
|
|
|
+
|
|
|
+
|
|
|
def get_remote_module(
|
|
|
dht: DHT,
|
|
|
uid_or_uids: Union[ModuleUID, List[ModuleUID]],
|
|
|
- expiration_time: Optional[DHTExpiration] = None,
|
|
|
+ config: src.DistributedBloomConfig,
|
|
|
+ dht_prefix: Optional[str] = None,
|
|
|
return_future: bool = False,
|
|
|
-) -> Union[List[Optional[src.RemoteTransformerBlock]], MPFuture[List[Optional[src.RemoteTransformerBlock]]]]:
|
|
|
+) -> Union[Union[src.RemoteTransformerBlock, List[src.RemoteTransformerBlock]], MPFuture]:
|
|
|
"""
|
|
|
:param uid_or_uids: find one or more modules with these ids from across the DHT
|
|
|
- :param expiration_time: if specified, return modules that expire no sooner than this (based on get_dht_time)
|
|
|
+ :param config: model config, usualy taken by .from_pretrained(MODEL_NAME)
|
|
|
:param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
|
|
|
- :returns: a list of [RemoteTransformerBlock if found else None]
|
|
|
+ :returns: a list of [RemoteTransformerBlock]
|
|
|
"""
|
|
|
- single_uid = isinstance(uid_or_uids, ModuleUID)
|
|
|
- uids = [uid_or_uids] if single_uid else uid_or_uids
|
|
|
- infos = dht.run_coroutine(
|
|
|
- partial(_get_remote_module_infos, uids=uids, expiration_time=expiration_time), return_future
|
|
|
+ return RemoteExpertWorker.run_coroutine(
|
|
|
+ _get_remote_module(dht, uid_or_uids, config, dht_prefix), return_future=return_future
|
|
|
)
|
|
|
|
|
|
- if return_future:
|
|
|
-
|
|
|
- async def _unpack(infos_future: MPFuture, dht: DHT):
|
|
|
- p2p = await dht.replicate_p2p()
|
|
|
- modules = _create_remote_modules_from_infos(await infos_future, p2p)
|
|
|
- return modules[0] if single_uid else modules
|
|
|
|
|
|
- return RemoteExpertWorker.run_coroutine(_unpack(infos, dht), return_future)
|
|
|
- p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
|
|
|
- modules = _create_remote_modules_from_infos(infos, p2p)
|
|
|
+async def _get_remote_module(
|
|
|
+ dht: DHT,
|
|
|
+ uid_or_uids: Union[ModuleUID, List[ModuleUID]],
|
|
|
+ config: src.DistributedBloomConfig,
|
|
|
+ dht_prefix: Optional[str] = None,
|
|
|
+) -> Union[src.RemoteTransformerBlock, List[src.RemoteTransformerBlock]]:
|
|
|
+ single_uid = isinstance(uid_or_uids, ModuleUID)
|
|
|
+ uids = [uid_or_uids] if single_uid else uid_or_uids
|
|
|
+ p2p = await dht.replicate_p2p()
|
|
|
+ managers = (src.RemoteSequenceManager(dht, [uid], p2p) for uid in uids)
|
|
|
+ modules = [
|
|
|
+ src.RemoteTransformerBlock(config, dht, dht_prefix=dht_prefix, p2p=p2p, sequence_manager=m) for m in managers
|
|
|
+ ]
|
|
|
return modules[0] if single_uid else modules
|
|
|
|
|
|
|
|
@@ -149,15 +178,3 @@ async def _get_remote_module_infos(
|
|
|
if servers:
|
|
|
modules[i] = RemoteModuleInfo(uid, servers)
|
|
|
return modules
|
|
|
-
|
|
|
-
|
|
|
-def _create_remote_modules_from_infos(
|
|
|
- infos: Sequence[Optional[RemoteModuleInfo]], p2p: P2P
|
|
|
-) -> List[Optional[src.RemoteTransformerBlock]]:
|
|
|
- modules: List[Optional[src.RemoteTransformerBlock]] = []
|
|
|
- for info in infos:
|
|
|
- if info is not None:
|
|
|
- modules.append(src.RemoteTransformerBlock(info, p2p))
|
|
|
- else:
|
|
|
- modules.append(None)
|
|
|
- return modules
|