|
@@ -12,8 +12,8 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
|
|
|
from hivemind.p2p import PeerID
|
|
|
from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger, use_hivemind_log_handler
|
|
|
|
|
|
-import src
|
|
|
-from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo, ServerState
|
|
|
+import petals
|
|
|
+from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo, ServerState
|
|
|
|
|
|
use_hivemind_log_handler("in_root_logger")
|
|
|
logger = get_logger(__file__)
|
|
@@ -76,10 +76,10 @@ def get_remote_sequence(
|
|
|
dht: DHT,
|
|
|
start: int,
|
|
|
stop: int,
|
|
|
- config: src.DistributedBloomConfig,
|
|
|
+ config: petals.DistributedBloomConfig,
|
|
|
dht_prefix: Optional[str] = None,
|
|
|
return_future: bool = False,
|
|
|
-) -> Union[src.RemoteSequential, MPFuture]:
|
|
|
+) -> Union[petals.RemoteSequential, MPFuture]:
|
|
|
return RemoteExpertWorker.run_coroutine(
|
|
|
_get_remote_sequence(dht, start, stop, config, dht_prefix), return_future=return_future
|
|
|
)
|
|
@@ -89,22 +89,22 @@ async def _get_remote_sequence(
|
|
|
dht: DHT,
|
|
|
start: int,
|
|
|
stop: int,
|
|
|
- config: src.DistributedBloomConfig,
|
|
|
+ config: petals.DistributedBloomConfig,
|
|
|
dht_prefix: Optional[str] = None,
|
|
|
-) -> src.RemoteSequential:
|
|
|
+) -> petals.RemoteSequential:
|
|
|
uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(start, stop)]
|
|
|
p2p = await dht.replicate_p2p()
|
|
|
- manager = src.RemoteSequenceManager(dht, uids, p2p)
|
|
|
- return src.RemoteSequential(config, dht, dht_prefix, p2p, manager)
|
|
|
+ manager = petals.RemoteSequenceManager(dht, uids, p2p)
|
|
|
+ return petals.RemoteSequential(config, dht, dht_prefix, p2p, manager)
|
|
|
|
|
|
|
|
|
def get_remote_module(
|
|
|
dht: DHT,
|
|
|
uid_or_uids: Union[ModuleUID, List[ModuleUID]],
|
|
|
- config: src.DistributedBloomConfig,
|
|
|
+ config: petals.DistributedBloomConfig,
|
|
|
dht_prefix: Optional[str] = None,
|
|
|
return_future: bool = False,
|
|
|
-) -> Union[Union[src.RemoteTransformerBlock, List[src.RemoteTransformerBlock]], MPFuture]:
|
|
|
+) -> Union[Union[petals.RemoteTransformerBlock, List[petals.RemoteTransformerBlock]], MPFuture]:
|
|
|
"""
|
|
|
:param uid_or_uids: find one or more modules with these ids from across the DHT
|
|
|
:param config: model config, usualy taken by .from_pretrained(MODEL_NAME)
|
|
@@ -119,15 +119,15 @@ def get_remote_module(
|
|
|
async def _get_remote_module(
|
|
|
dht: DHT,
|
|
|
uid_or_uids: Union[ModuleUID, List[ModuleUID]],
|
|
|
- config: src.DistributedBloomConfig,
|
|
|
+ config: petals.DistributedBloomConfig,
|
|
|
dht_prefix: Optional[str] = None,
|
|
|
-) -> Union[src.RemoteTransformerBlock, List[src.RemoteTransformerBlock]]:
|
|
|
+) -> Union[petals.RemoteTransformerBlock, List[petals.RemoteTransformerBlock]]:
|
|
|
single_uid = isinstance(uid_or_uids, ModuleUID)
|
|
|
uids = [uid_or_uids] if single_uid else uid_or_uids
|
|
|
p2p = await dht.replicate_p2p()
|
|
|
- managers = (src.RemoteSequenceManager(dht, [uid], p2p) for uid in uids)
|
|
|
+ managers = (petals.RemoteSequenceManager(dht, [uid], p2p) for uid in uids)
|
|
|
modules = [
|
|
|
- src.RemoteTransformerBlock(config, dht, dht_prefix=dht_prefix, p2p=p2p, sequence_manager=m) for m in managers
|
|
|
+ petals.RemoteTransformerBlock(config, dht, dht_prefix=dht_prefix, p2p=p2p, sequence_manager=m) for m in managers
|
|
|
]
|
|
|
return modules[0] if single_uid else modules
|
|
|
|