瀏覽代碼

Simplify rebalancing options

Aleksandr Borzunov 2 年之前
父節點
當前提交
e965719b58
共有 1 個文件被更改,包括 18 次插入11 次删除
  1. 18 11
      src/server/server.py

+ 18 - 11
src/server/server.py

@@ -41,7 +41,6 @@ class Server(threading.Thread):
         throughput: Union[float, str],
         num_blocks: Optional[int] = None,
         block_indices: Optional[str] = None,
-        allow_rebalancing: bool = True,
         num_handlers: int = 8,
         min_batch_size: int = 1,
         max_batch_size: int = 4096,
@@ -60,7 +59,7 @@ class Server(threading.Thread):
         prefetch_batches: int = 1,
         sender_threads: int = 1,
         mean_block_selection_delay: float = 0.5,
-        mean_balance_check_period: float = 300,  # TODO:
+        mean_balance_check_period: float = 300,
         use_auth_token: Optional[str] = None,
         load_in_8bit: bool = False,
         *,
@@ -133,16 +132,17 @@ class Server(threading.Thread):
                 logger.error(f"Failed to parse --block_indices ({e}), must be start:end (e.g. 0:18)")
                 raise
             block_indices = range(first_block_index, last_block_index)
-        self.block_indices, self.num_blocks = block_indices, num_blocks
-        self.allow_rebalancing = allow_rebalancing
+        self.strict_block_indices, self.num_blocks = block_indices, num_blocks
         self.mean_block_selection_delay = mean_block_selection_delay
         self.mean_balance_check_period = mean_balance_check_period
+        self._module_infos = None
 
         self.stop = threading.Event()
         if start:
             self.start()
 
     def run(self):
+        self._update_module_infos()
         while True:
             block_indices = self._choose_blocks()
             self.module_container = ModuleContainer.create(
@@ -178,27 +178,34 @@ class Server(threading.Thread):
                     # TODO: Follow ModuleContainer status (to restart/stop if it crashes)
                     if self.stop.wait(timeout):
                         return
+
+                    self._update_module_infos()
                     if self._should_choose_other_blocks():
                         logger.info("Network is imbalanced, server will load other blocks")
                         break  # Stop serving this set of modules
             finally:
                 self.module_container.shutdown()
 
-    def _choose_blocks(self) -> List[int]:
-        if self.block_indices is not None:
-            return self.block_indices
+    def _update_module_infos(self) -> None:
+        if self.strict_block_indices:
+            return  # No need for self._module_infos in this case
 
         # If multiple servers (e.g., launched on the same machine by a script) get to this line at the same time,
         # this delay decreases the probability of a race condition while choosing the best blocks to serve.
         time.sleep(random.random() * 2 * self.mean_block_selection_delay)
 
-        assert self.num_blocks is not None
         uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]
-        module_infos = get_remote_module_infos(self.dht, uids, expiration_time=float("inf"))
-        return choose_best_blocks(self.num_blocks, module_infos)
+        self._module_infos = get_remote_module_infos(self.dht, uids, expiration_time=float("inf"))
+
+    def _choose_blocks(self) -> List[int]:
+        if self.strict_block_indices:
+            return self.strict_block_indices
+
+        assert self.num_blocks is not None
+        return choose_best_blocks(self.num_blocks, self._module_infos)
 
     def _should_choose_other_blocks(self) -> bool:
-        if not self.allow_rebalancing:
+        if self.strict_block_indices:
             return False
 
         # TODO: Implement actual algorithm here