فهرست منبع

Fix runtime errors

Aleksandr Borzunov 3 سال پیش
والد
کامیت
7d489784e6
3 فایلهای تغییر یافته به همراه 8 افزوده شده و 8 حذف شده
  1. 3 3
      src/dht_utils.py
  2. 1 1
      src/server/load_balancing.py
  3. 4 4
      src/server/server.py

+ 3 - 3
src/dht_utils.py

@@ -8,11 +8,11 @@ from typing import Dict, List, Optional, Sequence, Union
 
 from hivemind.dht import DHT, DHTNode, DHTValue
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
-from hivemind.p2p import P2P, PeerID
+from hivemind.p2p import P2P
 from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger, use_hivemind_log_handler
 
 import src
-from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo
+from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo, ServerState
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -102,7 +102,7 @@ def get_remote_module_infos(
     single_uid = isinstance(uid_or_uids, ModuleUID)
     uids = [uid_or_uids] if single_uid else uid_or_uids
     infos = dht.run_coroutine(
-        partial(_get_remote_module_infos, uids=uids, expiration_time=expiration_time), return_future
+        partial(_get_remote_module_infos, uids=uids, expiration_time=expiration_time), return_future=False
     )
     return infos[0] if single_uid else infos
 

+ 1 - 1
src/server/load_balancing.py

@@ -1,6 +1,6 @@
 from typing import List, Optional
 
-from src.data_structures import ServerState
+from src.data_structures import RemoteModuleInfo, ServerState
 
 
 def choose_best_blocks(num_blocks: int, remote_module_infos: List[Optional[RemoteModuleInfo]]) -> List[int]:

+ 4 - 4
src/server/server.py

@@ -129,6 +129,10 @@ class Server(threading.Thread):
             torch_dtype = DTYPE_MAP[torch_dtype]
         assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
 
+        block_config = BloomConfig.from_pretrained(
+            converted_model_name_or_path, use_auth_token=use_auth_token
+        )
+
         if block_indices is not None:
             try:
                 first_block_index, last_block_index = block_indices.split(":")
@@ -143,10 +147,6 @@ class Server(threading.Thread):
             module_infos = get_remote_module_infos(dht, uids)
             block_indices = choose_best_blocks(num_blocks, module_infos)
 
-        block_config = BloomConfig.from_pretrained(
-            converted_model_name_or_path, use_auth_token=use_auth_token
-        )
-
         # initialize modules
         blocks = {}
         for block_index in block_indices: