Refactor RemoteSequenceManager (#309)

This PR:

1. **Extracts `SequenceManagerConfig` and `SequenceManagerState` subclasses.**

    The config is provided by the caller and never changed from inside `RemoteSequenceManager`. The state is the mutable part of `RemoteSequenceManager` that is shared between the main manager and its slices. We fix some slicing bugs along the way.
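
    For illustration, here is roughly how the split behaves after this PR (a sketch; `manager` stands for any `RemoteSequenceManager` instance spanning at least 6 blocks):

    ```python
    sliced = manager[3:6]

    # The config is shared by reference and treated as read-only:
    assert sliced.config is manager.config

    # The mutable pieces of the state (p2p, rpc_info, banned peers) are shared,
    # while sequence_info is narrowed to the requested slice:
    assert sliced.state.p2p is manager.state.p2p
    assert sliced.state.banned_peers is manager.state.banned_peers
    assert len(sliced.state.sequence_info) == 3
    ```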

2. **Removes `dht_prefix` and `p2p` arguments, makes `dht` argument optional.**

    `dht_prefix` can always be overridden using `config.dht_prefix`. `p2p` is actually needed only under the hood of `RemoteSequenceManager`, so the manager can create it by itself without exposing this low-level class to callers. If strictly necessary, a caller can still provide `p2p` as a part of `SequenceManagerState`. `dht` is also needed only by `RemoteSequenceManager`, so we can make it optional in the parent classes and create it automatically when it's not provided.
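
    After this change, a client can be constructed from a config alone, and a running DHT can still be shared explicitly. A sketch (`MODEL_NAME` is a placeholder for the model repo name):

    ```python
    import hivemind

    from petals.client import DistributedBloomConfig, RemoteSequential
    from petals.constants import PUBLIC_INITIAL_PEERS

    config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=PUBLIC_INITIAL_PEERS)
    blocks = RemoteSequential(config)  # a client-mode DHT and its P2P replica are created under the hood

    # Reusing one DHT instance for several models still works:
    dht = hivemind.DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
    blocks = RemoteSequential(config, dht=dht)
    ```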

3. **Simplifies retry logic.**

    Previously, we could have "nested" retry loops: one in `._update()`, another in the inference/forward/backward steps. The loop in `._update()` could cause issues for concurrent inference/forward/backward calls, since it blocks the entire class if its delay grows too large. Now this logic is simplified: `._update()` performs only one attempt to fetch the DHT info, and any retries are triggered by the inference/forward/backward steps.
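
    The surviving retry pattern, condensed from the forward/backward/inference loops in the diff below (`sequence_manager` and `span` refer to the variables used there):

    ```python
    import itertools
    import time

    for attempt_no in itertools.count():
        try:
            ...  # one attempt to run the step on the chosen span
            sequence_manager.on_request_success(span.peer_id)
            break
        except Exception:
            sequence_manager.on_request_failure(span.peer_id if span is not None else None)
            if attempt_no + 1 == sequence_manager.config.max_retries:
                raise
            time.sleep(sequence_manager.get_retry_delay(attempt_no))  # exponential backoff, capped
    ```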

4. **Removes deprecated `RemoteTransformerBlock`.**

    `RemoteTransformerBlock` was deprecated a long time ago, before Petals 1.0.0. Its removal is long overdue.
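
    Code that used a single remote block can index a `RemoteSequential` instead, which now returns a one-block `RemoteSequential` (a sketch based on the updated tests; `config` is built as in the example above):

    ```python
    # Before:
    # remote_block = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}{block_index}", config)
    # After:
    remote_block = RemoteSequential(config)[block_index]
    ```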

5. **Removes `dht_utils.get_remote_module()`, `dht_utils.get_remote_sequence()`.**

    These functions duplicate the functionality of the `RemoteSequential` constructor.
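
    Their use cases are covered by the constructor's new `start_block`/`end_block` arguments (a sketch based on the updated tests):

    ```python
    # Before:
    # remote_blocks = get_remote_sequence(dht, 3, 6, config)
    # After:
    remote_blocks = RemoteSequential(config, start_block=3, end_block=6)
    ```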

6. (minor) **Removes `RemoteSequential.is_subsequence` flag.**

    This flag worked incorrectly and was never used. I am removing it for the sake of simplicity.
Alexander Borzunov, 2 years ago
Commit 8f6342a861

+ 2 - 0
README.md

@@ -111,6 +111,8 @@ See the instructions for macOS and Windows, the full requirements, and troublesh
 
 ## Benchmarks
 
+The benchmarks below are for BLOOM-176B:
+
 <table align="center">
   <tr>
     <th colspan="2">Network</th>

+ 1 - 1
src/petals/client/__init__.py

@@ -5,6 +5,6 @@ from petals.client.remote_model import (
     DistributedBloomForSequenceClassification,
     DistributedBloomModel,
 )
-from petals.client.remote_sequential import RemoteSequential, RemoteTransformerBlock
+from petals.client.remote_sequential import RemoteSequential
 from petals.client.routing.sequence_manager import RemoteSequenceManager
 from petals.client.routing.spending_policy import NoSpendingPolicy, SpendingPolicyBase

+ 5 - 8
src/petals/client/inference_session.py

@@ -8,7 +8,6 @@ from typing import AsyncIterator, List, Optional
 
 import torch
 from hivemind import (
-    P2P,
     MSGPackSerializer,
     anext,
     deserialize_torch_tensor,
@@ -162,9 +161,8 @@ class InferenceSession:
     An interface to a multi-step *inference* session for a sequence of remote transformer blocks
     """
 
-    def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, max_length: int):
+    def __init__(self, sequence_manager: RemoteSequenceManager, max_length: int):
         self._sequence_manager = sequence_manager
-        self._p2p = p2p
         self._closed = False
         self._chosen_spans = []
         self._server_sessions = []
@@ -181,7 +179,7 @@ class InferenceSession:
         server_sessions = []
         try:
             for span in chosen_spans:
-                stub = TransformerConnectionHandler.get_stub(self._p2p, span.peer_id)
+                stub = TransformerConnectionHandler.get_stub(self._sequence_manager.state.p2p, span.peer_id)
                 span_uids = CHAIN_DELIMITER.join(self._sequence_manager.block_uids[span.start : span.end])
                 metadata = self._sequence_manager.get_request_metadata("rpc_inference", span_uids, peer_id=span.peer_id)
                 session = RemoteExpertWorker.run_coroutine(
@@ -189,7 +187,7 @@ class InferenceSession:
                         stub,
                         span_uids,
                         rpc_info=self._sequence_manager.rpc_info,
-                        timeout=self._sequence_manager.request_timeout,
+                        timeout=self._sequence_manager.config.request_timeout,
                         max_length=self._max_length,
                         **metadata,
                     )
@@ -305,9 +303,8 @@ class InferenceSession:
                     self._sequence_manager.on_request_success(span.peer_id)
                     break
                 except Exception as e:
-                    if span is not None:
-                        self._sequence_manager.on_request_failure(span.peer_id)
-                    if attempt_no + 1 == self._sequence_manager.max_retries:
+                    self._sequence_manager.on_request_failure(span.peer_id if span is not None else None)
+                    if attempt_no + 1 == self._sequence_manager.config.max_retries:
                         raise
                     delay = self._sequence_manager.get_retry_delay(attempt_no)
                     logger.warning(

+ 6 - 25
src/petals/client/remote_model.py

@@ -18,13 +18,14 @@ from transformers.models.bloom import (
 from petals.bloom.modeling_utils import LMHead
 from petals.client.remote_generation import RemoteGenerationMixin
 from petals.client.remote_sequential import RemoteSequential
+from petals.client.routing.sequence_manager import SequenceManagerConfig
 from petals.constants import PUBLIC_INITIAL_PEERS
 from petals.utils.misc import DUMMY
 
 logger = get_logger(__name__)
 
 
-class DistributedBloomConfig(BloomConfig):
+class DistributedBloomConfig(BloomConfig, SequenceManagerConfig):
     """
     A bloom config that contains information about DHT peers.
     To create a distributed model, one must provide dht_prefix and either initial_peers or dht.
@@ -33,15 +34,9 @@ class DistributedBloomConfig(BloomConfig):
     initial_peers: List[str] = PUBLIC_INITIAL_PEERS  # a list of initial peers for hivemind DHT
     dht_prefix: str  # a prefix for all dht keys that correspond to this model (usually equal to model name)
     daemon_startup_timeout: int = 60  # timeout for the libp2p daemon connecting to initial peers
-    dht: Optional[hivemind.DHT] = None  # a running DHT instance, e.g. when using the same DHT for multiple models
-    request_timeout: int = 3 * 60  # a number of seconds for waiting result from each node
-    max_retries: Optional[int] = None  # max number retries before the client raises an exception (default: inf)
-    allowed_servers: Optional[
-        Collection[Union[str, hivemind.PeerID]]
-    ] = None  # if defined, send requests only to these servers
 
     pre_seq_len: int = 0  # a number of tokens for prompt tuning.
-    tuning_mode: Optional[str] = None  # One of the finetune options: [None, 'shallow_ptune', 'deep_ptune', 'adapters']
+    tuning_mode: Optional[str] = None  # fine-tuning regime, one of [None, "ptune", "deep_ptune"]
 
     # This settings matter for running the client with dtype bfloat16 on CPU.
     # If the CPU doesn't support AVX512, chunked_forward() significantly speeds up computations.
@@ -106,30 +101,16 @@ class DistributedBloomModel(_FromPretrainedDefaultsMixin, BloomModel):
 
     config_class = DistributedBloomConfig
 
-    def __init__(self, config: DistributedBloomConfig):
+    def __init__(self, config: DistributedBloomConfig, *, dht: Optional[hivemind.DHT] = None):
         assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..."
-        assert config.initial_peers or config.dht, "Please specify initial_peers=list(...) or dht=hivemind.DHT(...)"
+        assert config.initial_peers or dht is not None, "Please specify `config.initial_peers` or `dht`"
 
         n_layer, config.n_layer = config.n_layer, 0  # temporarily set n_layer to 0 to prevent layer initialization
         super().__init__(config)
         assert len(self.h) == 0
         config.n_layer = n_layer
 
-        dht = config.dht
-        if dht is None:
-            dht = hivemind.DHT(
-                initial_peers=config.initial_peers,
-                client_mode=True,
-                num_workers=n_layer,
-                startup_timeout=config.daemon_startup_timeout,
-                start=True,
-            )
-        assert isinstance(dht, hivemind.DHT) and dht.is_alive(), "dht must be a running hivemind.DHT instance"
-        self.h = RemoteSequential(
-            config,
-            dht,
-            config.dht_prefix,
-        )
+        self.h = RemoteSequential(config, dht=dht)
 
         # Forbid accumulate grads for embeddings and layernorm
         self.set_requires_grad(False)

+ 20 - 61
src/petals/client/remote_sequential.py

@@ -3,7 +3,7 @@ from __future__ import annotations
 from typing import Optional, Union
 
 import torch
-from hivemind import DHT, P2P, get_logger
+from hivemind import DHT, get_logger
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from torch import nn
 
@@ -25,39 +25,26 @@ class RemoteSequential(nn.Module):
     def __init__(
         self,
         config: petals.client.DistributedBloomConfig,
-        dht: DHT,
-        dht_prefix: Optional[str] = None,
-        p2p: Optional[P2P] = None,
+        *,
         sequence_manager: Optional[RemoteSequenceManager] = None,
-        **kwargs,
+        dht: Optional[DHT] = None,
+        start_block: Optional[int] = None,
+        end_block: Optional[int] = None,
     ):
         super().__init__()
         self.config = config
-        self.dht = dht
-        self.dht_prefix = dht_prefix or config.dht_prefix
-        self.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p()) if p2p is None else p2p
 
-        num_blocks = self.config.n_layer if sequence_manager is None else len(sequence_manager)
-        block_uids = tuple(f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(num_blocks))
+        assert sequence_manager is None or (
+            dht is None and start_block is None and end_block is None
+        ), "`dht`, `start_block`, and `end_block` have no effect when you provide a custom `sequence_manager`"
         if sequence_manager is None:
-            logger.debug(f"Creating new sequence manager for block uids: {block_uids}")
-            self.sequence_manager = RemoteSequenceManager(
-                dht,
-                block_uids,
-                self.p2p,
-                request_timeout=config.request_timeout,
-                max_retries=config.max_retries,
-                allowed_servers=config.allowed_servers,
-                **kwargs,
-            )
-            self.is_subsequence = False
-        else:
-            logger.debug(f"Reusing sequence manager with {len(sequence_manager)} modules")
-            if kwargs:
-                logger.warning(f"Parameters {kwargs} are ignored because sequence_manager is explicitly provided")
-            self.sequence_manager = sequence_manager
-            assert isinstance(sequence_manager.sequence_info.block_uids, tuple)
-            self.is_subsequence = self.sequence_manager.sequence_info.block_uids != block_uids
+            if start_block is None:
+                start_block = 0
+            if end_block is None:
+                end_block = self.config.n_layer
+            block_uids = tuple(f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block, end_block))
+            sequence_manager = RemoteSequenceManager(config, block_uids, dht=dht)
+        self.sequence_manager = sequence_manager
 
     def forward(self, inputs: torch.Tensor, prompts: torch.Tensor = DUMMY):
         assert inputs.ndim == 3, "inputs must be a tensor of shape [batch_size, seq_length, hidden_size]"
@@ -66,23 +53,10 @@ class RemoteSequential(nn.Module):
         return outputs
 
     def __getitem__(self, ix: Union[int, slice]) -> RemoteSequential:
-        assert isinstance(ix, (int, slice))
-        if isinstance(ix, int):
-            return RemoteTransformerBlock(
-                self.config,
-                self.dht,
-                dht_prefix=self.dht_prefix,
-                p2p=self.p2p,
-                sequence_manager=self.sequence_manager[ix],
-            )
-        else:
-            return RemoteSequential(
-                self.config,
-                self.dht,
-                dht_prefix=self.dht_prefix,
-                p2p=self.p2p,
-                sequence_manager=self.sequence_manager[ix],
-            )
+        return RemoteSequential(
+            self.config,
+            sequence_manager=self.sequence_manager[ix],
+        )
 
     def __iter__(self):
         for block_index in range(len(self)):
@@ -92,22 +66,7 @@ class RemoteSequential(nn.Module):
         return len(self.sequence_manager)
 
     def inference_session(self, **kwargs) -> InferenceSession:
-        return InferenceSession(self.sequence_manager, self.p2p, **kwargs)
+        return InferenceSession(self.sequence_manager, **kwargs)
 
     def extra_repr(self) -> str:
         return f"modules={self.sequence_manager.block_uids[0]}..{self.sequence_manager.block_uids[-1]}"
-
-
-class RemoteTransformerBlock(RemoteSequential):
-    """Single transformer block hosted by swarm
-
-    This class is deprecated and kept for backward compatibility.
-    It will be removed soon in favor of using ``RemoteSequential`` directly.
-    """
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        assert len(self) == 1, "Remote Block is a sequence size 1"
-
-    def extra_repr(self):
-        return f"{self.sequence_manager.block_uids[0]}"

+ 2 - 2
src/petals/client/routing/sequence_info.py

@@ -27,14 +27,14 @@ class RemoteSequenceInfo:
     block_infos: Tuple[RemoteModuleInfo, ...]  # note: the contents of RemoteModuleInfo can and will be updated
     spans_by_priority: List[RemoteSpanInfo]
     spans_containing_block: Tuple[List[RemoteSpanInfo], ...]
-    last_updated_time: float
+    last_updated_time: Optional[float]
 
     @classmethod
     def make_empty(cls: Type[T], block_uids: Iterable[ModuleUID]) -> T:
         block_uids = tuple(block_uids)
         empty_block_infos = tuple(RemoteModuleInfo(uid, {}) for uid in block_uids)
         empty_spans = tuple([] for _ in range(len(block_uids)))
-        return cls(block_uids, empty_block_infos, [], empty_spans, last_updated_time=-float("inf"))
+        return cls(block_uids, empty_block_infos, [], empty_spans, last_updated_time=None)
 
     def __getitem__(self, ix: slice):
         assert isinstance(ix, slice)

+ 146 - 153
src/petals/client/routing/sequence_manager.py

@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import asyncio
+import dataclasses
 import itertools
 import logging
 import random
@@ -13,7 +14,6 @@ import numpy as np
 from hivemind import DHT, P2P, MSGPackSerializer, PeerID, get_dht_time
 from hivemind.dht.node import Blacklist
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
-from hivemind.p2p import P2PHandlerError
 from hivemind.proto import runtime_pb2
 from hivemind.utils.logging import get_logger
 
@@ -26,6 +26,33 @@ from petals.server.handler import TransformerConnectionHandler
 logger = get_logger(__name__)
 
 
+@dataclasses.dataclass
+class SequenceManagerConfig:
+    allowed_servers: Optional[Collection[Union[PeerID, str]]] = None  # if defined, send requests only to these servers
+
+    request_timeout: float = 3 * 60  # timeout for forward/backward/inference requests
+    update_period: float = 60  # refresh DHT information once in this many seconds
+
+    max_retries: Optional[int] = None  # max number retries before the client raises an exception (default: inf)
+    min_backoff: float = 1  # after a repeated failure, sleep for this many seconds times 2 ** (num_failures - 1)
+    max_backoff: float = 60  # limit maximal sleep time between retries to this value
+    ban_timeout: float = 15  # when a remote peer fails to respond, prevent routing to that peer for this many seconds
+
+
+@dataclasses.dataclass
+class SequenceManagerState:
+    p2p: P2P = None
+    sequence_info: Optional[RemoteSequenceInfo] = None
+    rpc_info: Optional[dict] = None
+    banned_peers: Optional[Blacklist] = None
+
+    def __getitem__(self, ix: Union[int, slice]) -> SequenceManagerState:
+        return dataclasses.replace(self, sequence_info=self.sequence_info[ix])
+
+    def __len__(self) -> int:
+        return len(self.sequence_info)
+
+
 class RemoteSequenceManager:
     """
     Sequence manager is a thread that keeps track of remote servers that hold the specified sequence of blocks.
@@ -34,67 +61,56 @@ class RemoteSequenceManager:
     Using this information, sequence manager can form sequences of servers that collectively have the full sequence.
     To form such a sequence, call .make_sequence with the appropriate optimization policy (see make_sequence docstr).
 
-    :param dht: a running hivemind.DHT instance, connected to peers that serve the corresponding blocks
-    :param block_uids: a sequence of DHT keys (strings) corresponding to remote layers
-    :param p2p: an optional P2P replica (if not specified, create one via dht.replicate_p2p())
-    :param update_period: by default, refresh DHT information once in this many seconds
-    :param request_timeout: float, in seconds, default timeout for RPC forward/backward/inference requests
-    :param min_backoff: after a repeated failure, sleep for this many seconds times 2 ^ (num_failures - 1)
-    :param max_backoff: limit maximal sleep time between retries to this value
-    :param ban_timeout: when a remote peer fails to respond, prevent routing to that peer for this many seconds
-    :param sequence_info: optionally, specify pre-generated sequence info. by default, create a new one using dht
-    :param rpc_info: optionally, specify rpc info (communicated tensor shapes and compression) to save time
-    :param allowed_servers: if defined, send requests only to these servers
-    :param start: start the background thread (see the note below). If false, you will need to start it manually.
     :note: RemoteSequenceManager takes up some CPU and network I/O to operate in background. It is recommended to avoid
       running redundant sequence managers for the same set of layers.
-
     """
 
     def __init__(
         self,
-        dht: DHT,
+        config: SequenceManagerConfig,
         block_uids: Sequence[ModuleUID],
-        p2p: P2P,
-        update_period: float = 30,
-        request_timeout: float = 30,
-        max_retries: Optional[int] = None,
-        min_backoff: float = 1,
-        max_backoff: float = 15 * 60,
-        ban_timeout: float = 15,
-        sequence_info: Optional[RemoteSequenceInfo] = None,
-        rpc_info: Optional[dict] = None,
-        allowed_servers: Optional[Collection[Union[str, hivemind.PeerID]]] = None,
-        banned_peers: Optional[Blacklist] = None,
-        # dear dev, if you add more parameters to this class, please make sure to handle them in __getitem__ (below)
+        *,
+        dht: Optional[DHT] = None,
+        state: Optional[SequenceManagerState] = None,
    ):
         assert len(block_uids) > 0, "Sequences must contain at least one block"
-        self.dht, self.p2p = dht, p2p
-        self.request_timeout, self.max_retries = request_timeout, max_retries
-        self.ban_timeout, self.min_backoff, self.max_backoff = ban_timeout, min_backoff, max_backoff
+
+        self.config = config
+        if state is None:
+            state = SequenceManagerState()
+        self.state = state
+
+        if dht is None:
+            dht = DHT(
+                initial_peers=config.initial_peers,
+                client_mode=True,
+                num_workers=config.n_layer,
+                startup_timeout=config.daemon_startup_timeout,
+                start=True,
+            )
+        assert isinstance(dht, DHT) and dht.is_alive(), "`dht` must be a running hivemind.DHT instance"
+        self.dht = dht
+
+        if state.p2p is None:
+            state.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
+
         self.lock_changes = threading.Lock()
-        self._thread = _SequenceManagerUpdateThread(update_period, WeakMethod(self._update))
+        self._thread = _SequenceManagerUpdateThread(config.update_period, WeakMethod(self._update))
         self._thread_start_lock = threading.Lock()
         self.policy = NoSpendingPolicy()
-        self._rpc_info = rpc_info
 
-        if allowed_servers is not None:
-            allowed_servers = {
-                PeerID.from_base58(peer_id) if isinstance(peer_id, str) else peer_id for peer_id in allowed_servers
-            }
-        self.allowed_servers = allowed_servers
-        self.banned_peers = Blacklist(base_time=ban_timeout, backoff_rate=2.0) if banned_peers is None else banned_peers
-
-        if sequence_info is None:
-            self.sequence_info = RemoteSequenceInfo.make_empty(block_uids)
+        if state.banned_peers is None:
+            state.banned_peers = Blacklist(base_time=config.ban_timeout, backoff_rate=2.0)
+        if state.sequence_info is None:
+            state.sequence_info = RemoteSequenceInfo.make_empty(block_uids)
 
+        if state.sequence_info.last_updated_time is None:
             # Pre-fetch module infos in DHT in parallel with .from_pretrained(), then use cached records
             # in the first _update() instead of the latest ones. This makes the first .update() faster.
             petals.dht_utils.get_remote_module_infos(self.dht, self.block_uids, latest=True, return_future=True)
             self._need_latest_infos = False
         else:
-            self.sequence_info = sequence_info
-            assert block_uids == sequence_info.block_uids
+            assert block_uids == state.sequence_info.block_uids
             self._thread.ready.set()  # no need to await the first dht fetch
             self._need_latest_infos = True
 
@@ -118,7 +134,7 @@ class RemoteSequenceManager:
         span_sequence = []
         current_index = start_index
         while current_index < end_index:
-            candidate_spans = self.sequence_info.spans_containing_block[current_index]
+            candidate_spans = self.state.sequence_info.spans_containing_block[current_index]
             if not candidate_spans:
                 raise MissingBlocksError(current_index)
             if mode == "random":
@@ -143,86 +159,62 @@ class RemoteSequenceManager:
         assert isinstance(ix, (int, slice))
         if not isinstance(ix, slice):
             ix = slice(int(ix), int(ix) + 1, 1)
-        return type(self)(
-            self.dht,
-            self.block_uids[ix],
-            self.p2p,
-            update_period=self._thread.update_period,
-            request_timeout=self.request_timeout,
-            ban_timeout=self.ban_timeout,
-            min_backoff=self.min_backoff,
-            max_backoff=self.max_backoff,
-            sequence_info=self.sequence_info[ix],
-            rpc_info=self._rpc_info,
-            allowed_servers=self.allowed_servers,
-            banned_peers=self.banned_peers,
-        )
+        return type(self)(self.config, self.block_uids[ix], dht=self.dht, state=self.state[ix])
 
     def update(self, *, wait: bool):
         """Run an asynchronous update in background as soon as possible"""
-        self.ready.clear()  # TODO this should be a separate event
+        self.ready.clear()
         self._thread.trigger.set()
         if wait:
             self.ready.wait()
 
     def _update(self):
         """Perform an immediate and synchronous refresh, may take time"""
-        for attempt_no in itertools.count():
-            try:
-                new_block_infos = petals.dht_utils.get_remote_module_infos(
-                    self.dht, self.block_uids, latest=self._need_latest_infos
-                )
-                self._need_latest_infos = True  # All future _update() should use latest infos
-
-                for block_info in new_block_infos:
-                    if not block_info:
-                        continue
-
-                    # Apply whitelist, if defined
-                    if self.allowed_servers is not None:
-                        block_info.servers = {
-                            peer_id: server_info
-                            for peer_id, server_info in block_info.servers.items()
-                            if peer_id in self.allowed_servers
-                        }
-
-                    # Remove temporarily banned peers, unless there are no peers left
-                    valid_servers = {
-                        peer_id: server_info
-                        for peer_id, server_info in block_info.servers.items()
-                        if peer_id not in self.banned_peers
-                    }
-                    if len(valid_servers) < len(block_info.servers):
-                        if valid_servers:
-                            logger.debug(
-                                f"Kept {len(valid_servers)} out of {len(block_info.servers)} servers holding {block_info.uid}"
-                            )
-                            block_info.servers = valid_servers
-                        else:
-                            # If we blacklisted all servers, the error may actually be client-caused
-                            logger.debug(f"All servers holding {block_info.uid} are blacklisted, ignoring blacklist")
-
-                with self.lock_changes:
-                    self.sequence_info.update_(new_block_infos)
-                missing_blocks = [i for i in range(len(self)) if not self.sequence_info.spans_containing_block[i]]
-                if missing_blocks:
-                    raise MissingBlocksError(missing_blocks)
-                self.ready.set()  # if there is an active server for every block, we may begin running
-                break
+        new_block_infos = petals.dht_utils.get_remote_module_infos(
+            self.dht, self.block_uids, latest=self._need_latest_infos
+        )
+        self._need_latest_infos = True  # All future _update() should use latest infos
+
+        for block_info in new_block_infos:
+            if not block_info:
+                continue
+
+            # Apply whitelist, if defined
+            if self.config.allowed_servers is not None:
+                block_info.servers = {
+                    peer_id: server_info
+                    for peer_id, server_info in block_info.servers.items()
+                    if peer_id in self.config.allowed_servers or str(peer_id) in self.config.allowed_servers
+                }
+
+            # Remove temporarily banned peers, unless there are no peers left
+            valid_servers = {
+                peer_id: server_info
+                for peer_id, server_info in block_info.servers.items()
+                if peer_id not in self.state.banned_peers
+            }
+            if len(valid_servers) < len(block_info.servers):
+                if valid_servers:
+                    logger.debug(
+                        f"Kept {len(valid_servers)} out of {len(block_info.servers)} servers holding {block_info.uid}"
+                    )
+                    block_info.servers = valid_servers
+                else:
+                    # If we blacklisted all servers, the error may actually be client-caused
+                    logger.debug(f"All servers holding {block_info.uid} are blacklisted, ignoring blacklist")
 
-            except Exception as e:
-                delay = self.get_retry_delay(attempt_no)
-                logger.warning(f"Could not find route through the model: {repr(e)} (retry in {delay:.0f} sec)")
-                maybe_log_traceback(e)
-                time.sleep(delay)
+        with self.lock_changes:
+            self.state.sequence_info.update_(new_block_infos)
+        self.ready.set()
 
-    def on_request_failure(self, peer_id: PeerID):
+    def on_request_failure(self, peer_id: Optional[PeerID]):
         """remove a given peer from the routing table. If the routing is no longer possible, trigger an update"""
-        logger.info(f"Peer {peer_id} did not respond, banning it temporarily")
-        self.banned_peers.register_failure(peer_id)
+        if peer_id is not None:
+            logger.debug(f"Peer {peer_id} did not respond, banning it temporarily")
+            self.state.banned_peers.register_failure(peer_id)
         with self.lock_changes:
             should_update = False
-            for info in self.sequence_info.block_infos:
+            for info in self.state.sequence_info.block_infos:
                 info.servers.pop(peer_id, None)
                 if not info.servers:
                     should_update = True
@@ -232,7 +224,7 @@ class RemoteSequenceManager:
 
     def on_request_success(self, peer_id: PeerID):
         """if peer has a failure streak, clear that streak"""
-        self.banned_peers.register_success(peer_id)
+        self.state.banned_peers.register_success(peer_id)
 
     def __len__(self):
         return len(self.block_uids)
@@ -247,57 +239,58 @@ class RemoteSequenceManager:
 
     @property
     def block_uids(self):
-        return self.sequence_info.block_uids
+        return self.state.sequence_info.block_uids
 
     @property
     def rpc_info(self):
         """Return the rpc_info queried from one of the servers that hold the first block"""
-        if self._rpc_info is None:
-            with self._thread_start_lock:
-                if not self.is_alive():
-                    self._thread.start()
-
-            for attempt_no in itertools.count():
-                peer_id = None
-                try:
-                    if not self.ready.is_set():
-                        self.update(wait=True)
-
-                    active_servers = [
-                        peer_id
-                        for peer_id, server in self.sequence_info.block_infos[0].servers.items()
-                        if server.state == ServerState.ONLINE
-                    ]
-                    if not active_servers:
-                        raise MissingBlocksError(0)
-                    peer_id = random.choice(active_servers)
-
-                    stub = TransformerConnectionHandler.get_stub(self.p2p, peer_id)
-                    outputs = RemoteExpertWorker.run_coroutine(
-                        stub.rpc_info(runtime_pb2.ExpertUID(uid=self.block_uids[0]))
-                    )
-                    self._rpc_info = MSGPackSerializer.loads(outputs.serialized_info)
-                    self.on_request_success(peer_id)
-                    break
-                except Exception as e:
-                    if peer_id is not None and not isinstance(e, P2PHandlerError):
-                        self.on_request_failure(peer_id)
-                    if attempt_no + 1 == self.max_retries:
-                        raise
-                    delay = self.get_retry_delay(attempt_no)
-                    logger.warning(
-                        f"Caught exception when gathering information from peer {peer_id} "
-                        f"(retry in {delay:.0f} sec): {repr(e)}"
-                    )
-                    maybe_log_traceback(e)
-                    time.sleep(delay)
+        if self.state.rpc_info is not None:
+            return self.state.rpc_info
+
+        with self._thread_start_lock:
+            if not self.is_alive():
+                self._thread.start()
+
+        for attempt_no in itertools.count():
+            peer_id = None
+            try:
+                if not self.ready.is_set():
+                    self.update(wait=True)
+
+                active_servers = [
+                    peer_id
+                    for peer_id, server in self.state.sequence_info.block_infos[0].servers.items()
+                    if server.state == ServerState.ONLINE
+                ]
+                if not active_servers:
+                    raise MissingBlocksError(0)
+                peer_id = random.choice(active_servers)
+
+                stub = TransformerConnectionHandler.get_stub(self.state.p2p, peer_id)
+                outputs = RemoteExpertWorker.run_coroutine(
+                    stub.rpc_info(runtime_pb2.ExpertUID(uid=self.block_uids[0]), timeout=self.config.request_timeout)
+                )
+                self.state.rpc_info = MSGPackSerializer.loads(outputs.serialized_info)
+                self.on_request_success(peer_id)
+                break
+            except Exception as e:
+                self.on_request_failure(peer_id)
+                if attempt_no + 1 == self.config.max_retries:
+                    raise
+                delay = self.get_retry_delay(attempt_no)
+                logger.warning(
+                    f"Caught exception when gathering information from peer {peer_id} "
+                    f"(retry in {delay:.0f} sec): {repr(e)}"
+                )
+                maybe_log_traceback(e)
+                time.sleep(delay)
 
-        return self._rpc_info
+        return self.state.rpc_info
 
     def get_retry_delay(self, attempt_no: int) -> float:
         if attempt_no == 0:
             return 0
-        return min(self.min_backoff * 2 ** (attempt_no - 1), self.max_backoff)
+        return min(self.config.min_backoff * 2 ** (attempt_no - 1), self.config.max_backoff)
 
     def get_request_metadata(self, protocol: str, *args, **kwargs) -> Optional[Dict[str, Any]]:
         """

+ 8 - 10
src/petals/client/sequential_autograd.py

@@ -67,7 +67,7 @@ async def sequential_forward(
 
                 span = sequences.popleft()
 
-                stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
+                stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id)
                 inputs_and_prompts = [inputs, prompts[span.start : span.end]]
 
                 span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
@@ -77,7 +77,7 @@ async def sequential_forward(
                     stub,
                     sequence_manager.rpc_info,
                     *inputs_and_prompts,
-                    timeout=sequence_manager.request_timeout,
+                    timeout=sequence_manager.config.request_timeout,
                     metadata=MSGPackSerializer.dumps(metadata),
                 )
 
@@ -93,9 +93,8 @@ async def sequential_forward(
                 sequence_manager.on_request_success(span.peer_id)
                 break
             except Exception as e:
-                if span is not None:
-                    sequence_manager.on_request_failure(span.peer_id)
-                if attempt_no + 1 == sequence_manager.max_retries:
+                sequence_manager.on_request_failure(span.peer_id if span is not None else None)
+                if attempt_no + 1 == sequence_manager.config.max_retries:
                     raise
                 delay = sequence_manager.get_retry_delay(attempt_no)
                 logger.warning(
@@ -152,7 +151,7 @@ async def sequential_backward(
                     span = forward_sequences.pop()
 
                 span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
-                stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
+                stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id)
                 metadata = sequence_manager.get_request_metadata(
                     "rpc_backward", span_uids, *inputs, *grad_outputs, peer_id=span.peer_id
                 )
@@ -163,7 +162,7 @@ async def sequential_backward(
                     inputs,
                     grad_outputs,
                     prompts[span.start : span.end],
-                    timeout=sequence_manager.request_timeout,
+                    timeout=sequence_manager.config.request_timeout,
                     metadata=MSGPackSerializer.dumps(metadata),
                 )
                 grad_outputs = [grad_outputs]
@@ -171,9 +170,8 @@ async def sequential_backward(
                 sequence_manager.on_request_success(span.peer_id)
                 break
             except Exception as e:
-                if span is not None:
-                    sequence_manager.on_request_failure(span.peer_id)
-                if attempt_no + 1 == sequence_manager.max_retries:
+                sequence_manager.on_request_failure(span.peer_id if span is not None else None)
+                if attempt_no + 1 == sequence_manager.config.max_retries:
                     raise
                 delay = sequence_manager.get_retry_delay(attempt_no)
                 logger.warning(

+ 0 - 61
src/petals/dht_utils.py

@@ -71,67 +71,6 @@ async def _declare_active_modules(
     )
 
 
-def get_remote_sequence(
-    dht: DHT,
-    start: int,
-    stop: int,
-    config: petals.client.DistributedBloomConfig,
-    dht_prefix: Optional[str] = None,
-    return_future: bool = False,
-) -> Union[petals.client.RemoteSequential, MPFuture]:
-    return RemoteExpertWorker.run_coroutine(
-        _get_remote_sequence(dht, start, stop, config, dht_prefix), return_future=return_future
-    )
-
-
-async def _get_remote_sequence(
-    dht: DHT,
-    start: int,
-    stop: int,
-    config: petals.client.DistributedBloomConfig,
-    dht_prefix: Optional[str] = None,
-) -> petals.client.RemoteSequential:
-    uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(start, stop)]
-    p2p = await dht.replicate_p2p()
-    manager = petals.client.RemoteSequenceManager(dht, uids, p2p)
-    return petals.client.RemoteSequential(config, dht, dht_prefix, p2p, manager)
-
-
-def get_remote_module(
-    dht: DHT,
-    uid_or_uids: Union[ModuleUID, List[ModuleUID]],
-    config: petals.client.DistributedBloomConfig,
-    dht_prefix: Optional[str] = None,
-    return_future: bool = False,
-) -> Union[Union[petals.client.RemoteTransformerBlock, List[petals.client.RemoteTransformerBlock]], MPFuture]:
-    """
-    :param uid_or_uids: find one or more modules with these ids from across the DHT
-    :param config: model config, usually taken by .from_pretrained(MODEL_NAME)
-    :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
-    :returns: a list of [RemoteTransformerBlock]
-    """
-    return RemoteExpertWorker.run_coroutine(
-        _get_remote_module(dht, uid_or_uids, config, dht_prefix), return_future=return_future
-    )
-
-
-async def _get_remote_module(
-    dht: DHT,
-    uid_or_uids: Union[ModuleUID, List[ModuleUID]],
-    config: petals.client.DistributedBloomConfig,
-    dht_prefix: Optional[str] = None,
-) -> Union[petals.client.RemoteTransformerBlock, List[petals.client.RemoteTransformerBlock]]:
-    single_uid = isinstance(uid_or_uids, ModuleUID)
-    uids = [uid_or_uids] if single_uid else uid_or_uids
-    p2p = await dht.replicate_p2p()
-    managers = (petals.client.RemoteSequenceManager(dht, [uid], p2p) for uid in uids)
-    modules = [
-        petals.client.RemoteTransformerBlock(config, dht, dht_prefix=dht_prefix, p2p=p2p, sequence_manager=m)
-        for m in managers
-    ]
-    return modules[0] if single_uid else modules
-
-
 def get_remote_module_infos(
     dht: DHT,
     uids: Sequence[ModuleUID],

+ 4 - 9
tests/test_block_exact_match.py

@@ -1,28 +1,24 @@
 import random
 from typing import Union
 
-import hivemind
 import pytest
 import torch
 from transformers.models.bloom.configuration_bloom import BloomConfig
 
 from petals.bloom.block import WrappedBloomBlock
 from petals.bloom.from_pretrained import DTYPE_MAP, _load_state_dict, load_pretrained_block
-from petals.client import DistributedBloomConfig
-from petals.client.remote_sequential import RemoteTransformerBlock
+from petals.client import DistributedBloomConfig, RemoteSequential
 from petals.data_structures import UID_DELIMITER
-from petals.dht_utils import get_remote_module
 from test_utils import *
 
 
 @pytest.mark.forked
 def test_remote_block_exact_match(atol_forward=1e-4, atol_inference=1e-3):
-    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
-    config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
+    config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+    remote_sequential = RemoteSequential(config)
 
     for block_index in random.sample(range(config.n_layer), 3):
-        remote_block = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}{block_index}", config)
-        assert isinstance(remote_block, RemoteTransformerBlock)
+        remote_block = remote_sequential[block_index]
 
         inputs = torch.randn(1, 8, config.hidden_size)
         outputs_forward = remote_block(inputs)
@@ -36,7 +32,6 @@ def test_remote_block_exact_match(atol_forward=1e-4, atol_inference=1e-3):
             with pytest.raises(ValueError, match=r"Maximum length exceeded") as exc_info:
                 sess.step(inputs[:, -1:, :])
             assert "Maximum length exceeded" in repr(exc_info.value)
-
         outputs_inference = torch.cat(outputs_inference, dim=1)
 
         ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)

+ 4 - 9
tests/test_chained_calls.py

@@ -4,22 +4,19 @@
 # - if you want to figure out chained inference, ask yozh
 
 
-import hivemind
 import pytest
 import torch
 
 from petals.bloom.from_pretrained import load_pretrained_block
 from petals.client import DistributedBloomConfig
 from petals.client.remote_sequential import RemoteSequential
-from petals.dht_utils import get_remote_sequence
 from test_utils import *
 
 
 @pytest.mark.forked
 def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1):
-    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
-    config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
-    remote_blocks = get_remote_sequence(dht, 3, 6, config)
+    config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+    remote_blocks = RemoteSequential(config, start_block=3, end_block=6)
     assert isinstance(remote_blocks, RemoteSequential)
 
     ref_blocks = [
@@ -46,10 +43,8 @@ def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq
 
 @pytest.mark.forked
 def test_chained_inference_exact_match(atol_inference=1e-4):
-    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
-    config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
-    remote_blocks = get_remote_sequence(dht, 3, 5, config)
-    assert isinstance(remote_blocks, RemoteSequential)
+    config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+    remote_blocks = RemoteSequential(config, start_block=3, end_block=5)
 
     inputs = torch.randn(1, 8, config.hidden_size)
 

+ 3 - 4
tests/test_remote_sequential.py

@@ -20,7 +20,7 @@ def test_remote_sequential():
     test_inputs = torch.randn(1, 5, config.hidden_size, requires_grad=True)
     grad_proj = torch.randn(1, 5, config.hidden_size)
 
-    sequential = RemoteSequential(config, dht)
+    sequential = RemoteSequential(config, dht=dht)
 
     full_outputs = sequential(test_inputs)
     (full_outputs * grad_proj).sum().backward()
@@ -48,7 +48,7 @@ def test_remote_sequential():
     # test RemoteSequential with lossy compression
     block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)]
     lossy_sequential = RemoteSequential(
-        config, dht, sequence_manager=DummyCustomSequenceManager(dht, block_uids, sequential.p2p)
+        config, sequence_manager=DummyCustomSequenceManager(config, block_uids, dht=dht)
     )
 
     test_inputs.grad = None
@@ -85,8 +85,7 @@ class DummyCustomSequenceManager(RemoteSequenceManager):
 @pytest.mark.forked
 def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):
     config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
-    dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
-    remote_sequential = RemoteSequential(config, dht)
+    remote_sequential = RemoteSequential(config)
 
     inputs = F.normalize(torch.randn(batch_size, seq_len, config.hidden_size), dim=-1)
     output_proj = F.normalize(torch.randn(batch_size, seq_len + pre_seq_len, config.hidden_size), dim=-1)

+ 2 - 3
tests/test_sequence_manager.py

@@ -18,15 +18,14 @@ logger = get_logger(__name__)
 def test_sequence_manager_basics(mode: str):
     config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
     dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
-    sequential = RemoteSequential(config, dht)
+    sequential = RemoteSequential(config, dht=dht)
     shutdown_evt = threading.Event()
 
     # test RemoteSequential with lossy compression
     block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)]
     sequential = RemoteSequential(
         config,
-        dht,
-        sequence_manager=TestSequenceManager(dht, block_uids, sequential.p2p, _was_shut_down=shutdown_evt),
+        sequence_manager=TestSequenceManager(config, block_uids, dht=dht, _was_shut_down=shutdown_evt),
     )
 
     sequence = sequential.sequence_manager.make_sequence(mode=mode)

+ 7 - 8
tests/test_server_stats.py

@@ -4,34 +4,33 @@ import hivemind
 import pytest
 import torch
 
-from petals.client import DistributedBloomConfig
+from petals.client import DistributedBloomConfig, RemoteSequential
 from petals.data_structures import UID_DELIMITER
-from petals.dht_utils import get_remote_sequence
 from petals.server.handler import CACHE_TOKENS_AVAILABLE
 from test_utils import *
 
 
 @pytest.mark.forked
 def test_server_info(block_from: int = 22, block_to: int = 24, max_length: int = 100, max_length2: int = 50):
-    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
     config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
+    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
+    blocks1 = RemoteSequential(config, dht=dht, start_block=block_from, end_block=block_to)
+    blocks2 = RemoteSequential(config, dht=dht, start_block=block_to - 1, end_block=block_to)
 
-    blocks1 = get_remote_sequence(dht, block_from, block_to, config, f"{MODEL_NAME}{UID_DELIMITER}")
-    blocks2 = get_remote_sequence(dht, block_to - 1, block_to, config, f"{MODEL_NAME}{UID_DELIMITER}")
     info_before = blocks1.sequence_manager.rpc_info
 
     with blocks1.inference_session(max_length=max_length) as sess:
         sess.step(torch.randn(1, 1, config.hidden_size))
-        blocks1.sequence_manager._rpc_info = None  # invalidate cache
+        blocks1.sequence_manager.state.rpc_info = None  # invalidate cache
         info_inside = blocks1.sequence_manager.rpc_info
 
         with blocks2.inference_session(max_length=max_length2) as sess2:
             sess2.step(torch.randn(1, 1, config.hidden_size))
-            blocks2.sequence_manager._rpc_info = None  # invalidate cache
+            blocks2.sequence_manager.state.rpc_info = None  # invalidate cache
             info_inside2 = blocks2.sequence_manager.rpc_info
 
     time.sleep(0.1)
-    blocks1.sequence_manager._rpc_info = None  # invalidate cache
+    blocks1.sequence_manager.state.rpc_info = None  # invalidate cache
     info_after = blocks1.sequence_manager.rpc_info
 
     assert info_before[CACHE_TOKENS_AVAILABLE] == info_after[CACHE_TOKENS_AVAILABLE]