|
@@ -115,14 +115,14 @@ class RemoteSequenceManager:
|
|
|
self._need_latest_infos = True
|
|
|
|
|
|
def make_sequence(
|
|
|
- self, start_index: int = 0, end_index: Optional[int] = None, mode: str = "random"
|
|
|
+ self, start_index: int = 0, end_index: Optional[int] = None, *, mode: str
|
|
|
) -> List[RemoteSpanInfo]:
|
|
|
"""
|
|
|
Form a sequence of remote servers that collectively serve all consecutive layers
|
|
|
|
|
|
:param start_index: optional index of the first module in a sequence, default = the first of block_uids
|
|
|
:param end_index: optional index of the last module (non-inclusive), default = after last of block uids
|
|
|
- :param mode: either random or fastest
|
|
|
+ :param mode: one of ["max_throughput", "min_latency"]
|
|
|
"""
|
|
|
with self._thread_start_lock:
|
|
|
if not self.is_alive():
|
|
@@ -137,17 +137,17 @@ class RemoteSequenceManager:
|
|
|
candidate_spans = self.state.sequence_info.spans_containing_block[current_index]
|
|
|
if not candidate_spans:
|
|
|
raise MissingBlocksError(current_index)
|
|
|
- if mode == "random":
|
|
|
- chosen_span = random.choice(candidate_spans) # TODO this should be replaced with proper load balancing
|
|
|
- elif mode == "fastest":
|
|
|
- # note: this too is a heuristic that will be replaced once we integrate fastest wall time routing
|
|
|
+
|
|
|
+ if mode == "max_throughput":
|
|
|
+ span_weights = np.array([span.throughput for span in candidate_spans], dtype=np.float64)
|
|
|
+ elif mode == "min_latency":
|
|
|
span_weights = np.array([span.end - current_index for span in candidate_spans], dtype=np.float64)
|
|
|
- chosen_span = np.random.choice(candidate_spans, p=span_weights / span_weights.sum())
|
|
|
else:
|
|
|
raise RuntimeError(f"Unexpected mode {mode}")
|
|
|
+ chosen_span = np.random.choice(candidate_spans, p=span_weights / span_weights.sum())
|
|
|
|
|
|
assert chosen_span.start <= current_index < chosen_span.end
|
|
|
- span_sequence.append(RemoteSpanInfo(start=current_index, end=chosen_span.end, peer_id=chosen_span.peer_id))
|
|
|
+ span_sequence.append(dataclasses.replace(chosen_span, start=current_index))
|
|
|
current_index = chosen_span.end
|
|
|
|
|
|
route_repr = " => ".join([f"{span.start}:{span.end} via …{str(span.peer_id)[-6:]}" for span in span_sequence])
|