Răsfoiți Sursa

Draft balance check logic

Aleksandr Borzunov 3 ani în urmă
părinte
comite
325ff0cef9
1 a modificat fișierele cu 58 adăugiri și 41 ștergeri
  1. 58 41
      src/server/server.py

+ 58 - 41
src/server/server.py

@@ -60,6 +60,7 @@ class Server(threading.Thread):
         prefetch_batches: int = 1,
         sender_threads: int = 1,
         max_block_selection_delay: float = 1,
+        max_balance_check_period: float = 600,
         use_auth_token: Optional[str] = None,
         load_in_8bit: bool = False,
         *,
@@ -94,8 +95,6 @@ class Server(threading.Thread):
             logger.info(f"Automatic dht prefix: {prefix}")
         self.prefix = prefix
 
-        assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both"
-
         if expiration is None:
             expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
         self.expiration = expiration
@@ -125,6 +124,7 @@ class Server(threading.Thread):
             revision=revision,
         )
 
+        assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both"
         if block_indices is not None:
             try:
                 first_block_index, last_block_index = block_indices.split(":")
@@ -133,51 +133,68 @@ class Server(threading.Thread):
                 logger.error(f"Failed to parse --block_indices ({e}), must be start:end (e.g. 0:18)")
                 raise
             block_indices = range(first_block_index, last_block_index)
-        else:
-            # If multiple servers (e.g., launched on the same machine by a script) get to this line at the same time,
-            # this delay decreases the probability of a race condition while choosing the best blocks to serve.
-            time.sleep(random.random() * max_block_selection_delay)
-
-            assert num_blocks is not None
-            uids = [f"{prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]
-            module_infos = get_remote_module_infos(self.dht, uids, expiration_time=float("inf"))
-            block_indices = choose_best_blocks(num_blocks, module_infos)
-        self.block_indices = block_indices
+        self.block_indices, self.num_blocks = block_indices, num_blocks
+        self.max_block_selection_delay, self.max_balance_check_period = max_block_selection_delay, max_balance_check_period
 
         self.stop = threading.Event()
         if start:
             self.start()
 
     def run(self):
-        self.module_container = ModuleContainer.create(
-            dht=self.dht,
-            prefix=self.prefix,
-            converted_model_name_or_path=self.converted_model_name_or_path,
-            block_config=self.block_config,
-            memory_cache=self.memory_cache,
-            throughput=self.throughput,
-            block_indices=self.block_indices,
-            num_handlers=self.num_handlers,
-            min_batch_size=self.min_batch_size,
-            max_batch_size=self.max_batch_size,
-            inference_max_length=self.inference_max_length,
-            torch_dtype=self.torch_dtype,
-            cache_dir=self.cache_dir,
-            device=self.device,
-            compression=self.compression,
-            stats_report_interval=self.stats_report_interval,
-            update_period=self.update_period,
-            expiration=self.expiration,
-            prefetch_batches=self.prefetch_batches,
-            sender_threads=self.sender_threads,
-            use_auth_token=self.use_auth_token,
-            load_in_8bit=self.load_in_8bit,
-            start=True,
-        )
-        try:
-            self.stop.wait()
-        finally:
-            self.module_container.shutdown()
+        while True:
+            block_indices = self._choose_blocks()
+            self.module_container = ModuleContainer.create(
+                dht=self.dht,
+                prefix=self.prefix,
+                converted_model_name_or_path=self.converted_model_name_or_path,
+                block_config=self.block_config,
+                memory_cache=self.memory_cache,
+                throughput=self.throughput,
+                block_indices=block_indices,
+                num_handlers=self.num_handlers,
+                min_batch_size=self.min_batch_size,
+                max_batch_size=self.max_batch_size,
+                inference_max_length=self.inference_max_length,
+                torch_dtype=self.torch_dtype,
+                cache_dir=self.cache_dir,
+                device=self.device,
+                compression=self.compression,
+                stats_report_interval=self.stats_report_interval,
+                update_period=self.update_period,
+                expiration=self.expiration,
+                prefetch_batches=self.prefetch_batches,
+                sender_threads=self.sender_threads,
+                use_auth_token=self.use_auth_token,
+                load_in_8bit=self.load_in_8bit,
+                start=True,
+            )
+            try:
+                self.module_container.ready.wait()
+
+                while True:
+                    timeout = random.random() * self.max_balance_check_period
+                    if self.stop.wait(timeout):
+                        return
+                    if self._should_choose_other_blocks():
+                        break  # Stop serving this set of modules
+            finally:
+                self.module_container.shutdown()
+
+    def _choose_blocks(self) -> List[int]:
+        if self.block_indices is not None:
+            return self.block_indices
+
+        # If multiple servers (e.g., launched on the same machine by a script) get to this line at the same time,
+        # this delay decreases the probability of a race condition while choosing the best blocks to serve.
+        time.sleep(random.random() * self.max_block_selection_delay)
+
+        assert self.num_blocks is not None
+        uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]
+        module_infos = get_remote_module_infos(self.dht, uids, expiration_time=float("inf"))
+        return choose_best_blocks(self.num_blocks, module_infos)
+
+    def _should_choose_other_blocks(self) -> bool:
+        return False
 
     def shutdown(self):
         self.stop.set()