Jelajahi Sumber

Implement rebalancing criterion

Aleksandr Borzunov 2 tahun lalu
induk
melakukan
4a05f40786
2 mengubah file dengan 110 tambahan dan 33 penghapusan
  1. 94 11
      src/server/block_selection.py
  2. 16 22
      src/server/server.py

+ 94 - 11
src/server/block_selection.py

@@ -1,18 +1,101 @@
-from typing import List, Optional
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Tuple
+
+import numpy as np
+from hivemind import PeerID, get_logger
 
 from src.data_structures import RemoteModuleInfo, ServerState
 
+logger = get_logger(__file__)
+
+
+@dataclass
+class Span:
+    start: int
+    end: int
+    throughput: float
+
+    @property
+    def length(self):
+        return self.end - self.start
+
+    def move_to(self, new_start: int) -> None:
+        self.start, self.end = new_start, new_start + self.length
+
 
-def choose_best_blocks(num_blocks: int, remote_module_infos: List[Optional[RemoteModuleInfo]]) -> List[int]:
-    throughputs = []
-    for module in remote_module_infos:
+def _compute_spans(module_infos: List[Optional[RemoteModuleInfo]]) -> Tuple[Dict[PeerID, Span], np.ndarray]:
+    spans = {}
+    throughputs = np.zeros(len(module_infos))
+    for block, module in enumerate(module_infos):
         if module is None:
-            throughputs.append(0)
             continue
-        throughputs.append(
-            sum(server.throughput for server in module.servers.values() if server.state != ServerState.OFFLINE)
-        )
 
-    options = [(sorted(throughputs[i : i + num_blocks]), i) for i in range(0, len(throughputs) - num_blocks + 1)]
-    best_start = min(options)[1]
-    return list(range(best_start, best_start + num_blocks))
+        for peer_id, server in module.servers.items():
+            if server.state == ServerState.OFFLINE:
+                continue
+
+            if peer_id in spans:
+                spans[peer_id].start = min(spans[peer_id].start, block)
+                spans[peer_id].end = max(spans[peer_id].start, block + 1)
+            else:
+                spans[peer_id] = Span(start=block, end=block + 1, throughput=server.throughput)
+
+            throughputs[block] += server.throughput
+
+    return spans, throughputs
+
+
+def _choose_best_start(throughputs: np.ndarray, num_blocks: int, cur_start: Optional[int]) -> int:
+    options = (
+        (sorted(throughputs[i : i + num_blocks]), i != cur_start, i)
+        for i in range(0, len(throughputs) - num_blocks + 1)
+    )
+    return min(options)[-1]
+
+
+def choose_best_blocks(num_blocks: int, module_infos: List[Optional[RemoteModuleInfo]]) -> List[int]:
+    _, throughputs = _compute_spans(module_infos)
+    start = _choose_best_start(throughputs, num_blocks, None)
+    return list(range(start, start + num_blocks))
+
+
+def should_choose_other_blocks(
+    local_peer_id: PeerID, module_infos: List[Optional[RemoteModuleInfo]], min_balance_quality: float
+) -> bool:
+    spans, throughputs = _compute_spans(module_infos)
+    initial_throughput = throughputs.min()
+
+    assert local_peer_id in spans, "Span served by this server is not present in the DHT"
+    local_span = spans[local_peer_id]
+    throughputs[local_span.start : local_span.end] -= local_span.throughput
+
+    new_start = _choose_best_start(throughputs, local_span.length, local_span.start)
+    if local_span.start == new_start:
+        return False  # This server is on its best place already
+    local_span.move_to(new_start)
+
+    throughputs[local_span.start : local_span.end] += local_span.throughput
+
+    moved = True
+    while moved:
+        servers = list(spans.keys())
+        np.random.shuffle(servers)
+
+        moved = False
+        for peer_id in servers:
+            span = spans[peer_id]
+            throughputs[span.start : span.end] -= span.throughput
+
+            new_start = _choose_best_start(throughputs, span.length, span.start)
+            if span.start != new_start:
+                span.move_to(new_start)
+                moved = True
+
+            throughputs[span.start : span.end] += span.throughput
+
+    new_throughput = throughputs.min()
+    balance_quality = initial_throughput / new_throughput
+    logger.info(f"Swarm balance quality: {balance_quality * 100:.1f}%")
+
+    eps = 1e-6
+    return balance_quality < min_balance_quality - eps

+ 16 - 22
src/server/server.py

@@ -6,6 +6,7 @@ import threading
 import time
 from typing import Dict, List, Optional, Sequence, Union
 
+import numpy as np
 import torch
 from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
 from hivemind.moe.server.layers import add_custom_models_from_file
@@ -17,8 +18,8 @@ from src import BloomConfig, declare_active_modules
 from src.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
 from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
 from src.dht_utils import get_remote_module_infos
+from src.server import block_selection
 from src.server.backend import TransformerBackend
-from src.server.block_selection import choose_best_blocks
 from src.server.cache import MemoryCache
 from src.server.handler import TransformerConnectionHandler
 from src.server.throughput import get_host_throughput
@@ -59,7 +60,8 @@ class Server(threading.Thread):
         prefetch_batches: int = 1,
         sender_threads: int = 1,
         mean_block_selection_delay: float = 0.5,
-        mean_balance_check_period: float = 300,
+        mean_balance_check_period: float = 150,
+        min_balance_quality: float = 0.8,
         use_auth_token: Optional[str] = None,
         load_in_8bit: bool = False,
         *,
@@ -122,6 +124,7 @@ class Server(threading.Thread):
             use_auth_token=use_auth_token,
             revision=revision,
         )
+        self.module_uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]
 
         assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both"
         if block_indices is not None:
@@ -135,14 +138,13 @@ class Server(threading.Thread):
         self.strict_block_indices, self.num_blocks = block_indices, num_blocks
         self.mean_block_selection_delay = mean_block_selection_delay
         self.mean_balance_check_period = mean_balance_check_period
-        self._module_infos = None
+        self.min_balance_quality = min_balance_quality
 
         self.stop = threading.Event()
         if start:
             self.start()
 
     def run(self):
-        self._update_module_infos()
         while True:
             block_indices = self._choose_blocks()
             self.module_container = ModuleContainer.create(
@@ -179,37 +181,29 @@ class Server(threading.Thread):
                     if self.stop.wait(timeout):
                         return
 
-                    self._update_module_infos()
                     if self._should_choose_other_blocks():
-                        logger.info("Network is imbalanced, server will load other blocks")
+                        logger.info("Swarm is imbalanced, server will load other blocks")
                         break  # Stop serving this set of modules
             finally:
                 self.module_container.shutdown()
 
-    def _update_module_infos(self) -> None:
-        if self.strict_block_indices:
-            return  # No need for self._module_infos in this case
+    def _choose_blocks(self) -> List[int]:
+        if self.strict_block_indices is not None:
+            return self.strict_block_indices
+        assert self.num_blocks is not None
 
         # If multiple servers (e.g., launched on the same machine by a script) get to this line at the same time,
         # this delay decreases the probability of a race condition while choosing the best blocks to serve.
         time.sleep(random.random() * 2 * self.mean_block_selection_delay)
-
-        uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]
-        self._module_infos = get_remote_module_infos(self.dht, uids, expiration_time=float("inf"))
-
-    def _choose_blocks(self) -> List[int]:
-        if self.strict_block_indices:
-            return self.strict_block_indices
-
-        assert self.num_blocks is not None
-        return choose_best_blocks(self.num_blocks, self._module_infos)
+        module_infos = get_remote_module_infos(self.dht, self.module_uids, expiration_time=np.inf)
+        return block_selection.choose_best_blocks(self.num_blocks, module_infos)
 
     def _should_choose_other_blocks(self) -> bool:
-        if self.strict_block_indices:
+        if self.strict_block_indices is not None:
             return False
 
-        # TODO: Implement actual algorithm here
-        return True
+        module_infos = get_remote_module_infos(self.dht, self.module_uids, expiration_time=np.inf)
+        return block_selection.should_choose_other_blocks(self.dht.peer_id, module_infos, self.min_balance_quality)
 
     def shutdown(self):
         self.stop.set()