|
@@ -8,7 +8,6 @@ from functools import partial
|
|
|
from typing import Dict, List, Optional, Sequence, Union
|
|
|
|
|
|
from hivemind.dht import DHT, DHTNode, DHTValue
|
|
|
-from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
|
|
|
from hivemind.p2p import P2P, PeerID
|
|
|
from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger, use_hivemind_log_handler
|
|
|
|
|
@@ -72,37 +71,6 @@ async def _declare_active_modules(
|
|
|
)
|
|
|
|
|
|
|
|
|
-def get_remote_module(
|
|
|
- dht: DHT,
|
|
|
- uid_or_uids: Union[ModuleUID, List[ModuleUID]],
|
|
|
- expiration_time: Optional[DHTExpiration] = None,
|
|
|
- return_future: bool = False,
|
|
|
-) -> Union[List[Optional[src.RemoteTransformerBlock]], MPFuture[List[Optional[src.RemoteTransformerBlock]]]]:
|
|
|
- """
|
|
|
- :param uid_or_uids: find one or more modules with these ids from across the DHT
|
|
|
- :param expiration_time: if specified, return modules that expire no sooner than this (based on get_dht_time)
|
|
|
- :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
|
|
|
- :returns: a list of [RemoteTransformerBlock if found else None]
|
|
|
- """
|
|
|
- single_uid = isinstance(uid_or_uids, ModuleUID)
|
|
|
- uids = [uid_or_uids] if single_uid else uid_or_uids
|
|
|
- infos = dht.run_coroutine(
|
|
|
- partial(_get_remote_module_infos, uids=uids, expiration_time=expiration_time), return_future
|
|
|
- )
|
|
|
-
|
|
|
- if return_future:
|
|
|
-
|
|
|
- async def _unpack(infos_future: MPFuture, dht: DHT):
|
|
|
- p2p = await dht.replicate_p2p()
|
|
|
- modules = _create_remote_modules_from_infos(await infos_future, p2p)
|
|
|
- return modules[0] if single_uid else modules
|
|
|
-
|
|
|
- return RemoteExpertWorker.run_coroutine(_unpack(infos, dht), return_future)
|
|
|
- p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
|
|
|
- modules = _create_remote_modules_from_infos(infos, p2p)
|
|
|
- return modules[0] if single_uid else modules
|
|
|
-
|
|
|
-
|
|
|
def get_remote_module_infos(
|
|
|
dht: DHT,
|
|
|
uid_or_uids: Union[ModuleUID, List[ModuleUID]],
|
|
@@ -149,15 +117,3 @@ async def _get_remote_module_infos(
|
|
|
if servers:
|
|
|
modules[i] = RemoteModuleInfo(uid, servers)
|
|
|
return modules
|
|
|
-
|
|
|
-
|
|
|
-def _create_remote_modules_from_infos(
|
|
|
- infos: Sequence[Optional[RemoteModuleInfo]], p2p: P2P
|
|
|
-) -> List[Optional[src.RemoteTransformerBlock]]:
|
|
|
- modules: List[Optional[src.RemoteTransformerBlock]] = []
|
|
|
- for info in infos:
|
|
|
- if info is not None:
|
|
|
- modules.append(src.RemoteTransformerBlock(info, p2p))
|
|
|
- else:
|
|
|
- modules.append(None)
|
|
|
- return modules
|