
Implement RemoteSequential slicing and extra repr, add tests (#30)

- finish renaming RemoteSequenceInfo -> RemoteSequenceManager (why: if it were an *Info, users would expect it to behave like a dataclass; in actuality, the class does heavy network interactions on its own)
- implement RemoteSequenceManager.make_sequence (from https://pastebin.com/uXgy2U8B )
- make RemoteSequentialInferenceSession use RemoteSequenceManager.make_sequence
- make tests pass again
- make it possible to create inference session without RemoteTransformerBlock
- make a standalone test for RemoteSequential
- rollback convert-model

Co-authored-by: Tim Dettmers <tim.dettmers@gmail.com>
justheuristic · 3 years ago · commit f0c7383181
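
For orientation, a minimal, hedged usage sketch of what this commit adds (RemoteSequential slicing, the new extra_repr, and inference sessions built on RemoteSequenceManager.make_sequence). It is illustrative only, not part of the diff; it assumes a running test swarm and reads MODEL_NAME / INITIAL_PEERS from the environment, like tests/test_utils.py below.

import os

import torch
from hivemind import DHT

from src import RemoteSequential
from src.client.remote_model import DistributedBloomConfig

MODEL_NAME = os.environ["MODEL_NAME"]                # e.g. a bloom-testing/test-bloomd-* checkpoint
INITIAL_PEERS = os.environ["INITIAL_PEERS"].split()  # multiaddrs of one or more swarm peers

config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)

sequential = RemoteSequential(config, dht)
print(sequential)                                 # extra_repr shows "modules=<first uid>..<last uid>"
first_half = sequential[: config.n_layer // 2]    # slicing returns another RemoteSequential
second_half = sequential[config.n_layer // 2 :]

hidden = torch.randn(1, 8, config.hidden_size)
outputs = second_half(first_half(hidden))         # forward through both halves, block by block

with sequential.inference_session() as session:   # spans are chosen by RemoteSequenceManager.make_sequence
    for position in range(hidden.shape[1]):
        session.step(hidden[:, position : position + 1, :])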

+ 3 - 8
.github/workflows/run-tests.yaml

@@ -66,6 +66,8 @@ jobs:
         run: |
           export HF_TAG=$(python -c "import os; print(os.environ.get('GITHUB_BASE_REF') or os.environ.get('GITHUB_REF_NAME'))")
           export MODEL_NAME=bloom-testing/test-bloomd-350m-$HF_TAG
+          export REF_NAME=bigscience/bloom-350m
+
           python -m cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:12 \
             --torch_dtype float32 --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 --throughput 1 &
           SERVER1_PID=$!
@@ -79,14 +81,7 @@ jobs:

           sleep 30  # wait for server to download layers

-          # test individual blocks
-          export PYTHONPATH=.
-          BLOCK_UID=$MODEL_NAME.0 REF_NAME=$MODEL_NAME REF_INDEX=0 pytest tests/test_block_exact_match.py
-          BLOCK_UID=$MODEL_NAME.19 REF_NAME=$MODEL_NAME REF_INDEX=19 pytest tests/test_block_exact_match.py
-
-          REF_NAME=$MODEL_NAME pytest tests/test_chained_calls.py
-
-          REF_NAME=bigscience/bloom-350m pytest tests/test_full_model.py
+          PYTHONPATH=. pytest tests

           kill -s SIGINT $SERVER1_PID $SERVER2_PID
           echo "Done!"

+ 2 - 1
src/client/__init__.py

@@ -1,4 +1,5 @@
-from src.client.remote_block import RemoteTransformerBlock, RemoteTransformerBlockInferenceSession
+from src.client.inference_session import RemoteSequentialInferenceSession, RemoteTransformerBlockInferenceSession
+from src.client.remote_block import RemoteTransformerBlock
 from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
 from src.client.remote_sequential import RemoteSequential
 from src.client.sequence_manager import RemoteSequenceManager

+ 173 - 0
src/client/inference_session.py

@@ -0,0 +1,173 @@
+from __future__ import annotations
+
+import asyncio
+import contextlib
+from typing import AsyncIterator, List, Optional
+
+import torch
+from hivemind import (
+    P2P,
+    anext,
+    deserialize_torch_tensor,
+    get_logger,
+    nested_flatten,
+    serialize_torch_tensor,
+    use_hivemind_log_handler,
+)
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+from hivemind.p2p import StubBase
+from hivemind.proto import runtime_pb2
+
+from src.client.sequence_manager import RemoteSequenceManager
+from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
+from src.server.handler import TransformerConnectionHandler
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
+
+class RemoteTransformerBlockInferenceSession:
+    """
+    An interface to a single multi-step *inference* session for a specific remote module on a specific server
+
+    :note: this inference session is *not* fault-tolerant out of the box
+    """
+
+    def __init__(self, uid: ModuleUID, rpc_info: RPCInfo, inputs_queue: asyncio.Queue, outputs_aiter: AsyncIterator):
+        self.uid, self.rpc_info = uid, rpc_info
+        # warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread;
+        # using them in any other EventLoop may cause side-effects including headaches, diarrhea, and loss of sleep
+        self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
+        self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
+        self.stepped = False
+        self.closed = False
+
+    @classmethod
+    async def _create(
+        cls, stub: StubBase, uid: ModuleUID, rpc_info: RPCInfo, timeout: Optional[float] = None
+    ) -> RemoteTransformerBlockInferenceSession:
+        """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
+        inputs_queue = asyncio.Queue()
+        outputs_stream = await stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue, timeout), timeout=timeout)
+        return cls(uid, rpc_info, inputs_queue, outputs_stream)
+
+    @staticmethod
+    async def _read_inputs_from_queue(queue: asyncio.Queue, timeout: Optional[float]) -> AsyncIterator:
+        while True:
+            next_input_message = await asyncio.wait_for(queue.get(), timeout)
+            yield next_input_message
+            if not next_input_message.uid and not next_input_message.tensors:
+                break  # this message means "done sending"
+
+    def step(self, new_hidden_states: torch.Tensor):
+        """Inference step: send a chunk of input tensors and receive a chunk of outputs"""
+        if self.closed:
+            raise Exception("Session is closed, cannot perform step")
+        # serialize inputs and put them into the queue
+        inputs = (new_hidden_states,)
+        outputs_serialized = RemoteExpertWorker.run_coroutine(
+            self._step(
+                runtime_pb2.ExpertRequest(
+                    uid=self.uid,
+                    tensors=[
+                        serialize_torch_tensor(tensor, proto.compression)
+                        for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["forward_schema"]))
+                    ],
+                )
+            )
+        )
+        outputs = list(map(deserialize_torch_tensor, outputs_serialized.tensors))
+        assert outputs[0].shape == inputs[0].shape, f"expected outputs[0] to be hidden states but got {outputs[0]}"
+        return outputs[0]
+
+    async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse:
+        """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
+        await self._inputs_queue.put(inputs_serialized)
+        self.stepped = True
+        return await anext(self._outputs_stream)
+
+    def close(self):
+        """Finish a given inference session, close the underlying connection"""
+        if self._outputs_stream is None:
+            return  # already closed
+        RemoteExpertWorker.run_coroutine(self._aclose_stream())
+        self._outputs_stream = self._inputs_queue = None
+        self.closed = True
+
+    async def _aclose_stream(self):
+        """Close the inference session. This code is meant to be run inside RemoteExpertWorker"""
+        if self._outputs_stream is None:
+            return  # already closed
+        if self.stepped:
+            await self._inputs_queue.put(runtime_pb2.ExpertRequest())  # empty request will trigger end of session
+            try:
+                await anext(self._outputs_stream)
+            except StopAsyncIteration:
+                pass
+
+    def __del__(self):
+        self.close()
+
+    def __enter__(self):
+        assert not self.closed
+        return self
+
+    def __exit__(self, *exc_details):
+        self.close()
+
+
+class RemoteSequentialInferenceSession:
+    """
+    An interface to a multi-step *inference* session for a sequence of remote transformer blocks
+    """
+
+    def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, timeout: Optional[float] = None):
+        self.sequence_manager = sequence_manager
+        self.p2p = p2p
+        self.closed = False
+        self.chosen_spans: List[RemoteSpanInfo] = []
+        self.stack = contextlib.ExitStack()
+        self.inference_sessions: List[RemoteTransformerBlockInferenceSession] = []
+        self.timeout = timeout
+
+    def __enter__(self):
+        assert not self.closed and not self.chosen_spans
+        self.stack.__enter__()
+        # TODO(yozh) replace this code with a fault-tolerant chain that can be reconstructed if some peers fail
+        self.chosen_spans.extend(self.sequence_manager.make_sequence())
+
+        for chosen_span in self.chosen_spans:
+            stub = TransformerConnectionHandler.get_stub(self.p2p, chosen_span.peer_id)
+            span_uids: str = CHAIN_DELIMITER.join(self.sequence_manager.block_uids[chosen_span.start : chosen_span.end])
+            inference_session = RemoteExpertWorker.run_coroutine(
+                RemoteTransformerBlockInferenceSession._create(
+                    stub, span_uids, rpc_info=self.sequence_manager.rpc_info, timeout=self.timeout
+                )
+            )
+            self.inference_sessions.append(inference_session)
+            self.stack.enter_context(inference_session)
+
+        return self
+
+    def step(self, inputs: torch.Tensor):
+        assert not self.closed
+        if torch.is_grad_enabled():
+            logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
+        for session in self.inference_sessions:
+            outputs = session.step(inputs)
+            assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
+            inputs = outputs
+        return inputs
+
+    def close(self, *exc_details):
+        """Finish a given inference session, close the underlying connection"""
+        if not self.closed:
+            self.stack.__exit__(*exc_details or (None, None, None))
+            self.inference_sessions.clear()
+            self.closed = True
+
+    def __exit__(self, *exc_details):
+        self.close(*exc_details)
+
+    def __del__(self):
+        self.close()
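
Below, a hedged sketch of how the block-level session added in this file is typically driven, one token at a time. It mirrors tests/test_block_exact_match.py further down; the uid format and environment variables are assumptions borrowed from the tests, not new API.

import os

import hivemind
import torch
import transformers

from src.data_structures import UID_DELIMITER
from src.dht_utils import get_remote_module

MODEL_NAME = os.environ["MODEL_NAME"]
INITIAL_PEERS = os.environ["INITIAL_PEERS"].split()

dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
remote_block = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}0")  # first transformer block
config = transformers.AutoConfig.from_pretrained(MODEL_NAME)

inputs = torch.randn(1, 8, config.hidden_size)
outputs = []
with remote_block.inference_session() as session:  # a RemoteTransformerBlockInferenceSession
    for t in range(inputs.shape[1]):
        outputs.append(session.step(inputs[:, t : t + 1, :]))
outputs = torch.cat(outputs, dim=1)  # same shape as inputs, computed step by step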

+ 5 - 99
src/client/remote_block.py

@@ -1,20 +1,16 @@
 # Note: this code is being actively modified by justheuristic. If you want to change anything about it, please warn me.
 from __future__ import annotations
 
-import asyncio
 import random
-from typing import Any, AsyncIterator, Dict, Optional
 
 import torch
-from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.moe.client.expert import RemoteExpert, RemoteExpertWorker
 from hivemind.moe.expert_uid import ExpertInfo
 from hivemind.p2p import P2P, StubBase
-from hivemind.proto import runtime_pb2
-from hivemind.utils import anext, get_logger, nested_flatten, use_hivemind_log_handler
+from hivemind.utils import get_logger, use_hivemind_log_handler
 
+from src.client.inference_session import RemoteTransformerBlockInferenceSession
 from src.data_structures import RemoteModuleInfo
-from src.dht_utils import ModuleUID
 from src.server.handler import TransformerConnectionHandler
 
 use_hivemind_log_handler("in_root_logger")
@@ -39,100 +35,10 @@ class RemoteTransformerBlock(RemoteExpert):
 
     def inference_session(self) -> RemoteTransformerBlockInferenceSession:
         """Initialize a new inference session with the specified remote server"""
-        _ = self.info  # create _info manually since the built-in property will not work inside RemoteExpertWorker
-        return RemoteExpertWorker.run_coroutine(RemoteTransformerBlockInferenceSession._create(self))
+        return RemoteExpertWorker.run_coroutine(
+            RemoteTransformerBlockInferenceSession._create(self.stub, self.uid, self.info)
+        )
 
     def begin_inference_session(self):
         logger.warning("beging_inference_session was renamed to just inference_session")
         return self.inference_session()
-
-
-class RemoteTransformerBlockInferenceSession:
-    """An interface to a single multi-step *inference* session for a specific remote module with a specific server"""
-
-    def __init__(self, uid: ModuleUID, info: Dict[str, Any], inputs_queue: asyncio.Queue, outputs_aiter: AsyncIterator):
-        self.uid, self.info = uid, info
-        # warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread;
-        # using them in any other EventLoop may cause side-effects including, headaches, diarrhea, and loss of sleep
-        self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
-        self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
-        self.stepped = False
-        self.closed = False
-
-    @classmethod
-    async def _create(
-        cls,
-        remote_module: RemoteTransformerBlock,
-        timeout: Optional[float] = None,
-    ) -> RemoteTransformerBlockInferenceSession:
-        """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
-        inputs_queue = asyncio.Queue()
-        outputs_stream = await remote_module.stub.rpc_inference(
-            cls._read_inputs_from_queue(inputs_queue, timeout),
-            timeout=timeout,
-        )
-        return cls(remote_module.uid, remote_module.info, inputs_queue, outputs_stream)
-
-    @staticmethod
-    async def _read_inputs_from_queue(queue: asyncio.Queue, timeout: Optional[float]) -> AsyncIterator:
-        while True:
-            next_input_message = await asyncio.wait_for(queue.get(), timeout)
-            yield next_input_message
-            if not next_input_message.uid and not next_input_message.tensors:
-                break  # this message means "done sending"
-
-    def step(self, new_hidden_states: torch.Tensor):
-        """Inference step: send a chunk of input tensors and receive a chunk of outputs"""
-        if self.closed:
-            raise Exception("Session is closed, cannot perform step")
-        # serialize inputs and put them into the queue
-        inputs = (new_hidden_states,)
-        outputs_serialized = RemoteExpertWorker.run_coroutine(
-            self._step(
-                runtime_pb2.ExpertRequest(
-                    uid=self.uid,
-                    tensors=[
-                        serialize_torch_tensor(tensor, proto.compression)
-                        for tensor, proto in zip(inputs, nested_flatten(self.info["forward_schema"]))
-                    ],
-                )
-            )
-        )
-        outputs = list(map(deserialize_torch_tensor, outputs_serialized.tensors))
-        assert outputs[0].shape == inputs[0].shape, f"expected outputs[0] to be hidden states but got {outputs[0]}"
-        return outputs[0]
-
-    async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse:
-        """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
-        await self._inputs_queue.put(inputs_serialized)
-        self.stepped = True
-        return await anext(self._outputs_stream)
-
-    def close(self):
-        """Finish a given inference session, close the underlying connection"""
-        if self._outputs_stream is None:
-            return  # already closed
-        RemoteExpertWorker.run_coroutine(self._aclose_stream())
-        self._outputs_stream = self._inputs_queue = None
-        self.closed = True
-
-    async def _aclose_stream(self):
-        """Close the inference session. This code is meant to be run inside RemoteExpertWorker"""
-        if self._outputs_stream is None:
-            return  # already closed
-        if self.stepped:
-            await self._inputs_queue.put(runtime_pb2.ExpertRequest())  # empty request will trigger end of session
-            try:
-                await anext(self._outputs_stream)
-            except StopAsyncIteration:
-                pass
-
-    def __del__(self):
-        self.close()
-
-    def __enter__(self):
-        assert not self.closed
-        return self
-
-    def __exit__(self, *exc_details):
-        self.close()

+ 16 - 78
src/client/remote_sequential.py

@@ -1,17 +1,15 @@
 from __future__ import annotations
 
-import contextlib
 import logging
-import random
 from typing import Optional, Union
 
 import torch
 from hivemind import DHT, P2P, get_logger, use_hivemind_log_handler
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
-from hivemind.moe.expert_uid import ExpertInfo
 from torch import nn
 
 import src
+from src.client.inference_session import RemoteSequentialInferenceSession
 from src.client.remote_block import RemoteTransformerBlock
 from src.client.sequence_manager import RemoteSequenceManager
 from src.data_structures import UID_DELIMITER
@@ -30,49 +28,41 @@ class RemoteSequential(nn.Module):
         self,
         config: src.DistributedBloomConfig,
         dht: DHT,
-        prefix: str,
-        max_retries: int = 3,
+        dht_prefix: Optional[str] = None,
         p2p: Optional[P2P] = None,
         sequence_manager: Optional[RemoteSequenceManager] = None,
     ):
         logger.warning(f"{self.__class__.__name__} is in active development; expect adventures")
-        if prefix.endswith(UID_DELIMITER):
-            logger.warning(
-                f"dht_prefix {prefix} already ends with '{UID_DELIMITER}'."
-                f"This will cause {self.__class__.__name__} to look for modules under "
-                f"{prefix}{UID_DELIMITER}*. Please make sure this is what you intended."
-            )
-
         super().__init__()
         self.config = config
         self.dht = dht
-        self.prefix = prefix
-        self.max_retries = max_retries
+        self.dht_prefix = dht_prefix or config.dht_prefix
         self.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p()) if p2p is None else p2p
 
-        block_uids = [f"{prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)]
+        num_blocks = self.config.n_layer if sequence_manager is None else len(sequence_manager)
+        block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(num_blocks)]
         if sequence_manager is None:
             logger.debug(f"Creating new sequence manager for block uids: {block_uids}")
-            self.sequence_manager = RemoteSequenceManager(dht, block_uids)
+            self.sequence_manager = RemoteSequenceManager(dht, block_uids, self.p2p)
             self.is_subsequence = False
         else:
+            logger.debug(f"Reusing sequence manager with {len(sequence_manager)} modules")
+            self.sequence_manager = sequence_manager
             assert isinstance(sequence_manager.block_uids, list)
-            logger.debug(f"Reusing sequence manager with {len(self.sequence_manager)}")
             self.is_subsequence = self.sequence_manager.block_uids == block_uids
 
     def forward(self, inputs: torch.Tensor):
         assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3 and inputs.shape[-1] == self.config.n_embed
-        for block_index in range(self.config.n_layer):
-            for retry_index in range(self.max_retries):
+        for block in iter(self):
+            for retry_index in range(self.sequence_manager.max_retries):
                 try:
-                    block = self[block_index]
                     (outputs,) = block(inputs)
                     assert isinstance(outputs, torch.Tensor)
                     assert outputs.shape == inputs.shape, f"Expected {block} output {inputs.shape}, got {outputs.shape}"
                     inputs = outputs
                     break
                 except Exception as e:
-                    if retry_index == self.max_retries - 1:
+                    if retry_index == self.sequence_manager.max_retries - 1:
                         raise e
                     else:
                         logging.debug(f"Caught {e} when running forward for block {block_index}", exc_info=True)
@@ -81,21 +71,20 @@ class RemoteSequential(nn.Module):
     def __getitem__(self, ix: Union[int, slice]) -> Union[RemoteTransformerBlock, RemoteSequential]:
         assert isinstance(ix, (int, slice))
         if isinstance(ix, int):
-            assert 0 <= ix < self.config.n_layer
+            assert 0 <= ix < len(self)
             (module,) = _create_remote_modules_from_infos([self.sequence_manager.block_infos[ix]], self.p2p)
             return module
         else:
             return RemoteSequential(
                 self.config,
                 self.dht,
-                prefix=self.prefix,
-                max_retries=self.max_retries,
+                dht_prefix=self.dht_prefix,
                 p2p=self.p2p,
                 sequence_manager=self.sequence_manager[ix],
             )
 
     def __iter__(self):
-        for block_index in range(self.config.n_layer):
+        for block_index in range(len(self)):
             yield self[block_index]
 
     def __len__(self):
@@ -105,56 +94,5 @@ class RemoteSequential(nn.Module):
         self.sequence_manager.update_()
         return RemoteSequentialInferenceSession(self.sequence_manager, self.p2p)
 
-
-class RemoteSequentialInferenceSession:
-    """An interface to a multi-step *inference* session for a sequence of remote transformer blocks"""
-
-    def __init__(self, remote_sequence_info: RemoteSequenceManager, p2p: P2P):
-        self.remote_sequence_info = remote_sequence_info
-        self.p2p = p2p
-        self.closed = False
-        self.stack = contextlib.ExitStack()
-        self.active_sessions = []
-
-    def __enter__(self):
-        assert not self.closed
-        self.stack.__enter__()
-        # TODO(yozh) replace this code with a fault-tolerant chain that can be reconstructed if some peers fail
-        current_block = 0
-        while current_block != len(self.remote_sequence_info):
-            candidate_spans = self.remote_sequence_info.spans_containing_block[current_block]
-            chosen_span = random.choice(candidate_spans)  # TODO this is a temporary code
-            assert chosen_span.start <= current_block < chosen_span.end
-
-            # TODO begin throwaway prototype code
-            remote = RemoteTransformerBlock(self.remote_sequence_info.block_infos[current_block], self.p2p)
-            _ = remote.info  # TODO fix
-            span_uids = self.remote_sequence_info.block_uids[current_block : chosen_span.end]
-            remote._info = ExpertInfo(" ".join(span_uids), chosen_span.peer_id)
-            self.active_sessions.append(remote.inference_session())
-            self.stack.enter_context(self.active_sessions[-1])
-            current_block = chosen_span.end
-            # TODO end throwaway prototype code
-
-        return self
-
-    def step(self, inputs: torch.Tensor):
-        assert not self.closed
-        for session in self.active_sessions:
-            outputs = session.step(inputs)
-            assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
-            inputs = outputs
-        return inputs
-
-    def close(self, *exc_details):
-        """Finish a given inference session, close the underlying connection"""
-        if not self.closed:
-            self.stack.__exit__(*exc_details or (None, None, None))
-            self.active_sessions.clear()
-            self.closed = True
-
-    def __exit__(self, *exc_details):
-        self.close(*exc_details)
-
-    def __del__(self):
-        self.close()
+    def extra_repr(self) -> str:
+        return f"modules={self.sequence_manager.block_uids[0]}..{self.sequence_manager.block_uids[-1]}"

+ 62 - 19
src/client/sequence_manager.py

@@ -1,36 +1,37 @@
 from __future__ import annotations
 
+import random
 import threading
 from typing import List, Optional, Sequence, Tuple, Union
 
-from hivemind import DHT, DHTExpiration
+from hivemind import DHT, P2P, DHTExpiration, MSGPackSerializer
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+from hivemind.proto import runtime_pb2
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
 from src.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
 from src.dht_utils import get_remote_module_infos
+from src.server.handler import TransformerConnectionHandler
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 
 class RemoteSequenceManager:
-    """Keeps and updates the meta-information about which peers host which blocks"""
-
-    dht: DHT
-    block_uids: List[ModuleUID]
-    block_infos: List[Optional[RemoteModuleInfo]]
-    spans_by_priority: List[RemoteSpanInfo]  # sorted from best to worst
-    spans_containing_block: Tuple[List[RemoteSpanInfo], ...]
-    last_update_time: DHTExpiration
-    lock_changes: threading.Lock
-
-    def __init__(self, dht: DHT, block_uids: Sequence[ModuleUID]):
-        self.dht = dht
-        self.block_uids = list(block_uids)
-        self.block_infos = [None] * len(self.block_uids)
-        self.spans_by_priority = []
-        self.spans_containing_block = tuple(list() for _ in range(len(self.block_uids)))
-        self.last_update_time = -float("inf")
+    """
+    Keeps and updates the meta-information about which peers host which blocks.
+    In future, this class is intended to maintain latency statistics, ban non-responsive peers, etc.
+    """
+
+    def __init__(self, dht: DHT, block_uids: Sequence[ModuleUID], p2p: P2P, max_retries: int = 3):
+        self.dht, self.p2p = dht, p2p
+        self.block_uids: List[ModuleUID] = list(block_uids)
+        self.block_infos: List[Optional[RemoteModuleInfo]] = [None] * len(self.block_uids)
+        self.spans_by_priority: List[RemoteSpanInfo] = []  # sorted from best to worst
+        self.spans_containing_block: Tuple[List[RemoteSpanInfo], ...] = tuple([] for _ in range(len(self.block_uids)))
+        self.last_update_time: DHTExpiration = -float("inf")
+        self.max_retries = max_retries
+        self._rpc_info = None
         self.lock_changes = threading.Lock()
         self.update_()
 
@@ -38,13 +39,33 @@ class RemoteSequenceManager:
             assert info is not None, f"Found no remote peers for block {uid}"
         assert self.spans_by_priority and self.spans_containing_block
 
+    def make_sequence(self, start_index: int = 0, end_index: Optional[int] = None) -> Sequence[RemoteSpanInfo]:
+        """
+        Form a sequence of remote servers that collectively serve all consecutive layers
+
+        :param start_index: optional index of the first module in a sequence, default = the first of block_uids
+        :param end_index: optional index of the last module (non-inclusive), default = after last of block uids
+        """
+        end_index = end_index if end_index is not None else len(self.block_uids)
+        span_sequence = []
+        current_index = start_index
+        while current_index < end_index:
+            candidate_spans = self.spans_containing_block[current_index]
+            chosen_span = random.choice(candidate_spans)  # TODO this should be replaced with proper load balancing
+
+            assert chosen_span.start <= current_index < chosen_span.end
+            span_sequence.append(chosen_span)
+            current_index = chosen_span.end
+
+        return span_sequence
+
     def __getitem__(self, ix: Union[int, slice]) -> RemoteSequenceManager:
         """Get a RemoteSequenceManager for a sub-sequence of blocks"""
         assert isinstance(ix, (int, slice))
         if not isinstance(ix, slice):
             ix = slice(int(ix), int(ix) + 1, 1)
         with self.lock_changes:
-            subseq = RemoteSequenceManager(self.dht, self.block_uids[ix])
+            subseq = RemoteSequenceManager(self.dht, self.block_uids[ix], self.p2p)
             subseq.block_infos = self.block_infos[ix]
             subseq.spans_by_priority, subseq.spans_containing_block = subseq.compute_spans(subseq.block_infos)
             subseq.last_update_time = self.last_update_time
@@ -102,3 +123,25 @@ class RemoteSequenceManager:
 
     def __len__(self):
         return len(self.block_uids)
+
+    @property
+    def rpc_info(self):
+        """Return the rpc_info queried from one of the servers that hold the first block"""
+        if self._rpc_info is None:
+            retries = 0
+            for i in range(self.max_retries):
+                try:
+                    self.update_()
+                    peer_id = random.choice(list(self.block_infos[0].servers.keys()))
+                    stub = TransformerConnectionHandler.get_stub(self.p2p, peer_id)
+                    outputs = RemoteExpertWorker.run_coroutine(
+                        stub.rpc_info(runtime_pb2.ExpertUID(uid=self.block_uids[0]))
+                    )
+                    self._rpc_info = MSGPackSerializer.loads(outputs.serialized_info)
+                except Exception as e:
+                    retries += 1
+                    if retries >= self.max_retries:
+                        raise e
+                    else:
+                        logger.warning(f"Tried to call rpc_info, but caught {repr(e)}", exc_info=True)
+        return self._rpc_info
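
To make the greedy chain construction in make_sequence above concrete, here is a hedged, self-contained toy walk-through; Span is a local stand-in for RemoteSpanInfo and the peer ids are made up.

import random
from collections import namedtuple

Span = namedtuple("Span", "start end peer_id")

# Four blocks served by two overlapping server spans: peerA covers [0, 3), peerB covers [2, 4).
spans_containing_block = (
    [Span(0, 3, "peerA")],
    [Span(0, 3, "peerA")],
    [Span(0, 3, "peerA"), Span(2, 4, "peerB")],
    [Span(2, 4, "peerB")],
)

current_index, chain = 0, []
while current_index < len(spans_containing_block):
    chosen = random.choice(spans_containing_block[current_index])  # placeholder for real load balancing
    assert chosen.start <= current_index < chosen.end
    chain.append(chosen)
    current_index = chosen.end  # continue from the first block the chosen span does not cover

print([(span.peer_id, span.start, span.end) for span in chain])
# possible output: [('peerA', 0, 3), ('peerB', 2, 4)]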

+ 4 - 1
src/data_structures.py

@@ -1,6 +1,6 @@
 from dataclasses import dataclass
 from enum import Enum
-from typing import Dict
+from typing import Any, Dict
 
 from hivemind import PeerID
 
@@ -36,3 +36,6 @@ class RemoteSpanInfo:
     start: int
     end: int
     peer_id: PeerID
+
+
+RPCInfo = Dict[str, Any]

+ 0 - 1
src/server/handler.py

@@ -1,4 +1,3 @@
-# Note: this code is being actively modified by justheuristic. If you want to change anything about it, please warn me.
 import contextlib
 from typing import AsyncIterator, Dict, Sequence
 

+ 51 - 0
tests/conftest.py

@@ -0,0 +1,51 @@
+import asyncio
+import gc
+from contextlib import suppress
+
+import psutil
+import pytest
+from hivemind.utils.crypto import RSAPrivateKey
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from hivemind.utils.mpfuture import MPFuture
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__name__)
+
+
+@pytest.fixture
+def event_loop():
+    """
+    This overrides the ``event_loop`` fixture from pytest-asyncio
+    (e.g. to make it compatible with ``asyncio.subprocess``).
+
+    This fixture is identical to the original one but does not call ``loop.close()`` in the end.
+    Indeed, at this point, the loop is already stopped (i.e. next tests are free to create new loops).
+    However, finalizers of objects created in the current test may reference the current loop and fail if it is closed.
+    For example, this happens while using ``asyncio.subprocess`` (the ``asyncio.subprocess.Process`` finalizer
+    fails if the loop is closed, but works if the loop is only stopped).
+    """
+
+    yield asyncio.get_event_loop()
+
+
+@pytest.fixture(autouse=True, scope="session")
+def cleanup_children():
+    yield
+
+    with RSAPrivateKey._process_wide_key_lock:
+        RSAPrivateKey._process_wide_key = None
+
+    gc.collect()  # Call .__del__() for removed objects
+
+    children = psutil.Process().children(recursive=True)
+    if children:
+        logger.info(f"Cleaning up {len(children)} leftover child processes")
+        for child in children:
+            with suppress(psutil.NoSuchProcess):
+                child.terminate()
+        psutil.wait_procs(children, timeout=1)
+        for child in children:
+            with suppress(psutil.NoSuchProcess):
+                child.kill()
+
+    MPFuture.reset_backend()

+ 22 - 30
tests/test_block_exact_match.py

@@ -1,47 +1,39 @@
-# Note: this code is being actively modified by justheuristic. If you want to change anything about it, please warn me.
-import os
+import random
 
 import hivemind
+import pytest
 import torch
 import transformers
+from test_utils import *
 
 from src.bloom.from_pretrained import load_pretrained_block
 from src.client.remote_block import RemoteTransformerBlock
+from src.data_structures import UID_DELIMITER
 from src.dht_utils import get_remote_module
 
-INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
-if not INITIAL_PEERS:
-    raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
-INITIAL_PEERS = INITIAL_PEERS.split()
-
-
-BLOCK_UID = os.environ.get("BLOCK_UID")
-if not BLOCK_UID:
-    raise RuntimeError("Must specify BLOCK_UID as an index of a transformer block to be tested")
-
-REF_NAME = os.environ.get("REF_NAME", "bigscience/test-bloomd-6b3")
-REF_INDEX = int(os.environ.get("REF_INDEX", BLOCK_UID.split(".")[-1]))
-
 
+@pytest.mark.forked
 def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3):
     dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
+    config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
 
-    remote_block = get_remote_module(dht, BLOCK_UID)
-    assert remote_block is not None, f"Could not find {BLOCK_UID} in DHT"
-    assert isinstance(remote_block, RemoteTransformerBlock)
-    ref_config = transformers.AutoConfig.from_pretrained(REF_NAME)
+    for block_index in random.sample(range(config.n_layer), 3):
+        block_uid = f"{MODEL_NAME}{UID_DELIMITER}{block_index}"
+        remote_block = get_remote_module(dht, block_uid)
+        assert remote_block is not None, f"Could not find {block_uid} in DHT"
+        assert isinstance(remote_block, RemoteTransformerBlock)
 
-    inputs = torch.randn(1, 8, ref_config.hidden_size)
-    (outputs_forward,) = remote_block(inputs)
+        inputs = torch.randn(1, 8, config.hidden_size)
+        (outputs_forward,) = remote_block(inputs)
 
-    outputs_inference = []
-    with remote_block.inference_session() as sess:
-        for i in range(inputs.shape[1]):
-            outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
-    outputs_inference = torch.cat(outputs_inference, dim=1)
+        outputs_inference = []
+        with remote_block.inference_session() as sess:
+            for i in range(inputs.shape[1]):
+                outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
+        outputs_inference = torch.cat(outputs_inference, dim=1)
 
-    ref_block = load_pretrained_block(REF_NAME, REF_INDEX, torch_dtype=torch.float32)
-    (outputs_local,) = ref_block(inputs)
+        ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
+        (outputs_local,) = ref_block(inputs)
 
-    assert torch.allclose(outputs_local, outputs_forward, rtol=0, atol=atol_forward)
-    assert torch.allclose(outputs_local, outputs_inference, rtol=0, atol=atol_inference)
+        assert torch.allclose(outputs_local, outputs_forward, rtol=0, atol=atol_forward)
+        assert torch.allclose(outputs_local, outputs_inference, rtol=0, atol=atol_inference)

+ 9 - 18
tests/test_chained_calls.py

@@ -3,30 +3,20 @@
 # - if you want more stable tests, see test_block_exact_match
 # - if you want to figure out chained inference, ask yozh
 
-import os
 
 import hivemind
+import pytest
 import torch
 import transformers
 from hivemind.moe.expert_uid import UID_DELIMITER, ExpertInfo
+from test_utils import *
 
 from src.bloom.from_pretrained import load_pretrained_block
 from src.client.remote_block import RemoteTransformerBlock
 from src.dht_utils import get_remote_module
 
-INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
-if not INITIAL_PEERS:
-    raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
-INITIAL_PEERS = INITIAL_PEERS.split()
-
-
-MODEL_NAME = os.environ.get("MODEL_NAME")
-if not MODEL_NAME:
-    raise RuntimeError("Must specify MODEL_NAME as a name of a model to be tested")
-
-REF_NAME = os.environ.get("REF_NAME", "bigscience/test-bloomd-6b3")
-
 
+@pytest.mark.forked
 def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1):
     dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
     config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
@@ -38,9 +28,9 @@ def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq
     remote_block._info = ExpertInfo(f"{MODEL_NAME}.3 {MODEL_NAME}.4 {MODEL_NAME}.5", remote_block._info.peer_id)
 
     ref_blocks = [
-        load_pretrained_block(REF_NAME, 3, torch_dtype=torch.float32),
-        load_pretrained_block(REF_NAME, 4, torch_dtype=torch.float32),
-        load_pretrained_block(REF_NAME, 5, torch_dtype=torch.float32),
+        load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch.float32),
+        load_pretrained_block(MODEL_NAME, 4, torch_dtype=torch.float32),
+        load_pretrained_block(MODEL_NAME, 5, torch_dtype=torch.float32),
     ]
     inputs = torch.randn(1, seq_length, config.hidden_size, requires_grad=True)
     outputs_rpc = remote_block.forward(inputs)[0]
@@ -59,6 +49,7 @@ def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq
     assert torch.allclose(grads_ref, grads_rpc, rtol=0, atol=atol_backward)
 
 
+@pytest.mark.forked
 def test_chained_inference_exact_match(atol_inference=1e-4):
     dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
     config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
@@ -78,8 +69,8 @@ def test_chained_inference_exact_match(atol_inference=1e-4):
     outputs_inference = torch.cat(outputs_inference, dim=1)
 
     ref_blocks = [
-        load_pretrained_block(REF_NAME, 3, torch_dtype=torch.float32),
-        load_pretrained_block(REF_NAME, 4, torch_dtype=torch.float32),
+        load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch.float32),
+        load_pretrained_block(MODEL_NAME, 4, torch_dtype=torch.float32),
     ]
     outputs_ref = []
     caches = [None, None]

+ 24 - 33
tests/test_full_model.py

@@ -1,9 +1,8 @@
-# Note: this code is being actively modified by justheuristic. If you want to change anything about it, please warn me.
-import os
-
+import pytest
 import torch
 import transformers
 from hivemind import get_logger, use_hivemind_log_handler
+from test_utils import *
 
 from src.client.remote_model import DistributedBloomForCausalLM
 
@@ -11,19 +10,7 @@ use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 
-INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
-if not INITIAL_PEERS:
-    raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
-INITIAL_PEERS = INITIAL_PEERS.split()
-
-
-MODEL_NAME = os.environ.get("MODEL_NAME")
-if not MODEL_NAME:
-    raise RuntimeError("Must specify MODEL_NAME as an index of a transformer block to be tested")
-
-REF_NAME = os.environ.get("REF_NAME")
-
-
+@pytest.mark.forked
 def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
     tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
     model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
@@ -31,23 +18,12 @@ def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
     assert len(model.transformer.h) == model.config.n_layer
 
     test_inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
-    parallel_outputs = model.forward(test_inputs).logits
-    assert torch.all(torch.isfinite(parallel_outputs))
-    logger.info("Forward outputs are finite")
 
-    if REF_NAME:
-        with torch.no_grad():
-            ref_model = transformers.AutoModelForCausalLM.from_pretrained(REF_NAME)
-            dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
-            # note: this creates a dummy mask to make the test compatible with older transformer versions
-            # prior to https://github.com/huggingface/transformers/pull/17837
-            ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits
-            assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward)
-            del ref_model, ref_outputs
-    else:
-        logger.warning("Did not test exact match with local model: REF_NAME environment variable is not set")
+    with torch.no_grad():
+        parallel_outputs = model.forward(test_inputs).logits
+        assert torch.all(torch.isfinite(parallel_outputs))
+        logger.info("Forward outputs are finite")
 
-    with torch.inference_mode():
         embs = model.transformer.word_embeddings(test_inputs)
         embs = model.transformer.word_embeddings_layernorm(embs)
         recurrent_outputs = []
@@ -60,5 +36,20 @@ def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
         dictionary = model.transformer.word_embeddings.weight.t()
         recurrent_outputs = recurrent_outputs.to(dictionary.dtype)
         recurrent_outputs = (recurrent_outputs @ dictionary).float()
-    assert torch.allclose(recurrent_outputs, parallel_outputs, rtol=0, atol=atol_inference)
-    logger.info("Inference is consistent with forward")
+        assert torch.allclose(recurrent_outputs, parallel_outputs, rtol=0, atol=atol_inference)
+        logger.info("Inference is consistent with forward")
+
+        del model, recurrent_outputs
+
+        if REF_NAME:
+            ref_model = transformers.AutoModelForCausalLM.from_pretrained(REF_NAME)
+            dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
+            # note: this creates a dummy mask to make the test compatible with older transformer versions
+            # prior to https://github.com/huggingface/transformers/pull/17837
+            ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits
+            assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward)
+            logger.warning(f"Distributed forward is consistent with {type(ref_model)}.forward")
+            del ref_model, ref_outputs, dummy_mask
+        else:
+            logger.warning("Did not test exact match with local model: REF_NAME environment variable is not set")
+            assert False

+ 43 - 0
tests/test_remote_sequential.py

@@ -0,0 +1,43 @@
+import pytest
+import torch
+from hivemind import DHT, get_logger, use_hivemind_log_handler
+from test_utils import *
+
+from src import RemoteSequential
+from src.client.remote_model import DistributedBloomConfig
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
+
+@pytest.mark.forked
+def test_remote_sequential():
+    config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+    dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
+    test_inputs = torch.randn(1, 5, config.hidden_size, requires_grad=True)
+    grad_proj = torch.randn(1, 5, config.hidden_size)
+
+    sequential = RemoteSequential(config, dht)
+
+    full_outputs = sequential(test_inputs)
+    (full_outputs * grad_proj).sum().backward()
+    assert test_inputs.grad is not None
+    full_grad = test_inputs.grad.clone()
+    test_inputs.grad.data.zero_()
+
+    first_half = sequential[: config.n_layer // 2]
+    second_half = sequential[config.n_layer // 2 :]
+    assert len(first_half) + len(second_half) == len(sequential)
+    assert abs(len(first_half) - len(second_half)) == config.n_layer % 2
+    for m in sequential, first_half, second_half:
+        assert isinstance(repr(m), str)
+
+    hidden = first_half(test_inputs)
+    assert isinstance(hidden, torch.Tensor)
+    assert hidden.shape == test_inputs.shape
+    assert hidden.requires_grad
+    second_half_outputs = second_half(hidden)
+    assert torch.allclose(second_half_outputs, full_outputs)
+
+    (second_half_outputs * grad_proj).sum().backward()
+    assert torch.allclose(test_inputs.grad, full_grad)

+ 13 - 0
tests/test_utils.py

@@ -0,0 +1,13 @@
+import os
+
+INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
+if not INITIAL_PEERS:
+    raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
+INITIAL_PEERS = INITIAL_PEERS.split()
+
+
+MODEL_NAME = os.environ.get("MODEL_NAME")
+if not MODEL_NAME:
+    raise RuntimeError("Must specify MODEL_NAME as an index of a transformer block to be tested")
+
+REF_NAME = os.environ.get("REF_NAME")