justheuristic 3 years ago
parent
commit
c7eda494bf
2 changed files with 104 additions and 23 deletions
  1. +1 −1 src/client/remote_sequential.py
  2. +103 −22 src/client/sequence_manager.py

+ 1 - 1
src/client/remote_sequential.py

@@ -43,7 +43,7 @@ class RemoteSequential(nn.Module):
         block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(num_blocks)]
         if sequence_manager is None:
             logger.debug(f"Creating new sequence manager for block uids: {block_uids}")
-            self.sequence_manager = RemoteSequenceManager(dht, block_uids, self.p2p)
+            self.sequence_manager = RemoteSequenceManager(dht, block_uids, p2p=self.p2p, start=True)
             self.is_subsequence = False
         else:
             logger.debug(f"Reusing sequence manager with {len(sequence_manager)} modules")

+ 103 - 22
src/client/sequence_manager.py

@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import enum
 import random
 import threading
 from typing import List, Optional, Sequence, Tuple, Union
@@ -17,65 +18,145 @@ use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 
-class RemoteSequenceManager:
+class RoutingStrategy(enum.Enum):
+    RANDOM = enum.auto()  # choose a random compatible server at each branch and include all layers served by it
+    FASTEST = enum.auto()  # [WIP] minimize the estimated time to process a given number of tokens, including latency
+    LOAD_BALANCED = enum.auto()  # [WIP] use servers in proportion to their speed, on average over many sequences
+
+
+class RemoteSequenceManager(threading.Thread):
     """
-    Keeps and updates the meta-information about which peers host which blocks.
-    In future, this class is intended to maintain latency statistics, ban non-responsive peers, etc.
+    A sequence manager is a background thread that keeps track of the remote servers constituting a RemoteSequential.
+    TL;DR it tells you which peers you should ask to get a specific layer. It is used in RemoteSequential.
+
+    When created, RemoteSequenceManager looks up which servers serve the necessary layers by reading from the DHT.
+    Using this information, the sequence manager can form sequences of servers that collectively serve all layers.
+
+    To form such a sequence, call .make_sequence with the appropriate optimization policy (see its docstring).
+
+    :note: RemoteSequenceManager takes up some CPU and network I/O to operate in the background. It is recommended
+      to avoid running redundant sequence managers for the same set of layers.
+
+    Example
+    =======
+    >>> sequence_manager = RemoteSequenceManager(dht=..., block_uids=('me/my-model.0', 'me/my-model.1', 'me/my-model.2'))
+    >>> seq1_full_model = sequence_manager.make_sequence()
+    >>> seq2_partial = sequence_manager.make_sequence(start_index=0, end_index=2)  # the end index is exclusive
+    >>> seq3_fastest = sequence_manager.make_sequence(strategy=RoutingStrategy.FASTEST, num_tokens=128)  # [WIP]
+
     """
 
-    def __init__(self, dht: DHT, block_uids: Sequence[ModuleUID], p2p: P2P, max_retries: int = 3):
-        self.dht, self.p2p = dht, p2p
+    def __init__(
+        self,
+        dht: DHT,
+        block_uids: Sequence[ModuleUID],
+        *,
+        p2p: Optional[P2P] = None,
+        start: bool,
+        max_retries: int = 3,
+        update_period: float = 30,
+    ):  # NB: if you add any more parameters, please make sure you pass them to sub-sequences in .__getitem__ below!
+        super().__init__(daemon=True)
+        self.dht, self.p2p = dht, (p2p if p2p is not None else dht.replicate_p2p())
         self.block_uids: List[ModuleUID] = list(block_uids)
         self.block_infos: List[Optional[RemoteModuleInfo]] = [None] * len(self.block_uids)
         self.spans_by_priority: List[RemoteSpanInfo] = []  # sorted from best to worst
         self.spans_containing_block: Tuple[List[RemoteSpanInfo], ...] = tuple([] for _ in range(len(self.block_uids)))
+
+        self.update_period, self.max_retries = update_period, max_retries
         self.last_update_time: DHTExpiration = -float("inf")
-        self.max_retries = max_retries
+
         self._rpc_info = None
-        self.lock_changes = threading.Lock()
-        self.update_()
+        self._lock_changes = threading.Lock()
+        self.ready = threading.Event()  # whether or not this thread is ready to make_sequence
+
+        if start:
+            self.run_in_background()
 
-        for uid, info in zip(self.block_uids, self.block_infos):
-            assert info is not None, f"Found no remote peers for block {uid}"
-        assert self.spans_by_priority and self.spans_containing_block
+        if start:  # with start=False (e.g. in __getitem__ below), block infos are filled in by the caller
+            for uid, info in zip(self.block_uids, self.block_infos):
+                assert info is not None, f"Found no remote peers for block {uid}"
+            assert self.spans_by_priority and self.spans_containing_block
 
-    def make_sequence(self, start_index: int = 0, end_index: Optional[int] = None) -> Sequence[RemoteSpanInfo]:
+    def make_sequence(
+        self,
+        start_index: int = 0,
+        end_index: Optional[int] = None,
+        strategy: RoutingStrategy = RoutingStrategy.RANDOM,
+        num_tokens: Optional[int] = None,
+    ) -> Sequence[RemoteSpanInfo]:
         """
         Form a sequence of remote servers that collectively serve all consecutive layers
 
         :param start_index: optional index of the first module in a sequence, default = the first of block_uids
-        :param end_index: optional index of the last module (non-inclusive), default = after last of block uids
+        :param end_index: optional index of the last module (non-inclusive), default = after last of block_uids
+        :param strategy: the routing algorithm to use (e.g. random, fastest, balanced), see RoutingStrategy for details
+        :param num_tokens: the number of tokens sent through this sequence at a time, used by RoutingStrategy.FASTEST
         """
+        assert self.is_alive()
+        if not self.ready.is_set():
+            logger.warning(f"{self.__class__.__name__} is still initializing, waiting until it's ready...")
+            self.ready.wait()
+            logger.warning(f"Finished waiting for {self.__class__.__name__} to initialize")
+        if (strategy is RoutingStrategy.FASTEST) != (num_tokens is not None):
+            logger.warning("please specify num_tokens with FASTEST strategy (and only with FASTEST strategy)")
         end_index = end_index if end_index is not None else len(self.block_uids)
-        span_sequence = []
-        current_index = start_index
-        while current_index < end_index:
-            candidate_spans = self.spans_containing_block[current_index]
-            chosen_span = random.choice(candidate_spans)  # TODO this should be replaced with proper load balancing
 
-            assert chosen_span.start <= current_index < chosen_span.end
-            span_sequence.append(chosen_span)
-            current_index = chosen_span.end
+        if strategy == RoutingStrategy.RANDOM:
+            span_sequence = []
+            current_index = start_index
+            while current_index < end_index:
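+                # consider every server whose span covers the current block, pick one
+                # uniformly at random, then jump to the index right after its span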
+                candidate_spans = self.spans_containing_block[current_index]
+                chosen_span = random.choice(candidate_spans)
+                assert chosen_span.start <= current_index < chosen_span.end
+                span_sequence.append(chosen_span)
+                current_index = chosen_span.end
+            return span_sequence
+        elif strategy == RoutingStrategy.FASTEST:
+            raise NotImplementedError("Fastest routing strategy is not implemented (yet)")
+        elif strategy == RoutingStrategy.LOAD_BALANCED:
+            raise NotImplementedError("Load-balanced routing strategy is not implemented (yet)")
+        else:
+            raise ValueError(f"Unknown routing strategy: {strategy}")

-        return span_sequence

     def __getitem__(self, ix: Union[int, slice]) -> RemoteSequenceManager:
         """Get a RemoteSequenceManager for a sub-sequence of blocks"""
         assert isinstance(ix, (int, slice))
         if not isinstance(ix, slice):
             ix = slice(int(ix), int(ix) + 1, 1)
-        with self.lock_changes:
-            subseq = RemoteSequenceManager(self.dht, self.block_uids[ix], self.p2p)
+
+        self.ready.wait()
+        with self._lock_changes:
+            subseq = RemoteSequenceManager(
+                self.dht,
+                self.block_uids[ix],
+                p2p=self.p2p,
+                max_retries=self.max_retries,
+                update_period=self.update_period,
+                start=False,
+            )  # NB: if you've added more parameters to __init__, please forward them in the instantiation above
             subseq.block_infos = self.block_infos[ix]
             subseq.spans_by_priority, subseq.spans_containing_block = subseq.compute_spans(subseq.block_infos)
+            subseq._rpc_info = self._rpc_info
             subseq.last_update_time = self.last_update_time
+            if self.is_alive():
+                subseq.run_in_background()
         return subseq
 
     def update_(self):
-        with self.lock_changes:
+        with self._lock_changes:
             self.update_block_infos_()
             self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos)
 
+    def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None:
+        """
+        Starts the sequence manager in a background thread. If await_ready, this method will wait until
+        the sequence manager is ready to process incoming requests, or for :timeout: seconds max.
+        """
+        self.start()
+        if await_ready:
+            self.ready.wait(timeout)
+
     def update_block_infos_(self):
         new_block_infos = get_remote_module_infos(self.dht, self.block_uids, expiration_time=float("inf"))
         assert len(new_block_infos) == len(self.block_uids)
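
Taken together, the new API can be exercised as in the following hedged sketch. The `dht` and `block_uids` placeholders are as in the earlier example, the indices are illustrative, and the FASTEST and LOAD_BALANCED strategies are still work-in-progress in this commit (they raise NotImplementedError):

    # Illustrative end-to-end usage of the new sequence manager API
    manager = RemoteSequenceManager(dht, block_uids, start=True)  # waits until ready

    # RANDOM strategy: pick a random compatible server at each branch point
    spans = manager.make_sequence(strategy=RoutingStrategy.RANDOM)
    for span in spans:
        print(span.start, span.end)  # each span covers a contiguous range of blocks

    # Slicing returns a sub-sequence manager that reuses cached block infos and
    # keeps updating in the background if the parent thread is alive
    sub_manager = manager[0:2]

    # Deferred start: construct without the background thread, launch it later
    lazy = RemoteSequenceManager(dht, block_uids, start=False)
    lazy.run_in_background(await_ready=True)  # blocks until the first update completes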