Просмотр исходного кода

Prefer longer servers for fine-tuning, exclude unreachable (#448)

We choose longer servers to minimize the number of hops but leave some randomization to distribute the load. We also exclude servers known to be unreachable.
Alexander Borzunov 2 лет назад
Родитель
Сommit
2a150770a4
1 измененных файлов с 14 добавлено и 2 удалено
  1. 14 2
      src/petals/client/routing/sequence_manager.py

+ 14 - 2
src/petals/client/routing/sequence_manager.py

@@ -50,7 +50,7 @@ class SequenceManagerConfig:
     ban_timeout: float = 15  # when a remote peer fails to respond, prevent routing to that peer for this many seconds
     active_adapter: Optional[str] = None  # name of active LoRA adapter (usually, Hugging Face repo)
 
-    max_pinged: int = 5  # max servers to ping from each sequence side, per update
+    max_pinged: int = 3  # max servers to ping from each sequence side, per update
     ping_timeout: float = 2  # max time to wait for pings, per update
 
 
@@ -293,6 +293,8 @@ class RemoteSequenceManager:
         return cache_tokens_needed * 2 * span.length <= span.server_info.cache_tokens_left
 
     def _make_sequence_with_max_throughput(self, start_index: int, end_index: int) -> List[RemoteSpanInfo]:
+        client_server_rtts = self.ping_aggregator.to_dict()
+
         span_sequence = []
         current_index = start_index
         while current_index < end_index:
@@ -300,7 +302,13 @@ class RemoteSequenceManager:
             if not candidate_spans:
                 raise MissingBlocksError(current_index)
 
-            span_weights = np.array([span.server_info.throughput for span in candidate_spans], dtype=np.float64)
+            # We choose longer servers to minimize the number of hops but leave some randomization
+            # to distribute the load. We also exclude servers known to be unreachable.
+            eps = 1e-6
+            span_weights = np.array(
+                [span.length if client_server_rtts.get(span.peer_id) != np.inf else eps for span in candidate_spans],
+                dtype=np.float64,
+            )
             chosen_span = np.random.choice(candidate_spans, p=span_weights / span_weights.sum())
 
             assert chosen_span.start <= current_index < chosen_span.end
@@ -361,9 +369,13 @@ class RemoteSequenceManager:
             self.state.sequence_info.update_(new_block_infos)
 
             first_servers = [span.peer_id for span in self.state.sequence_info.spans_containing_block[0]]
+            middle_servers = [
+                span.peer_id for spans in self.state.sequence_info.spans_containing_block[1:-1] for span in spans
+            ]
             last_servers = [span.peer_id for span in self.state.sequence_info.spans_containing_block[-1]]
 
         pinged_servers = set(sample_up_to(first_servers, self.config.max_pinged))
+        pinged_servers = set(sample_up_to(middle_servers, self.config.max_pinged))
         pinged_servers |= set(sample_up_to(last_servers, self.config.max_pinged))
         self.ping_aggregator.ping(list(pinged_servers), wait_timeout=self.config.ping_timeout)