|
@@ -50,7 +50,7 @@ class SequenceManagerConfig:
|
|
|
ban_timeout: float = 15 # when a remote peer fails to respond, prevent routing to that peer for this many seconds
|
|
|
active_adapter: Optional[str] = None # name of active LoRA adapter (usually, Hugging Face repo)
|
|
|
|
|
|
- max_pinged: int = 5 # max servers to ping from each sequence side, per update
|
|
|
+ max_pinged: int = 3 # max servers to ping from each sequence side, per update
|
|
|
ping_timeout: float = 2 # max time to wait for pings, per update
|
|
|
|
|
|
|
|
@@ -293,6 +293,8 @@ class RemoteSequenceManager:
|
|
|
return cache_tokens_needed * 2 * span.length <= span.server_info.cache_tokens_left
|
|
|
|
|
|
def _make_sequence_with_max_throughput(self, start_index: int, end_index: int) -> List[RemoteSpanInfo]:
|
|
|
+ client_server_rtts = self.ping_aggregator.to_dict()
|
|
|
+
|
|
|
span_sequence = []
|
|
|
current_index = start_index
|
|
|
while current_index < end_index:
|
|
@@ -300,7 +302,13 @@ class RemoteSequenceManager:
|
|
|
if not candidate_spans:
|
|
|
raise MissingBlocksError(current_index)
|
|
|
|
|
|
- span_weights = np.array([span.server_info.throughput for span in candidate_spans], dtype=np.float64)
|
|
|
+ # We choose longer servers to minimize the number of hops but leave some randomization
|
|
|
+ # to distribute the load. We also exclude servers known to be unreachable.
|
|
|
+ eps = 1e-6
|
|
|
+ span_weights = np.array(
|
|
|
+ [span.length if client_server_rtts.get(span.peer_id) != np.inf else eps for span in candidate_spans],
|
|
|
+ dtype=np.float64,
|
|
|
+ )
|
|
|
chosen_span = np.random.choice(candidate_spans, p=span_weights / span_weights.sum())
|
|
|
|
|
|
assert chosen_span.start <= current_index < chosen_span.end
|
|
@@ -361,9 +369,13 @@ class RemoteSequenceManager:
|
|
|
self.state.sequence_info.update_(new_block_infos)
|
|
|
|
|
|
first_servers = [span.peer_id for span in self.state.sequence_info.spans_containing_block[0]]
|
|
|
+ middle_servers = [
|
|
|
+ span.peer_id for spans in self.state.sequence_info.spans_containing_block[1:-1] for span in spans
|
|
|
+ ]
|
|
|
last_servers = [span.peer_id for span in self.state.sequence_info.spans_containing_block[-1]]
|
|
|
|
|
|
pinged_servers = set(sample_up_to(first_servers, self.config.max_pinged))
|
|
|
+ pinged_servers = set(sample_up_to(middle_servers, self.config.max_pinged))
|
|
|
pinged_servers |= set(sample_up_to(last_servers, self.config.max_pinged))
|
|
|
self.ping_aggregator.ping(list(pinged_servers), wait_timeout=self.config.ping_timeout)
|
|
|
|