Ver Fonte

Fix load balancing race condition when starting multiple servers at once

Aleksandr Borzunov há 3 anos atrás
pai
commit
d6889cbe96
4 ficheiros alterados com 54 adições e 33 exclusões
  1. 1 1
      src/client/remote_block.py
  2. 16 17
      src/client/remote_sequence_info.py
  3. 18 7
      src/dht_utils.py
  4. 19 8
      src/server/server.py

+ 1 - 1
src/client/remote_block.py

@@ -25,7 +25,7 @@ class RemoteTransformerBlock(RemoteExpert):
     """A class that interacts with a remote module on a specific server for forward/backward or inference"""
 
     def __init__(self, peers_info: RemoteModuleInfo, p2p: P2P):
-        peer_info = ExpertInfo(peers_info.uid, random.choice(list(peers_info.peer_ids)))  # TODO replace this
+        peer_info = ExpertInfo(peers_info.uid, random.choice(list(peers_info.servers)))  # TODO replace this
         super().__init__(peer_info, p2p)
 
     @property

+ 16 - 17
src/client/remote_sequence_info.py

@@ -1,15 +1,13 @@
 from __future__ import annotations
 
-import dataclasses
 import threading
-from functools import partial
 from typing import List, NamedTuple, Optional, Sequence, Tuple
 
 from hivemind import DHT, PeerID
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
-from src.data_structures import ModuleUID, RemoteModuleInfo
-from src.dht_utils import _get_remote_module_infos
+from src.data_structures import ModuleUID, RemoteModuleInfo, ServerState
+from src.dht_utils import get_remote_module_infos
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -18,21 +16,20 @@ logger = get_logger(__file__)
 Span = NamedTuple("Span", [("start", int), ("end", Optional[int]), ("peer_id", PeerID)])
 
 
-@dataclasses.dataclass(frozen=False, init=False)  # TODO[borzunov@] eto ne dataclass
 class RemoteSequenceInfo:
     """Keeps and updates the meta-information about which peers host which blocks"""
 
     dht: DHT
-    block_uids: List[ModuleUID, ...]
-    block_infos: List[Optional[RemoteModuleInfo], ...]
+    block_uids: List[ModuleUID]
+    block_infos: List[Optional[RemoteModuleInfo]]
     spans_by_priority: List[Span]  # sorted from best to worst
-    spans_containing_block: Tuple[List[Span], ...]
+    spans_containing_block: Tuple[List[Span]]
     lock_changes: threading.Lock
 
     def __init__(self, dht: DHT, block_uids: Sequence[ModuleUID]):
         self.dht = dht
         self.block_uids = list(block_uids)
-        self.block_infos: List[Optional[RemoteModuleInfo], ...] = [None] * len(self.block_uids)
+        self.block_infos = [None] * len(self.block_uids)
         self.spans_by_priority = []
         self.spans_containing_block = tuple(list() for _ in range(len(self.block_uids)))
         self.lock_changes = threading.Lock()
@@ -48,21 +45,17 @@ class RemoteSequenceInfo:
             self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos)
 
     def update_block_infos_(self):
-        new_block_infos: Sequence[RemoteModuleInfo] = self.dht.run_coroutine(
-            partial(_get_remote_module_infos, uids=self.block_uids, expiration_time=float("inf")), return_future=False
-        )
+        new_block_infos = get_remote_module_infos(self.dht, self.block_uids, expiration_time=float("inf"))
         assert len(new_block_infos) == len(self.block_uids)
         for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
             if info is None:
                 logger.warning(f"Found no block info for block {uid}")
             if not isinstance(info, RemoteModuleInfo):
                 logger.warning(f"Unexpected dht entry type for {uid}: {info}")
-            if not info.peer_ids:
+            if not info.servers:
                 logger.warning(f"Found no active peers for block {uid}")
             if info.uid != uid:
                 logger.warning(f"The DHT entry for {uid} actually points to {info.uid}")
-            if not isinstance(info.peer_ids, set):
-                logger.warning(f"Expected peer_ids for {uid} to be a set, got {type(info.peer_ids)}")
             self.block_infos[block_index] = info
 
     @staticmethod
@@ -70,14 +63,20 @@ class RemoteSequenceInfo:
         closed_spans = []
         active_spans = {}
         for block_index, info in enumerate(block_infos):
-            for peer_id in info.peer_ids:
+            for peer_id, server in info.servers.items():
+                if server.state != ServerState.ONLINE:
+                    continue
                 if peer_id not in active_spans:
                     active_spans[peer_id] = Span(start=block_index, end=block_index + 1, peer_id=peer_id)
                 else:  # peer_id in active_spans
                     active_spans[peer_id] = active_spans[peer_id]._replace(end=block_index + 1)
 
             for peer_id in list(active_spans.keys()):
-                if peer_id not in info.peer_ids or block_index == len(block_infos) - 1:
+                if (
+                    peer_id not in info.servers or
+                    info.servers[peer_id].state != ServerState.ONLINE or
+                    block_index == len(block_infos) - 1
+                ):
                     closed_spans.append(active_spans.pop(peer_id))
         assert not active_spans
 

+ 18 - 7
src/dht_utils.py

@@ -22,6 +22,7 @@ def declare_active_modules(
     dht: DHT,
     uids: Sequence[ModuleUID],
     expiration_time: DHTExpiration,
+    state: ServerState,
     throughput: float,
     wait: bool = True,
 ) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]:
@@ -41,7 +42,13 @@ def declare_active_modules(
     for uid in uids:
         assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid
     return dht.run_coroutine(
-        partial(_declare_active_modules, uids=uids, expiration_time=expiration_time, throughput=throughput),
+        partial(
+            _declare_active_modules,
+            uids=uids,
+            expiration_time=expiration_time,
+            state=state,
+            throughput=throughput,
+        ),
         return_future=not wait,
     )
 
@@ -51,13 +58,14 @@ async def _declare_active_modules(
     node: DHTNode,
     uids: List[ModuleUID],
     expiration_time: DHTExpiration,
+    state: ServerState,
     throughput: float,
 ) -> Dict[ModuleUID, bool]:
     num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
     return await node.store_many(
         keys=uids,
         subkeys=[dht.peer_id.to_base58()] * len(uids),
-        values=[throughput] * len(uids),
+        values=[(state.value, throughput)] * len(uids),
         expiration_time=expiration_time,
         num_workers=num_workers,
     )
@@ -123,12 +131,15 @@ async def _get_remote_module_infos(
                 logger.error(f"Incorrect metadata for {uid}: {metadata}")
             continue
         servers = {}
-        for peer_id, throughput in metadata.value.items():
+        for peer_id, server_info in metadata.value.items():
             try:
-                if not isinstance(throughput.value, float):
-                    raise ValueError(f'Throughput expected to be a float, not {throughput.value}')
-                servers[peer_id] = ServerInfo(ServerState.ONLINE, throughput.value)
-            except (ValueError, TypeError) as e:
+                server_info = server_info.value
+                if not (isinstance(server_info, tuple) and len(server_info) == 2 and
+                        isinstance(server_info[0], int) and isinstance(server_info[1], float)):
+                    raise ValueError(f"Invalid server info for uid={uid}, peer_id={peer_id}: {server_info}")
+                state, throughput = server_info
+                servers[peer_id] = ServerInfo(ServerState(state), throughput)
+            except ValueError as e:
                 logger.error(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")
         if servers:
             modules[i] = RemoteModuleInfo(uid, servers)

+ 19 - 8
src/server/server.py

@@ -13,7 +13,7 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
 from src import declare_active_modules, BloomConfig
 from src.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
-from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER
+from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
 from src.dht_utils import get_remote_module_infos
 from src.server.backend import TransformerBackend
 from src.server.cache import MemoryCache
@@ -125,6 +125,9 @@ class Server(threading.Thread):
             )
             logger.info(f"Automatic dht prefix: {prefix}")
         assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both"
+        if expiration is None:
+            expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
+
         dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
         visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
         logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
@@ -151,13 +154,21 @@ class Server(threading.Thread):
         else:
             assert num_blocks is not None
             uids = [f"{prefix}.{block_index}" for block_index in range(block_config.n_layer)]
-            module_infos = get_remote_module_infos(dht, uids)
+            module_infos = get_remote_module_infos(dht, uids, expiration_time=float("inf"))
             block_indices = choose_best_blocks(num_blocks, module_infos)
 
-        # initialize modules
+        module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
+        declare_active_modules(
+            dht,
+            module_uids,
+            expiration_time=get_dht_time() + expiration,
+            state=ServerState.JOINING,
+            throughput=throughput,
+        )
+
+        logger.info(f"Loading blocks with indices {block_indices}")
         blocks = {}
-        for block_index in block_indices:
-            module_uid = f"{prefix}.{block_index}"
+        for module_uid, block_index in zip(module_uids, block_indices):
             block = load_pretrained_block(
                 converted_model_name_or_path,
                 block_index,
@@ -252,14 +263,13 @@ class ModuleAnnouncerThread(threading.Thread):
         self,
         module_backends: Dict[str, TransformerBackend],
         dht: DHT,
+        *,
         throughput: float,
         update_period: float = 30,
-        expiration: Optional[int] = None,
+        expiration: float,
         **kwargs
     ):
         super().__init__(**kwargs)
-        if expiration is None:
-            expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
         self.module_backends = module_backends
         self.dht = dht
         self.throughput = throughput
@@ -273,6 +283,7 @@ class ModuleAnnouncerThread(threading.Thread):
                 self.dht,
                 self.module_backends.keys(),
                 expiration_time=get_dht_time() + self.expiration,
+                state=ServerState.ONLINE,
                 throughput=self.throughput,
             )
             if self.stop.wait(self.update_period):