Kaynağa Gözat

black-isort

justheuristic 3 yıl önce
ebeveyn
işleme
01b9bced78
4 değiştirilmiş dosya ile 34 ekleme ve 21 silme
  1. 1 1
      src/__init__.py
  2. 4 7
      src/client/remote_block.py
  3. 24 11
      src/dht_utils.py
  4. 5 2
      src/server/server.py

+ 1 - 1
src/__init__.py

@@ -1,3 +1,3 @@
 from .bloom import *
-from .dht_utils import get_remote_module, declare_active_modules
 from .client import RemoteTransformerBlock, RemoteTransformerBlockInferenceSession
+from .dht_utils import declare_active_modules, get_remote_module

+ 4 - 7
src/client/remote_block.py

@@ -6,14 +6,14 @@ from typing import Any, AsyncIterator, Dict, Optional
 
 import torch
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
-from hivemind.moe.client.expert import RemoteExpertWorker, RemoteExpert
+from hivemind.moe.client.expert import RemoteExpert, RemoteExpertWorker
 from hivemind.moe.expert_uid import ExpertInfo
 from hivemind.p2p import P2P, StubBase
 from hivemind.proto import runtime_pb2
 from hivemind.utils import anext, nested_flatten
 
-from src.dht_utils import ModuleUID
 from src.data_structures import RemoteModuleInfo
+from src.dht_utils import ModuleUID
 from src.server.handler import TransformerConnectionHandler
 
 
@@ -21,7 +21,7 @@ class RemoteTransformerBlock(RemoteExpert):
     """A class that interacts with a remote module on a specific server for forward/backward or inference"""
 
     def __init__(self, peers_info: RemoteModuleInfo, p2p: P2P):
-        peer_info = ExpertInfo(peers_info.uid, random.choice(list(peers_info.peer_ids))) #TODO replace this
+        peer_info = ExpertInfo(peers_info.uid, random.choice(list(peers_info.peer_ids)))  # TODO replace this
         super().__init__(peer_info, p2p)
 
     @property
@@ -37,9 +37,7 @@ class RemoteTransformerBlock(RemoteExpert):
 class RemoteTransformerBlockInferenceSession:
     """An interface to a single multi-step *inference* session for a specific remote module with a specific server"""
 
-    def __init__(
-            self, uid: ModuleUID, info: Dict[str, Any], inputs_queue: asyncio.Queue, outputs_aiter: AsyncIterator
-    ):
+    def __init__(self, uid: ModuleUID, info: Dict[str, Any], inputs_queue: asyncio.Queue, outputs_aiter: AsyncIterator):
         self.uid, self.info = uid, info
         # warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread;
         # using them in any other EventLoop may cause side-effects including, headaches, diarrhea, and loss of sleep
@@ -119,4 +117,3 @@ class RemoteTransformerBlockInferenceSession:
 
     def __exit__(self, *exc_details):
         self.close()
-

+ 24 - 11
src/dht_utils.py

@@ -8,25 +8,29 @@ from typing import Dict, List, Optional, Sequence, Union
 
 from hivemind.dht import DHT, DHTNode, DHTValue
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
-from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, use_hivemind_log_handler, get_logger
 from hivemind.p2p import P2P, PeerID
+from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger, use_hivemind_log_handler
 
 import src
-from src.data_structures import RemoteModuleInfo, ModuleUID, UID_DELIMITER
-
+from src.data_structures import UID_DELIMITER, ModuleUID, RemoteModuleInfo
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 
 def declare_active_modules(
-    dht: DHT, uids: Sequence[ModuleUID], expiration_time: DHTExpiration, wait: bool = True
+    dht: DHT,
+    uids: Sequence[ModuleUID],
+    expiration_time: DHTExpiration,
+    throughput: Optional[float] = None,
+    wait: bool = True,
 ) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]:
     """
     Declare that your node serves the specified modules; update timestamps if declared previously
 
     :param uids: a list of module ids to declare
     :param wait: if True, awaits for declaration to finish, otherwise runs in background
+    :param throughput: optionally specify your performance in terms of compute throughput
     :param expiration_time: declated modules will be visible for this many seconds
     :returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
     """
@@ -37,25 +41,33 @@ def declare_active_modules(
     for uid in uids:
         assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid
     return dht.run_coroutine(
-        partial(_declare_active_modules, uids=uids, expiration_time=expiration_time), return_future=not wait
+        partial(_declare_active_modules, uids=uids, expiration_time=expiration_time, throughput=throughput),
+        return_future=not wait,
     )
 
 
 async def _declare_active_modules(
-    dht: DHT, node: DHTNode, uids: List[ModuleUID], expiration_time: DHTExpiration
+    dht: DHT,
+    node: DHTNode,
+    uids: List[ModuleUID],
+    expiration_time: DHTExpiration,
+    throughput: Optional[float] = None,
 ) -> Dict[ModuleUID, bool]:
     num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
     return await node.store_many(
         keys=uids,
         subkeys=[dht.peer_id.to_base58()] * len(uids),
-        values=[None] * len(uids),
+        values=[throughput] * len(uids),
         expiration_time=expiration_time,
-        num_workers=num_workers
+        num_workers=num_workers,
     )
 
 
 def get_remote_module(
-    dht: DHT, uid_or_uids: Union[ModuleUID, List[ModuleUID]], expiration_time: Optional[DHTExpiration] = None, return_future: bool = False
+    dht: DHT,
+    uid_or_uids: Union[ModuleUID, List[ModuleUID]],
+    expiration_time: Optional[DHTExpiration] = None,
+    return_future: bool = False,
 ) -> Union[List[Optional["src.RemoteTransformerBlock"]], MPFuture[List[Optional["src.RemoteTransformerBlock"]]]]:
     """
     :param uid_or_uids: find one or more modules with these ids from across the DHT
@@ -70,6 +82,7 @@ def get_remote_module(
     )
 
     if return_future:
+
         async def _unpack(infos_future: MPFuture, dht: DHT):
             p2p = await dht.replicate_p2p()
             modules = _create_remote_modules_from_infos(await infos_future, p2p)
@@ -82,7 +95,7 @@ def get_remote_module(
 
 
 async def _get_remote_module_infos(
-        dht: DHT, node: DHTNode, uids: List[ModuleUID], expiration_time: Optional[DHTExpiration]
+    dht: DHT, node: DHTNode, uids: List[ModuleUID], expiration_time: Optional[DHTExpiration]
 ) -> List[Optional[RemoteModuleInfo]]:
     if expiration_time is None:
         expiration_time = get_dht_time()
@@ -115,4 +128,4 @@ def _create_remote_modules_from_infos(
             modules.append(src.RemoteTransformerBlock(info, p2p))
         else:
             modules.append(None)
-    return modules
+    return modules

+ 5 - 2
src/server/server.py

@@ -5,7 +5,7 @@ import threading
 from typing import Dict, Optional, Sequence, Union
 
 import torch
-from hivemind import DHT, BatchTensorDescriptor, MAX_DHT_TIME_DISCREPANCY_SECONDS, get_dht_time
+from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
 from hivemind.moe.server.dht_handler import DHTHandlerThread
 from hivemind.moe.server.layers import add_custom_models_from_file
 from hivemind.moe.server.runtime import Runtime
@@ -43,7 +43,9 @@ class Server(threading.Thread):
             TransformerConnectionHandler(dht, self.module_backends) for _ in range(num_connection_handlers)
         ]
         self.runtime = Runtime(self.module_backends, device=device, **kwargs)
-        self.dht_handler_thread = ModuleAnnouncerThread(self.module_backends, dht, update_period, expiration, daemon=True)
+        self.dht_handler_thread = ModuleAnnouncerThread(
+            self.module_backends, dht, update_period, expiration, daemon=True
+        )
         self.checkpoint_saver = None  # no need to save checkpoints since we do not change model state
 
         if start:
@@ -217,6 +219,7 @@ class Server(threading.Thread):
 
 class ModuleAnnouncerThread(threading.Thread):
     """Periodically announces that this server hosts the specified modules, visible to all DHT peers"""
+
     def __init__(
         self, module_backends, dht: DHT, update_period: float = 30, expiration: Optional[int] = None, **kwargs
     ):