
Penalize servers that use relays during rebalancing (#428)

Servers accessible only via relays may cause issues if they are the only servers holding certain blocks: connections to them may be unstable or take a noticeable delay to open.

This PR lowers their self-reported throughput, so that the rebalancing algorithm prefers directly reachable servers when assigning each block.
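
As a rough illustration of the intended effect (a minimal sketch with made-up numbers; the 0.2 factor is the relay_penalty default added in throughput.py below), a relay-only server with the same raw network speed as a direct server now reports a much lower throughput, so it is weighted down wherever servers are chosen in proportion to reported throughput, such as the client-side span sampling in sequence_manager.py:

    # Hypothetical numbers: two servers whose raw network throughput is 100 requests/sec each.
    relay_penalty = 0.2
    direct_reported = 100.0                  # direct server reports its full throughput
    relay_reported = 100.0 * relay_penalty   # relay-only server reports a penalized value
    total = direct_reported + relay_reported
    print(direct_reported / total, relay_reported / total)  # ~0.83 vs ~0.17 selection probability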
Alexander Borzunov · 2 years ago
commit 351e96bc46
3 changed files with 16 additions and 18 deletions
  1. +2 −10  src/petals/client/routing/sequence_manager.py
  2. +8 −7   src/petals/server/server.py
  3. +6 −1   src/petals/server/throughput.py

+ 2 - 10
src/petals/client/routing/sequence_manager.py

@@ -292,9 +292,7 @@ class RemoteSequenceManager:
         # This is okay since false positives are more costly than false negatives here.
         return cache_tokens_needed * 2 * span.length <= span.server_info.cache_tokens_left
 
-    def _make_sequence_with_max_throughput(
-        self, start_index: int, end_index: int, *, relay_penalty: float = 0.5
-    ) -> List[RemoteSpanInfo]:
+    def _make_sequence_with_max_throughput(self, start_index: int, end_index: int) -> List[RemoteSpanInfo]:
         span_sequence = []
         current_index = start_index
         while current_index < end_index:
@@ -302,13 +300,7 @@ class RemoteSequenceManager:
             if not candidate_spans:
                 raise MissingBlocksError(current_index)
 
-            span_weights = np.array(
-                [
-                    span.server_info.throughput * (1 if not span.server_info.using_relay else relay_penalty)
-                    for span in candidate_spans
-                ],
-                dtype=np.float64,
-            )
+            span_weights = np.array([span.server_info.throughput for span in candidate_spans], dtype=np.float64)
             chosen_span = np.random.choice(candidate_spans, p=span_weights / span_weights.sum())
 
             assert chosen_span.start <= current_index < chosen_span.end
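
For context, a standalone sketch of the simplified client-side choice after this change (the candidate throughputs below are made up): the client now samples spans purely in proportion to the server-reported throughput, since the relay penalty is baked in on the server side instead of being applied here.

    import numpy as np

    # Hypothetical reported throughputs of three candidate spans covering the current block.
    reported_throughputs = [120.0, 30.0, 8.0]
    weights = np.array(reported_throughputs, dtype=np.float64)
    chosen = np.random.choice(len(weights), p=weights / weights.sum())
    print(f"chose candidate span #{chosen}")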

+ 8 - 7
src/petals/server/server.py

@@ -83,7 +83,7 @@ class Server:
         quant_type: Optional[QuantType] = None,
         tensor_parallel_devices: Optional[Sequence[torch.device]] = None,
         skip_reachability_check: bool = False,
-        dht_client_mode: Optional[bool] = None,
+        reachable_via_relay: Optional[bool] = None,
         use_relay: bool = True,
         use_auto_relay: bool = True,
         adapters: Sequence[str] = (),
@@ -129,20 +129,20 @@ class Server:
             for block_index in range(self.block_config.num_hidden_layers)
         ]
 
-        if dht_client_mode is None:
+        if reachable_via_relay is None:
             is_reachable = check_direct_reachability(initial_peers=initial_peers, use_relay=False, **kwargs)
-            dht_client_mode = is_reachable is False  # if could not check reachability (returns None), run a full peer
-            logger.info(f"This server is accessible {'via relays' if dht_client_mode else 'directly'}")
+            reachable_via_relay = is_reachable is False  # if can't check reachability (returns None), run a full peer
+            logger.info(f"This server is accessible {'via relays' if reachable_via_relay else 'directly'}")
         self.dht = DHT(
             initial_peers=initial_peers,
             start=True,
             num_workers=self.block_config.num_hidden_layers,
             use_relay=use_relay,
             use_auto_relay=use_auto_relay,
-            client_mode=dht_client_mode,
+            client_mode=reachable_via_relay,
             **kwargs,
         )
-        self.reachability_protocol = ReachabilityProtocol.attach_to_dht(self.dht) if not dht_client_mode else None
+        self.reachability_protocol = ReachabilityProtocol.attach_to_dht(self.dht) if not reachable_via_relay else None
 
         visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()]
         if initial_peers == PUBLIC_INITIAL_PEERS:
@@ -227,6 +227,7 @@ class Server:
                 num_blocks=num_blocks,
                 quant_type=quant_type,
                 tensor_parallel_devices=self.tensor_parallel_devices,
+                reachable_via_relay=reachable_via_relay,
                 force_eval=(throughput == "eval"),
                 cache_dir=cache_dir,
             )
@@ -239,7 +240,7 @@ class Server:
             adapters=tuple(adapters),
             torch_dtype=str(torch_dtype).replace("torch.", ""),
             quant_type=quant_type.name.lower(),
-            using_relay=self.dht.client_mode,
+            using_relay=reachable_via_relay,
             **throughput_info,
         )
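
A small sketch of the auto-detection semantics kept (and renamed) above, with the reachability probe replaced by a precomputed is_reachable value; the None/True/False mapping follows the diff, while the helper name is made up for illustration:

    from typing import Optional

    def resolve_reachable_via_relay(
        reachable_via_relay: Optional[bool], is_reachable: Optional[bool]
    ) -> bool:
        # An explicitly passed value wins; otherwise fall back to the probe result.
        if reachable_via_relay is None:
            # If reachability could not be checked (None), keep running as a full (direct) peer.
            reachable_via_relay = is_reachable is False
        return reachable_via_relay

    assert resolve_reachable_via_relay(None, True) is False   # confirmed reachable -> direct peer
    assert resolve_reachable_via_relay(None, False) is True   # confirmed unreachable -> relay client mode
    assert resolve_reachable_via_relay(None, None) is False   # unknown -> assume full peer
    assert resolve_reachable_via_relay(True, True) is True    # explicit override is respected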
 

+ 6 - 1
src/petals/server/throughput.py

@@ -41,6 +41,8 @@ def get_server_throughput(
     num_blocks: int,
     quant_type: QuantType,
     tensor_parallel_devices: Sequence[torch.device],
+    reachable_via_relay: bool,
+    relay_penalty: float = 0.2,
     force_eval: bool = False,
     cache_dir: Optional[str] = None,
 ) -> Dict[str, float]:
@@ -94,7 +96,10 @@ def get_server_throughput(
     # E[Uniform{1, 2, ..., num_blocks}] = (num_blocks + 1) / 2
     average_blocks_used = (num_blocks + 1) / 2
     throughput = throughput_info["forward_rps"] / average_blocks_used
-    throughput = min(throughput, throughput_info.get("network_rps", math.inf))
+
+    network_rps = throughput_info["network_rps"] * (relay_penalty if reachable_via_relay else 1)
+    throughput = min(throughput, network_rps)
+
     throughput_info["throughput"] = throughput
     logger.info(f"Reporting throughput: {throughput:.1f} tokens/sec for {num_blocks} blocks")