@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import enum
 import random
 import threading
 from typing import List, Optional, Sequence, Tuple, Union
@@ -17,65 +18,145 @@ use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 
-class RemoteSequenceManager:
+class RoutingStrategy(enum.Enum):
+    RANDOM = enum.auto()  # choose a random compatible server at each branch and include all layers served by it
+    FASTEST = enum.auto()  # [WIP] minimize the estimated time to process a given number of tokens, including latency
+    LOAD_BALANCED = enum.auto()  # [WIP] use servers in proportion to their speed, on average over many sequences
+
+
+class RemoteSequenceManager(threading.Thread):
     """
-    Keeps and updates the meta-information about which peers host which blocks.
-    In future, this class is intended to maintain latency statistics, ban non-responsive peers, etc.
+    Sequence manager is a thread that keeps track of the remote servers that constitute a RemoteSequential.
+    TL;DR: it tells you which peers you should ask to get a specific layer. It is used in RemoteSequential.
+
+    When created, RemoteSequenceManager looks up which servers serve the necessary layers by reading from the DHT.
+    Using this information, the sequence manager can form sequences of servers that collectively serve all layers.
+
+    To form such a sequence, call .make_sequence with the appropriate optimization policy (see its docstring).
+
+    :note: RemoteSequenceManager takes up some CPU and network I/O to operate in the background. It is recommended
+      to avoid running redundant sequence managers for the same set of layers.
+
+    Example
+    =======
+    >>> sequence_manager = RemoteSequenceManager(dht=..., block_uids=('me/my-model.0', 'me/my-model.1', 'me/my-model.2'), start=True)
+    >>> seq1_full_model = sequence_manager.make_sequence()
+    >>> seq2_partial = sequence_manager.make_sequence(start_index=0, end_index=2)  # the end index is exclusive
+    >>> seq3_fastest = sequence_manager.make_sequence(strategy=RoutingStrategy.FASTEST, num_tokens=128)  # [WIP]
+
     """
 
-    def __init__(self, dht: DHT, block_uids: Sequence[ModuleUID], p2p: P2P, max_retries: int = 3):
-        self.dht, self.p2p = dht, p2p
+    def __init__(
+        self,
+        dht: DHT,
+        block_uids: Sequence[ModuleUID],
+        *,
+        p2p: Optional[P2P] = None,
+        start: bool,
+        max_retries: int = 3,
+        update_period: float = 30,
+    ):  # NB: if you add any more parameters, please make sure you pass them to sub-sequences in .__getitem__ below!
+        super().__init__(daemon=True)
+        self.dht, self.p2p = dht, (p2p if p2p is not None else dht.replicate_p2p())
         self.block_uids: List[ModuleUID] = list(block_uids)
         self.block_infos: List[Optional[RemoteModuleInfo]] = [None] * len(self.block_uids)
         self.spans_by_priority: List[RemoteSpanInfo] = []  # sorted from best to worst
         self.spans_containing_block: Tuple[List[RemoteSpanInfo], ...] = tuple([] for _ in range(len(self.block_uids)))
+
+        self.update_period, self.max_retries = update_period, max_retries
         self.last_update_time: DHTExpiration = -float("inf")
-        self.max_retries = max_retries
+
         self._rpc_info = None
-        self.lock_changes = threading.Lock()
-        self.update_()
+        self._lock_changes = threading.Lock()
+        self.ready = threading.Event()  # whether or not this thread is ready to make_sequence
+
+        if start:
+            self.run_in_background()
 
         for uid, info in zip(self.block_uids, self.block_infos):
             assert info is not None, f"Found no remote peers for block {uid}"
         assert self.spans_by_priority and self.spans_containing_block
 
-    def make_sequence(self, start_index: int = 0, end_index: Optional[int] = None) -> Sequence[RemoteSpanInfo]:
+    def make_sequence(
+        self,
+        start_index: int = 0,
+        end_index: Optional[int] = None,
+        strategy: RoutingStrategy = RoutingStrategy.RANDOM,
+        num_tokens: Optional[int] = None,
+    ) -> Sequence[RemoteSpanInfo]:
         """
         Form a sequence of remote servers that collectively serve all consecutive layers
 
         :param start_index: optional index of the first module in a sequence, default = the first of block_uids
-        :param end_index: optional index of the last module (non-inclusive), default = after last of block uids
+        :param end_index: optional index of the last module (non-inclusive), default = after last of block_uids
+        :param strategy: the routing algorithm to use (e.g. random, fastest, balanced), see RoutingStrategy for details
+        :param num_tokens: the number of tokens sent through this sequence at a time, used by RoutingStrategy.FASTEST
         """
+        assert self.is_alive()
+        if not self.ready.is_set():
+            logger.warning(f"{self.__class__.__name__} is still initializing, waiting until it's ready...")
+            self.ready.wait()
+            logger.warning(f"Finished waiting for {self.__class__.__name__} to initialize")
+        if (strategy is RoutingStrategy.FASTEST) != (num_tokens is not None):
+            logger.warning("please specify num_tokens with FASTEST strategy (and only with FASTEST strategy)")
         end_index = end_index if end_index is not None else len(self.block_uids)
-        span_sequence = []
-        current_index = start_index
-        while current_index < end_index:
-            candidate_spans = self.spans_containing_block[current_index]
-            chosen_span = random.choice(candidate_spans)  # TODO this should be replaced with proper load balancing
-
-            assert chosen_span.start <= current_index < chosen_span.end
-            span_sequence.append(chosen_span)
-            current_index = chosen_span.end
-
-        return span_sequence
+        if strategy == RoutingStrategy.RANDOM:
+            span_sequence = []
+            current_index = start_index
+            while current_index < end_index:
+                candidate_spans = self.spans_containing_block[current_index]
+                chosen_span = random.choice(candidate_spans)
+                assert chosen_span.start <= current_index < chosen_span.end
+                span_sequence.append(chosen_span)
+                current_index = chosen_span.end
+            return span_sequence
+        elif strategy == RoutingStrategy.FASTEST:
+            raise NotImplementedError("Fastest routing strategy is not implemented (yet)")
+        elif strategy == RoutingStrategy.LOAD_BALANCED:
+            raise NotImplementedError("Load-balanced routing strategy is not implemented (yet)")
 
     def __getitem__(self, ix: Union[int, slice]) -> RemoteSequenceManager:
         """Get a RemoteSequenceManager for a sub-sequence of blocks"""
         assert isinstance(ix, (int, slice))
         if not isinstance(ix, slice):
             ix = slice(int(ix), int(ix) + 1, 1)
-        with self.lock_changes:
-            subseq = RemoteSequenceManager(self.dht, self.block_uids[ix], self.p2p)
+
+        self.ready.wait()
+        with self._lock_changes:
+            subseq = RemoteSequenceManager(
+                self.dht,
+                self.block_uids[ix],
+                p2p=self.p2p,
+                max_retries=self.max_retries,
+                update_period=self.update_period,
+                start=False,
+            )  # NB: if you've added more parameters to __init__, please forward them in the instantiation above
             subseq.block_infos = self.block_infos[ix]
             subseq.spans_by_priority, subseq.spans_containing_block = subseq.compute_spans(subseq.block_infos)
+            subseq._rpc_info = self._rpc_info
             subseq.last_update_time = self.last_update_time
+        if self.is_alive():
+            subseq.run_in_background()
         return subseq
 
     def update_(self):
-        with self.lock_changes:
+        with self._lock_changes:
             self.update_block_infos_()
             self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos)
 
+    def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None:
+        """
+        Starts the sequence manager in a background thread. If await_ready, this method will wait until
+        the background thread is ready to make sequences, or for :timeout: seconds, whichever comes first.
+        """
+        self.start()
+        if await_ready:
+            self.ready.wait(timeout)
+
     def update_block_infos_(self):
         new_block_infos = get_remote_module_infos(self.dht, self.block_uids, expiration_time=float("inf"))
         assert len(new_block_infos) == len(self.block_uids)
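
For intuition, here is a minimal, self-contained sketch of the greedy random walk that make_sequence performs
under RoutingStrategy.RANDOM. It is not part of the diff: Span and the toy server table below are hypothetical
stand-ins for RemoteSpanInfo and the spans_containing_block index that compute_spans builds from DHT data.

import random
from collections import namedtuple

# Hypothetical stand-in for RemoteSpanInfo: peer_id serves the half-open block range [start, end)
Span = namedtuple("Span", ["peer_id", "start", "end"])

servers = [Span("A", 0, 3), Span("B", 2, 5), Span("C", 3, 5), Span("D", 0, 4)]
num_blocks = 5

# spans_containing_block[i] lists every span that covers block i
spans_containing_block = [[s for s in servers if s.start <= i < s.end] for i in range(num_blocks)]

def make_sequence_random(start_index: int = 0, end_index: int = num_blocks) -> list:
    """Pick any span covering the current block, jump to that span's end, repeat until every block is covered."""
    span_sequence, current_index = [], start_index
    while current_index < end_index:
        chosen_span = random.choice(spans_containing_block[current_index])
        span_sequence.append(chosen_span)
        current_index = chosen_span.end  # the chosen server handles every block up to its end
    return span_sequence

print(make_sequence_random())  # e.g. [Span(peer_id='D', start=0, end=4), Span(peer_id='B', start=2, end=5)]

Choosing uniformly at random at each branch spreads load across servers without any coordination; the FASTEST
and LOAD_BALANCED strategies (marked [WIP] in the enum above) are intended to replace this blind choice with
latency- and throughput-aware routing.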