فهرست منبع

Optimize RemoteSequenceManager (#106)

- [x] made RemoteSequenceManager into a background thread that pre-fetches information instead of running just in time
- [x] moved routing-related stuff to petals.client.routing
- [x] extract remote peer routing information to RemoteSequenceInfo
- [x] made sure that the code survives continued use (e.g. one hour)
- [x] updated every spot where update_ is called manually
- [x] modified get_sequence to check that the thread is alive, warn if not
- [x] removed max_retries, switched rpc_info to exponential backoff
- [x] fixed a bg that causes RemoteSeq* to lose user-defined hyperparameters (e.g. timeout) upon subsequencing (sequential[3:5])
- [x] moved client-side points strategy to client.routing
- [x] ensured that RemoteSequenceManager thread created in get_remote_module properly shuts down when the module is destroyed
- [x] resolved minor affected todos
- [x] modified tests to no longer use PYTHONPATH
- [x] worked around protocol error in rpc_info


Co-authored-by: Aleksandr Borzunov <borzunov.alexander@gmail.com>
Co-authored-by: Artem Chumachenko <artek.chumak@gmail.com>
justheuristic 2 سال پیش
والد
کامیت
a2066a4096

+ 1 - 1
.github/workflows/run-tests.yaml

@@ -104,7 +104,7 @@ jobs:
 
           kill -0 $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $SERVER5_PID # ensure all servers survived init
 
-          PYTHONPATH=. pytest tests --durations=0 --durations-min=1.0 -v
+          pytest tests --durations=0 --durations-min=1.0 -v
 
           kill -0 $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $SERVER5_PID # ensure all servers survived tests
 

+ 2 - 2
src/petals/client/__init__.py

@@ -1,5 +1,5 @@
 from petals.client.inference_session import InferenceSession
 from petals.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
 from petals.client.remote_sequential import RemoteSequential, RemoteTransformerBlock
-from petals.client.sequence_manager import RemoteSequenceManager
-from petals.client.spending_policy import NoSpendingPolicy, SpendingPolicyBase
+from petals.client.routing.sequence_manager import RemoteSequenceManager
+from petals.client.routing.spending_policy import NoSpendingPolicy, SpendingPolicyBase

+ 7 - 7
src/petals/client/inference_session.py

@@ -20,7 +20,7 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.p2p import StubBase
 from hivemind.proto import runtime_pb2
 
-from petals.client.sequence_manager import RemoteSequenceManager
+from petals.client.routing.sequence_manager import RemoteSequenceManager
 from petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
 from petals.server.handler import TransformerConnectionHandler
 from petals.utils.misc import DUMMY, is_dummy
@@ -44,14 +44,14 @@ class _ServerInferenceSession:
         *,
         timeout: float,
         max_length: int,
-        points: int = 0,
+        **metadata,
     ):
         self.uid, self.rpc_info = uid, rpc_info
         self.num_blocks = uid.count(CHAIN_DELIMITER) + 1
         self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
         self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
         self.timeout = timeout
-        self._serialized_metadata = MSGPackSerializer.dumps(dict(max_length=max_length, points=points))
+        self._serialized_metadata = MSGPackSerializer.dumps(dict(max_length=max_length, **metadata))
         self.stepped = False
         self.closed = False
 
@@ -162,7 +162,7 @@ class InferenceSession:
     An interface to a multi-step *inference* session for a sequence of remote transformer blocks
     """
 
-    def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, max_length: int, **metadata):
+    def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, max_length: int):
         self._sequence_manager = sequence_manager
         self._p2p = p2p
         self._closed = False
@@ -171,7 +171,6 @@ class InferenceSession:
         self._server_inputs = []  # Used in case of server failures to regenerate attention caches on new servers
         self._position = 0
         self._max_length = max_length
-        self._metadata = metadata
 
     def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[_ServerInferenceSession]:
         server_sessions = []
@@ -179,6 +178,7 @@ class InferenceSession:
             for span in chosen_spans:
                 stub = TransformerConnectionHandler.get_stub(self._p2p, span.peer_id)
                 span_uids = CHAIN_DELIMITER.join(self._sequence_manager.block_uids[span.start : span.end])
+                metadata = self._sequence_manager.get_request_metadata("rpc_inference", span_uids, peer_id=span.peer_id)
                 session = RemoteExpertWorker.run_coroutine(
                     _ServerInferenceSession.create(
                         stub,
@@ -186,7 +186,7 @@ class InferenceSession:
                         rpc_info=self._sequence_manager.rpc_info,
                         timeout=self._sequence_manager.request_timeout,
                         max_length=self._max_length,
-                        **self._metadata,
+                        **metadata,
                     )
                 )
                 server_sessions.append(session)
@@ -237,7 +237,7 @@ class InferenceSession:
                 logger.debug(f"Inference: block {block_idx}, attempt {attempt_no}")
                 try:
                     if attempt_no >= 1:
-                        self._sequence_manager.update_()
+                        self._sequence_manager.update(wait=True)
                     if not self._chosen_spans or not self._server_sessions or attempt_no >= 1:
                         # If there is a failed server session, this code closes it
                         self._exit_server_sessions(self._server_sessions[server_idx : server_idx + 1])

+ 1 - 1
src/petals/client/remote_model.py

@@ -36,7 +36,7 @@ class DistributedBloomConfig(BloomConfig):
     chunk_size_for_efficient_fp16_on_cpu: int = 10000  # a chunk size for a LM head for efficient half-precision on CPU
     pre_seq_len: int = 0  # a number of tokens for prompt tuning.
     tuning_mode: Optional[str] = None  # One of the finetune options: [None, 'shallow_ptune', 'deep_ptune', 'adapters']
-    request_timeout: int = 20  # a number of seconds for waiting result from each node
+    request_timeout: int = 30  # a number of seconds for waiting result from each node
 
 
 original_register_parameter = nn.Module.register_parameter

+ 8 - 7
src/petals/client/remote_sequential.py

@@ -9,7 +9,7 @@ from torch import nn
 
 import petals.client
 from petals.client.inference_session import InferenceSession
-from petals.client.sequence_manager import RemoteSequenceManager
+from petals.client.routing.sequence_manager import RemoteSequenceManager
 from petals.client.sequential_autograd import _RemoteSequentialAutogradFunction
 from petals.data_structures import UID_DELIMITER
 from petals.utils.misc import DUMMY
@@ -30,7 +30,7 @@ class RemoteSequential(nn.Module):
         dht_prefix: Optional[str] = None,
         p2p: Optional[P2P] = None,
         sequence_manager: Optional[RemoteSequenceManager] = None,
-        request_timeout: int = 20,
+        **kwargs,
     ):
         super().__init__()
         self.config = config
@@ -39,16 +39,18 @@ class RemoteSequential(nn.Module):
         self.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p()) if p2p is None else p2p
 
         num_blocks = self.config.n_layer if sequence_manager is None else len(sequence_manager)
-        block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(num_blocks)]
+        block_uids = tuple(f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(num_blocks))
         if sequence_manager is None:
             logger.debug(f"Creating new sequence manager for block uids: {block_uids}")
-            self.sequence_manager = RemoteSequenceManager(dht, block_uids, self.p2p, request_timeout=request_timeout)
+            self.sequence_manager = RemoteSequenceManager(dht, block_uids, self.p2p, start=True, **kwargs)
             self.is_subsequence = False
         else:
             logger.debug(f"Reusing sequence manager with {len(sequence_manager)} modules")
+            if kwargs:
+                logger.warning(f"Parameters {kwargs} are ignored because sequence_manager is explicitly provided")
             self.sequence_manager = sequence_manager
-            assert isinstance(sequence_manager.block_uids, list)
-            self.is_subsequence = self.sequence_manager.block_uids != block_uids
+            assert isinstance(sequence_manager.sequence_info.block_uids, tuple)
+            self.is_subsequence = self.sequence_manager.sequence_info.block_uids != block_uids
 
     def forward(self, inputs: torch.Tensor, prompts: torch.Tensor = DUMMY):
         outputs = _RemoteSequentialAutogradFunction.apply(inputs, prompts, self.sequence_manager)
@@ -81,7 +83,6 @@ class RemoteSequential(nn.Module):
         return len(self.sequence_manager)
 
     def inference_session(self, **kwargs) -> InferenceSession:
-        self.sequence_manager.update_()
         return InferenceSession(self.sequence_manager, self.p2p, **kwargs)
 
     def extra_repr(self) -> str:

+ 1 - 0
src/petals/client/routing/__init__.py

@@ -0,0 +1 @@
+"""Client-side functions responsible for choosing the best server, """

+ 102 - 0
src/petals/client/routing/sequence_info.py

@@ -0,0 +1,102 @@
+import dataclasses
+import time
+from typing import Iterable, List, Optional, Sequence, Tuple, Type, TypeVar
+
+from hivemind import get_logger, use_hivemind_log_handler
+
+from petals.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
+
+T = TypeVar("T")
+
+
+@dataclasses.dataclass
+class RemoteSequenceInfo:
+    """
+    A dataclass that stores general information about which servers hold any given layer;
+    - updated by RemoteSequenceManager in a background thread
+    - accessed by routing strategies in .on_update
+    :note: this class should *not* be modified by RoutingStrategy.on_update to avoid interference between strategies;
+     Any metadata specific to one routing strategy, it should be stored inside that strategy. Any information that
+     is used by most routing strategies should be moved from said strategies to this class.
+    """
+
+    block_uids: Tuple[ModuleUID, ...]
+    block_infos: Tuple[RemoteModuleInfo, ...]  # note: the contents of RemoteModuleInfo can and will be updated
+    spans_by_priority: List[RemoteSpanInfo]
+    spans_containing_block: Tuple[List[RemoteSpanInfo], ...]
+    last_updated_time: float
+
+    @classmethod
+    def make_empty(cls: Type[T], block_uids: Iterable[ModuleUID]) -> T:
+        block_uids = tuple(block_uids)
+        empty_block_infos = tuple(RemoteModuleInfo(uid, {}) for uid in block_uids)
+        empty_spans = tuple([] for _ in range(len(block_uids)))
+        return cls(block_uids, empty_block_infos, [], empty_spans, last_updated_time=-float("inf"))
+
+    def __getitem__(self, ix: slice):
+        assert isinstance(ix, slice)
+        block_uids, block_infos = self.block_uids[ix], self.block_infos[ix]
+        spans_by_priority, spans_containing_block = self.compute_spans(block_infos)
+        return RemoteSequenceInfo(
+            block_uids, block_infos, spans_by_priority, spans_containing_block, self.last_updated_time
+        )
+
+    def __len__(self):
+        return len(self.block_uids)
+
+    def update_(self, new_block_infos: List[Optional[RemoteModuleInfo]]):
+        assert len(new_block_infos) == len(self.block_uids)
+        for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
+            if info is None:
+                logger.debug(f"Found no block info for block {uid}")
+                continue
+            if not isinstance(info, RemoteModuleInfo):
+                logger.warning(f"Unexpected dht entry type for {uid}: {info}")
+                continue
+            if not info.servers:
+                logger.debug(f"Found no active peers for block {uid}")
+                continue
+            if info.uid != uid:
+                logger.warning(f"The DHT entry for {uid} actually points to {info.uid}")
+                continue
+            self.block_infos[block_index].servers = info.servers
+
+        self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos)
+        self.last_updated_time = time.perf_counter()
+
+    @staticmethod
+    def compute_spans(block_infos: Sequence[RemoteModuleInfo]):
+        closed_spans = []
+        active_spans = {}
+        for block_index, info in enumerate(block_infos):
+            if info is not None:
+                for peer_id, server in info.servers.items():
+                    if server.state != ServerState.ONLINE:
+                        continue
+                    if peer_id not in active_spans:
+                        active_spans[peer_id] = RemoteSpanInfo(start=block_index, end=block_index + 1, peer_id=peer_id)
+                    else:  # peer_id in active_spans
+                        active_spans[peer_id].end = block_index + 1
+
+            for peer_id in list(active_spans.keys()):
+                if (
+                    info is None
+                    or peer_id not in info.servers
+                    or info.servers[peer_id] != ServerState.ONLINE
+                    or block_index == len(block_infos) - 1
+                ):
+                    closed_spans.append(active_spans.pop(peer_id))
+        assert not active_spans, f"spans: {active_spans}"
+
+        closed_spans.sort(key=lambda span: span.end - span.start, reverse=True)
+
+        spans_containing_block = tuple(list() for _ in range(len(block_infos)))
+        for span in closed_spans:
+            for block_index in range(span.start, span.end):
+                spans_containing_block[block_index].append(span)
+
+        return closed_spans, spans_containing_block

+ 265 - 0
src/petals/client/routing/sequence_manager.py

@@ -0,0 +1,265 @@
+from __future__ import annotations
+
+import itertools
+import logging
+import random
+import threading
+import time
+from typing import Any, Dict, List, Optional, Sequence, Union
+from weakref import WeakMethod
+
+from hivemind import DHT, P2P, MSGPackSerializer
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+from hivemind.proto import runtime_pb2
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+
+import petals.dht_utils
+from petals.client.routing.sequence_info import RemoteSequenceInfo
+from petals.client.routing.spending_policy import NoSpendingPolicy
+from petals.data_structures import ModuleUID, RemoteSpanInfo
+from petals.server.handler import TransformerConnectionHandler
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
+
+class RemoteSequenceManager:
+    """
+    Sequence manager is a thread that keeps track of remote servers that hold the specified sequence of blocks.
+    TL;DR it tells you, which peers you should ask to get a specific layer. It is used in RemoteSequential.
+    When created, RemoteSequenceManager looks up which servers serve necessary layers by reading from DHT.
+    Using this information, sequence manager can form sequences of servers that collectively have the full sequence.
+    To form such a sequence, call .make_sequence with the appropriate optimization policy (see make_sequence docstr).
+
+    :param dht: a running hivemind.DHT instance, connected to peers that serve the corresponding blocks
+    :param block_uids: a sequence of DHT keys (strings) corresponding to remote layers
+    :param p2p: an optional P2P replica (if not specified, create one via dht.replicate_p2p())
+    :param update_period: by default, refresh DHT information once in this many seconds
+    :param request_timeout: float, in seconds, default timeout for RPC forwad/backward/inference requests
+    :param min_backoff: after a repeated failure, sleep for this many seconds times 2 ^ (num_failures - 1)
+    :param sequence_info: optionally, specify pre-generated sequence info. by default, create a new one using dht
+    :param rpc_info: optionally, specify rpc info (communicated tensor shapes and compression) to save time
+    :param start: start the background thread (see the note below). If false, you will need to start it manually.
+    :note: RemoteSequenceManager takes up some CPU and network I/O to operate in background. It is recommended to avoid
+      running redundant sequence managers for the same set of layers.
+
+    """
+
+    def __init__(
+        self,
+        dht: DHT,
+        block_uids: Sequence[ModuleUID],
+        p2p: P2P,
+        update_period: float = 30,
+        request_timeout: float = 30,
+        min_backoff: float = 1,
+        sequence_info: Optional[RemoteSequenceInfo] = None,
+        rpc_info: Optional[dict] = None,
+        *,  # dear dev, if you add more parameters to this class, please make sure to handle them in __getitem__ (below)
+        start: bool,
+    ):
+        assert len(block_uids) > 0, "Sequences must contain at least one block"
+        self.dht, self.p2p = dht, p2p
+        self.request_timeout, self.min_backoff = request_timeout, min_backoff
+        self.lock_changes = threading.Lock()
+        self._thread = _SequenceManagerUpdateThread(update_period, WeakMethod(self._update))
+        self.policy = NoSpendingPolicy()
+        self._rpc_info = rpc_info
+
+        if sequence_info is None:
+            self.sequence_info = RemoteSequenceInfo.make_empty(block_uids)
+            self.update(wait=False)
+        else:
+            self.sequence_info = sequence_info
+            assert block_uids == sequence_info.block_uids
+            self._thread.ready.set()  # no need to await the first dht fetch
+
+        if start:
+            self.run_in_background()
+
+    def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None:
+        """
+        Starts the updater thread in a background. if await_ready, this method will wait until sequence manager
+        is ready to process incoming requests or for :timeout: seconds max.
+        """
+        self._thread.start()
+        if await_ready:
+            self._thread.ready.wait(timeout)
+
+    def make_sequence(self, start_index: int = 0, end_index: Optional[int] = None) -> List[RemoteSpanInfo]:
+        """
+        Form a sequence of remote servers that collectively serve all consecutive layers
+
+        :param start_index: optional index of the first module in a sequence, default = the first of block_uids
+        :param end_index: optional index of the last module (non-inclusive), default = after last of block uids
+        """
+        if not self.is_alive():
+            logger.error("Using a sequence manager that is not running: it has either crashed or never started")
+        if not self.ready.is_set():
+            logger.warning("Remote SequenceManager is still searching for routes, waiting for it to become ready")
+            self.ready.wait()
+
+        end_index = end_index if end_index is not None else len(self)
+        span_sequence = []
+        current_index = start_index
+        while current_index < end_index:
+            candidate_spans = self.sequence_info.spans_containing_block[current_index]
+            chosen_span = random.choice(candidate_spans)  # TODO this should be replaced with proper load balancing
+
+            assert chosen_span.start <= current_index < chosen_span.end
+            span_sequence.append(RemoteSpanInfo(start=current_index, end=chosen_span.end, peer_id=chosen_span.peer_id))
+            current_index = chosen_span.end
+
+        return span_sequence
+
+    def __getitem__(self, ix: Union[int, slice]) -> RemoteSequenceManager:
+        """Get a RemoteSequenceManager for a sub-sequence of blocks"""
+        assert isinstance(ix, (int, slice))
+        if not isinstance(ix, slice):
+            ix = slice(int(ix), int(ix) + 1, 1)
+        return type(self)(
+            self.dht,
+            self.block_uids[ix],
+            self.p2p,
+            update_period=self._thread.update_period,
+            request_timeout=self.request_timeout,
+            min_backoff=self.min_backoff,
+            sequence_info=self.sequence_info[ix],
+            rpc_info=self._rpc_info,
+            start=True,
+        )
+
+    def update(self, *, wait: bool):
+        """Run an asynchronous update in background as soon as possible"""
+        self.ready.clear()  # TODO this should be a separate event
+        self._thread.trigger.set()
+        if wait:
+            self.ready.wait()
+
+    def _update(self):
+        """Perform an immediate and synchronous refresh, may take time"""
+        for attempt_no in itertools.count():
+            try:
+                new_block_infos = petals.dht_utils.get_remote_module_infos(
+                    self.dht, self.block_uids, expiration_time=float("inf")
+                )
+                with self.lock_changes:
+                    self.sequence_info.update_(new_block_infos)
+                missing_blocks = [i for i in range(len(self)) if not self.sequence_info.spans_containing_block[i]]
+                if missing_blocks:
+                    raise MissingBlocksError(f"no servers holding blocks {missing_blocks}")
+                self.ready.set()  # if there is an active server for every block, we may begin running
+                break
+
+            except Exception as e:
+                delay = self.get_retry_delay(attempt_no)
+                logger.warning(f"Could not find route through the model: {repr(e)} (retry in {delay:.0f} sec)")
+                traceback_level = logging.DEBUG if str(e) else logging.WARNING
+                logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
+                time.sleep(delay)
+
+    def __len__(self):
+        return len(self.block_uids)
+
+    @property
+    def is_alive(self):
+        return self._thread.is_alive
+
+    @property
+    def ready(self) -> threading.Event:
+        return self._thread.ready
+
+    @property
+    def block_uids(self):
+        return self.sequence_info.block_uids
+
+    @property
+    def rpc_info(self):
+        """Return the rpc_info queried from one of the servers that hold the first block"""
+        if self._rpc_info is None:
+            for attempt_no in itertools.count():
+                try:
+                    self._update()
+                    peer_id, _ = random.choice(list(self.sequence_info.block_infos[0].servers.items()))
+                    stub = TransformerConnectionHandler.get_stub(self.p2p, peer_id)
+                    outputs = RemoteExpertWorker.run_coroutine(
+                        stub.rpc_info(runtime_pb2.ExpertUID(uid=self.block_uids[0]))
+                    )
+                    self._rpc_info = MSGPackSerializer.loads(outputs.serialized_info)
+                    break
+                except Exception as e:
+                    delay = self.get_retry_delay(attempt_no)
+                    logger.warning(
+                        f"Caught exception when gathering information from peer {peer_id} "
+                        f"(retry in {delay:.0f} sec): {repr(e)}"
+                    )
+                    traceback_level = logging.DEBUG if str(e) else logging.WARNING
+                    logger.log(traceback_level, "See detailed traceback below:", exc_info=True)
+                    time.sleep(delay)
+
+        return self._rpc_info
+
+    def get_retry_delay(self, attempt_no: int) -> float:
+        if attempt_no == 0:
+            return 0
+        return self.min_backoff * 2 ** (attempt_no - 1)
+
+    def get_request_metadata(self, protocol: str, *args, **kwargs) -> Optional[Dict[str, Any]]:
+        """
+        :param protocol: one of "rpc_forward", "rpc_backward" or "rpc_inference"
+        :param args: request-specific inputs, typicall block uids and input tensors
+        :param kwargs: additional request context, such as remote peer ID
+        :returns: msgpack-serialized metadata dict that will be passed alongside a given request
+        """
+        return dict(points=self.policy.get_points(protocol, *args, **kwargs))
+
+    def shutdown(self):
+        self._thread.shutdown()
+
+
+class _SequenceManagerUpdateThread(threading.Thread):
+    def __init__(self, update_period: float, ref_update_manager: WeakMethod):
+        super().__init__(daemon=True)
+        self.ref_update_manager = ref_update_manager
+        self.ready = threading.Event()
+        self.trigger = threading.Event()
+        self.last_update_time = -float("inf")
+        self.update_period = update_period
+        self.should_shutdown = False
+
+    def run(self) -> None:
+        while not self.should_shutdown:
+            self.trigger.wait(max(0.0, min(self.update_period, time.perf_counter() - self.last_update_time)))
+
+            if self.should_shutdown:
+                logger.debug(f"{self.__class__.__name__} is shutting down")
+                break
+
+            update_manager = self.ref_update_manager()
+            if update_manager is None:
+                logger.debug(f"{self.__class__.__name__} exited because the sequence manager no longer exists")
+                break
+
+            try:
+                update_manager()
+                self.trigger.clear()
+            except Exception as e:
+                logger.exception(e)
+            finally:
+                del update_manager
+
+        logger.debug(f"{self.__class__.__name__} thread exited")
+
+    def shutdown(self, timeout: Optional[float] = None):
+        self.should_shutdown = True
+        self.trigger.set()
+        self.join(timeout)
+
+    def __del__(self):
+        if self.is_alive():
+            self.shutdown()
+
+
+class MissingBlocksError(Exception):
+    def __repr__(self):
+        return self.args[0]

+ 0 - 0
src/petals/client/spending_policy.py → src/petals/client/routing/spending_policy.py


+ 0 - 179
src/petals/client/sequence_manager.py

@@ -1,179 +0,0 @@
-from __future__ import annotations
-
-import random
-import threading
-from typing import List, Optional, Sequence, Tuple, Union
-
-from hivemind import DHT, P2P, DHTExpiration, MSGPackSerializer
-from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
-from hivemind.proto import runtime_pb2
-from hivemind.utils.logging import get_logger, use_hivemind_log_handler
-
-import petals.dht_utils
-from petals.client.spending_policy import NoSpendingPolicy
-from petals.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
-from petals.server.handler import TransformerConnectionHandler
-
-use_hivemind_log_handler("in_root_logger")
-logger = get_logger(__file__)
-
-
-class RemoteSequenceManager:
-    """
-    Keeps and updates the meta-information about which peers host which blocks.
-    In future, this class is intended to maintain latency statistics, ban non-responsive peers, etc.
-    """
-
-    def __init__(
-        self,
-        dht: DHT,
-        block_uids: Sequence[ModuleUID],
-        p2p: P2P,
-        max_retries: int = 3,
-        request_timeout: float = 20,
-        min_backoff: float = 1,
-    ):
-        assert len(block_uids) > 0, "Sequences must contain at least one block"
-        self.dht, self.p2p = dht, p2p
-        self.block_uids: List[ModuleUID] = list(block_uids)
-        self.block_infos: List[Optional[RemoteModuleInfo]] = [None] * len(self.block_uids)
-        self.spans_by_priority: List[RemoteSpanInfo] = []  # sorted from best to worst
-        self.spans_containing_block: Tuple[List[RemoteSpanInfo], ...] = tuple([] for _ in range(len(self.block_uids)))
-        self.last_update_time: DHTExpiration = -float("inf")
-        self.max_retries = max_retries
-        self.request_timeout, self.min_backoff = request_timeout, min_backoff
-        self._rpc_info = None
-        self.lock_changes = threading.Lock()
-        self.policy = NoSpendingPolicy()
-        self.update_()
-
-        for uid, info in zip(self.block_uids, self.block_infos):
-            assert info is not None, f"Found no remote peers for block {uid}"
-        assert self.spans_by_priority and self.spans_containing_block
-
-    def make_sequence(self, start_index: int = 0, end_index: Optional[int] = None) -> List[RemoteSpanInfo]:
-        """
-        Form a sequence of remote servers that collectively serve all consecutive layers
-
-        :param start_index: optional index of the first module in a sequence, default = the first of block_uids
-        :param end_index: optional index of the last module (non-inclusive), default = after last of block uids
-        """
-        end_index = end_index if end_index is not None else len(self.block_uids)
-        span_sequence = []
-        current_index = start_index
-        while current_index < end_index:
-            candidate_spans = self.spans_containing_block[current_index]
-            chosen_span = random.choice(candidate_spans)  # TODO this should be replaced with proper load balancing
-
-            assert chosen_span.start <= current_index < chosen_span.end
-            span_sequence.append(RemoteSpanInfo(start=current_index, end=chosen_span.end, peer_id=chosen_span.peer_id))
-            current_index = chosen_span.end
-
-        return span_sequence
-
-    def __getitem__(self, ix: Union[int, slice]) -> RemoteSequenceManager:
-        """Get a RemoteSequenceManager for a sub-sequence of blocks"""
-        assert isinstance(ix, (int, slice))
-        if not isinstance(ix, slice):
-            ix = slice(int(ix), int(ix) + 1, 1)
-        with self.lock_changes:
-            subseq = RemoteSequenceManager(self.dht, self.block_uids[ix], self.p2p)
-            subseq.block_infos = self.block_infos[ix]
-            subseq.spans_by_priority, subseq.spans_containing_block = subseq.compute_spans(subseq.block_infos)
-            subseq.last_update_time = self.last_update_time
-        return subseq
-
-    def update_(self):
-        with self.lock_changes:
-            self.update_block_infos_()
-            self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos)
-
-    def update_block_infos_(self):
-        new_block_infos = petals.dht_utils.get_remote_module_infos(
-            self.dht, self.block_uids, expiration_time=float("inf")
-        )
-        assert len(new_block_infos) == len(self.block_uids)
-        for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
-            if info is None:
-                logger.warning(f"Found no block info for block {uid}")
-                continue
-            if not isinstance(info, RemoteModuleInfo):
-                logger.warning(f"Unexpected dht entry type for {uid}: {info}")
-            if not info.servers:
-                logger.warning(f"Found no active peers for block {uid}")
-            if info.uid != uid:
-                logger.warning(f"The DHT entry for {uid} actually points to {info.uid}")
-            self.block_infos[block_index] = info
-
-    @staticmethod
-    def compute_spans(block_infos: Sequence[RemoteModuleInfo]):
-        closed_spans = []
-        active_spans = {}
-        for block_index, info in enumerate(block_infos):
-            if info is not None:
-                for peer_id, server in info.servers.items():
-                    if server.state != ServerState.ONLINE:
-                        continue
-                    if peer_id not in active_spans:
-                        active_spans[peer_id] = RemoteSpanInfo(start=block_index, end=block_index + 1, peer_id=peer_id)
-                    else:  # peer_id in active_spans
-                        active_spans[peer_id].end = block_index + 1
-
-            for peer_id in list(active_spans.keys()):
-                if (
-                    info is None
-                    or peer_id not in info.servers
-                    or info.servers[peer_id].state != ServerState.ONLINE
-                    or block_index == len(block_infos) - 1
-                ):
-                    closed_spans.append(active_spans.pop(peer_id))
-        assert not active_spans, f"spans: {active_spans}"
-
-        closed_spans.sort(key=lambda span: span.end - span.start, reverse=True)
-
-        spans_containing_block = tuple(list() for _ in range(len(block_infos)))
-        for span in closed_spans:
-            for block_index in range(span.start, span.end):
-                spans_containing_block[block_index].append(span)
-
-        return closed_spans, spans_containing_block
-
-    def __len__(self):
-        return len(self.block_uids)
-
-    @property
-    def rpc_info(self):
-        """Return the rpc_info queried from one of the servers that hold the first block"""
-        if self._rpc_info is None:
-            retries = 0
-            for i in range(self.max_retries):
-                try:
-                    self.update_()
-                    peer_id = random.choice(list(self.block_infos[0].servers.keys()))
-                    stub = TransformerConnectionHandler.get_stub(self.p2p, peer_id)
-                    outputs = RemoteExpertWorker.run_coroutine(
-                        stub.rpc_info(runtime_pb2.ExpertUID(uid=self.block_uids[0]))
-                    )
-                    self._rpc_info = MSGPackSerializer.loads(outputs.serialized_info)
-                    break
-                except Exception as e:
-                    retries += 1
-                    if retries >= self.max_retries:
-                        raise e
-                    else:
-                        logger.warning(f"Tried to call rpc_info, but caught {repr(e)}", exc_info=True)
-        return self._rpc_info
-
-    def get_retry_delay(self, attempt_no: int) -> float:
-        if attempt_no == 0:
-            return 0
-        return self.min_backoff * 2 ** (attempt_no - 1)
-
-    def get_request_metadata(self, protocol: str, *args, **kwargs) -> Optional[bytes]:
-        """
-        :param protocol: one of "rpc_forward", "rpc_backward" or "rpc_inference"
-        :param args: request-specific inputs, typicall block uids and input tensors
-        :param kwargs: additional request context, such as remote peer ID
-        :returns: msgpack-serialized metadata dict that will be passed alongside a given request
-        """
-        return MSGPackSerializer.dumps(dict(points=self.policy.get_points(protocol, *args, **kwargs)))

+ 6 - 5
src/petals/client/sequential_autograd.py

@@ -8,11 +8,12 @@ from collections import deque
 from typing import List, Optional, Sequence, Tuple
 
 import torch
+from hivemind import MSGPackSerializer
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.utils.logging import get_logger
 
 from petals.client.remote_forward_backward import run_remote_backward, run_remote_forward
-from petals.client.sequence_manager import RemoteSequenceManager
+from petals.client.routing.sequence_manager import RemoteSequenceManager
 from petals.data_structures import CHAIN_DELIMITER, RemoteSpanInfo
 from petals.server.handler import TransformerConnectionHandler
 from petals.utils.misc import DUMMY, is_dummy
@@ -58,7 +59,7 @@ async def sequential_forward(
             logger.debug(f"Forward: block {block_idx}, attempt {attempt_no}")
             try:
                 if attempt_no >= 1:
-                    sequence_manager.update_()
+                    sequence_manager._update()
                 if not sequences or attempt_no >= 1:
                     sequences = deque(sequence_manager.make_sequence(block_idx, end_index))
                     # make_sequence() could return a longer sequence
@@ -78,7 +79,7 @@ async def sequential_forward(
                     sequence_manager.rpc_info,
                     *inputs_and_prompts,
                     timeout=sequence_manager.request_timeout,
-                    metadata=metadata,
+                    metadata=MSGPackSerializer.dumps(metadata),
                 )
 
                 assert isinstance(outputs, torch.Tensor)
@@ -136,7 +137,7 @@ async def sequential_backward(
             logger.debug(f"Backward: block {span.end - 1}, attempt {attempt_no}")
             try:
                 if attempt_no >= 1:
-                    sequence_manager.update_()
+                    sequence_manager.update(wait=True)
                     _, backup_inputs, backup_sequences = await sequential_forward(
                         inputs, prompts, sequence_manager, start_index=span.start, end_index=span.end
                     )
@@ -162,7 +163,7 @@ async def sequential_backward(
                     grad_outputs,
                     prompts[span.start : span.end],
                     timeout=sequence_manager.request_timeout,
-                    metadata=metadata,
+                    metadata=MSGPackSerializer.dumps(metadata),
                 )
                 grad_outputs = [grad_outputs]
                 grad_prompts_reversed.extend(span_grad_prompts)

+ 5 - 6
src/petals/dht_utils.py

@@ -94,7 +94,7 @@ async def _get_remote_sequence(
 ) -> petals.client.RemoteSequential:
     uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(start, stop)]
     p2p = await dht.replicate_p2p()
-    manager = petals.client.RemoteSequenceManager(dht, uids, p2p)
+    manager = petals.client.RemoteSequenceManager(dht, uids, p2p, start=True)
     return petals.client.RemoteSequential(config, dht, dht_prefix, p2p, manager)
 
 
@@ -125,7 +125,7 @@ async def _get_remote_module(
     single_uid = isinstance(uid_or_uids, ModuleUID)
     uids = [uid_or_uids] if single_uid else uid_or_uids
     p2p = await dht.replicate_p2p()
-    managers = (petals.client.RemoteSequenceManager(dht, [uid], p2p) for uid in uids)
+    managers = (petals.client.RemoteSequenceManager(dht, [uid], p2p, start=True) for uid in uids)
     modules = [
         petals.client.RemoteTransformerBlock(config, dht, dht_prefix=dht_prefix, p2p=p2p, sequence_manager=m)
         for m in managers
@@ -134,14 +134,13 @@ async def _get_remote_module(
 
 
 def get_remote_module_infos(
-    dht: DHT,
-    uid_or_uids: Union[ModuleUID, List[ModuleUID]],
-    expiration_time: Optional[DHTExpiration] = None,
+    dht: DHT, uid_or_uids: Union[ModuleUID, Sequence[ModuleUID]], expiration_time: Optional[DHTExpiration] = None
 ) -> List[Optional[RemoteModuleInfo]]:
     single_uid = isinstance(uid_or_uids, ModuleUID)
     uids = [uid_or_uids] if single_uid else uid_or_uids
     infos = dht.run_coroutine(
-        partial(_get_remote_module_infos, uids=uids, expiration_time=expiration_time), return_future=False
+        partial(_get_remote_module_infos, uids=uids, expiration_time=expiration_time),
+        return_future=False,
     )
     return infos[0] if single_uid else infos
 

+ 3 - 3
src/petals/server/block_selection.py

@@ -25,7 +25,7 @@ class Span:
         self.start, self.end = new_start, new_start + self.length
 
 
-def _compute_spans(module_infos: List[Optional[RemoteModuleInfo]]) -> Tuple[Dict[PeerID, Span], np.ndarray]:
+def compute_spans(module_infos: List[Optional[RemoteModuleInfo]]) -> Tuple[Dict[PeerID, Span], np.ndarray]:
     spans = {}
     throughputs = np.zeros(len(module_infos))
     for block, module in enumerate(module_infos):
@@ -56,7 +56,7 @@ def _choose_best_start(throughputs: np.ndarray, num_blocks: int) -> int:
 
 
 def choose_best_blocks(num_blocks: int, module_infos: List[Optional[RemoteModuleInfo]]) -> List[int]:
-    _, throughputs = _compute_spans(module_infos)
+    _, throughputs = compute_spans(module_infos)
     start = _choose_best_start(throughputs, num_blocks)
     return list(range(start, start + num_blocks))
 
@@ -67,7 +67,7 @@ def should_choose_other_blocks(
     if balance_quality > 1.0:
         return True  # Forces rebalancing on each check (may be used for debugging purposes)
 
-    spans, throughputs = _compute_spans(module_infos)
+    spans, throughputs = compute_spans(module_infos)
     initial_throughput = throughputs.min()
     eps = 1e-3
 

+ 8 - 8
tests/test_remote_sequential.py

@@ -1,6 +1,6 @@
 import pytest
 import torch
-from hivemind import DHT, BatchTensorDescriptor, MSGPackSerializer, get_logger, use_hivemind_log_handler
+from hivemind import DHT, BatchTensorDescriptor, get_logger, use_hivemind_log_handler
 from hivemind.proto import runtime_pb2
 from test_utils import *
 
@@ -48,7 +48,7 @@ def test_remote_sequential():
     # test RemoteSequential with lossy compression
     block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)]
     lossy_sequential = RemoteSequential(
-        config, dht, sequence_manager=DummyCustomSequenceManager(dht, block_uids, sequential.p2p)
+        config, dht, sequence_manager=DummyCustomSequenceManager(dht, block_uids, sequential.p2p, start=True)
     )
 
     test_inputs.grad = None
@@ -58,7 +58,8 @@ def test_remote_sequential():
     assert not torch.allclose(approx_outputs, full_outputs, rtol=0, atol=1e-4), "compression was not used"
     assert not torch.allclose(test_inputs.grad, full_grad, rtol=0, atol=1e-2), "compression was not used"
     assert abs(approx_outputs - full_outputs).mean() < 0.01
-    assert abs(test_inputs.grad - full_grad).mean() < 0.3
+    absmax = abs(full_grad).max()
+    assert abs(test_inputs.grad / absmax - full_grad / absmax).mean() < 0.01
 
 
 class DummyCustomSequenceManager(RemoteSequenceManager):
@@ -73,13 +74,12 @@ class DummyCustomSequenceManager(RemoteSequenceManager):
         return rpc_info
 
     def get_request_metadata(self, protocol: str, *args, **kwargs):
+        metadata = super().get_request_metadata(protocol, *args, **kwargs)
         if protocol == "rpc_forward":
-            return MSGPackSerializer.dumps(dict(output_compression=(runtime_pb2.CompressionType.FLOAT16,)))
+            metadata["output_compression"] = (runtime_pb2.CompressionType.FLOAT16,)
         elif protocol == "rpc_backward":
-            return MSGPackSerializer.dumps(dict(output_compression=(runtime_pb2.CompressionType.BLOCKWISE_8BIT,)))
-        else:
-            assert protocol == "rpc_inference"
-            return super().get_request_metadata(protocol, *args, **kwargs)
+            metadata["output_compression"] = (runtime_pb2.CompressionType.BLOCKWISE_8BIT,)
+        return metadata
 
 
 @pytest.mark.forked

+ 54 - 0
tests/test_sequence_manager.py

@@ -0,0 +1,54 @@
+import threading
+import time
+
+import pytest
+import torch
+from hivemind import DHT, get_logger, use_hivemind_log_handler
+from test_utils import *
+
+from petals.client import RemoteSequenceManager, RemoteSequential
+from petals.client.remote_model import DistributedBloomConfig
+from petals.data_structures import UID_DELIMITER
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
+
+@pytest.mark.forked
+def test_sequence_manager_shutdown():
+    config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+    dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
+    sequential = RemoteSequential(config, dht)
+    shutdown_evt = threading.Event()
+
+    # test RemoteSequential with lossy compression
+    block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)]
+    sequential = RemoteSequential(
+        config,
+        dht,
+        sequence_manager=TestSequenceManager(dht, block_uids, sequential.p2p, _was_shut_down=shutdown_evt, start=True),
+    )
+
+    assert sequential.sequence_manager.is_alive()
+    assert sequential.sequence_manager._thread.ready.is_set()
+    assert not shutdown_evt.is_set()
+    sequential(torch.randn(1, 2, config.hidden_size))
+
+    sequential.sequence_manager.shutdown()
+    del sequential
+    time.sleep(1)
+
+    assert shutdown_evt.is_set()
+
+
+class TestSequenceManager(RemoteSequenceManager):
+    """A sequence manager that signals if it was shut down"""
+
+    def __init__(self, *args, _was_shut_down: threading.Event, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._was_shut_down = _was_shut_down
+
+    def shutdown(self):
+        super().shutdown()
+        assert not self.is_alive()
+        self._was_shut_down.set()