|
@@ -9,6 +9,7 @@ import time
|
|
|
from typing import Any, Dict, List, Optional, Sequence, Union
|
|
|
from weakref import WeakMethod
|
|
|
|
|
|
+import numpy as np
|
|
|
from hivemind import DHT, P2P, MSGPackSerializer, PeerID
|
|
|
from hivemind.dht.node import Blacklist
|
|
|
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
|
|
@@ -92,12 +93,15 @@ class RemoteSequenceManager:
|
|
|
if await_ready:
|
|
|
self._thread.ready.wait(timeout)
|
|
|
|
|
|
- def make_sequence(self, start_index: int = 0, end_index: Optional[int] = None) -> List[RemoteSpanInfo]:
|
|
|
+ def make_sequence(
|
|
|
+ self, start_index: int = 0, end_index: Optional[int] = None, mode: str = "random"
|
|
|
+ ) -> List[RemoteSpanInfo]:
|
|
|
"""
|
|
|
Form a sequence of remote servers that collectively serve all consecutive layers
|
|
|
|
|
|
:param start_index: optional index of the first module in a sequence, default = the first of block_uids
|
|
|
:param end_index: optional index of the last module (non-inclusive), default = after last of block uids
|
|
|
+ :param mode: either random or fastest
|
|
|
"""
|
|
|
if not self.is_alive():
|
|
|
logger.error("Using a sequence manager that is not running: it has either crashed or never started")
|
|
@@ -110,7 +114,14 @@ class RemoteSequenceManager:
|
|
|
current_index = start_index
|
|
|
while current_index < end_index:
|
|
|
candidate_spans = self.sequence_info.spans_containing_block[current_index]
|
|
|
- chosen_span = random.choice(candidate_spans) # TODO this should be replaced with proper load balancing
|
|
|
+ if mode == "random":
|
|
|
+ chosen_span = random.choice(candidate_spans) # TODO this should be replaced with proper load balancing
|
|
|
+ elif mode == "fastest":
|
|
|
+ # note: this too is a heuristic that will be replaced once we integrate fastest wall time routing
|
|
|
+ span_weights = np.array([span.end - current_index for span in candidate_spans], dtype=np.float64)
|
|
|
+ chosen_span = np.random.choice(candidate_spans, p=span_weights / span_weights.sum())
|
|
|
+ else:
|
|
|
+ raise RuntimeError(f"Unexpected mode {mode}")
|
|
|
|
|
|
assert chosen_span.start <= current_index < chosen_span.end
|
|
|
span_sequence.append(RemoteSpanInfo(start=current_index, end=chosen_span.end, peer_id=chosen_span.peer_id))
|