|
@@ -8,7 +8,8 @@ from functools import partial
|
|
|
from typing import Dict, List, Optional, Sequence, Union
|
|
|
|
|
|
from hivemind.dht import DHT, DHTNode, DHTValue
|
|
|
-from hivemind.p2p import P2P, PeerID
|
|
|
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
|
|
|
+from hivemind.p2p import PeerID
|
|
|
from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger, use_hivemind_log_handler
|
|
|
|
|
|
import src
|
|
@@ -71,6 +72,40 @@ async def _declare_active_modules(
|
|
|
)
|
|
|
|
|
|
|
|
|
+def get_remote_module(
|
|
|
+ dht: DHT,
|
|
|
+ uid_or_uids: Union[ModuleUID, List[ModuleUID]],
|
|
|
+ config: src.DistributedBloomConfig,
|
|
|
+ dht_prefix: Optional[str] = None,
|
|
|
+ return_future: bool = False,
|
|
|
+) -> List[src.RemoteTransformerBlock]:
|
|
|
+ """
|
|
|
+ :param uid_or_uids: find one or more modules with these ids from across the DHT
|
|
|
+ :param config: model config, usualy taken by .from_pretrained(MODEL_NAME)
|
|
|
+ :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
|
|
|
+ :returns: a list of [RemoteTransformerBlock]
|
|
|
+ """
|
|
|
+ return RemoteExpertWorker.run_coroutine(
|
|
|
+ _get_distinct_blocks(dht, uid_or_uids, config, dht_prefix), return_future=return_future
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+async def _get_distinct_blocks(
|
|
|
+ dht: DHT,
|
|
|
+ uid_or_uids: Union[ModuleUID, List[ModuleUID]],
|
|
|
+ config: src.DistributedBloomConfig,
|
|
|
+ dht_prefix: Optional[str] = None,
|
|
|
+):
|
|
|
+ single_uid = isinstance(uid_or_uids, ModuleUID)
|
|
|
+ uids = [uid_or_uids] if single_uid else uid_or_uids
|
|
|
+ p2p = await dht.replicate_p2p()
|
|
|
+ managers = (src.RemoteSequenceManager(dht, [uid], p2p) for uid in uids)
|
|
|
+ modules = [
|
|
|
+ src.RemoteTransformerBlock(config, dht, dht_prefix=dht_prefix, p2p=p2p, sequence_manager=m) for m in managers
|
|
|
+ ]
|
|
|
+ return modules[0] if single_uid else modules
|
|
|
+
|
|
|
+
|
|
|
def get_remote_module_infos(
|
|
|
dht: DHT,
|
|
|
uid_or_uids: Union[ModuleUID, List[ModuleUID]],
|