Sfoglia il codice sorgente

Rename --min-balance-quality => --balance-quality

Aleksandr Borzunov 2 anni fa
parent
commit
2e321e506d
3 ha cambiato i file con 9 aggiunte e 9 eliminazioni
  1. 1 1
      cli/run_server.py
  2. 5 5
      src/server/block_selection.py
  3. 3 3
      src/server/server.py

+ 1 - 1
cli/run_server.py

@@ -83,7 +83,7 @@ def main():
                         help='Path of a file with custom nn.modules, wrapped into special decorator')
     parser.add_argument('--identity_path', type=str, required=False, help='Path to identity file to be used in P2P')
 
-    parser.add_argument("--min_balance_quality", type=float, default=0.75,
+    parser.add_argument("--balance_quality", type=float, default=0.75,
                         help="Rebalance the swarm if its throughput is worse than this share of the optimal "
                              "throughput. Use 0.0 to disable rebalancing, values > 1.0 to force rebalancing "
                              "on each check for debugging purposes.")

+ 5 - 5
src/server/block_selection.py

@@ -62,9 +62,9 @@ def choose_best_blocks(num_blocks: int, module_infos: List[Optional[RemoteModule
 
 
 def should_choose_other_blocks(
-    local_peer_id: PeerID, module_infos: List[Optional[RemoteModuleInfo]], min_balance_quality: float
+    local_peer_id: PeerID, module_infos: List[Optional[RemoteModuleInfo]], balance_quality: float
 ) -> bool:
-    if min_balance_quality > 1.0:
+    if balance_quality > 1.0:
         return True  # Forces rebalancing on each check (may be used for debugging purposes)
 
     spans, throughputs = _compute_spans(module_infos)
@@ -99,8 +99,8 @@ def should_choose_other_blocks(
             throughputs[span.start : span.end] += span.throughput
 
     new_throughput = throughputs.min()
-    balance_quality = initial_throughput / new_throughput
-    logger.info(f"Swarm balance quality: {balance_quality * 100:.1f}%")
+    actual_quality = initial_throughput / new_throughput
+    logger.info(f"Swarm balance quality: {actual_quality * 100:.1f}%")
 
     eps = 1e-6
-    return balance_quality < min_balance_quality - eps
+    return actual_quality < balance_quality - eps

+ 3 - 3
src/server/server.py

@@ -61,7 +61,7 @@ class Server(threading.Thread):
         expiration: Optional[float] = None,
         prefetch_batches: int = 1,
         sender_threads: int = 1,
-        min_balance_quality: float = 0.75,
+        balance_quality: float = 0.75,
         mean_balance_check_period: float = 150,
         mean_block_selection_delay: float = 0.5,
         use_auth_token: Optional[str] = None,
@@ -138,7 +138,7 @@ class Server(threading.Thread):
                 raise
             block_indices = range(first_block_index, last_block_index)
         self.strict_block_indices, self.num_blocks = block_indices, num_blocks
-        self.min_balance_quality = min_balance_quality
+        self.balance_quality = balance_quality
         self.mean_balance_check_period = mean_balance_check_period
         self.mean_block_selection_delay = mean_block_selection_delay
 
@@ -215,7 +215,7 @@ class Server(threading.Thread):
             return False
 
         module_infos = get_remote_module_infos(self.dht, self.module_uids, expiration_time=np.inf)
-        return block_selection.should_choose_other_blocks(self.dht.peer_id, module_infos, self.min_balance_quality)
+        return block_selection.should_choose_other_blocks(self.dht.peer_id, module_infos, self.balance_quality)
 
     def shutdown(self):
         self.stop.set()