5
0
Aleksandr Borzunov 2 жил өмнө
parent
commit
da66b67663

+ 2 - 2
src/petals/client/remote_sequential.py

@@ -7,8 +7,8 @@ from hivemind import DHT, P2P, get_logger, use_hivemind_log_handler
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from torch import nn
 
+import petals.client
 from petals.client.inference_session import InferenceSession
-from petals.client.remote_model import DistributedBloomConfig
 from petals.client.sequence_manager import RemoteSequenceManager
 from petals.client.sequential_autograd import _RemoteSequentialAutogradFunction
 from petals.data_structures import UID_DELIMITER
@@ -25,7 +25,7 @@ class RemoteSequential(nn.Module):
 
     def __init__(
         self,
-        config: DistributedBloomConfig,
+        config: petals.client.DistributedBloomConfig,
         dht: DHT,
         dht_prefix: Optional[str] = None,
         p2p: Optional[P2P] = None,

+ 14 - 13
src/petals/dht_utils.py

@@ -12,7 +12,7 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.p2p import PeerID
 from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger, use_hivemind_log_handler
 
-from petals.client import DistributedBloomConfig, RemoteSequenceManager, RemoteSequential, RemoteTransformerBlock
+import petals.client
 from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo, ServerState
 
 use_hivemind_log_handler("in_root_logger")
@@ -76,10 +76,10 @@ def get_remote_sequence(
     dht: DHT,
     start: int,
     stop: int,
-    config: DistributedBloomConfig,
+    config: petals.client.DistributedBloomConfig,
     dht_prefix: Optional[str] = None,
     return_future: bool = False,
-) -> Union[RemoteSequential, MPFuture]:
+) -> Union[petals.client.RemoteSequential, MPFuture]:
     return RemoteExpertWorker.run_coroutine(
         _get_remote_sequence(dht, start, stop, config, dht_prefix), return_future=return_future
     )
@@ -89,22 +89,22 @@ async def _get_remote_sequence(
     dht: DHT,
     start: int,
     stop: int,
-    config: DistributedBloomConfig,
+    config: petals.client.DistributedBloomConfig,
     dht_prefix: Optional[str] = None,
-) -> RemoteSequential:
+) -> petals.client.RemoteSequential:
     uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(start, stop)]
     p2p = await dht.replicate_p2p()
-    manager = RemoteSequenceManager(dht, uids, p2p)
-    return RemoteSequential(config, dht, dht_prefix, p2p, manager)
+    manager = petals.client.RemoteSequenceManager(dht, uids, p2p)
+    return petals.client.RemoteSequential(config, dht, dht_prefix, p2p, manager)
 
 
 def get_remote_module(
     dht: DHT,
     uid_or_uids: Union[ModuleUID, List[ModuleUID]],
-    config: DistributedBloomConfig,
+    config: petals.client.DistributedBloomConfig,
     dht_prefix: Optional[str] = None,
     return_future: bool = False,
-) -> Union[Union[RemoteTransformerBlock, List[RemoteTransformerBlock]], MPFuture]:
+) -> Union[Union[petals.client.RemoteTransformerBlock, List[petals.client.RemoteTransformerBlock]], MPFuture]:
     """
     :param uid_or_uids: find one or more modules with these ids from across the DHT
     :param config: model config, usually taken by .from_pretrained(MODEL_NAME)
@@ -119,15 +119,16 @@ def get_remote_module(
 async def _get_remote_module(
     dht: DHT,
     uid_or_uids: Union[ModuleUID, List[ModuleUID]],
-    config: DistributedBloomConfig,
+    config: petals.client.DistributedBloomConfig,
     dht_prefix: Optional[str] = None,
-) -> Union[RemoteTransformerBlock, List[RemoteTransformerBlock]]:
+) -> Union[petals.client.RemoteTransformerBlock, List[petals.client.RemoteTransformerBlock]]:
     single_uid = isinstance(uid_or_uids, ModuleUID)
     uids = [uid_or_uids] if single_uid else uid_or_uids
     p2p = await dht.replicate_p2p()
-    managers = (RemoteSequenceManager(dht, [uid], p2p) for uid in uids)
+    managers = (petals.client.RemoteSequenceManager(dht, [uid], p2p) for uid in uids)
     modules = [
-        RemoteTransformerBlock(config, dht, dht_prefix=dht_prefix, p2p=p2p, sequence_manager=m) for m in managers
+        petals.client.RemoteTransformerBlock(config, dht, dht_prefix=dht_prefix, p2p=p2p, sequence_manager=m)
+        for m in managers
     ]
     return modules[0] if single_uid else modules