
routing strategy interface

justheuristic 3 years ago
parent
commit
32dc9649ca

+ 3 - 0
src/client/routing/__init__.py

@@ -0,0 +1,3 @@
+from src.client.routing.routing_strategy import *
+from src.client.routing.sequence_info import RemoteSequenceInfo
+from src.client.routing.sequence_manager import RemoteSequenceManager

+ 82 - 0
src/client/routing/routing_strategy.py

@@ -0,0 +1,82 @@
+"""RoutingStrategies are helpers for RemoteSequenceManager (sequence_manager.py) that implement make_sequence"""
+import random
+from abc import ABC
+from typing import List, Optional, Tuple
+
+from src.client.routing.sequence_info import RemoteSequenceInfo
+from src.data_structures import RemoteSpanInfo, ServerState
+
+
+class RoutingStrategyBase(ABC):
+    name: str  # used in RemoteSequenceManager.make_sequence(strategy, **kwargs)
+
+    def update_(self):
+        """Called when sequence manager fetches new info from the dht"""
+        raise NotImplementedError()
+
+    def make_sequence(self, start_index: int = 0, end_index: Optional[int] = None, **kwargs):
+        """Form and return a sequence;"""
+        raise NotImplementedError()
+
+
+class RandomRoutingStrategy(RoutingStrategyBase):
+    """choose a random compatible server at each branch and include all layers served by it"""
+
+    name = "RANDOM"
+
+    def __init__(self, sequence_info: RemoteSequenceInfo):
+        self.sequence_info = sequence_info
+        self.spans_by_priority: List[RemoteSpanInfo] = []  # sorted from best to worst
+        self.spans_containing_block: Tuple[List[RemoteSpanInfo], ...] = tuple([] for _ in range(len(sequence_info)))
+
+    def update_(self):
+        for uid, info in zip(self.sequence_info.block_uids, self.sequence_info.block_infos):
+            assert info.servers, f"Found no remote peers for block {uid}"
+            # TODO change this to waiting and warning - instead of crashing the thread :)
+
+        closed_spans = []
+        active_spans = {}
+        for block_index, info in enumerate(self.sequence_info.block_infos):
+            for peer_id, server in info.servers.items():
+                if server.state != ServerState.ONLINE:
+                    continue
+                if peer_id not in active_spans:
+                    active_spans[peer_id] = RemoteSpanInfo(start=block_index, end=block_index + 1, peer_id=peer_id)
+                else:  # peer_id in active_spans
+                    active_spans[peer_id].end = block_index + 1
+
+            for peer_id in list(active_spans.keys()):
+                if (
+                    peer_id not in info.servers
+                    or info.servers[peer_id].state != ServerState.ONLINE
+                    or block_index == len(self.sequence_info.block_infos) - 1
+                ):
+                    closed_spans.append(active_spans.pop(peer_id))
+        assert not active_spans
+
+        closed_spans.sort(key=lambda span: span.end - span.start, reverse=True)
+        self.spans_by_priority = closed_spans
+
+        spans_containing_block = tuple(list() for _ in range(len(self.sequence_info.block_infos)))
+        for span in closed_spans:
+            for block_index in range(span.start, span.end):
+                spans_containing_block[block_index].append(span)
+
+        self.spans_containing_block = spans_containing_block
+        assert self.spans_by_priority and self.spans_containing_block
+
+    def make_sequence(self, start_index: int = 0, end_index: Optional[int] = None, **kwargs):
+        assert not kwargs, f"Unexpected kwargs: {kwargs}"
+        end_index = end_index if end_index is not None else len(self.sequence_info)
+        span_sequence = []
+        current_index = start_index
+        while current_index < end_index:
+            candidate_spans = self.spans_containing_block[current_index]
+            chosen_span = random.choice(candidate_spans)
+            assert chosen_span.start <= current_index < chosen_span.end
+            span_sequence.append(chosen_span)
+            current_index = chosen_span.end
+        return span_sequence
+
+
+ALL_ROUTING_STRATEGIES = (RandomRoutingStrategy,)
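
As a rough illustration of the interface above (a hypothetical sketch, not part of this commit), a new strategy can subclass RandomRoutingStrategy to reuse its update_() bookkeeping and override only the span selection; the GREEDY name and class are assumptions made for the example:

# Hypothetical sketch, not in this commit: reuse RandomRoutingStrategy's span
# bookkeeping but always pick the longest span covering the current block.
from typing import Optional

from src.client.routing.routing_strategy import RandomRoutingStrategy


class GreedyRoutingStrategy(RandomRoutingStrategy):
    name = "GREEDY"  # hypothetical; would be selected via make_sequence(strategy="GREEDY")

    def make_sequence(self, start_index: int = 0, end_index: Optional[int] = None, **kwargs):
        assert not kwargs, f"Unexpected kwargs: {kwargs}"
        end_index = end_index if end_index is not None else len(self.sequence_info)
        span_sequence, current_index = [], start_index
        while current_index < end_index:
            # spans are appended longest-first during update_(), so [0] is the longest candidate
            chosen_span = self.spans_containing_block[current_index][0]
            span_sequence.append(chosen_span)
            current_index = chosen_span.end
        return span_sequence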

+ 61 - 0
src/client/routing/sequence_info.py

@@ -0,0 +1,61 @@
+import dataclasses
+from typing import Iterable, Tuple, Type, TypeVar
+
+from hivemind import DHT, get_logger, use_hivemind_log_handler
+
+from src.data_structures import ModuleUID, RemoteModuleInfo
+from src.dht_utils import get_remote_module_infos
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
+
+T = TypeVar("T")
+
+
+@dataclasses.dataclass(frozen=True)
+class RemoteSequenceInfo:
+    """
+    A dataclass that stores general information about which servers hold any given layer;
+    - updated by RemoteSequenceManager in a background thread
+    - accessed by routing strategies in .update_()
+
+    :note: this class should *not* be modified by RoutingStrategy.update_() to avoid interference between strategies;
+     any metadata specific to one routing strategy should be stored inside that strategy. Any information that
+     is used by most routing strategies should be moved from said strategies to this class.
+
+    """
+
+    block_uids: Tuple[ModuleUID, ...]
+    block_infos: Tuple[RemoteModuleInfo, ...]  # note: the contents of RemoteModuleInfo can and will be updated
+
+    @classmethod
+    def make_empty(cls: Type[T], block_uids: Iterable[ModuleUID]) -> T:
+        block_uids = tuple(block_uids)
+        empty_block_infos = tuple(RemoteModuleInfo(uid, dict()) for uid in block_uids)
+        return cls(block_uids, empty_block_infos)
+
+    def __getitem__(self, ix: slice):
+        assert isinstance(ix, slice)
+        return RemoteSequenceInfo(self.block_uids[ix], self.block_infos[ix])
+
+    def __len__(self):
+        return len(self.block_uids)
+
+    def update_(self, dht: DHT):
+        new_block_infos = get_remote_module_infos(dht, self.block_uids, expiration_time=float("inf"))
+        assert len(new_block_infos) == len(self.block_uids)
+        for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
+            if info is None:
+                logger.warning(f"Found no block info for block {uid}")
+                continue
+            if not isinstance(info, RemoteModuleInfo):
+                logger.warning(f"Unexpected dht entry type for {uid}: {info}")
+                continue
+            if not info.servers:
+                logger.warning(f"Found no active peers for block {uid}")
+                continue
+            if info.uid != uid:
+                logger.warning(f"The DHT entry for {uid} actually points to {info.uid}")
+                continue
+            self.block_infos[block_index].servers = info.servers
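
A short usage sketch for RemoteSequenceInfo (hypothetical, not part of this commit; the block UIDs below are placeholders):

# Hypothetical sketch, not in this commit: block UIDs are placeholders.
from src.client.routing.sequence_info import RemoteSequenceInfo

block_uids = [f"some-model-prefix.{i}" for i in range(8)]  # placeholder ModuleUIDs
seq_info = RemoteSequenceInfo.make_empty(block_uids)
assert len(seq_info) == 8 and all(not info.servers for info in seq_info.block_infos)

# Slicing returns a new RemoteSequenceInfo that shares the same RemoteModuleInfo objects,
# so a later seq_info.update_(dht) also fills the servers visible through the slice.
first_half = seq_info[:4]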

+ 45 - 94
src/client/sequence_manager.py → src/client/routing/sequence_manager.py

@@ -3,13 +3,15 @@ from __future__ import annotations
 import enum
 import random
 import threading
-from typing import List, Optional, Sequence, Tuple, Union
+from typing import Collection, Dict, List, Optional, Sequence, Tuple, Union
 
 from hivemind import DHT, P2P, DHTExpiration, MSGPackSerializer
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.proto import runtime_pb2
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
+from src.client.routing.routing_strategy import ALL_ROUTING_STRATEGIES, RoutingStrategyBase
+from src.client.routing.sequence_info import RemoteSequenceInfo
 from src.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
 from src.dht_utils import get_remote_module_infos
 from src.server.handler import TransformerConnectionHandler
@@ -18,12 +20,6 @@ use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 
-class RoutingStrategy(enum.Enum):
-    RANDOM = enum.auto()  # choose a random compatible server at each branch and include all layers served by it
-    FASTEST = enum.auto()  # [WIP] minimize the estimated time to process a given number of tokens, including latency
-    LOAD_BALANCED = enum.auto()  # [WIP] use servers in proportion to their speed, on average over many sequences
-
-
 class RemoteSequenceManager(threading.Thread):
     """
     Sequence manager is a thread that keeps track of information on remote servers that constitute a RemoteSequential.
@@ -34,6 +30,11 @@ class RemoteSequenceManager(threading.Thread):
 
     To form such a sequence, call .make_sequence with the appropriate optimization policy (see make_sequence docstr).
 
+    :param dht: a running hivemind.DHT instance, connected to peers that serve the corresponding blocks
+    :param block_uids: a sequence of DHT keys (strings) corresponding to remote layers
+    :param p2p: an optional P2P replica (if not specified, create one via dht.replicate_p2p())
+    :param update_period: by default, refresh DHT information once in this many seconds
+
     :note: RemoteSequenceManager takes up some CPU and network I/O to operate in background. It is recommended to avoid
       running redundant sequence managers for the same set of layers.
 
@@ -53,71 +54,50 @@ class RemoteSequenceManager(threading.Thread):
         *,
         p2p: Optional[P2P] = None,
         start: bool,
-        max_retries: int = 3,
         update_period: float = 30,
+        routing_strategies: Optional[Collection[RoutingStrategyBase]] = None,
     ):  # NB: if you add any more parameters, please make sure you pass them to sub-sequences in .__getitem__ below!
         super().__init__(daemon=True)
         self.dht, self.p2p = dht, (p2p if p2p is not None else dht.replicate_p2p())
-        self.block_uids: List[ModuleUID] = list(block_uids)
-        self.block_infos: List[Optional[RemoteModuleInfo]] = [None] * len(self.block_uids)
-        self.spans_by_priority: List[RemoteSpanInfo] = []  # sorted from best to worst
-        self.spans_containing_block: Tuple[List[RemoteSpanInfo], ...] = tuple([] for _ in range(len(self.block_uids)))
+        self.sequence_info = RemoteSequenceInfo.make_empty(block_uids)  # to be updated in a background thread
+
+        if routing_strategies is None:
+            routing_strategies = [Strategy(self.sequence_info) for Strategy in ALL_ROUTING_STRATEGIES]
+        self.routing_strategies: Dict[str, RoutingStrategyBase] = {s.name: s for s in routing_strategies}
 
-        self.update_period, self.max_retries = update_period, max_retries
         self.last_update_time: DHTExpiration = -float("inf")
+        self.update_period = update_period
 
         self._rpc_info = None
         self._lock_changes = threading.Lock()
-        self.ready = threading.Event()  # whether or not this thread is ready to make_sequence
+        self.ready = threading.Event()  # whether or not this sequence manager is ready to make_sequence
 
         if start:
             self.run_in_background()
 
-        for uid, info in zip(self.block_uids, self.block_infos):
-            assert info is not None, f"Found no remote peers for block {uid}"
-        assert self.spans_by_priority and self.spans_containing_block
+    def run(self) -> None:
+        self.ready.set()
+        # TODO
 
     def make_sequence(
-        self,
-        start_index: int = 0,
-        end_index: Optional[int] = None,
-        strategy: RoutingStrategy = RoutingStrategy.RANDOM,
-        num_tokens: Optional[int] = None,
+        self, strategy: Union[str, RoutingStrategyBase], start_index: int = 0, end_index: Optional[int] = None, **kwargs
     ) -> Sequence[RemoteSpanInfo]:
         """
         Form a sequence of remote servers that collectively serve all consecutive layers
 
+        :param strategy: the routing algorithm to use (e.g. random, fastest, balanced), see routing_strategy.py
         :param start_index: optional index of the first module in a sequence, default = the first of block_uids
         :param end_index: optional index of the last module (non-inclusive), default = after last of block_uids
-        :param strategy: the routing algorithm to use (e.g. random, fastest, balanced), see RoutingStrategy for details
-        :param num_tokens: the number of tokens sent through this sequence at a time, used by RoutingStrategy.FASTEST
+        :param kwargs: additional keyword arguments, depending on your chosen routing strategy
         """
         assert self.is_alive()
         if not self.ready.is_set():
             logger.warning(f"{self.__class__.__name__} is still initializing, waiting until it's ready...")
             self.ready.wait()
             logger.warning(f"Finished waiting for {self.__class__.__name__} to initialize")
-        if (strategy is RoutingStrategy.FASTEST) != (num_tokens is not None):
-            logger.warning("please specify num_tokens with FASTEST strategy (and only with FASTEST strategy)")
-        end_index = end_index if end_index is not None else len(self.block_uids)
-
-        if strategy == RoutingStrategy.RANDOM:
-            span_sequence = []
-            current_index = start_index
-            while current_index < end_index:
-                candidate_spans = self.spans_containing_block[current_index]
-                chosen_span = random.choice(candidate_spans)
-                assert chosen_span.start <= current_index < chosen_span.end
-                span_sequence.append(chosen_span)
-                current_index = chosen_span.end
-            return span_sequence
-        elif strategy == RoutingStrategy.FASTEST:
-            raise NotImplementedError("Fastest routing strategy is not implemented (yet)")
-        elif strategy == RoutingStrategy.LOAD_BALANCED:
-            raise NotImplementedError("Load-balanced routing strategy is not implemented (yet)")
-
-
-
+        if not isinstance(strategy, RoutingStrategyBase):
+            strategy = self.routing_strategies[strategy]
+        return strategy.make_sequence(start_index, end_index, **kwargs)
 
     def __getitem__(self, ix: Union[int, slice]) -> RemoteSequenceManager:
         """Get a RemoteSequenceManager for a sub-sequence of blocks"""
@@ -131,11 +111,10 @@ class RemoteSequenceManager(threading.Thread):
                 self.dht,
                 self.block_uids[ix],
                 p2p=self.p2p,
-                max_retries=self.max_retries,
                 update_period=self.update_period,
                 start=False,
             )  # NB: if you've added more parameters to __init__, please forward them in the instantiation above
-            subseq.block_infos = self.block_infos[ix]
+            subseq.sequence_info = self.sequence_info[ix]
-                subseq.spans_by_priority, subseq.spans_containing_block = subseq.compute_spans(subseq.block_infos)
             subseq._rpc_info = self._rpc_info
             subseq.last_update_time = self.last_update_time
@@ -145,8 +124,9 @@ class RemoteSequenceManager(threading.Thread):
 
     def update_(self):
         with self._lock_changes:
-            self.update_block_infos_()
-            self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos)
+            self.sequence_info.update_(self.dht)
+            for strategy in self.routing_strategies.values():
+                strategy.update_()
 
     def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None:
         """
@@ -157,54 +137,17 @@ class RemoteSequenceManager(threading.Thread):
         if await_ready:
             self.ready.wait(timeout)
 
-    def update_block_infos_(self):
-        new_block_infos = get_remote_module_infos(self.dht, self.block_uids, expiration_time=float("inf"))
-        assert len(new_block_infos) == len(self.block_uids)
-        for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
-            if info is None:
-                logger.warning(f"Found no block info for block {uid}")
-            if not isinstance(info, RemoteModuleInfo):
-                logger.warning(f"Unexpected dht entry type for {uid}: {info}")
-            if not info.servers:
-                logger.warning(f"Found no active peers for block {uid}")
-            if info.uid != uid:
-                logger.warning(f"The DHT entry for {uid} actually points to {info.uid}")
-            self.block_infos[block_index] = info
-
-    @staticmethod
-    def compute_spans(block_infos: Sequence[RemoteModuleInfo]):
-        closed_spans = []
-        active_spans = {}
-        for block_index, info in enumerate(block_infos):
-            for peer_id, server in info.servers.items():
-                if server.state != ServerState.ONLINE:
-                    continue
-                if peer_id not in active_spans:
-                    active_spans[peer_id] = RemoteSpanInfo(start=block_index, end=block_index + 1, peer_id=peer_id)
-                else:  # peer_id in active_spans
-                    active_spans[peer_id].end = block_index + 1
-
-            for peer_id in list(active_spans.keys()):
-                if (
-                    peer_id not in info.servers
-                    or info.servers[peer_id].state != ServerState.ONLINE
-                    or block_index == len(block_infos) - 1
-                ):
-                    closed_spans.append(active_spans.pop(peer_id))
-        assert not active_spans
-
-        closed_spans.sort(key=lambda span: span.end - span.start, reverse=True)
-
-        spans_containing_block = tuple(list() for _ in range(len(block_infos)))
-        for span in closed_spans:
-            for block_index in range(span.start, span.end):
-                spans_containing_block[block_index].append(span)
-
-        return closed_spans, spans_containing_block
-
     def __len__(self):
         return len(self.block_uids)
 
+    @property
+    def block_uids(self) -> Sequence[ModuleUID]:
+        return self.sequence_info.block_uids
+
+    @property
+    def block_infos(self) -> Sequence[RemoteModuleInfo]:
+        return self.sequence_info.block_infos
+
     @property
     def rpc_info(self):
         """Return the rpc_info queried from one of the servers that hold the first block"""
@@ -226,3 +169,11 @@ class RemoteSequenceManager(threading.Thread):
                     else:
                         logger.warning(f"Tried to call rpc_info, but caught {repr(e)}", exc_info=True)
         return self._rpc_info
+
+    @property
+    def max_retries(self) -> int:
+        logger.warning(
+            "RemoteSequenceManager.max_retries is deprecated and will be removed when dbaranchuk@ implements"
+            " chained forward/backward. If you have questions about the roadmap, please ping yozh@ ."
+        )
+        return 3
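
End-to-end, the intended usage would look roughly like the sketch below (hypothetical, not part of this commit: run() above is still a TODO, so the background refresh is assumed to be implemented; the initial peers and block UIDs are placeholders):

# Hypothetical sketch, not in this commit: assumes run() refreshes DHT info in the background.
import hivemind

from src.client.routing.sequence_manager import RemoteSequenceManager

dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)  # INITIAL_PEERS is a placeholder
block_uids = [f"some-model-prefix.{i}" for i in range(4)]  # placeholder ModuleUIDs

manager = RemoteSequenceManager(dht, block_uids, start=True)
manager.ready.wait()  # block until the manager has fetched block info at least once

# A strategy can be passed by name (RoutingStrategyBase.name) or as a strategy instance.
for span in manager.make_sequence(strategy="RANDOM"):
    print(f"blocks [{span.start}, {span.end}) served by peer {span.peer_id}")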