
Start implementing load balancing

Aleksandr Borzunov 3 years ago
parent
commit
be71cf7992

+ 21 - 2
src/data_structures.py

@@ -1,8 +1,27 @@
-from typing import Collection, NamedTuple
+from dataclasses import dataclass
+from enum import Enum
+from typing import Dict
 
 from hivemind import PeerID
 
 ModuleUID = str
 UID_DELIMITER = "."  # delimits parts of one module uid, e.g. "bloom.transformer.h.4.self_attention"
 CHAIN_DELIMITER = " "  # delimits multiple uids in a sequence, e.g. "bloom.layer3 bloom.layer4"
-RemoteModuleInfo = NamedTuple("RemoteModuleInfo", [("uid", ModuleUID), ("peer_ids", Collection[PeerID])])
+
+
+class ServerState(Enum):
+    OFFLINE = 0
+    JOINING = 1
+    ONLINE = 2
+
+
+@dataclass
+class ServerInfo:
+    state: ServerState
+    throughput: float
+
+
+@dataclass
+class RemoteModuleInfo:
+    uid: ModuleUID
+    servers: Dict[PeerID, ServerInfo]
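
A rough illustration (not part of the commit) of how the new data structures fit together; the uid, peers, and throughput values below are made up, and plain strings stand in for the PeerID keys:

from src.data_structures import RemoteModuleInfo, ServerInfo, ServerState

# One block served by two hypothetical peers; string keys stand in for PeerIDs
# purely for illustration (the dataclass does not enforce key types at runtime).
info = RemoteModuleInfo(
    uid="bloom.transformer.h.4",
    servers={
        "peer_a": ServerInfo(state=ServerState.ONLINE, throughput=1.5),
        "peer_b": ServerInfo(state=ServerState.JOINING, throughput=0.5),
    },
)

# Total throughput of servers that are not offline, as the load balancer aggregates it.
non_offline_throughput = sum(
    s.throughput for s in info.servers.values() if s.state != ServerState.OFFLINE
)
assert non_offline_throughput == 2.0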

+ 22 - 7
src/dht_utils.py

@@ -94,6 +94,19 @@ def get_remote_module(
     return modules[0] if single_uid else modules
 
 
+def get_remote_module_infos(
+    dht: DHT,
+    uid_or_uids: Union[ModuleUID, List[ModuleUID]],
+    expiration_time: Optional[DHTExpiration] = None,
+) -> List[Optional[RemoteModuleInfo]]:
+    single_uid = isinstance(uid_or_uids, ModuleUID)
+    uids = [uid_or_uids] if single_uid else uid_or_uids
+    infos = dht.run_coroutine(
+        partial(_get_remote_module_infos, uids=uids, expiration_time=expiration_time), return_future=False
+    )
+    return infos[0] if single_uid else infos
+
+
 async def _get_remote_module_infos(
     dht: DHT, node: DHTNode, uids: List[ModuleUID], expiration_time: Optional[DHTExpiration]
 ) -> List[Optional[RemoteModuleInfo]]:
@@ -109,14 +122,16 @@ async def _get_remote_module_infos(
             if metadata is not None:
                 logger.error(f"Incorrect metadata for {uid}: {metadata}")
             continue
-        valid_entries = set()
-        for maybe_peer_id, _unused_value in metadata.value.items():
+        servers = {}
+        for peer_id, throughput in metadata.value.items():
+            if throughput is None:
+                throughput = 0.0  # FIXME:
             try:
-                valid_entries.add(PeerID.from_base58(maybe_peer_id))
-            except:
-                logger.error(f"Incorrect peer entry for {uid}: {maybe_peer_id}")
-        if valid_entries:
-            modules[i] = RemoteModuleInfo(uid, valid_entries)
+                servers[PeerID.from_base58(peer_id)] = ServerInfo(ServerState.ONLINE, throughput)
+            except (ValueError, TypeError):
+                logger.error(f"Incorrect peer entry for {uid}: {peer_id}")
+        if servers:
+            modules[i] = RemoteModuleInfo(uid, servers)
     return modules
 
 

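A hedged usage sketch of the new helper (assumes hivemind's DHT can be started locally; with no peers announcing these blocks, every entry simply comes back as None):

from hivemind import DHT

from src.dht_utils import get_remote_module_infos

dht = DHT(start=True)
uids = ["bloom.transformer.h.0", "bloom.transformer.h.1"]

# Synchronous wrapper around _get_remote_module_infos: one result per requested uid.
infos = get_remote_module_infos(dht, uids)
for uid, info in zip(uids, infos):
    if info is None:
        print(uid, "is not served by anyone yet")
    else:
        print(uid, {peer_id: server.throughput for peer_id, server in info.servers.items()})
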
+ 18 - 0
src/server/load_balancing.py

@@ -0,0 +1,18 @@
+from typing import List, Optional
+
+from src.data_structures import RemoteModuleInfo, ServerState
+
+
+def choose_best_blocks(num_blocks: int, remote_module_infos: List[Optional[RemoteModuleInfo]]) -> List[int]:
+    throughputs = []
+    for module in remote_module_infos:
+        if module is None:
+            throughputs.append(0)
+            continue
+        throughputs.append(sum(server.throughput for server in module.servers.values()
+                               if server.state != ServerState.OFFLINE))
+
+    options = [(throughputs[i:i + num_blocks], i)
+               for i in range(0, len(throughputs) - num_blocks + 1)]
+    best_start = min(options)[1]
+    return list(range(best_start, best_start + num_blocks))
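
A worked sketch of choose_best_blocks on invented numbers (the uids, peer key, and throughputs are all made up). With per-block totals [3.0, 1.0, 0.5, 2.0] and num_blocks=2, the candidate windows are ([3.0, 1.0], 0), ([1.0, 0.5], 1) and ([0.5, 2.0], 2); min() compares the windows element-wise and picks start index 2:

from src.data_structures import RemoteModuleInfo, ServerInfo, ServerState
from src.server.load_balancing import choose_best_blocks

def _fake_module(uid: str, throughput: float) -> RemoteModuleInfo:
    # One fake online server per block; the string key stands in for a PeerID.
    return RemoteModuleInfo(uid, {"peer": ServerInfo(ServerState.ONLINE, throughput)})

infos = [_fake_module(f"bloom.{i}", t) for i, t in enumerate([3.0, 1.0, 0.5, 2.0])]
assert choose_best_blocks(2, infos) == [2, 3]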

+ 5 - 1
src/server/server.py

@@ -14,9 +14,11 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from src import declare_active_modules, BloomConfig
 from src.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
 from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER
+from src.dht_utils import get_remote_module_infos
 from src.server.backend import TransformerBackend
 from src.server.cache import MemoryCache
 from src.server.handler import TransformerConnectionHandler
+from src.server.load_balancing import choose_best_blocks
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -137,7 +139,9 @@ class Server(threading.Thread):
             block_indices = range(first_block_index, last_block_index)
         else:
             assert num_blocks is not None
-            block_indices = range(num_blocks)  # TODO replace with proper load balancing
+            uids = [f"{prefix}.{block_index}" for block_index in range(block_config.n_layer)]
+            module_infos = get_remote_module_infos(dht, uids)
+            block_indices = choose_best_blocks(num_blocks, module_infos)
 
         block_config = BloomConfig.from_pretrained(
             converted_model_name_or_path, use_auth_token=use_auth_token