
Merge branch 'main' into priority-tasks

Pavel Samygin, 3 years ago
commit aea707088b

README.md (+102 -21)

@@ -1,13 +1,86 @@
-# PETALS: Collaborative Inference of Large Models
+<p align="center">
+    <img src="https://i.imgur.com/7eR7Pan.png" width="400"><br>
+    Decentralized platform for running 100B+ language models<br><br>
+    <a href="https://github.com/bigscience-workshop/petals/actions">
+        <img src="https://github.com/bigscience-workshop/petals/actions/workflows/run-tests.yaml/badge.svg?branch=main">
+    </a>
+    <a href="https://github.com/psf/black">
+        <img src="https://img.shields.io/badge/code%20style-black-000000.svg">
+    </a>
+</p>
 
-Run BLOOM-176B, the largest open language model, by collaborating over the Internet.
+## Key features
 
-__[EARLY PROTOTYPE]__ - this project is a work in progress. Stuff breaks and gets fixed every day. Docs are nonexistent.
-If you want us to wake you up when it's ready, click Watch -> Custom and tick "Releases".
+- Run inference or fine-tune large language models like [BLOOM-176B](https://huggingface.co/bigscience/bloom) by joining compute resources with people all over the Internet. No need to have high-end GPUs.
+- It's difficult to fit the whole BLOOM-176B into GPU memory [unless](https://twitter.com/Tim_Dettmers/status/1559892918395031552) you have multiple high-end GPUs. Instead, **Petals** allows you to load and serve a small part of the model, then team up with people serving all the other parts to run inference or fine-tuning.
+- This way, one inference step takes ≈ 1 sec — much faster than is possible with offloading. This is enough for chatbots and other interactive apps.
+- Beyond traditional language model APIs — you can employ any fine-tuning and sampling methods by executing custom paths through the model or accessing its hidden states. This allows for the comforts of an API with the flexibility of PyTorch.
 
-Roadmap: [__Issue #12__](https://github.com/learning-at-home/bloom-demo/issues/12)
+<p align="center">
+    <b><a href="https://arxiv.org/pdf/2209.01188.pdf">[Read paper]</a></b> | <b><a href="https://petals.ml/">[View website]</a></b>
+</p>
 
-### Installation
+## How does it work?
+
+<p align="center">
+    <img src="https://i.imgur.com/RTYF3yW.png" width="800">
+</p>
+
+### 🛠️ Examples
+
+Petals integrates seamlessly with PyTorch and the Hugging Face [Transformers](https://github.com/huggingface/transformers) library.
+
+This snippet shows how to **(a)** generate text with BLOOM and **(b)** solve a sequence classification task via soft prompt tuning:
+
+```python
+import torch
+import torch.nn.functional as F
+
+from src.client import DistributedBloomForCausalLM
+
+# Initialize distributed BLOOM and connect to the swarm
+model = DistributedBloomForCausalLM.from_pretrained(
+    "bigscience/distributed-bloom", tuning_mode="ptune", initial_peers=SEE_BELOW
+)  # Embeddings & prompts are on your device, BLOOM blocks are distributed
+
+print("Generated:", model.generate(tokenized_prefix, max_new_tokens=5))
+
+# Training (updates only local prompts / adapters)
+optimizer = torch.optim.AdamW(model.parameters())
+for input_ids, labels in data_loader:  # data_loader: your own DataLoader of (input_ids, labels)
+    outputs = model.forward(input_ids)
+    loss = F.cross_entropy(outputs.logits, labels)
+    optimizer.zero_grad()
+    loss.backward()
+    optimizer.step()
+```
+
+### 🚧 This project is in active development
+
+Be careful: some features may not work, interfaces may change, and we have no detailed docs yet (see [roadmap](https://github.com/bigscience-workshop/petals/issues/12)).
+
+A stable version of the code and a public swarm open to everyone will be released in November 2022. You can [subscribe](https://petals.ml/) to be emailed when it happens or fill in [this form](https://forms.gle/TV3wtRPeHewjZ1vH9) to help the public launch by donating GPU time. In the meantime, you can launch and use your own private swarm.
+
+### 🔒 Privacy and security
+
+If you work with sensitive data, you should only use a private swarm (or a subset of servers in the public swarm) hosted by people and institutions you trust, who are authorized to process this data.
+
+This is important because it's technically possible for peers serving model layers to recover input data or model outputs. Also, malicious peers may alter their outputs to influence the results of inference or fine-tuning. See a more detailed discussion in Section 4 of our [paper](https://arxiv.org/pdf/2209.01188.pdf).
+
+## FAQ
+
+1. **What's the motivation for people to host model layers in the public swarm?**
+
+    People who run inference and fine-tuning themselves get a certain speedup if they host a part of the model locally. Some may also be motivated to "give back" to the community that helps them run the model (similarly to how [BitTorrent](https://en.wikipedia.org/wiki/BitTorrent) users help others by sharing data they have already downloaded).
+
+    Since this may not be enough for everyone, we are also working on introducing explicit __incentives__ ("bloom points") for people donating their GPU time to the public swarm. Once this system is ready, people who earn these points will be able to spend them on inference/fine-tuning with higher priority or increased security guarantees, or (maybe) exchange them for other rewards.
+
+2. **Why is the platform named "Petals"?**
+
+    "Petals" is a metaphor for people serving different parts of the model. Together, they host the entire language model &mdash; [BLOOM](https://huggingface.co/bigscience/bloom).
+
+    While our platform focuses on BLOOM now, we aim to support more [foundation models](https://arxiv.org/abs/2108.07258) in the future.
+
+## Installation
+
+🚧 **Note:** These are short instructions for running a private swarm with a test 6B version of BLOOM. We will replace them with instructions involving the full 176B BLOOM and more detailed explanations soon (in a day or two).
+
+--------------------------------------------------------------------------------
 
 ```bash
 conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
@@ -16,7 +89,6 @@ pip install -r requirements.txt
 pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda113
 ```
 
-
 ### Basic functionality
 
 All tests are run on localhost
@@ -37,18 +109,18 @@ Then open a python notebook or console and run:
 ```python
 import torch
 import hivemind
-from src import get_remote_module
+from src import DistributedBloomConfig, get_remote_module
 
 
 dht = hivemind.DHT(
     initial_peers=[TODO_COPY_FULL_ADDRESS_FROM_ANY_OF_THE_SERVERS],  # e.g. /ip4/127.0.0.1/...
     client_mode=True, start=True,
 )
-
-layer3, layer4 = get_remote_module(dht, ['bigscience/test-bloomd-6b3.3', 'bigscience/test-bloomd-6b3.4'])
+config = DistributedBloomConfig.from_pretrained("bigscience/test-bloomd-6b3")
+layer3, layer4 = get_remote_module(dht, ['bigscience/test-bloomd-6b3.3', 'bigscience/test-bloomd-6b3.4'], config)
 assert layer3 is not None and layer4 is not None, "one or both layers were not found in DHT"
 # test forward/backward, two blocks
-outputs, = layer4(*layer3(torch.randn(1, 64, 4096)))
+outputs = layer4(layer3(torch.randn(1, 64, 4096)))
 loss = (outputs * torch.randn_like(outputs)).norm()
 loss.backward()
 
@@ -74,18 +146,27 @@ python -m cli.convert_model --model bigscience/bloom-6b3  \
 
 To test distributed inference, run one or more servers, then open a new shell and run pytest with environment variables:
 ```bash
-# shell A: serve blocks 3 and 4
+# shell A: serve model
 python -m cli.run_server --converted_model_name_or_path bigscience/test-bloomd-6b3 \
-  --block_indices 3:5 --torch_dtype float32 --identity_path ./server1.id --host_maddrs /ip4/127.0.0.1/tcp/31337
+  --torch_dtype float32 --identity_path ./server1.id --host_maddrs /ip4/127.0.0.1/tcp/31337
 
-# shell B: connect to the swarm and test individual blocks for exact match
-export PYTHONPATH=. INITIAL_PEERS="/ip4/TODO_COPY_INITIAL_PEERS_FROM_SERVER_OUTPUT"
-BLOCK_UID=bigscience/test-bloomd-6b3.3 pytest tests/test_block_exact_match.py
-BLOCK_UID=bigscience/test-bloomd-6b3.4 pytest tests/test_block_exact_match.py
+# shell B: connect to the swarm and run the tests
+export PYTHONPATH=.
+export INITIAL_PEERS="/ip4/TODO_COPY_INITIAL_PEERS_FROM_SERVER_OUTPUT"
+export MODEL_NAME="bigscience/test-bloomd-6b3"
 
-# the test below will fail because there is no server that serves layer 7
-# BLOCK_UID=bigscience/test-bloomd-6b3.7 pytest tests/test_block_exact_match.py
+# test individual random blocks for exact match
+pytest tests/test_block_exact_match.py
 
-# test the full model (requires that servers collectively serve all model layers)
-REF_NAME=bigscience/bloom-6b3 pytest tests/test_full_model.py
+# test the full model
+pytest tests/test_full_model.py
 ```
+
+--------------------------------------------------------------------------------
+
+<p align="center">
+    This project is a part of the <a href="https://bigscience.huggingface.co/">BigScience</a> research workshop.
+</p>
+<p align="center">
+    <img src="https://petals.ml/bigscience.png" width="150">
+</p>

src/client/__init__.py (+1 -2)

@@ -1,7 +1,6 @@
 from src.client.dust_bank import DummyDustBank, DustBankBase
 from src.client.dusty_block import DustyRemoteBlock
 from src.client.inference_session import RemoteSequentialInferenceSession, RemoteTransformerBlockInferenceSession
-from src.client.remote_block import RemoteTransformerBlock
 from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
-from src.client.remote_sequential import RemoteSequential
+from src.client.remote_sequential import RemoteSequential, RemoteTransformerBlock
 from src.client.sequence_manager import RemoteSequenceManager
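
For downstream code, this re-export keeps the old name importable even though `src/client/remote_block.py` is deleted below. A minimal sanity-check sketch (assuming the repository root is on `PYTHONPATH`):

```python
from src.client import RemoteTransformerBlock as reexported
from src.client.remote_sequential import RemoteTransformerBlock as original

# The package-level name now points at the deprecated wrapper defined in remote_sequential.py
assert reexported is original
```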

src/client/remote_block.py (+0 -40)

@@ -1,40 +0,0 @@
-# Note: this code is being actively modified by justheuristic. If you want to change anything about it, please warn me.
-from __future__ import annotations
-
-import random
-
-import torch
-from hivemind.moe.client.expert import RemoteExpert, RemoteExpertWorker
-from hivemind.moe.expert_uid import ExpertInfo
-from hivemind.p2p import P2P, StubBase
-from hivemind.utils import get_logger, use_hivemind_log_handler
-
-from src.client.inference_session import RemoteTransformerBlockInferenceSession
-from src.data_structures import RemoteModuleInfo
-from src.server.handler import TransformerConnectionHandler
-
-use_hivemind_log_handler("in_root_logger")
-logger = get_logger(__file__)
-
-
-class RemoteTransformerBlock(RemoteExpert):
-    """A class that interacts with a remote module on a specific server for forward/backward or inference"""
-
-    def __init__(self, peers_info: RemoteModuleInfo, p2p: P2P):
-        peer_info = ExpertInfo(peers_info.uid, random.choice(list(peers_info.servers.keys())))  # TODO replace this
-        super().__init__(peer_info, p2p)
-
-    @property
-    def stub(self) -> StubBase:
-        return TransformerConnectionHandler.get_stub(self.p2p, self.peer_id)
-
-    def forward(self, inputs: torch.Tensor, **kwargs):
-        for k, v in kwargs.items():
-            assert v is None or v is False, f"Extra keyword arguments are not yet supported (got {k} = {v})"
-        return super().forward(inputs)
-
-    def inference_session(self, **kwargs) -> RemoteTransformerBlockInferenceSession:
-        """Initialize a new inference session with the specified remote server"""
-        return RemoteExpertWorker.run_coroutine(
-            RemoteTransformerBlockInferenceSession._create(self.stub, self.uid, self.info, **kwargs)
-        )

src/client/remote_generation.py (+4 -0)

@@ -63,6 +63,7 @@ class RemoteGenerationMixin:
         if inputs is not None:
             assert isinstance(inputs, torch.Tensor) and inputs.ndim == 2, "inputs must be a 2d tensor [batch, length]"
         prefix_length = 0 if inputs is None else inputs.size(1)
+        prefix_length += self.config.pre_seq_len
 
         bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
@@ -104,6 +105,9 @@ class RemoteGenerationMixin:
             hypo_ids = torch.arange(outputs[0].size(0))
             while True:
                 embs = self.transformer.word_embeddings(outputs[-1])
+                if self.config.pre_seq_len > 0 and len(outputs) == 1:
+                    prompts, _ = self.transformer.get_prompt(embs.size(0))
+                    embs = torch.cat([prompts, embs], dim=1)
                 embs = self.transformer.word_embeddings_layernorm(embs)
                 hidden_state = sess.step(embs)[:, -1]
                 hidden_state = self.transformer.ln_f(hidden_state)
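
To make the shape bookkeeping above concrete, here is a minimal standalone sketch (hypothetical sizes, plain tensors instead of the real model) of how the learned prompts are prepended to the token embeddings on the first generation step only:

```python
import torch

batch, pre_seq_len, hidden = 2, 4, 16          # hypothetical sizes; pre_seq_len mimics config.pre_seq_len
outputs = [torch.randint(0, 100, (batch, 3))]  # token ids generated so far (first step: only the prefix)

embs = torch.randn(batch, outputs[-1].size(1), hidden)  # stand-in for word_embeddings(outputs[-1])
if pre_seq_len > 0 and len(outputs) == 1:
    prompts = torch.randn(batch, pre_seq_len, hidden)   # stand-in for transformer.get_prompt(batch)
    embs = torch.cat([prompts, embs], dim=1)            # [soft prompts; prefix embeddings]

# prefix_length is also increased by pre_seq_len so that max_length accounting stays correct
assert embs.shape == (batch, pre_seq_len + 3, hidden)
```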

src/client/remote_sequential.py (+23 -7)

@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import logging
 from typing import Optional, Union
 
 import torch
@@ -10,11 +9,9 @@ from torch import nn
 
 import src
 from src.client.inference_session import RemoteSequentialInferenceSession
-from src.client.remote_block import RemoteTransformerBlock
 from src.client.sequence_manager import RemoteSequenceManager
 from src.client.sequential_autograd import _RemoteSequentialAutogradFunction
 from src.data_structures import UID_DELIMITER
-from src.dht_utils import _create_remote_modules_from_infos
 from src.utils.misc import DUMMY
 
 use_hivemind_log_handler("in_root_logger")
@@ -57,12 +54,16 @@ class RemoteSequential(nn.Module):
         outputs = _RemoteSequentialAutogradFunction.apply(inputs, prompts, self.sequence_manager)
         return outputs
 
-    def __getitem__(self, ix: Union[int, slice]) -> Union[RemoteTransformerBlock, RemoteSequential]:
+    def __getitem__(self, ix: Union[int, slice]) -> RemoteSequential:
         assert isinstance(ix, (int, slice))
         if isinstance(ix, int):
-            assert 0 <= ix < len(self)
-            (module,) = _create_remote_modules_from_infos([self.sequence_manager.block_infos[ix]], self.p2p)
-            return module
+            return RemoteTransformerBlock(
+                self.config,
+                self.dht,
+                dht_prefix=self.dht_prefix,
+                p2p=self.p2p,
+                sequence_manager=self.sequence_manager[ix],
+            )
         else:
             return RemoteSequential(
                 self.config,
@@ -85,3 +86,18 @@ class RemoteSequential(nn.Module):
 
     def extra_repr(self) -> str:
         return f"modules={self.sequence_manager.block_uids[0]}..{self.sequence_manager.block_uids[-1]}"
+
+
+class RemoteTransformerBlock(RemoteSequential):
+    """Single transformer block hosted by swarm
+
+    This class is deprecated and kept for backward compatibility.
+    It will be removed soon in favor of using ``RemoteSequential`` directly.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert len(self) == 1, "RemoteTransformerBlock is a sequence of size 1"
+
+    def extra_repr(self):
+        return f"{self.sequence_manager.block_uids[0]}"

src/client/sequence_manager.py (+12 -9)

@@ -82,6 +82,7 @@ class RemoteSequenceManager:
         for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
             if info is None:
                 logger.warning(f"Found no block info for block {uid}")
+                continue
             if not isinstance(info, RemoteModuleInfo):
                 logger.warning(f"Unexpected dht entry type for {uid}: {info}")
             if not info.servers:
@@ -95,22 +96,24 @@ class RemoteSequenceManager:
         closed_spans = []
         active_spans = {}
         for block_index, info in enumerate(block_infos):
-            for peer_id, server in info.servers.items():
-                if server.state != ServerState.ONLINE:
-                    continue
-                if peer_id not in active_spans:
-                    active_spans[peer_id] = RemoteSpanInfo(start=block_index, end=block_index + 1, peer_id=peer_id)
-                else:  # peer_id in active_spans
-                    active_spans[peer_id].end = block_index + 1
+            if info is not None:
+                for peer_id, server in info.servers.items():
+                    if server.state != ServerState.ONLINE:
+                        continue
+                    if peer_id not in active_spans:
+                        active_spans[peer_id] = RemoteSpanInfo(start=block_index, end=block_index + 1, peer_id=peer_id)
+                    else:  # peer_id in active_spans
+                        active_spans[peer_id].end = block_index + 1
 
             for peer_id in list(active_spans.keys()):
                 if (
-                    peer_id not in info.servers
+                    info is None
+                    or peer_id not in info.servers
                     or info.servers[peer_id].state != ServerState.ONLINE
                     or block_index == len(block_infos) - 1
                 ):
                     closed_spans.append(active_spans.pop(peer_id))
-        assert not active_spans
+        assert not active_spans, f"spans: {active_spans}"
 
         closed_spans.sort(key=lambda span: span.end - span.start, reverse=True)
 

src/client/sequential_autograd.py (+1 -1)

@@ -110,7 +110,7 @@ async def sequential_forward(
     If some subsequence fails, reconstructs the remaining path and tries to finish the forward.
     """
 
-    assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3
+    assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"
 
     end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
     assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)

src/dht_utils.py (+47 -30)

@@ -9,7 +9,7 @@ from typing import Dict, List, Optional, Sequence, Union
 
 from hivemind.dht import DHT, DHTNode, DHTValue
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
-from hivemind.p2p import P2P, PeerID
+from hivemind.p2p import PeerID
 from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger, use_hivemind_log_handler
 
 import src
@@ -72,34 +72,63 @@ async def _declare_active_modules(
     )
 
 
+def get_remote_sequence(
+    dht: DHT,
+    start: int,
+    stop: int,
+    config: src.DistributedBloomConfig,
+    dht_prefix: Optional[str] = None,
+    return_future: bool = False,
+) -> Union[src.RemoteSequential, MPFuture]:
+    return RemoteExpertWorker.run_coroutine(
+        _get_remote_sequence(dht, start, stop, config, dht_prefix), return_future=return_future
+    )
+
+
+async def _get_remote_sequence(
+    dht: DHT,
+    start: int,
+    stop: int,
+    config: src.DistributedBloomConfig,
+    dht_prefix: Optional[str] = None,
+) -> src.RemoteSequential:
+    uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(start, stop)]
+    p2p = await dht.replicate_p2p()
+    manager = src.RemoteSequenceManager(dht, uids, p2p)
+    return src.RemoteSequential(config, dht, dht_prefix, p2p, manager)
+
+
 def get_remote_module(
     dht: DHT,
     uid_or_uids: Union[ModuleUID, List[ModuleUID]],
-    expiration_time: Optional[DHTExpiration] = None,
+    config: src.DistributedBloomConfig,
+    dht_prefix: Optional[str] = None,
     return_future: bool = False,
-) -> Union[List[Optional[src.RemoteTransformerBlock]], MPFuture[List[Optional[src.RemoteTransformerBlock]]]]:
+) -> Union[Union[src.RemoteTransformerBlock, List[src.RemoteTransformerBlock]], MPFuture]:
     """
     :param uid_or_uids: find one or more modules with these ids from across the DHT
-    :param expiration_time: if specified, return modules that expire no sooner than this (based on get_dht_time)
+    :param config: model config, usually obtained via .from_pretrained(MODEL_NAME)
     :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
-    :returns: a list of [RemoteTransformerBlock if found else None]
+    :returns: a single RemoteTransformerBlock if a single uid was given, otherwise a list of RemoteTransformerBlock
     """
-    single_uid = isinstance(uid_or_uids, ModuleUID)
-    uids = [uid_or_uids] if single_uid else uid_or_uids
-    infos = dht.run_coroutine(
-        partial(_get_remote_module_infos, uids=uids, expiration_time=expiration_time), return_future
+    return RemoteExpertWorker.run_coroutine(
+        _get_remote_module(dht, uid_or_uids, config, dht_prefix), return_future=return_future
     )
 
-    if return_future:
-
-        async def _unpack(infos_future: MPFuture, dht: DHT):
-            p2p = await dht.replicate_p2p()
-            modules = _create_remote_modules_from_infos(await infos_future, p2p)
-            return modules[0] if single_uid else modules
 
-        return RemoteExpertWorker.run_coroutine(_unpack(infos, dht), return_future)
-    p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
-    modules = _create_remote_modules_from_infos(infos, p2p)
+async def _get_remote_module(
+    dht: DHT,
+    uid_or_uids: Union[ModuleUID, List[ModuleUID]],
+    config: src.DistributedBloomConfig,
+    dht_prefix: Optional[str] = None,
+) -> Union[src.RemoteTransformerBlock, List[src.RemoteTransformerBlock]]:
+    single_uid = isinstance(uid_or_uids, ModuleUID)
+    uids = [uid_or_uids] if single_uid else uid_or_uids
+    p2p = await dht.replicate_p2p()
+    managers = (src.RemoteSequenceManager(dht, [uid], p2p) for uid in uids)
+    modules = [
+        src.RemoteTransformerBlock(config, dht, dht_prefix=dht_prefix, p2p=p2p, sequence_manager=m) for m in managers
+    ]
     return modules[0] if single_uid else modules
 
 
@@ -149,15 +178,3 @@ async def _get_remote_module_infos(
         if servers:
             modules[i] = RemoteModuleInfo(uid, servers)
     return modules
-
-
-def _create_remote_modules_from_infos(
-    infos: Sequence[Optional[RemoteModuleInfo]], p2p: P2P
-) -> List[Optional[src.RemoteTransformerBlock]]:
-    modules: List[Optional[src.RemoteTransformerBlock]] = []
-    for info in infos:
-        if info is not None:
-            modules.append(src.RemoteTransformerBlock(info, p2p))
-        else:
-            modules.append(None)
-    return modules
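
A compact sketch of how callers are expected to use the updated helpers (the `config` argument replaces `expiration_time`; the peer address is a placeholder and the model name mirrors the private-swarm instructions in the README):

```python
import hivemind
import torch

from src import DistributedBloomConfig
from src.dht_utils import get_remote_module, get_remote_sequence

INITIAL_PEERS = ["/ip4/127.0.0.1/tcp/31337/p2p/..."]  # placeholder: copy from your server's output
config = DistributedBloomConfig.from_pretrained("bigscience/test-bloomd-6b3")
dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)

# one uid ("<model name>.<block index>") -> a single RemoteTransformerBlock; a list of uids -> a list of blocks
block_4 = get_remote_module(dht, "bigscience/test-bloomd-6b3.4", config)

# a contiguous span of blocks [3, 6) -> a RemoteSequential
blocks_3_to_5 = get_remote_sequence(dht, 3, 6, config)
outputs = blocks_3_to_5(torch.randn(1, 8, config.hidden_size))
```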

tests/test_block_exact_match.py (+6 -6)

@@ -7,8 +7,10 @@ import transformers
 from hivemind import P2PHandlerError
 from test_utils import *
 
+import src
+from src import DistributedBloomConfig
 from src.bloom.from_pretrained import load_pretrained_block
-from src.client.remote_block import RemoteTransformerBlock
+from src.client.remote_sequential import RemoteTransformerBlock
 from src.data_structures import UID_DELIMITER
 from src.dht_utils import get_remote_module
 
@@ -16,16 +18,14 @@ from src.dht_utils import get_remote_module
 @pytest.mark.forked
 def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3):
     dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
-    config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
+    config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
 
     for block_index in random.sample(range(config.n_layer), 3):
-        block_uid = f"{MODEL_NAME}{UID_DELIMITER}{block_index}"
-        remote_block = get_remote_module(dht, block_uid)
-        assert remote_block is not None, f"Could not find {block_uid} in DHT"
+        remote_block = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}{block_index}", config)
         assert isinstance(remote_block, RemoteTransformerBlock)
 
         inputs = torch.randn(1, 8, config.hidden_size)
-        (outputs_forward,) = remote_block(inputs)
+        outputs_forward = remote_block(inputs)
 
         outputs_inference = []
         with remote_block.inference_session(max_length=inputs.shape[1]) as sess:

tests/test_chained_calls.py (+11 -20)

@@ -7,25 +7,20 @@
 import hivemind
 import pytest
 import torch
-import transformers
-from hivemind.moe.expert_uid import UID_DELIMITER, ExpertInfo
 from test_utils import *
 
+import src
 from src.bloom.from_pretrained import load_pretrained_block
-from src.client.remote_block import RemoteTransformerBlock
-from src.dht_utils import get_remote_module
+from src.client.remote_sequential import RemoteSequential
+from src.dht_utils import get_remote_sequence
 
 
 @pytest.mark.forked
 def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1):
     dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
-    config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
-    remote_block = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}0")
-    assert remote_block is not None, f"Could not find {MODEL_NAME}{UID_DELIMITER}0 in DHT"
-    assert isinstance(remote_block, RemoteTransformerBlock)
-
-    _ = remote_block.info  # lazy-init info now, because otherwise we will _break_ info init by chaning _info
-    remote_block._info = ExpertInfo(f"{MODEL_NAME}.3 {MODEL_NAME}.4 {MODEL_NAME}.5", remote_block._info.peer_id)
+    config = src.DistributedBloomConfig.from_pretrained(MODEL_NAME)
+    remote_blocks = get_remote_sequence(dht, 3, 6, config)
+    assert isinstance(remote_blocks, RemoteSequential)
 
     ref_blocks = [
         load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch.float32),
@@ -33,7 +28,7 @@ def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq
         load_pretrained_block(MODEL_NAME, 5, torch_dtype=torch.float32),
     ]
     inputs = torch.randn(1, seq_length, config.hidden_size, requires_grad=True)
-    outputs_rpc = remote_block.forward(inputs)[0]
+    outputs_rpc = remote_blocks.forward(inputs)
     outputs_rpc.sum().backward()
     grads_rpc = inputs.grad
 
@@ -52,18 +47,14 @@ def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq
 @pytest.mark.forked
 def test_chained_inference_exact_match(atol_inference=1e-4):
     dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
-    config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
-    remote_block = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}0")
-    assert remote_block is not None, f"Could not find {MODEL_NAME}{UID_DELIMITER}0 in DHT"
-    assert isinstance(remote_block, RemoteTransformerBlock)
-
-    _ = remote_block.info  # lazy-init info now, because otherwise we will _break_ info init by chaning _info
-    remote_block._info = ExpertInfo(f"{MODEL_NAME}.3 {MODEL_NAME}.4", remote_block._info.peer_id)
+    config = src.DistributedBloomConfig.from_pretrained(MODEL_NAME)
+    remote_blocks = get_remote_sequence(dht, 3, 5, config)
+    assert isinstance(remote_blocks, RemoteSequential)
 
     inputs = torch.randn(1, 8, config.hidden_size)
 
     outputs_inference = []
-    with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
+    with remote_blocks.inference_session(max_length=inputs.shape[1]) as sess:
         for i in range(inputs.shape[1]):
             outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
     outputs_inference = torch.cat(outputs_inference, dim=1)