Fix load balancing race condition when starting multiple servers at once

Aleksandr Borzunov, 3 years ago
parent commit d6889cbe96
4 changed files with 54 additions and 33 deletions:

  1. src/client/remote_block.py (+1, -1)
  2. src/client/remote_sequence_info.py (+16, -17)
  3. src/dht_utils.py (+18, -7)
  4. src/server/server.py (+19, -8)
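
In short: a server now claims its blocks in the DHT in a JOINING state before the slow block-loading step, so other servers starting at the same moment can see the claim and choose different blocks; the announcer thread later re-declares the blocks as ONLINE, and clients only build routing spans over ONLINE servers. The diff relies on a ServerState enum from src/data_structures.py, which is not part of this diff; a minimal sketch of its assumed shape:

    # Assumed shape of ServerState in src/data_structures.py (that file is
    # not shown in this diff); the diff only confirms that JOINING and ONLINE
    # exist and that .value is an int stored in the DHT.
    from enum import Enum

    class ServerState(Enum):
        OFFLINE = 0
        JOINING = 1
        ONLINE = 2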

+ 1 - 1
src/client/remote_block.py

@@ -25,7 +25,7 @@ class RemoteTransformerBlock(RemoteExpert):
     """A class that interacts with a remote module on a specific server for forward/backward or inference"""
     """A class that interacts with a remote module on a specific server for forward/backward or inference"""
 
 
     def __init__(self, peers_info: RemoteModuleInfo, p2p: P2P):
     def __init__(self, peers_info: RemoteModuleInfo, p2p: P2P):
-        peer_info = ExpertInfo(peers_info.uid, random.choice(list(peers_info.peer_ids)))  # TODO replace this
+        peer_info = ExpertInfo(peers_info.uid, random.choice(list(peers_info.servers)))  # TODO replace this
         super().__init__(peer_info, p2p)
         super().__init__(peer_info, p2p)
 
 
     @property
     @property

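Note on the change above: peers_info.servers is now a dict mapping peer IDs to ServerInfo (it replaces the old peer_ids set), and list() over a dict yields its keys, so random.choice still samples a random peer ID. A toy illustration (values are made up):

    import random

    # servers maps peer_id -> server info; toy stand-in values here
    servers = {"peerA": ("ONLINE", 1.0), "peerB": ("ONLINE", 2.0)}
    print(random.choice(list(servers)))  # prints "peerA" or "peerB"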
+ 16 - 17
src/client/remote_sequence_info.py

@@ -1,15 +1,13 @@
 from __future__ import annotations
 
-import dataclasses
 import threading
-from functools import partial
 from typing import List, NamedTuple, Optional, Sequence, Tuple
 
 from hivemind import DHT, PeerID
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
-from src.data_structures import ModuleUID, RemoteModuleInfo
-from src.dht_utils import _get_remote_module_infos
+from src.data_structures import ModuleUID, RemoteModuleInfo, ServerState
+from src.dht_utils import get_remote_module_infos
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -18,21 +16,20 @@ logger = get_logger(__file__)
 Span = NamedTuple("Span", [("start", int), ("end", Optional[int]), ("peer_id", PeerID)])
 Span = NamedTuple("Span", [("start", int), ("end", Optional[int]), ("peer_id", PeerID)])
 
 
 
 
-@dataclasses.dataclass(frozen=False, init=False)  # TODO[borzunov@] this is not a dataclass
 class RemoteSequenceInfo:
     """Keeps and updates the meta-information about which peers host which blocks"""
 
     dht: DHT
-    block_uids: List[ModuleUID, ...]
-    block_infos: List[Optional[RemoteModuleInfo], ...]
+    block_uids: List[ModuleUID]
+    block_infos: List[Optional[RemoteModuleInfo]]
     spans_by_priority: List[Span]  # sorted from best to worst
-    spans_containing_block: Tuple[List[Span], ...]
+    spans_containing_block: Tuple[List[Span]]
     lock_changes: threading.Lock
 
     def __init__(self, dht: DHT, block_uids: Sequence[ModuleUID]):
         self.dht = dht
         self.block_uids = list(block_uids)
-        self.block_infos: List[Optional[RemoteModuleInfo], ...] = [None] * len(self.block_uids)
+        self.block_infos = [None] * len(self.block_uids)
         self.spans_by_priority = []
         self.spans_containing_block = tuple(list() for _ in range(len(self.block_uids)))
         self.lock_changes = threading.Lock()
@@ -48,21 +45,17 @@ class RemoteSequenceInfo:
             self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos)
 
     def update_block_infos_(self):
-        new_block_infos: Sequence[RemoteModuleInfo] = self.dht.run_coroutine(
-            partial(_get_remote_module_infos, uids=self.block_uids, expiration_time=float("inf")), return_future=False
-        )
+        new_block_infos = get_remote_module_infos(self.dht, self.block_uids, expiration_time=float("inf"))
         assert len(new_block_infos) == len(self.block_uids)
         for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
             if info is None:
                 logger.warning(f"Found no block info for block {uid}")
             if not isinstance(info, RemoteModuleInfo):
                 logger.warning(f"Unexpected dht entry type for {uid}: {info}")
-            if not info.peer_ids:
+            if not info.servers:
                 logger.warning(f"Found no active peers for block {uid}")
             if info.uid != uid:
                 logger.warning(f"The DHT entry for {uid} actually points to {info.uid}")
-            if not isinstance(info.peer_ids, set):
-                logger.warning(f"Expected peer_ids for {uid} to be a set, got {type(info.peer_ids)}")
             self.block_infos[block_index] = info
 
     @staticmethod
@@ -70,14 +63,20 @@ class RemoteSequenceInfo:
         closed_spans = []
         active_spans = {}
         for block_index, info in enumerate(block_infos):
-            for peer_id in info.peer_ids:
+            for peer_id, server in info.servers.items():
+                if server.state != ServerState.ONLINE:
+                    continue
                 if peer_id not in active_spans:
                     active_spans[peer_id] = Span(start=block_index, end=block_index + 1, peer_id=peer_id)
                 else:  # peer_id in active_spans
                     active_spans[peer_id] = active_spans[peer_id]._replace(end=block_index + 1)
 
             for peer_id in list(active_spans.keys()):
-                if peer_id not in info.peer_ids or block_index == len(block_infos) - 1:
+                if (
+                    peer_id not in info.servers or
+                    info.servers[peer_id].state != ServerState.ONLINE or
+                    block_index == len(block_infos) - 1
+                ):
                     closed_spans.append(active_spans.pop(peer_id))
         assert not active_spans
 

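To see what the new ServerState filter changes, here is a condensed, self-contained sketch of the span logic above. The ServerState/ServerInfo/RemoteModuleInfo stand-ins are assumed to mirror src/data_structures.py (not shown in this diff); the effect is that a peer's span is cut as soon as one of its blocks is not ONLINE, so clients will not route through a server that is still joining:

    # Condensed sketch of the updated compute_spans logic; ServerState,
    # ServerInfo and RemoteModuleInfo are stand-ins (assumed to mirror
    # src/data_structures.py, which is not part of this diff).
    from enum import Enum
    from typing import Dict, List, NamedTuple, Optional

    class ServerState(Enum):
        OFFLINE = 0
        JOINING = 1
        ONLINE = 2

    class ServerInfo(NamedTuple):
        state: ServerState
        throughput: float

    class RemoteModuleInfo(NamedTuple):
        uid: str
        servers: Dict[str, ServerInfo]

    Span = NamedTuple("Span", [("start", int), ("end", Optional[int]), ("peer_id", str)])

    def compute_spans(block_infos: List[RemoteModuleInfo]) -> List[Span]:
        closed_spans, active_spans = [], {}
        for block_index, info in enumerate(block_infos):
            for peer_id, server in info.servers.items():
                if server.state != ServerState.ONLINE:
                    continue  # the fix: non-ONLINE servers no longer produce spans
                if peer_id not in active_spans:
                    active_spans[peer_id] = Span(block_index, block_index + 1, peer_id)
                else:
                    active_spans[peer_id] = active_spans[peer_id]._replace(end=block_index + 1)
            for peer_id in list(active_spans.keys()):
                if (peer_id not in info.servers
                        or info.servers[peer_id].state != ServerState.ONLINE
                        or block_index == len(block_infos) - 1):
                    closed_spans.append(active_spans.pop(peer_id))
        return closed_spans

    infos = [
        RemoteModuleInfo("bloom.0", {"peerA": ServerInfo(ServerState.ONLINE, 1.0)}),
        RemoteModuleInfo("bloom.1", {"peerA": ServerInfo(ServerState.JOINING, 1.0)}),
    ]
    print(compute_spans(infos))  # [Span(start=0, end=1, peer_id='peerA')]

In this toy run, peerA's span covers only block 0: its entry for block 1 is JOINING, so the span is closed there rather than extended.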
+ 18 - 7
src/dht_utils.py

@@ -22,6 +22,7 @@ def declare_active_modules(
     dht: DHT,
     uids: Sequence[ModuleUID],
     expiration_time: DHTExpiration,
+    state: ServerState,
     throughput: float,
     wait: bool = True,
 ) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]:
@@ -41,7 +42,13 @@ def declare_active_modules(
     for uid in uids:
         assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid
     return dht.run_coroutine(
-        partial(_declare_active_modules, uids=uids, expiration_time=expiration_time, throughput=throughput),
+        partial(
+            _declare_active_modules,
+            uids=uids,
+            expiration_time=expiration_time,
+            state=state,
+            throughput=throughput,
+        ),
         return_future=not wait,
     )
 
@@ -51,13 +58,14 @@ async def _declare_active_modules(
     node: DHTNode,
     uids: List[ModuleUID],
     expiration_time: DHTExpiration,
+    state: ServerState,
     throughput: float,
 ) -> Dict[ModuleUID, bool]:
     num_workers = len(uids) if dht.num_workers is None else min(len(uids), dht.num_workers)
     return await node.store_many(
         keys=uids,
         subkeys=[dht.peer_id.to_base58()] * len(uids),
-        values=[throughput] * len(uids),
+        values=[(state.value, throughput)] * len(uids),
         expiration_time=expiration_time,
         num_workers=num_workers,
     )
@@ -123,12 +131,15 @@ async def _get_remote_module_infos(
                 logger.error(f"Incorrect metadata for {uid}: {metadata}")
                 logger.error(f"Incorrect metadata for {uid}: {metadata}")
             continue
             continue
         servers = {}
         servers = {}
-        for peer_id, throughput in metadata.value.items():
+        for peer_id, server_info in metadata.value.items():
             try:
             try:
-                if not isinstance(throughput.value, float):
-                    raise ValueError(f'Throughput expected to be a float, not {throughput.value}')
-                servers[peer_id] = ServerInfo(ServerState.ONLINE, throughput.value)
-            except (ValueError, TypeError) as e:
+                server_info = server_info.value
+                if not (isinstance(server_info, tuple) and len(server_info) == 2 and
+                        isinstance(server_info[0], int) and isinstance(server_info[1], float)):
+                    raise ValueError(f"Invalid server info for uid={uid}, peer_id={peer_id}: {server_info}")
+                state, throughput = server_info
+                servers[peer_id] = ServerInfo(ServerState(state), throughput)
+            except ValueError as e:
                 logger.error(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")
                 logger.error(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")
         if servers:
         if servers:
             modules[i] = RemoteModuleInfo(uid, servers)
             modules[i] = RemoteModuleInfo(uid, servers)

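The DHT wire format changes here from peer_id -> throughput (a bare float) to peer_id -> (state.value, throughput), which is what lets readers tell JOINING servers apart from ONLINE ones. A small round-trip sketch of the new value (toy numbers; the tuple check mirrors the validation in _get_remote_module_infos above):

    # Round-trip sketch of the new per-peer DHT value: (state.value, throughput)
    # replaces the bare throughput float. The concrete numbers are made up;
    # state.value == 1 standing for JOINING is an assumption about the enum.
    stored = (1, 176.5)  # what _declare_active_modules now writes per peer

    assert isinstance(stored, tuple) and len(stored) == 2
    state_value, throughput = stored
    assert isinstance(state_value, int) and isinstance(throughput, float)
    print(state_value, throughput)  # 1 176.5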
+ 19 - 8
src/server/server.py

@@ -13,7 +13,7 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
 from src import declare_active_modules, BloomConfig
 from src.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
-from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER
+from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
 from src.dht_utils import get_remote_module_infos
 from src.server.backend import TransformerBackend
 from src.server.cache import MemoryCache
@@ -125,6 +125,9 @@ class Server(threading.Thread):
             )
             logger.info(f"Automatic dht prefix: {prefix}")
         assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both"
+        if expiration is None:
+            expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
+
         dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
         visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]
         logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
@@ -151,13 +154,21 @@ class Server(threading.Thread):
         else:
             assert num_blocks is not None
             uids = [f"{prefix}.{block_index}" for block_index in range(block_config.n_layer)]
-            module_infos = get_remote_module_infos(dht, uids)
+            module_infos = get_remote_module_infos(dht, uids, expiration_time=float("inf"))
             block_indices = choose_best_blocks(num_blocks, module_infos)
 
-        # initialize modules
+        module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
+        declare_active_modules(
+            dht,
+            module_uids,
+            expiration_time=get_dht_time() + expiration,
+            state=ServerState.JOINING,
+            throughput=throughput,
+        )
+
+        logger.info(f"Loading blocks with indices {block_indices}")
         blocks = {}
-        for block_index in block_indices:
-            module_uid = f"{prefix}.{block_index}"
+        for module_uid, block_index in zip(module_uids, block_indices):
             block = load_pretrained_block(
                 converted_model_name_or_path,
                 block_index,
@@ -252,14 +263,13 @@ class ModuleAnnouncerThread(threading.Thread):
         self,
         module_backends: Dict[str, TransformerBackend],
         dht: DHT,
+        *,
         throughput: float,
         update_period: float = 30,
-        expiration: Optional[int] = None,
+        expiration: float,
         **kwargs
     ):
         super().__init__(**kwargs)
-        if expiration is None:
-            expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
         self.module_backends = module_backends
         self.dht = dht
         self.throughput = throughput
@@ -273,6 +283,7 @@ class ModuleAnnouncerThread(threading.Thread):
                 self.dht,
                 self.module_backends.keys(),
                 expiration_time=get_dht_time() + self.expiration,
+                state=ServerState.ONLINE,
                 throughput=self.throughput,
             )
             if self.stop.wait(self.update_period):
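
Putting the server.py changes together: the expiration default moves out of ModuleAnnouncerThread into Server, the blocks are declared JOINING before loading, and the announcer thread flips them to ONLINE. A runnable toy of that ordering (stub DHT and stub helpers; all names below are stand-ins, not the real hivemind API):

    # Toy of the ordering that fixes the race: blocks are claimed as JOINING
    # *before* the slow loading step, then re-declared as ONLINE once served.
    import time
    from enum import Enum

    MAX_DHT_TIME_DISCREPANCY_SECONDS = 3.0  # placeholder value, an assumption

    class ServerState(Enum):
        OFFLINE = 0
        JOINING = 1
        ONLINE = 2

    dht_store = {}  # toy stand-in for the real hivemind DHT

    def declare_active_modules(uids, expiration_time, state, throughput):
        for uid in uids:
            dht_store[uid] = (state.value, throughput)

    def load_block(uid):
        time.sleep(0.01)  # pretend this is the expensive load_pretrained_block call
        return f"weights of {uid}"

    def start_server(prefix, block_indices, throughput, update_period=30.0):
        expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
        module_uids = [f"{prefix}.{i}" for i in block_indices]
        # 1) claim the blocks right away, in the JOINING state, so that other
        #    servers starting at the same time choose different blocks
        declare_active_modules(module_uids, time.time() + expiration,
                               ServerState.JOINING, throughput)
        blocks = {uid: load_block(uid) for uid in module_uids}
        # 2) in the real server, ModuleAnnouncerThread re-declares the modules
        #    as ONLINE every update_period seconds; one iteration shown here
        declare_active_modules(module_uids, time.time() + expiration,
                               ServerState.ONLINE, throughput)
        return blocks

    start_server("bloom", [3, 4, 5], throughput=100.0)
    print(dht_store)  # every uid now maps to (ServerState.ONLINE.value, 100.0)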