
Refactor RemoteSequenceManager (#309)

This PR:

1. **Extracts `SequenceManagerConfig` and `SequenceManagerState` subclasses.**

    The config is provided by the caller and is never changed from inside `RemoteSequenceManager`. The state is the part of `RemoteSequenceManager`'s internal state that is shared between the main manager and its slices. We fix some slicing bugs along the way.

2. **Removes `dht_prefix` and `p2p` arguments, makes `dht` argument optional.**

    `dht_prefix` can always be overridden using `config.dht_prefix`. `p2p` is actually needed only under the hood of `RemoteSequenceManager`, so the manager can obtain it by itself without exposing this low-level class to callers. If strictly necessary, a caller can still provide `p2p` as part of `SequenceManagerState`. `dht` is also needed only by `RemoteSequenceManager`, so we can make it optional in the parent classes and create it automatically when it's not provided (see the usage sketch after this list).

3. **Simplifies retry logic.**

    Previously, we could have "nested" retry loops: one in `._update()` and another in the inference/forward/backward steps. The loop in `._update()` could cause issues for concurrent inference/forward/backward calls, since it blocks the entire class if its retry delay grows too large. Now this logic is simplified: `._update()` performs only one attempt to fetch the DHT info, and any retries are triggered by the inference/forward/backward steps themselves.

4. **Removes deprecated `RemoteTransformerBlock`.**

    `RemoteTransformerBlock` was deprecated a long time ago, before Petals 1.0.0. Its removal is long overdue.

5. **Removes `dht_utils.get_remote_module()`, `dht_utils.get_remote_sequence()`.**

    These functions duplicate the functionality of the `RemoteSequential` constructor.

6. (minor) **Removes `RemoteSequential.is_subsequence` flag.**

    This flag worked incorrectly and was never used. I am removing it for the sake of simplicity.
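
For illustration, here is a minimal usage sketch of the refactored client API, based on the tests updated in this PR. The model name is a placeholder, and running it assumes access to a Petals swarm serving that model (and a hosted config that defines `dht_prefix`):

```python
import torch

from petals.client import DistributedBloomConfig, RemoteSequential

# Routing options (request_timeout, max_retries, allowed_servers, ...) now live in
# SequenceManagerConfig, which DistributedBloomConfig inherits from, so they can be
# passed through .from_pretrained() instead of RemoteSequential() keyword arguments.
config = DistributedBloomConfig.from_pretrained(
    "bigscience/bloom-petals",  # placeholder model name; any Petals-served model works
    request_timeout=3 * 60,
)

# No explicit hivemind.DHT or P2P is needed anymore: RemoteSequenceManager creates a
# client-mode DHT and replicates P2P under the hood (a DHT can still be passed via `dht=`).
blocks = RemoteSequential(config, start_block=3, end_block=6)

# Indexing/slicing now returns a RemoteSequential sharing the same manager state,
# replacing the removed RemoteTransformerBlock wrapper.
single_block = blocks[0]

outputs = blocks(torch.randn(1, 8, config.hidden_size))
```

This mirrors the updated calls in `tests/test_chained_calls.py`, `tests/test_block_exact_match.py`, and `tests/test_server_stats.py`.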
Alexander Borzunov, 2 years ago
Commit 8f6342a861

+ 2 - 0
README.md

@@ -111,6 +111,8 @@ See the instructions for macOS and Windows, the full requirements, and troublesh
 
 ## Benchmarks
 
+The benchmarks below are for BLOOM-176B:
+
 <table align="center">
   <tr>
     <th colspan="2">Network</th>

+ 1 - 1
src/petals/client/__init__.py

@@ -5,6 +5,6 @@ from petals.client.remote_model import (
     DistributedBloomForSequenceClassification,
     DistributedBloomModel,
 )
-from petals.client.remote_sequential import RemoteSequential, RemoteTransformerBlock
+from petals.client.remote_sequential import RemoteSequential
 from petals.client.routing.sequence_manager import RemoteSequenceManager
 from petals.client.routing.spending_policy import NoSpendingPolicy, SpendingPolicyBase

+ 5 - 8
src/petals/client/inference_session.py

@@ -8,7 +8,6 @@ from typing import AsyncIterator, List, Optional
 
 import torch
 from hivemind import (
-    P2P,
     MSGPackSerializer,
     anext,
     deserialize_torch_tensor,
@@ -162,9 +161,8 @@ class InferenceSession:
     An interface to a multi-step *inference* session for a sequence of remote transformer blocks
     """
 
-    def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, max_length: int):
+    def __init__(self, sequence_manager: RemoteSequenceManager, max_length: int):
         self._sequence_manager = sequence_manager
-        self._p2p = p2p
         self._closed = False
         self._chosen_spans = []
         self._server_sessions = []
@@ -181,7 +179,7 @@ class InferenceSession:
         server_sessions = []
         try:
             for span in chosen_spans:
-                stub = TransformerConnectionHandler.get_stub(self._p2p, span.peer_id)
+                stub = TransformerConnectionHandler.get_stub(self._sequence_manager.state.p2p, span.peer_id)
                 span_uids = CHAIN_DELIMITER.join(self._sequence_manager.block_uids[span.start : span.end])
                 metadata = self._sequence_manager.get_request_metadata("rpc_inference", span_uids, peer_id=span.peer_id)
                 session = RemoteExpertWorker.run_coroutine(
@@ -189,7 +187,7 @@ class InferenceSession:
                         stub,
                         span_uids,
                         rpc_info=self._sequence_manager.rpc_info,
-                        timeout=self._sequence_manager.request_timeout,
+                        timeout=self._sequence_manager.config.request_timeout,
                         max_length=self._max_length,
                         **metadata,
                     )
@@ -305,9 +303,8 @@ class InferenceSession:
                     self._sequence_manager.on_request_success(span.peer_id)
                     break
                 except Exception as e:
-                    if span is not None:
-                        self._sequence_manager.on_request_failure(span.peer_id)
-                    if attempt_no + 1 == self._sequence_manager.max_retries:
+                    self._sequence_manager.on_request_failure(span.peer_id if span is not None else None)
+                    if attempt_no + 1 == self._sequence_manager.config.max_retries:
                         raise
                     delay = self._sequence_manager.get_retry_delay(attempt_no)
                     logger.warning(

+ 6 - 25
src/petals/client/remote_model.py

@@ -18,13 +18,14 @@ from transformers.models.bloom import (
 from petals.bloom.modeling_utils import LMHead
 from petals.client.remote_generation import RemoteGenerationMixin
 from petals.client.remote_sequential import RemoteSequential
+from petals.client.routing.sequence_manager import SequenceManagerConfig
 from petals.constants import PUBLIC_INITIAL_PEERS
 from petals.utils.misc import DUMMY
 
 logger = get_logger(__name__)
 
 
-class DistributedBloomConfig(BloomConfig):
+class DistributedBloomConfig(BloomConfig, SequenceManagerConfig):
     """
     A bloom config that contains information about DHT peers.
     To create a distributed model, one must provide dht_prefix and either initial_peers or dht.
@@ -33,15 +34,9 @@ class DistributedBloomConfig(BloomConfig):
     initial_peers: List[str] = PUBLIC_INITIAL_PEERS  # a list of initial peers for hivemind DHT
     dht_prefix: str  # a prefix for all dht keys that correspond to this model (usually equal to model name)
     daemon_startup_timeout: int = 60  # timeout for the libp2p daemon connecting to initial peers
-    dht: Optional[hivemind.DHT] = None  # a running DHT instance, e.g. when using the same DHT for multiple models
-    request_timeout: int = 3 * 60  # a number of seconds for waiting result from each node
-    max_retries: Optional[int] = None  # max number retries before the client raises an exception (default: inf)
-    allowed_servers: Optional[
-        Collection[Union[str, hivemind.PeerID]]
-    ] = None  # if defined, send requests only to these servers
 
     pre_seq_len: int = 0  # a number of tokens for prompt tuning.
-    tuning_mode: Optional[str] = None  # One of the finetune options: [None, 'shallow_ptune', 'deep_ptune', 'adapters']
+    tuning_mode: Optional[str] = None  # fine-tuning regime, one of [None, "ptune", "deep_ptune"]
 
     # This settings matter for running the client with dtype bfloat16 on CPU.
     # If the CPU doesn't support AVX512, chunked_forward() significantly speeds up computations.
@@ -106,30 +101,16 @@ class DistributedBloomModel(_FromPretrainedDefaultsMixin, BloomModel):
 
     config_class = DistributedBloomConfig
 
-    def __init__(self, config: DistributedBloomConfig):
+    def __init__(self, config: DistributedBloomConfig, *, dht: Optional[hivemind.DHT] = None):
         assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..."
-        assert config.initial_peers or config.dht, "Please specify initial_peers=list(...) or dht=hivemind.DHT(...)"
+        assert config.initial_peers or dht is not None, "Please specify `config.initial_peers` or `dht`"
 
         n_layer, config.n_layer = config.n_layer, 0  # temporarily set n_layer to 0 to prevent layer initialization
         super().__init__(config)
         assert len(self.h) == 0
         config.n_layer = n_layer
 
-        dht = config.dht
-        if dht is None:
-            dht = hivemind.DHT(
-                initial_peers=config.initial_peers,
-                client_mode=True,
-                num_workers=n_layer,
-                startup_timeout=config.daemon_startup_timeout,
-                start=True,
-            )
-        assert isinstance(dht, hivemind.DHT) and dht.is_alive(), "dht must be a running hivemind.DHT instance"
-        self.h = RemoteSequential(
-            config,
-            dht,
-            config.dht_prefix,
-        )
+        self.h = RemoteSequential(config, dht=dht)
 
         # Forbid accumulate grads for embeddings and layernorm
         self.set_requires_grad(False)

+ 20 - 61
src/petals/client/remote_sequential.py

@@ -3,7 +3,7 @@ from __future__ import annotations
 from typing import Optional, Union
 
 import torch
-from hivemind import DHT, P2P, get_logger
+from hivemind import DHT, get_logger
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from torch import nn
 
@@ -25,39 +25,26 @@ class RemoteSequential(nn.Module):
     def __init__(
         self,
         config: petals.client.DistributedBloomConfig,
-        dht: DHT,
-        dht_prefix: Optional[str] = None,
-        p2p: Optional[P2P] = None,
+        *,
         sequence_manager: Optional[RemoteSequenceManager] = None,
-        **kwargs,
+        dht: Optional[DHT] = None,
+        start_block: Optional[int] = None,
+        end_block: Optional[int] = None,
     ):
         super().__init__()
         self.config = config
-        self.dht = dht
-        self.dht_prefix = dht_prefix or config.dht_prefix
-        self.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p()) if p2p is None else p2p
 
-        num_blocks = self.config.n_layer if sequence_manager is None else len(sequence_manager)
-        block_uids = tuple(f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(num_blocks))
+        assert sequence_manager is None or (
+            dht is None and start_block is None and end_block is None
+        ), "`dht`, `start_block`, and `end_block` have no effect when you provide a custom `sequence_manager`"
         if sequence_manager is None:
-            logger.debug(f"Creating new sequence manager for block uids: {block_uids}")
-            self.sequence_manager = RemoteSequenceManager(
-                dht,
-                block_uids,
-                self.p2p,
-                request_timeout=config.request_timeout,
-                max_retries=config.max_retries,
-                allowed_servers=config.allowed_servers,
-                **kwargs,
-            )
-            self.is_subsequence = False
-        else:
-            logger.debug(f"Reusing sequence manager with {len(sequence_manager)} modules")
-            if kwargs:
-                logger.warning(f"Parameters {kwargs} are ignored because sequence_manager is explicitly provided")
-            self.sequence_manager = sequence_manager
-            assert isinstance(sequence_manager.sequence_info.block_uids, tuple)
-            self.is_subsequence = self.sequence_manager.sequence_info.block_uids != block_uids
+            if start_block is None:
+                start_block = 0
+            if end_block is None:
+                end_block = self.config.n_layer
+            block_uids = tuple(f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block, end_block))
+            sequence_manager = RemoteSequenceManager(config, block_uids, dht=dht)
+        self.sequence_manager = sequence_manager
 
     def forward(self, inputs: torch.Tensor, prompts: torch.Tensor = DUMMY):
         assert inputs.ndim == 3, "inputs must be a tensor of shape [batch_size, seq_length, hidden_size]"
@@ -66,23 +53,10 @@ class RemoteSequential(nn.Module):
         return outputs
 
     def __getitem__(self, ix: Union[int, slice]) -> RemoteSequential:
-        assert isinstance(ix, (int, slice))
-        if isinstance(ix, int):
-            return RemoteTransformerBlock(
-                self.config,
-                self.dht,
-                dht_prefix=self.dht_prefix,
-                p2p=self.p2p,
-                sequence_manager=self.sequence_manager[ix],
-            )
-        else:
-            return RemoteSequential(
-                self.config,
-                self.dht,
-                dht_prefix=self.dht_prefix,
-                p2p=self.p2p,
-                sequence_manager=self.sequence_manager[ix],
-            )
+        return RemoteSequential(
+            self.config,
+            sequence_manager=self.sequence_manager[ix],
+        )
 
     def __iter__(self):
         for block_index in range(len(self)):
@@ -92,22 +66,7 @@ class RemoteSequential(nn.Module):
         return len(self.sequence_manager)
 
     def inference_session(self, **kwargs) -> InferenceSession:
-        return InferenceSession(self.sequence_manager, self.p2p, **kwargs)
+        return InferenceSession(self.sequence_manager, **kwargs)
 
     def extra_repr(self) -> str:
         return f"modules={self.sequence_manager.block_uids[0]}..{self.sequence_manager.block_uids[-1]}"
-
-
-class RemoteTransformerBlock(RemoteSequential):
-    """Single transformer block hosted by swarm
-
-    This class is deprecated and kept for backward compatibility.
-    It will be removed soon in favor of using ``RemoteSequential`` directly.
-    """
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        assert len(self) == 1, "Remote Block is a sequence size 1"
-
-    def extra_repr(self):
-        return f"{self.sequence_manager.block_uids[0]}"

+ 2 - 2
src/petals/client/routing/sequence_info.py

@@ -27,14 +27,14 @@ class RemoteSequenceInfo:
     block_infos: Tuple[RemoteModuleInfo, ...]  # note: the contents of RemoteModuleInfo can and will be updated
     spans_by_priority: List[RemoteSpanInfo]
     spans_containing_block: Tuple[List[RemoteSpanInfo], ...]
-    last_updated_time: float
+    last_updated_time: Optional[float]
 
     @classmethod
     def make_empty(cls: Type[T], block_uids: Iterable[ModuleUID]) -> T:
         block_uids = tuple(block_uids)
         empty_block_infos = tuple(RemoteModuleInfo(uid, {}) for uid in block_uids)
         empty_spans = tuple([] for _ in range(len(block_uids)))
-        return cls(block_uids, empty_block_infos, [], empty_spans, last_updated_time=-float("inf"))
+        return cls(block_uids, empty_block_infos, [], empty_spans, last_updated_time=None)
 
     def __getitem__(self, ix: slice):
         assert isinstance(ix, slice)

+ 146 - 153
src/petals/client/routing/sequence_manager.py

@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import asyncio
+import dataclasses
 import itertools
 import logging
 import random
@@ -13,7 +14,6 @@ import numpy as np
 from hivemind import DHT, P2P, MSGPackSerializer, PeerID, get_dht_time
 from hivemind.dht.node import Blacklist
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
-from hivemind.p2p import P2PHandlerError
 from hivemind.proto import runtime_pb2
 from hivemind.utils.logging import get_logger
 
@@ -26,6 +26,33 @@ from petals.server.handler import TransformerConnectionHandler
 logger = get_logger(__name__)
 
 
+@dataclasses.dataclass
+class SequenceManagerConfig:
+    allowed_servers: Optional[Collection[Union[PeerID, str]]] = None  # if defined, send requests only to these servers
+
+    request_timeout: float = 3 * 60  # timeout for forward/backward/inference requests
+    update_period: float = 60  # refresh DHT information once in this many seconds
+
+    max_retries: Optional[int] = None  # max number retries before the client raises an exception (default: inf)
+    min_backoff: float = 1  # after a repeated failure, sleep for this many seconds times 2 ** (num_failures - 1)
+    max_backoff: float = 60  # limit maximal sleep time between retries to this value
+    ban_timeout: float = 15  # when a remote peer fails to respond, prevent routing to that peer for this many seconds
+
+
+@dataclasses.dataclass
+class SequenceManagerState:
+    p2p: P2P = None
+    sequence_info: Optional[RemoteSequenceInfo] = None
+    rpc_info: Optional[dict] = None
+    banned_peers: Optional[Blacklist] = None
+
+    def __getitem__(self, ix: Union[int, slice]) -> SequenceManagerState:
+        return dataclasses.replace(self, sequence_info=self.sequence_info[ix])
+
+    def __len__(self) -> int:
+        return len(self.sequence_info)
+
+
 class RemoteSequenceManager:
     """
     Sequence manager is a thread that keeps track of remote servers that hold the specified sequence of blocks.
@@ -34,67 +61,56 @@ class RemoteSequenceManager:
     Using this information, sequence manager can form sequences of servers that collectively have the full sequence.
     To form such a sequence, call .make_sequence with the appropriate optimization policy (see make_sequence docstr).
 
-    :param dht: a running hivemind.DHT instance, connected to peers that serve the corresponding blocks
-    :param block_uids: a sequence of DHT keys (strings) corresponding to remote layers
-    :param p2p: an optional P2P replica (if not specified, create one via dht.replicate_p2p())
-    :param update_period: by default, refresh DHT information once in this many seconds
-    :param request_timeout: float, in seconds, default timeout for RPC forward/backward/inference requests
-    :param min_backoff: after a repeated failure, sleep for this many seconds times 2 ^ (num_failures - 1)
-    :param max_backoff: limit maximal sleep time between retries to this value
-    :param ban_timeout: when a remote peer fails to respond, prevent routing to that peer for this many seconds
-    :param sequence_info: optionally, specify pre-generated sequence info. by default, create a new one using dht
-    :param rpc_info: optionally, specify rpc info (communicated tensor shapes and compression) to save time
-    :param allowed_servers: if defined, send requests only to these servers
-    :param start: start the background thread (see the note below). If false, you will need to start it manually.
     :note: RemoteSequenceManager takes up some CPU and network I/O to operate in background. It is recommended to avoid
       running redundant sequence managers for the same set of layers.
-
     """
 
     def __init__(
         self,
-        dht: DHT,
+        config: SequenceManagerConfig,
         block_uids: Sequence[ModuleUID],
-        p2p: P2P,
-        update_period: float = 30,
-        request_timeout: float = 30,
-        max_retries: Optional[int] = None,
-        min_backoff: float = 1,
-        max_backoff: float = 15 * 60,
-        ban_timeout: float = 15,
-        sequence_info: Optional[RemoteSequenceInfo] = None,
-        rpc_info: Optional[dict] = None,
-        allowed_servers: Optional[Collection[Union[str, hivemind.PeerID]]] = None,
-        banned_peers: Optional[Blacklist] = None,
-        # dear dev, if you add more parameters to this class, please make sure to handle them in __getitem__ (below)
+        *,
+        dht: Optional[DHT] = None,
+        state: Optional[SequenceManagerState] = None,
     ):
         assert len(block_uids) > 0, "Sequences must contain at least one block"
-        self.dht, self.p2p = dht, p2p
-        self.request_timeout, self.max_retries = request_timeout, max_retries
-        self.ban_timeout, self.min_backoff, self.max_backoff = ban_timeout, min_backoff, max_backoff
+
+        self.config = config
+        if state is None:
+            state = SequenceManagerState()
+        self.state = state
+
+        if dht is None:
+            dht = DHT(
+                initial_peers=config.initial_peers,
+                client_mode=True,
+                num_workers=config.n_layer,
+                startup_timeout=config.daemon_startup_timeout,
+                start=True,
+            )
+        assert isinstance(dht, DHT) and dht.is_alive(), "`dht` must be a running hivemind.DHT instance"
+        self.dht = dht
+
+        if state.p2p is None:
+            state.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
+
         self.lock_changes = threading.Lock()
-        self._thread = _SequenceManagerUpdateThread(update_period, WeakMethod(self._update))
+        self._thread = _SequenceManagerUpdateThread(config.update_period, WeakMethod(self._update))
         self._thread_start_lock = threading.Lock()
         self.policy = NoSpendingPolicy()
-        self._rpc_info = rpc_info
 
-        if allowed_servers is not None:
-            allowed_servers = {
-                PeerID.from_base58(peer_id) if isinstance(peer_id, str) else peer_id for peer_id in allowed_servers
-            }
-        self.allowed_servers = allowed_servers
-        self.banned_peers = Blacklist(base_time=ban_timeout, backoff_rate=2.0) if banned_peers is None else banned_peers
-
-        if sequence_info is None:
-            self.sequence_info = RemoteSequenceInfo.make_empty(block_uids)
+        if state.banned_peers is None:
+            state.banned_peers = Blacklist(base_time=config.ban_timeout, backoff_rate=2.0)
+        if state.sequence_info is None:
+            state.sequence_info = RemoteSequenceInfo.make_empty(block_uids)
 
+        if state.sequence_info.last_updated_time is None:
             # Pre-fetch module infos in DHT in parallel with .from_pretrained(), then use cached records
             # in the first _update() instead of the latest ones. This makes the first .update() faster.
             petals.dht_utils.get_remote_module_infos(self.dht, self.block_uids, latest=True, return_future=True)
             self._need_latest_infos = False
         else:
-            self.sequence_info = sequence_info
-            assert block_uids == sequence_info.block_uids
+            assert block_uids == state.sequence_info.block_uids
             self._thread.ready.set()  # no need to await the first dht fetch
             self._need_latest_infos = True
 
@@ -118,7 +134,7 @@ class RemoteSequenceManager:
         span_sequence = []
         current_index = start_index
         while current_index < end_index:
-            candidate_spans = self.sequence_info.spans_containing_block[current_index]
+            candidate_spans = self.state.sequence_info.spans_containing_block[current_index]
             if not candidate_spans:
                 raise MissingBlocksError(current_index)
             if mode == "random":
@@ -143,86 +159,62 @@ class RemoteSequenceManager:
         assert isinstance(ix, (int, slice))
         if not isinstance(ix, slice):
             ix = slice(int(ix), int(ix) + 1, 1)
-        return type(self)(
-            self.dht,
-            self.block_uids[ix],
-            self.p2p,
-            update_period=self._thread.update_period,
-            request_timeout=self.request_timeout,
-            ban_timeout=self.ban_timeout,
-            min_backoff=self.min_backoff,
-            max_backoff=self.max_backoff,
-            sequence_info=self.sequence_info[ix],
-            rpc_info=self._rpc_info,
-            allowed_servers=self.allowed_servers,
-            banned_peers=self.banned_peers,
-        )
+        return type(self)(self.config, self.block_uids[ix], dht=self.dht, state=self.state[ix])
 
     def update(self, *, wait: bool):
         """Run an asynchronous update in background as soon as possible"""
-        self.ready.clear()  # TODO this should be a separate event
+        self.ready.clear()
         self._thread.trigger.set()
         if wait:
             self.ready.wait()
 
     def _update(self):
         """Perform an immediate and synchronous refresh, may take time"""
-        for attempt_no in itertools.count():
-            try:
-                new_block_infos = petals.dht_utils.get_remote_module_infos(
-                    self.dht, self.block_uids, latest=self._need_latest_infos
-                )
-                self._need_latest_infos = True  # All future _update() should use latest infos
-
-                for block_info in new_block_infos:
-                    if not block_info:
-                        continue
-
-                    # Apply whitelist, if defined
-                    if self.allowed_servers is not None:
-                        block_info.servers = {
-                            peer_id: server_info
-                            for peer_id, server_info in block_info.servers.items()
-                            if peer_id in self.allowed_servers
-                        }
-
-                    # Remove temporarily banned peers, unless there are no peers left
-                    valid_servers = {
-                        peer_id: server_info
-                        for peer_id, server_info in block_info.servers.items()
-                        if peer_id not in self.banned_peers
-                    }
-                    if len(valid_servers) < len(block_info.servers):
-                        if valid_servers:
-                            logger.debug(
-                                f"Kept {len(valid_servers)} out of {len(block_info.servers)} servers holding {block_info.uid}"
-                            )
-                            block_info.servers = valid_servers
-                        else:
-                            # If we blacklisted all servers, the error may actually be client-caused
-                            logger.debug(f"All servers holding {block_info.uid} are blacklisted, ignoring blacklist")
-
-                with self.lock_changes:
-                    self.sequence_info.update_(new_block_infos)
-                missing_blocks = [i for i in range(len(self)) if not self.sequence_info.spans_containing_block[i]]
-                if missing_blocks:
-                    raise MissingBlocksError(missing_blocks)
-                self.ready.set()  # if there is an active server for every block, we may begin running
-                break
+        new_block_infos = petals.dht_utils.get_remote_module_infos(
+            self.dht, self.block_uids, latest=self._need_latest_infos
+        )
+        self._need_latest_infos = True  # All future _update() should use latest infos
+
+        for block_info in new_block_infos:
+            if not block_info:
+                continue
+
+            # Apply whitelist, if defined
+            if self.config.allowed_servers is not None:
+                block_info.servers = {
+                    peer_id: server_info
+                    for peer_id, server_info in block_info.servers.items()
+                    if peer_id in self.config.allowed_servers or str(peer_id) in self.config.allowed_servers
+                }
+
+            # Remove temporarily banned peers, unless there are no peers left
+            valid_servers = {
+                peer_id: server_info
+                for peer_id, server_info in block_info.servers.items()
+                if peer_id not in self.state.banned_peers
+            }
+            if len(valid_servers) < len(block_info.servers):
+                if valid_servers:
+                    logger.debug(
+                        f"Kept {len(valid_servers)} out of {len(block_info.servers)} servers holding {block_info.uid}"
+                    )
+                    block_info.servers = valid_servers
+                else:
+                    # If we blacklisted all servers, the error may actually be client-caused
+                    logger.debug(f"All servers holding {block_info.uid} are blacklisted, ignoring blacklist")
 
-            except Exception as e:
-                delay = self.get_retry_delay(attempt_no)
-                logger.warning(f"Could not find route through the model: {repr(e)} (retry in {delay:.0f} sec)")
-                maybe_log_traceback(e)
-                time.sleep(delay)
+        with self.lock_changes:
+            self.state.sequence_info.update_(new_block_infos)
+        self.ready.set()
 
-    def on_request_failure(self, peer_id: PeerID):
+    def on_request_failure(self, peer_id: Optional[PeerID]):
         """remove a given peer from the routing table. If the routing is no longer possible, trigger an update"""
-        logger.info(f"Peer {peer_id} did not respond, banning it temporarily")
-        self.banned_peers.register_failure(peer_id)
+        if peer_id is not None:
+            logger.debug(f"Peer {peer_id} did not respond, banning it temporarily")
+            self.state.banned_peers.register_failure(peer_id)
         with self.lock_changes:
             should_update = False
-            for info in self.sequence_info.block_infos:
+            for info in self.state.sequence_info.block_infos:
                 info.servers.pop(peer_id, None)
                 if not info.servers:
                     should_update = True
@@ -232,7 +224,7 @@ class RemoteSequenceManager:
 
     def on_request_success(self, peer_id: PeerID):
         """if peer has a failure streak, clear that streak"""
-        self.banned_peers.register_success(peer_id)
+        self.state.banned_peers.register_success(peer_id)
 
     def __len__(self):
         return len(self.block_uids)
@@ -247,57 +239,58 @@ class RemoteSequenceManager:
 
     @property
     def block_uids(self):
-        return self.sequence_info.block_uids
+        return self.state.sequence_info.block_uids
 
     @property
     def rpc_info(self):
         """Return the rpc_info queried from one of the servers that hold the first block"""
-        if self._rpc_info is None:
-            with self._thread_start_lock:
-                if not self.is_alive():
-                    self._thread.start()
-
-            for attempt_no in itertools.count():
-                peer_id = None
-                try:
-                    if not self.ready.is_set():
-                        self.update(wait=True)
-
-                    active_servers = [
-                        peer_id
-                        for peer_id, server in self.sequence_info.block_infos[0].servers.items()
-                        if server.state == ServerState.ONLINE
-                    ]
-                    if not active_servers:
-                        raise MissingBlocksError(0)
-                    peer_id = random.choice(active_servers)
-
-                    stub = TransformerConnectionHandler.get_stub(self.p2p, peer_id)
-                    outputs = RemoteExpertWorker.run_coroutine(
-                        stub.rpc_info(runtime_pb2.ExpertUID(uid=self.block_uids[0]))
-                    )
-                    self._rpc_info = MSGPackSerializer.loads(outputs.serialized_info)
-                    self.on_request_success(peer_id)
-                    break
-                except Exception as e:
-                    if peer_id is not None and not isinstance(e, P2PHandlerError):
-                        self.on_request_failure(peer_id)
-                    if attempt_no + 1 == self.max_retries:
-                        raise
-                    delay = self.get_retry_delay(attempt_no)
-                    logger.warning(
-                        f"Caught exception when gathering information from peer {peer_id} "
-                        f"(retry in {delay:.0f} sec): {repr(e)}"
-                    )
-                    maybe_log_traceback(e)
-                    time.sleep(delay)
+        if self.state.rpc_info is not None:
+            return self.state.rpc_info
+
+        with self._thread_start_lock:
+            if not self.is_alive():
+                self._thread.start()
+
+        for attempt_no in itertools.count():
+            peer_id = None
+            try:
+                if not self.ready.is_set():
+                    self.update(wait=True)
+
+                active_servers = [
+                    peer_id
+                    for peer_id, server in self.state.sequence_info.block_infos[0].servers.items()
+                    if server.state == ServerState.ONLINE
+                ]
+                if not active_servers:
+                    raise MissingBlocksError(0)
+                peer_id = random.choice(active_servers)
+
+                stub = TransformerConnectionHandler.get_stub(self.state.p2p, peer_id)
+                outputs = RemoteExpertWorker.run_coroutine(
+                    stub.rpc_info(runtime_pb2.ExpertUID(uid=self.block_uids[0]), timeout=self.config.request_timeout)
+                )
+                self.state.rpc_info = MSGPackSerializer.loads(outputs.serialized_info)
+                self.on_request_success(peer_id)
+                break
+            except Exception as e:
+                self.on_request_failure(peer_id)
+                if attempt_no + 1 == self.config.max_retries:
+                    raise
+                delay = self.get_retry_delay(attempt_no)
+                logger.warning(
+                    f"Caught exception when gathering information from peer {peer_id} "
+                    f"(retry in {delay:.0f} sec): {repr(e)}"
+                )
+                maybe_log_traceback(e)
+                time.sleep(delay)
 
-        return self._rpc_info
+        return self.state.rpc_info
 
     def get_retry_delay(self, attempt_no: int) -> float:
         if attempt_no == 0:
             return 0
-        return min(self.min_backoff * 2 ** (attempt_no - 1), self.max_backoff)
+        return min(self.config.min_backoff * 2 ** (attempt_no - 1), self.config.max_backoff)
 
     def get_request_metadata(self, protocol: str, *args, **kwargs) -> Optional[Dict[str, Any]]:
         """

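For reference (not part of the diff above), a hedged sketch of how the new config/state split behaves when working with `RemoteSequenceManager` directly. The model name is a placeholder, and the defaults assume access to the public swarm:

```python
from petals.client import DistributedBloomConfig, RemoteSequenceManager
from petals.data_structures import UID_DELIMITER

# Placeholder model name; assumes the hosted config defines dht_prefix.
config = DistributedBloomConfig.from_pretrained("bigscience/bloom-petals")
block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)]

# If `dht` is omitted, the manager creates a client-mode hivemind.DHT itself
# and fills state.p2p via dht.replicate_p2p().
manager = RemoteSequenceManager(config, block_uids)

# Slicing yields a manager for a sub-sequence that shares the parent's config and
# SequenceManagerState (same P2P handle, same blacklist of banned peers).
head = manager[0:4]
assert head.config is manager.config
assert head.state.banned_peers is manager.state.banned_peers

# Retry delays now come from the config: 0, min_backoff, 2 * min_backoff, ...,
# capped at max_backoff. With the defaults above this prints [0, 1, 2, 4, 8].
print([manager.get_retry_delay(i) for i in range(5)])
```
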
+ 8 - 10
src/petals/client/sequential_autograd.py

@@ -67,7 +67,7 @@ async def sequential_forward(
 
                 span = sequences.popleft()
 
-                stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
+                stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id)
                 inputs_and_prompts = [inputs, prompts[span.start : span.end]]
 
                 span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
@@ -77,7 +77,7 @@ async def sequential_forward(
                     stub,
                     sequence_manager.rpc_info,
                     *inputs_and_prompts,
-                    timeout=sequence_manager.request_timeout,
+                    timeout=sequence_manager.config.request_timeout,
                     metadata=MSGPackSerializer.dumps(metadata),
                 )
 
@@ -93,9 +93,8 @@ async def sequential_forward(
                 sequence_manager.on_request_success(span.peer_id)
                 break
             except Exception as e:
-                if span is not None:
-                    sequence_manager.on_request_failure(span.peer_id)
-                if attempt_no + 1 == sequence_manager.max_retries:
+                sequence_manager.on_request_failure(span.peer_id if span is not None else None)
+                if attempt_no + 1 == sequence_manager.config.max_retries:
                     raise
                 delay = sequence_manager.get_retry_delay(attempt_no)
                 logger.warning(
@@ -152,7 +151,7 @@ async def sequential_backward(
                     span = forward_sequences.pop()
 
                 span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
-                stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
+                stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id)
                 metadata = sequence_manager.get_request_metadata(
                     "rpc_backward", span_uids, *inputs, *grad_outputs, peer_id=span.peer_id
                 )
@@ -163,7 +162,7 @@ async def sequential_backward(
                     inputs,
                     grad_outputs,
                     prompts[span.start : span.end],
-                    timeout=sequence_manager.request_timeout,
+                    timeout=sequence_manager.config.request_timeout,
                     metadata=MSGPackSerializer.dumps(metadata),
                 )
                 grad_outputs = [grad_outputs]
@@ -171,9 +170,8 @@ async def sequential_backward(
                 sequence_manager.on_request_success(span.peer_id)
                 break
             except Exception as e:
-                if span is not None:
-                    sequence_manager.on_request_failure(span.peer_id)
-                if attempt_no + 1 == sequence_manager.max_retries:
+                sequence_manager.on_request_failure(span.peer_id if span is not None else None)
+                if attempt_no + 1 == sequence_manager.config.max_retries:
                     raise
                 delay = sequence_manager.get_retry_delay(attempt_no)
                 logger.warning(

+ 0 - 61
src/petals/dht_utils.py

@@ -71,67 +71,6 @@ async def _declare_active_modules(
     )
 
 
-def get_remote_sequence(
-    dht: DHT,
-    start: int,
-    stop: int,
-    config: petals.client.DistributedBloomConfig,
-    dht_prefix: Optional[str] = None,
-    return_future: bool = False,
-) -> Union[petals.client.RemoteSequential, MPFuture]:
-    return RemoteExpertWorker.run_coroutine(
-        _get_remote_sequence(dht, start, stop, config, dht_prefix), return_future=return_future
-    )
-
-
-async def _get_remote_sequence(
-    dht: DHT,
-    start: int,
-    stop: int,
-    config: petals.client.DistributedBloomConfig,
-    dht_prefix: Optional[str] = None,
-) -> petals.client.RemoteSequential:
-    uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(start, stop)]
-    p2p = await dht.replicate_p2p()
-    manager = petals.client.RemoteSequenceManager(dht, uids, p2p)
-    return petals.client.RemoteSequential(config, dht, dht_prefix, p2p, manager)
-
-
-def get_remote_module(
-    dht: DHT,
-    uid_or_uids: Union[ModuleUID, List[ModuleUID]],
-    config: petals.client.DistributedBloomConfig,
-    dht_prefix: Optional[str] = None,
-    return_future: bool = False,
-) -> Union[Union[petals.client.RemoteTransformerBlock, List[petals.client.RemoteTransformerBlock]], MPFuture]:
-    """
-    :param uid_or_uids: find one or more modules with these ids from across the DHT
-    :param config: model config, usually taken by .from_pretrained(MODEL_NAME)
-    :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
-    :returns: a list of [RemoteTransformerBlock]
-    """
-    return RemoteExpertWorker.run_coroutine(
-        _get_remote_module(dht, uid_or_uids, config, dht_prefix), return_future=return_future
-    )
-
-
-async def _get_remote_module(
-    dht: DHT,
-    uid_or_uids: Union[ModuleUID, List[ModuleUID]],
-    config: petals.client.DistributedBloomConfig,
-    dht_prefix: Optional[str] = None,
-) -> Union[petals.client.RemoteTransformerBlock, List[petals.client.RemoteTransformerBlock]]:
-    single_uid = isinstance(uid_or_uids, ModuleUID)
-    uids = [uid_or_uids] if single_uid else uid_or_uids
-    p2p = await dht.replicate_p2p()
-    managers = (petals.client.RemoteSequenceManager(dht, [uid], p2p) for uid in uids)
-    modules = [
-        petals.client.RemoteTransformerBlock(config, dht, dht_prefix=dht_prefix, p2p=p2p, sequence_manager=m)
-        for m in managers
-    ]
-    return modules[0] if single_uid else modules
-
-
 def get_remote_module_infos(
     dht: DHT,
     uids: Sequence[ModuleUID],

+ 4 - 9
tests/test_block_exact_match.py

@@ -1,28 +1,24 @@
 import random
 from typing import Union
 
-import hivemind
 import pytest
 import torch
 from transformers.models.bloom.configuration_bloom import BloomConfig
 
 from petals.bloom.block import WrappedBloomBlock
 from petals.bloom.from_pretrained import DTYPE_MAP, _load_state_dict, load_pretrained_block
-from petals.client import DistributedBloomConfig
-from petals.client.remote_sequential import RemoteTransformerBlock
+from petals.client import DistributedBloomConfig, RemoteSequential
 from petals.data_structures import UID_DELIMITER
-from petals.dht_utils import get_remote_module
 from test_utils import *
 
 
 @pytest.mark.forked
 def test_remote_block_exact_match(atol_forward=1e-4, atol_inference=1e-3):
-    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
-    config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
+    config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+    remote_sequential = RemoteSequential(config)
 
     for block_index in random.sample(range(config.n_layer), 3):
-        remote_block = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}{block_index}", config)
-        assert isinstance(remote_block, RemoteTransformerBlock)
+        remote_block = remote_sequential[block_index]
 
         inputs = torch.randn(1, 8, config.hidden_size)
         outputs_forward = remote_block(inputs)
@@ -36,7 +32,6 @@ def test_remote_block_exact_match(atol_forward=1e-4, atol_inference=1e-3):
             with pytest.raises(ValueError, match=r"Maximum length exceeded") as exc_info:
                 sess.step(inputs[:, -1:, :])
             assert "Maximum length exceeded" in repr(exc_info.value)
-
         outputs_inference = torch.cat(outputs_inference, dim=1)
 
         ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)

+ 4 - 9
tests/test_chained_calls.py

@@ -4,22 +4,19 @@
 # - if you want to figure out chained inference, ask yozh
 
 
-import hivemind
 import pytest
 import torch
 
 from petals.bloom.from_pretrained import load_pretrained_block
 from petals.client import DistributedBloomConfig
 from petals.client.remote_sequential import RemoteSequential
-from petals.dht_utils import get_remote_sequence
 from test_utils import *
 
 
 @pytest.mark.forked
 def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1):
-    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
-    config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
-    remote_blocks = get_remote_sequence(dht, 3, 6, config)
+    config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+    remote_blocks = RemoteSequential(config, start_block=3, end_block=6)
     assert isinstance(remote_blocks, RemoteSequential)
 
     ref_blocks = [
@@ -46,10 +43,8 @@ def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq
 
 @pytest.mark.forked
 def test_chained_inference_exact_match(atol_inference=1e-4):
-    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
-    config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
-    remote_blocks = get_remote_sequence(dht, 3, 5, config)
-    assert isinstance(remote_blocks, RemoteSequential)
+    config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+    remote_blocks = RemoteSequential(config, start_block=3, end_block=5)
 
     inputs = torch.randn(1, 8, config.hidden_size)
 

+ 3 - 4
tests/test_remote_sequential.py

@@ -20,7 +20,7 @@ def test_remote_sequential():
     test_inputs = torch.randn(1, 5, config.hidden_size, requires_grad=True)
     grad_proj = torch.randn(1, 5, config.hidden_size)
 
-    sequential = RemoteSequential(config, dht)
+    sequential = RemoteSequential(config, dht=dht)
 
     full_outputs = sequential(test_inputs)
     (full_outputs * grad_proj).sum().backward()
@@ -48,7 +48,7 @@ def test_remote_sequential():
     # test RemoteSequential with lossy compression
     block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)]
     lossy_sequential = RemoteSequential(
-        config, dht, sequence_manager=DummyCustomSequenceManager(dht, block_uids, sequential.p2p)
+        config, sequence_manager=DummyCustomSequenceManager(config, block_uids, dht=dht)
     )
 
     test_inputs.grad = None
@@ -85,8 +85,7 @@ class DummyCustomSequenceManager(RemoteSequenceManager):
 @pytest.mark.forked
 def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):
     config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
-    dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
-    remote_sequential = RemoteSequential(config, dht)
+    remote_sequential = RemoteSequential(config)
 
     inputs = F.normalize(torch.randn(batch_size, seq_len, config.hidden_size), dim=-1)
     output_proj = F.normalize(torch.randn(batch_size, seq_len + pre_seq_len, config.hidden_size), dim=-1)

+ 2 - 3
tests/test_sequence_manager.py

@@ -18,15 +18,14 @@ logger = get_logger(__name__)
 def test_sequence_manager_basics(mode: str):
     config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
     dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
-    sequential = RemoteSequential(config, dht)
+    sequential = RemoteSequential(config, dht=dht)
     shutdown_evt = threading.Event()
 
     # test RemoteSequential with lossy compression
     block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)]
     sequential = RemoteSequential(
         config,
-        dht,
-        sequence_manager=TestSequenceManager(dht, block_uids, sequential.p2p, _was_shut_down=shutdown_evt),
+        sequence_manager=TestSequenceManager(config, block_uids, dht=dht, _was_shut_down=shutdown_evt),
     )
 
     sequence = sequential.sequence_manager.make_sequence(mode=mode)

+ 7 - 8
tests/test_server_stats.py

@@ -4,34 +4,33 @@ import hivemind
 import pytest
 import torch
 
-from petals.client import DistributedBloomConfig
+from petals.client import DistributedBloomConfig, RemoteSequential
 from petals.data_structures import UID_DELIMITER
-from petals.dht_utils import get_remote_sequence
 from petals.server.handler import CACHE_TOKENS_AVAILABLE
 from test_utils import *
 
 
 @pytest.mark.forked
 def test_server_info(block_from: int = 22, block_to: int = 24, max_length: int = 100, max_length2: int = 50):
-    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
     config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
+    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
+    blocks1 = RemoteSequential(config, dht=dht, start_block=block_from, end_block=block_to)
+    blocks2 = RemoteSequential(config, dht=dht, start_block=block_to - 1, end_block=block_to)
 
-    blocks1 = get_remote_sequence(dht, block_from, block_to, config, f"{MODEL_NAME}{UID_DELIMITER}")
-    blocks2 = get_remote_sequence(dht, block_to - 1, block_to, config, f"{MODEL_NAME}{UID_DELIMITER}")
     info_before = blocks1.sequence_manager.rpc_info
 
     with blocks1.inference_session(max_length=max_length) as sess:
         sess.step(torch.randn(1, 1, config.hidden_size))
-        blocks1.sequence_manager._rpc_info = None  # invalidate cache
+        blocks1.sequence_manager.state.rpc_info = None  # invalidate cache
         info_inside = blocks1.sequence_manager.rpc_info
 
         with blocks2.inference_session(max_length=max_length2) as sess2:
             sess2.step(torch.randn(1, 1, config.hidden_size))
-            blocks2.sequence_manager._rpc_info = None  # invalidate cache
+            blocks2.sequence_manager.state.rpc_info = None  # invalidate cache
             info_inside2 = blocks2.sequence_manager.rpc_info
 
     time.sleep(0.1)
-    blocks1.sequence_manager._rpc_info = None  # invalidate cache
+    blocks1.sequence_manager.state.rpc_info = None  # invalidate cache
     info_after = blocks1.sequence_manager.rpc_info
 
     assert info_before[CACHE_TOKENS_AVAILABLE] == info_after[CACHE_TOKENS_AVAILABLE]