Selaa lähdekoodia

Implement RemoteSequential slicing and extra repr, add tests (#30)

- finish renaming RemoteSequenceInfo -> RemoteSequenceManager (why: if it was an *Info, user would expect it to be similar - to a dataclass; whereas in actuality, the class is doing heavy network interactions on its own)
- implement RemoteSequenceManager.make_sequence (from https://pastebin.com/uXgy2U8B )
- make RemoteSequentialInferenceSession use RemoteSequenceManager.make_sequence
- make tests pass again
- make it possible to create inference session without RemoteTransformerBlock
- make a standalone test for RemoteSequential
- rollback convert-model

Co-authored-by: Tim Dettmers <tim.dettmers@gmail.com>
justheuristic 3 vuotta sitten
vanhempi
commit
f0c7383181

+ 3 - 8
.github/workflows/run-tests.yaml

@@ -66,6 +66,8 @@ jobs:
         run: |
           export HF_TAG=$(python -c "import os; print(os.environ.get('GITHUB_BASE_REF') or os.environ.get('GITHUB_REF_NAME'))")
           export MODEL_NAME=bloom-testing/test-bloomd-350m-$HF_TAG
+          export REF_NAME=bigscience/bloom-350m
+
           python -m cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:12 \
             --torch_dtype float32 --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 --throughput 1 &
           SERVER1_PID=$!
@@ -79,14 +81,7 @@ jobs:
 
           sleep 30  # wait for server to download layers
           
-          # test individual blocks
-          export PYTHONPATH=.
-          BLOCK_UID=$MODEL_NAME.0 REF_NAME=$MODEL_NAME REF_INDEX=0 pytest tests/test_block_exact_match.py
-          BLOCK_UID=$MODEL_NAME.19 REF_NAME=$MODEL_NAME REF_INDEX=19 pytest tests/test_block_exact_match.py
-
-          REF_NAME=$MODEL_NAME pytest tests/test_chained_calls.py
-          
-          REF_NAME=bigscience/bloom-350m pytest tests/test_full_model.py
+          PYTHONPATH=. pytest tests
           
           kill -s SIGINT $SERVER1_PID $SERVER2_PID
           echo "Done!"

+ 2 - 1
src/client/__init__.py

@@ -1,4 +1,5 @@
-from src.client.remote_block import RemoteTransformerBlock, RemoteTransformerBlockInferenceSession
+from src.client.inference_session import RemoteSequentialInferenceSession, RemoteTransformerBlockInferenceSession
+from src.client.remote_block import RemoteTransformerBlock
 from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
 from src.client.remote_sequential import RemoteSequential
 from src.client.sequence_manager import RemoteSequenceManager

+ 173 - 0
src/client/inference_session.py

@@ -0,0 +1,173 @@
+from __future__ import annotations
+
+import asyncio
+import contextlib
+from typing import AsyncIterator, List, Optional
+
+import torch
+from hivemind import (
+    P2P,
+    anext,
+    deserialize_torch_tensor,
+    get_logger,
+    nested_flatten,
+    serialize_torch_tensor,
+    use_hivemind_log_handler,
+)
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+from hivemind.p2p import StubBase
+from hivemind.proto import runtime_pb2
+
+from src.client.sequence_manager import RemoteSequenceManager
+from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
+from src.server.handler import TransformerConnectionHandler
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
+
+class RemoteTransformerBlockInferenceSession:
+    """
+    An interface to a single multi-step *inference* session for a specific remote module on a specific server
+
+    :note: this inference session is *not* fault-tolerant out of the box
+    """
+
+    def __init__(self, uid: ModuleUID, rpc_info: RPCInfo, inputs_queue: asyncio.Queue, outputs_aiter: AsyncIterator):
+        self.uid, self.rpc_info = uid, rpc_info
+        # warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread;
+        # using them in any other EventLoop may cause side-effects including, headaches, diarrhea, and loss of sleep
+        self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
+        self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
+        self.stepped = False
+        self.closed = False
+
+    @classmethod
+    async def _create(
+        cls, stub: StubBase, uid: ModuleUID, rpc_info: RPCInfo, timeout: Optional[float] = None
+    ) -> RemoteTransformerBlockInferenceSession:
+        """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
+        inputs_queue = asyncio.Queue()
+        outputs_stream = await stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue, timeout), timeout=timeout)
+        return cls(uid, rpc_info, inputs_queue, outputs_stream)
+
+    @staticmethod
+    async def _read_inputs_from_queue(queue: asyncio.Queue, timeout: Optional[float]) -> AsyncIterator:
+        while True:
+            next_input_message = await asyncio.wait_for(queue.get(), timeout)
+            yield next_input_message
+            if not next_input_message.uid and not next_input_message.tensors:
+                break  # this message means "done sending"
+
+    def step(self, new_hidden_states: torch.Tensor):
+        """Inference step: send a chunk of input tensors and receive a chunk of outputs"""
+        if self.closed:
+            raise Exception("Session is closed, cannot perform step")
+        # serialize inputs and put them into the queue
+        inputs = (new_hidden_states,)
+        outputs_serialized = RemoteExpertWorker.run_coroutine(
+            self._step(
+                runtime_pb2.ExpertRequest(
+                    uid=self.uid,
+                    tensors=[
+                        serialize_torch_tensor(tensor, proto.compression)
+                        for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["forward_schema"]))
+                    ],
+                )
+            )
+        )
+        outputs = list(map(deserialize_torch_tensor, outputs_serialized.tensors))
+        assert outputs[0].shape == inputs[0].shape, f"expected outputs[0] to be hidden states but got {outputs[0]}"
+        return outputs[0]
+
+    async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse:
+        """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
+        await self._inputs_queue.put(inputs_serialized)
+        self.stepped = True
+        return await anext(self._outputs_stream)
+
+    def close(self):
+        """Finish a given inference session, close the underlying connection"""
+        if self._outputs_stream is None:
+            return  # already closed
+        RemoteExpertWorker.run_coroutine(self._aclose_stream())
+        self._outputs_stream = self._inputs_queue = None
+        self.closed = True
+
+    async def _aclose_stream(self):
+        """Close the inference session. This code is meant to be run inside RemoteExpertWorker"""
+        if self._outputs_stream is None:
+            return  # already closed
+        if self.stepped:
+            await self._inputs_queue.put(runtime_pb2.ExpertRequest())  # empty request will trigger end of session
+            try:
+                await anext(self._outputs_stream)
+            except StopAsyncIteration:
+                pass
+
+    def __del__(self):
+        self.close()
+
+    def __enter__(self):
+        assert not self.closed
+        return self
+
+    def __exit__(self, *exc_details):
+        self.close()
+
+
+class RemoteSequentialInferenceSession:
+    """
+    An interface to a multi-step *inference* session for a sequence of remote transformer blocks
+    """
+
+    def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, timeout: Optional[float] = None):
+        self.sequence_manager = sequence_manager
+        self.p2p = p2p
+        self.closed = False
+        self.chosen_spans: List[RemoteSpanInfo] = []
+        self.stack = contextlib.ExitStack()
+        self.inference_sessions: List[RemoteTransformerBlockInferenceSession] = []
+        self.timeout = timeout
+
+    def __enter__(self):
+        assert not self.closed and not self.chosen_spans
+        self.stack.__enter__()
+        # TODO(yozh) replace this code with a fault-tolerant chain that can be reconstructed if some peers fail
+        self.chosen_spans.extend(self.sequence_manager.make_sequence())
+
+        for chosen_span in self.chosen_spans:
+            stub = TransformerConnectionHandler.get_stub(self.p2p, chosen_span.peer_id)
+            span_uids: str = CHAIN_DELIMITER.join(self.sequence_manager.block_uids[chosen_span.start : chosen_span.end])
+            inference_session = RemoteExpertWorker.run_coroutine(
+                RemoteTransformerBlockInferenceSession._create(
+                    stub, span_uids, rpc_info=self.sequence_manager.rpc_info, timeout=self.timeout
+                )
+            )
+            self.inference_sessions.append(inference_session)
+            self.stack.enter_context(inference_session)
+
+        return self
+
+    def step(self, inputs: torch.Tensor):
+        assert not self.closed
+        if torch.is_grad_enabled():
+            logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
+        for session in self.inference_sessions:
+            outputs = session.step(inputs)
+            assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
+            inputs = outputs
+        return inputs
+
+    def close(self, *exc_details):
+        """Finish a given inference session, close the underlying connection"""
+        if not self.closed:
+            self.stack.__exit__(*exc_details or (None, None, None))
+            self.inference_sessions.clear()
+            self.closed = True
+
+    def __exit__(self, *exc_details):
+        self.close(*exc_details)
+
+    def __del__(self):
+        self.close()

+ 5 - 99
src/client/remote_block.py

@@ -1,20 +1,16 @@
 # Note: this code is being actively modified by justheuristic. If you want to change anything about it, please warn me.
 from __future__ import annotations
 
-import asyncio
 import random
-from typing import Any, AsyncIterator, Dict, Optional
 
 import torch
-from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.moe.client.expert import RemoteExpert, RemoteExpertWorker
 from hivemind.moe.expert_uid import ExpertInfo
 from hivemind.p2p import P2P, StubBase
-from hivemind.proto import runtime_pb2
-from hivemind.utils import anext, get_logger, nested_flatten, use_hivemind_log_handler
+from hivemind.utils import get_logger, use_hivemind_log_handler
 
+from src.client.inference_session import RemoteTransformerBlockInferenceSession
 from src.data_structures import RemoteModuleInfo
-from src.dht_utils import ModuleUID
 from src.server.handler import TransformerConnectionHandler
 
 use_hivemind_log_handler("in_root_logger")
@@ -39,100 +35,10 @@ class RemoteTransformerBlock(RemoteExpert):
 
     def inference_session(self) -> RemoteTransformerBlockInferenceSession:
         """Initialize a new inference session with the specified remote server"""
-        _ = self.info  # create _info manually since the built-in property will not work inside RemoteExpertWorker
-        return RemoteExpertWorker.run_coroutine(RemoteTransformerBlockInferenceSession._create(self))
+        return RemoteExpertWorker.run_coroutine(
+            RemoteTransformerBlockInferenceSession._create(self.stub, self.uid, self.info)
+        )
 
     def begin_inference_session(self):
         logger.warning("beging_inference_session was renamed to just inference_session")
         return self.inference_session()
-
-
-class RemoteTransformerBlockInferenceSession:
-    """An interface to a single multi-step *inference* session for a specific remote module with a specific server"""
-
-    def __init__(self, uid: ModuleUID, info: Dict[str, Any], inputs_queue: asyncio.Queue, outputs_aiter: AsyncIterator):
-        self.uid, self.info = uid, info
-        # warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread;
-        # using them in any other EventLoop may cause side-effects including, headaches, diarrhea, and loss of sleep
-        self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
-        self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
-        self.stepped = False
-        self.closed = False
-
-    @classmethod
-    async def _create(
-        cls,
-        remote_module: RemoteTransformerBlock,
-        timeout: Optional[float] = None,
-    ) -> RemoteTransformerBlockInferenceSession:
-        """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
-        inputs_queue = asyncio.Queue()
-        outputs_stream = await remote_module.stub.rpc_inference(
-            cls._read_inputs_from_queue(inputs_queue, timeout),
-            timeout=timeout,
-        )
-        return cls(remote_module.uid, remote_module.info, inputs_queue, outputs_stream)
-
-    @staticmethod
-    async def _read_inputs_from_queue(queue: asyncio.Queue, timeout: Optional[float]) -> AsyncIterator:
-        while True:
-            next_input_message = await asyncio.wait_for(queue.get(), timeout)
-            yield next_input_message
-            if not next_input_message.uid and not next_input_message.tensors:
-                break  # this message means "done sending"
-
-    def step(self, new_hidden_states: torch.Tensor):
-        """Inference step: send a chunk of input tensors and receive a chunk of outputs"""
-        if self.closed:
-            raise Exception("Session is closed, cannot perform step")
-        # serialize inputs and put them into the queue
-        inputs = (new_hidden_states,)
-        outputs_serialized = RemoteExpertWorker.run_coroutine(
-            self._step(
-                runtime_pb2.ExpertRequest(
-                    uid=self.uid,
-                    tensors=[
-                        serialize_torch_tensor(tensor, proto.compression)
-                        for tensor, proto in zip(inputs, nested_flatten(self.info["forward_schema"]))
-                    ],
-                )
-            )
-        )
-        outputs = list(map(deserialize_torch_tensor, outputs_serialized.tensors))
-        assert outputs[0].shape == inputs[0].shape, f"expected outputs[0] to be hidden states but got {outputs[0]}"
-        return outputs[0]
-
-    async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse:
-        """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
-        await self._inputs_queue.put(inputs_serialized)
-        self.stepped = True
-        return await anext(self._outputs_stream)
-
-    def close(self):
-        """Finish a given inference session, close the underlying connection"""
-        if self._outputs_stream is None:
-            return  # already closed
-        RemoteExpertWorker.run_coroutine(self._aclose_stream())
-        self._outputs_stream = self._inputs_queue = None
-        self.closed = True
-
-    async def _aclose_stream(self):
-        """Close the inference session. This code is meant to be run inside RemoteExpertWorker"""
-        if self._outputs_stream is None:
-            return  # already closed
-        if self.stepped:
-            await self._inputs_queue.put(runtime_pb2.ExpertRequest())  # empty request will trigger end of session
-            try:
-                await anext(self._outputs_stream)
-            except StopAsyncIteration:
-                pass
-
-    def __del__(self):
-        self.close()
-
-    def __enter__(self):
-        assert not self.closed
-        return self
-
-    def __exit__(self, *exc_details):
-        self.close()

+ 16 - 78
src/client/remote_sequential.py

@@ -1,17 +1,15 @@
 from __future__ import annotations
 
-import contextlib
 import logging
-import random
 from typing import Optional, Union
 
 import torch
 from hivemind import DHT, P2P, get_logger, use_hivemind_log_handler
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
-from hivemind.moe.expert_uid import ExpertInfo
 from torch import nn
 
 import src
+from src.client.inference_session import RemoteSequentialInferenceSession
 from src.client.remote_block import RemoteTransformerBlock
 from src.client.sequence_manager import RemoteSequenceManager
 from src.data_structures import UID_DELIMITER
@@ -30,49 +28,41 @@ class RemoteSequential(nn.Module):
         self,
         config: src.DistributedBloomConfig,
         dht: DHT,
-        prefix: str,
-        max_retries: int = 3,
+        dht_prefix: Optional[str] = None,
         p2p: Optional[P2P] = None,
         sequence_manager: Optional[RemoteSequenceManager] = None,
     ):
         logger.warning(f"{self.__class__.__name__} is in active development; expect adventures")
-        if prefix.endswith(UID_DELIMITER):
-            logger.warning(
-                f"dht_prefix {prefix} already ends with '{UID_DELIMITER}'."
-                f"This will cause {self.__class__.__name__} to look for modules under "
-                f"{prefix}{UID_DELIMITER}*. Please make sure this is what you intended."
-            )
-
         super().__init__()
         self.config = config
         self.dht = dht
-        self.prefix = prefix
-        self.max_retries = max_retries
+        self.dht_prefix = dht_prefix or config.dht_prefix
         self.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p()) if p2p is None else p2p
 
-        block_uids = [f"{prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)]
+        num_blocks = self.config.n_layer if sequence_manager is None else len(sequence_manager)
+        block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(num_blocks)]
         if sequence_manager is None:
             logger.debug(f"Creating new sequence manager for block uids: {block_uids}")
-            self.sequence_manager = RemoteSequenceManager(dht, block_uids)
+            self.sequence_manager = RemoteSequenceManager(dht, block_uids, self.p2p)
             self.is_subsequence = False
         else:
+            logger.debug(f"Reusing sequence manager with {len(sequence_manager)} modules")
+            self.sequence_manager = sequence_manager
             assert isinstance(sequence_manager.block_uids, list)
-            logger.debug(f"Reusing sequence manager with {len(self.sequence_manager)}")
             self.is_subsequence = self.sequence_manager.block_uids == block_uids
 
     def forward(self, inputs: torch.Tensor):
         assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3 and inputs.shape[-1] == self.config.n_embed
-        for block_index in range(self.config.n_layer):
-            for retry_index in range(self.max_retries):
+        for block in iter(self):
+            for retry_index in range(self.sequence_manager.max_retries):
                 try:
-                    block = self[block_index]
                     (outputs,) = block(inputs)
                     assert isinstance(outputs, torch.Tensor)
                     assert outputs.shape == inputs.shape, f"Expected {block} output {inputs.shape}, got {outputs.shape}"
                     inputs = outputs
                     break
                 except Exception as e:
-                    if retry_index == self.max_retries - 1:
+                    if retry_index == self.sequence_manager.max_retries - 1:
                         raise e
                     else:
                         logging.debug(f"Caught {e} when running forward for block {block_index}", exc_info=True)
@@ -81,21 +71,20 @@ class RemoteSequential(nn.Module):
     def __getitem__(self, ix: Union[int, slice]) -> Union[RemoteTransformerBlock, RemoteSequential]:
         assert isinstance(ix, (int, slice))
         if isinstance(ix, int):
-            assert 0 <= ix < self.config.n_layer
+            assert 0 <= ix < len(self)
             (module,) = _create_remote_modules_from_infos([self.sequence_manager.block_infos[ix]], self.p2p)
             return module
         else:
             return RemoteSequential(
                 self.config,
                 self.dht,
-                prefix=self.prefix,
-                max_retries=self.max_retries,
+                dht_prefix=self.dht_prefix,
                 p2p=self.p2p,
                 sequence_manager=self.sequence_manager[ix],
             )
 
     def __iter__(self):
-        for block_index in range(self.config.n_layer):
+        for block_index in range(len(self)):
             yield self[block_index]
 
     def __len__(self):
@@ -105,56 +94,5 @@ class RemoteSequential(nn.Module):
         self.sequence_manager.update_()
         return RemoteSequentialInferenceSession(self.sequence_manager, self.p2p)
 
-
-class RemoteSequentialInferenceSession:
-    """An interface to a multi-step *inference* session for a sequence of remote transformer blocks"""
-
-    def __init__(self, remote_sequence_info: RemoteSequenceManager, p2p: P2P):
-        self.remote_sequence_info = remote_sequence_info
-        self.p2p = p2p
-        self.closed = False
-        self.stack = contextlib.ExitStack()
-        self.active_sessions = []
-
-    def __enter__(self):
-        assert not self.closed
-        self.stack.__enter__()
-        # TODO(yozh) replace this code with a fault-tolerant chain that can be reconstructed if some peers fail
-        current_block = 0
-        while current_block != len(self.remote_sequence_info):
-            candidate_spans = self.remote_sequence_info.spans_containing_block[current_block]
-            chosen_span = random.choice(candidate_spans)  # TODO this is a temporary code
-            assert chosen_span.start <= current_block < chosen_span.end
-
-            # TODO begin throwaway prototype code
-            remote = RemoteTransformerBlock(self.remote_sequence_info.block_infos[current_block], self.p2p)
-            _ = remote.info  # TODO fix
-            span_uids = self.remote_sequence_info.block_uids[current_block : chosen_span.end]
-            remote._info = ExpertInfo(" ".join(span_uids), chosen_span.peer_id)
-            self.active_sessions.append(remote.inference_session())
-            self.stack.enter_context(self.active_sessions[-1])
-            current_block = chosen_span.end
-            # TODO end throwaway prototype code
-
-        return self
-
-    def step(self, inputs: torch.Tensor):
-        assert not self.closed
-        for session in self.active_sessions:
-            outputs = session.step(inputs)
-            assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
-            inputs = outputs
-        return inputs
-
-    def close(self, *exc_details):
-        """Finish a given inference session, close the underlying connection"""
-        if not self.closed:
-            self.stack.__exit__(*exc_details or (None, None, None))
-            self.active_sessions.clear()
-            self.closed = True
-
-    def __exit__(self, *exc_details):
-        self.close(*exc_details)
-
-    def __del__(self):
-        self.close()
+    def extra_repr(self) -> str:
+        return f"modules={self.sequence_manager.block_uids[0]}..{self.sequence_manager.block_uids[-1]}"

+ 62 - 19
src/client/sequence_manager.py

@@ -1,36 +1,37 @@
 from __future__ import annotations
 
+import random
 import threading
 from typing import List, Optional, Sequence, Tuple, Union
 
-from hivemind import DHT, DHTExpiration
+from hivemind import DHT, P2P, DHTExpiration, MSGPackSerializer
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+from hivemind.proto import runtime_pb2
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
 from src.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
 from src.dht_utils import get_remote_module_infos
+from src.server.handler import TransformerConnectionHandler
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 
 class RemoteSequenceManager:
-    """Keeps and updates the meta-information about which peers host which blocks"""
-
-    dht: DHT
-    block_uids: List[ModuleUID]
-    block_infos: List[Optional[RemoteModuleInfo]]
-    spans_by_priority: List[RemoteSpanInfo]  # sorted from best to worst
-    spans_containing_block: Tuple[List[RemoteSpanInfo], ...]
-    last_update_time: DHTExpiration
-    lock_changes: threading.Lock
-
-    def __init__(self, dht: DHT, block_uids: Sequence[ModuleUID]):
-        self.dht = dht
-        self.block_uids = list(block_uids)
-        self.block_infos = [None] * len(self.block_uids)
-        self.spans_by_priority = []
-        self.spans_containing_block = tuple(list() for _ in range(len(self.block_uids)))
-        self.last_update_time = -float("inf")
+    """
+    Keeps and updates the meta-information about which peers host which blocks.
+    In future, this class is intended to maintain latency statistics, ban non-responsive peers, etc.
+    """
+
+    def __init__(self, dht: DHT, block_uids: Sequence[ModuleUID], p2p: P2P, max_retries: int = 3):
+        self.dht, self.p2p = dht, p2p
+        self.block_uids: List[ModuleUID] = list(block_uids)
+        self.block_infos: List[Optional[RemoteModuleInfo]] = [None] * len(self.block_uids)
+        self.spans_by_priority: List[RemoteSpanInfo] = []  # sorted from best to worst
+        self.spans_containing_block: Tuple[List[RemoteSpanInfo], ...] = tuple([] for _ in range(len(self.block_uids)))
+        self.last_update_time: DHTExpiration = -float("inf")
+        self.max_retries = max_retries
+        self._rpc_info = None
         self.lock_changes = threading.Lock()
         self.update_()
 
@@ -38,13 +39,33 @@ class RemoteSequenceManager:
             assert info is not None, f"Found no remote peers for block {uid}"
         assert self.spans_by_priority and self.spans_containing_block
 
+    def make_sequence(self, start_index: int = 0, end_index: Optional[int] = None) -> Sequence[RemoteSpanInfo]:
+        """
+        Form a sequence of remote servers that collectively serve all consecutive layers
+
+        :param start_index: optional index of the first module in a sequence, default = the first of block_uids
+        :param end_index: optional index of the last module (non-inclusive), default = after last of block uids
+        """
+        end_index = end_index if end_index is not None else len(self.block_uids)
+        span_sequence = []
+        current_index = start_index
+        while current_index < end_index:
+            candidate_spans = self.spans_containing_block[current_index]
+            chosen_span = random.choice(candidate_spans)  # TODO this should be replaced with proper load balancing
+
+            assert chosen_span.start <= current_index < chosen_span.end
+            span_sequence.append(chosen_span)
+            current_index = chosen_span.end
+
+        return span_sequence
+
     def __getitem__(self, ix: Union[int, slice]) -> RemoteSequenceManager:
         """Get a RemoteSequenceManager for a sub-sequence of blocks"""
         assert isinstance(ix, (int, slice))
         if not isinstance(ix, slice):
             ix = slice(int(ix), int(ix) + 1, 1)
         with self.lock_changes:
-            subseq = RemoteSequenceManager(self.dht, self.block_uids[ix])
+            subseq = RemoteSequenceManager(self.dht, self.block_uids[ix], self.p2p)
             subseq.block_infos = self.block_infos[ix]
             subseq.spans_by_priority, subseq.spans_containing_block = subseq.compute_spans(subseq.block_infos)
             subseq.last_update_time = self.last_update_time
@@ -102,3 +123,25 @@ class RemoteSequenceManager:
 
     def __len__(self):
         return len(self.block_uids)
+
+    @property
+    def rpc_info(self):
+        """Return the rpc_info queried from one of the servers that hold the first block"""
+        if self._rpc_info is None:
+            retries = 0
+            for i in range(self.max_retries):
+                try:
+                    self.update_()
+                    peer_id = random.choice(list(self.block_infos[0].servers.keys()))
+                    stub = TransformerConnectionHandler.get_stub(self.p2p, peer_id)
+                    outputs = RemoteExpertWorker.run_coroutine(
+                        stub.rpc_info(runtime_pb2.ExpertUID(uid=self.block_uids[0]))
+                    )
+                    self._rpc_info = MSGPackSerializer.loads(outputs.serialized_info)
+                except Exception as e:
+                    retries += 1
+                    if retries >= self.max_retries:
+                        raise e
+                    else:
+                        logger.warning(f"Tried to call rpc_info, but caught {repr(e)}", exc_info=True)
+        return self._rpc_info

+ 4 - 1
src/data_structures.py

@@ -1,6 +1,6 @@
 from dataclasses import dataclass
 from enum import Enum
-from typing import Dict
+from typing import Any, Dict
 
 from hivemind import PeerID
 
@@ -36,3 +36,6 @@ class RemoteSpanInfo:
     start: int
     end: int
     peer_id: PeerID
+
+
+RPCInfo = Dict[str, Any]

+ 0 - 1
src/server/handler.py

@@ -1,4 +1,3 @@
-# Note: this code is being actively modified by justheuristic. If you want to change anything about it, please warn me.
 import contextlib
 from typing import AsyncIterator, Dict, Sequence
 

+ 51 - 0
tests/conftest.py

@@ -0,0 +1,51 @@
+import asyncio
+import gc
+from contextlib import suppress
+
+import psutil
+import pytest
+from hivemind.utils.crypto import RSAPrivateKey
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from hivemind.utils.mpfuture import MPFuture
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__name__)
+
+
+@pytest.fixture
+def event_loop():
+    """
+    This overrides the ``event_loop`` fixture from pytest-asyncio
+    (e.g. to make it compatible with ``asyncio.subprocess``).
+
+    This fixture is identical to the original one but does not call ``loop.close()`` in the end.
+    Indeed, at this point, the loop is already stopped (i.e. next tests are free to create new loops).
+    However, finalizers of objects created in the current test may reference the current loop and fail if it is closed.
+    For example, this happens while using ``asyncio.subprocess`` (the ``asyncio.subprocess.Process`` finalizer
+    fails if the loop is closed, but works if the loop is only stopped).
+    """
+
+    yield asyncio.get_event_loop()
+
+
+@pytest.fixture(autouse=True, scope="session")
+def cleanup_children():
+    yield
+
+    with RSAPrivateKey._process_wide_key_lock:
+        RSAPrivateKey._process_wide_key = None
+
+    gc.collect()  # Call .__del__() for removed objects
+
+    children = psutil.Process().children(recursive=True)
+    if children:
+        logger.info(f"Cleaning up {len(children)} leftover child processes")
+        for child in children:
+            with suppress(psutil.NoSuchProcess):
+                child.terminate()
+        psutil.wait_procs(children, timeout=1)
+        for child in children:
+            with suppress(psutil.NoSuchProcess):
+                child.kill()
+
+    MPFuture.reset_backend()

+ 22 - 30
tests/test_block_exact_match.py

@@ -1,47 +1,39 @@
-# Note: this code is being actively modified by justheuristic. If you want to change anything about it, please warn me.
-import os
+import random
 
 import hivemind
+import pytest
 import torch
 import transformers
+from test_utils import *
 
 from src.bloom.from_pretrained import load_pretrained_block
 from src.client.remote_block import RemoteTransformerBlock
+from src.data_structures import UID_DELIMITER
 from src.dht_utils import get_remote_module
 
-INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
-if not INITIAL_PEERS:
-    raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
-INITIAL_PEERS = INITIAL_PEERS.split()
-
-
-BLOCK_UID = os.environ.get("BLOCK_UID")
-if not BLOCK_UID:
-    raise RuntimeError("Must specify BLOCK_UID as an index of a transformer block to be tested")
-
-REF_NAME = os.environ.get("REF_NAME", "bigscience/test-bloomd-6b3")
-REF_INDEX = int(os.environ.get("REF_INDEX", BLOCK_UID.split(".")[-1]))
-
 
+@pytest.mark.forked
 def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3):
     dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
+    config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
 
-    remote_block = get_remote_module(dht, BLOCK_UID)
-    assert remote_block is not None, f"Could not find {BLOCK_UID} in DHT"
-    assert isinstance(remote_block, RemoteTransformerBlock)
-    ref_config = transformers.AutoConfig.from_pretrained(REF_NAME)
+    for block_index in random.sample(range(config.n_layer), 3):
+        block_uid = f"{MODEL_NAME}{UID_DELIMITER}{block_index}"
+        remote_block = get_remote_module(dht, block_uid)
+        assert remote_block is not None, f"Could not find {block_uid} in DHT"
+        assert isinstance(remote_block, RemoteTransformerBlock)
 
-    inputs = torch.randn(1, 8, ref_config.hidden_size)
-    (outputs_forward,) = remote_block(inputs)
+        inputs = torch.randn(1, 8, config.hidden_size)
+        (outputs_forward,) = remote_block(inputs)
 
-    outputs_inference = []
-    with remote_block.inference_session() as sess:
-        for i in range(inputs.shape[1]):
-            outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
-    outputs_inference = torch.cat(outputs_inference, dim=1)
+        outputs_inference = []
+        with remote_block.inference_session() as sess:
+            for i in range(inputs.shape[1]):
+                outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
+        outputs_inference = torch.cat(outputs_inference, dim=1)
 
-    ref_block = load_pretrained_block(REF_NAME, REF_INDEX, torch_dtype=torch.float32)
-    (outputs_local,) = ref_block(inputs)
+        ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
+        (outputs_local,) = ref_block(inputs)
 
-    assert torch.allclose(outputs_local, outputs_forward, rtol=0, atol=atol_forward)
-    assert torch.allclose(outputs_local, outputs_inference, rtol=0, atol=atol_inference)
+        assert torch.allclose(outputs_local, outputs_forward, rtol=0, atol=atol_forward)
+        assert torch.allclose(outputs_local, outputs_inference, rtol=0, atol=atol_inference)

+ 9 - 18
tests/test_chained_calls.py

@@ -3,30 +3,20 @@
 # - if you want more stable tests, see test_block_exact_match
 # - if you want to figure out chained inference, ask yozh
 
-import os
 
 import hivemind
+import pytest
 import torch
 import transformers
 from hivemind.moe.expert_uid import UID_DELIMITER, ExpertInfo
+from test_utils import *
 
 from src.bloom.from_pretrained import load_pretrained_block
 from src.client.remote_block import RemoteTransformerBlock
 from src.dht_utils import get_remote_module
 
-INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
-if not INITIAL_PEERS:
-    raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
-INITIAL_PEERS = INITIAL_PEERS.split()
-
-
-MODEL_NAME = os.environ.get("MODEL_NAME")
-if not MODEL_NAME:
-    raise RuntimeError("Must specify MODEL_NAME as a name of a model to be tested")
-
-REF_NAME = os.environ.get("REF_NAME", "bigscience/test-bloomd-6b3")
-
 
+@pytest.mark.forked
 def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1):
     dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
     config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
@@ -38,9 +28,9 @@ def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq
     remote_block._info = ExpertInfo(f"{MODEL_NAME}.3 {MODEL_NAME}.4 {MODEL_NAME}.5", remote_block._info.peer_id)
 
     ref_blocks = [
-        load_pretrained_block(REF_NAME, 3, torch_dtype=torch.float32),
-        load_pretrained_block(REF_NAME, 4, torch_dtype=torch.float32),
-        load_pretrained_block(REF_NAME, 5, torch_dtype=torch.float32),
+        load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch.float32),
+        load_pretrained_block(MODEL_NAME, 4, torch_dtype=torch.float32),
+        load_pretrained_block(MODEL_NAME, 5, torch_dtype=torch.float32),
     ]
     inputs = torch.randn(1, seq_length, config.hidden_size, requires_grad=True)
     outputs_rpc = remote_block.forward(inputs)[0]
@@ -59,6 +49,7 @@ def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq
     assert torch.allclose(grads_ref, grads_rpc, rtol=0, atol=atol_backward)
 
 
+@pytest.mark.forked
 def test_chained_inference_exact_match(atol_inference=1e-4):
     dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
     config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
@@ -78,8 +69,8 @@ def test_chained_inference_exact_match(atol_inference=1e-4):
     outputs_inference = torch.cat(outputs_inference, dim=1)
 
     ref_blocks = [
-        load_pretrained_block(REF_NAME, 3, torch_dtype=torch.float32),
-        load_pretrained_block(REF_NAME, 4, torch_dtype=torch.float32),
+        load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch.float32),
+        load_pretrained_block(MODEL_NAME, 4, torch_dtype=torch.float32),
     ]
     outputs_ref = []
     caches = [None, None]

+ 24 - 33
tests/test_full_model.py

@@ -1,9 +1,8 @@
-# Note: this code is being actively modified by justheuristic. If you want to change anything about it, please warn me.
-import os
-
+import pytest
 import torch
 import transformers
 from hivemind import get_logger, use_hivemind_log_handler
+from test_utils import *
 
 from src.client.remote_model import DistributedBloomForCausalLM
 
@@ -11,19 +10,7 @@ use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 
-INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
-if not INITIAL_PEERS:
-    raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
-INITIAL_PEERS = INITIAL_PEERS.split()
-
-
-MODEL_NAME = os.environ.get("MODEL_NAME")
-if not MODEL_NAME:
-    raise RuntimeError("Must specify MODEL_NAME as an index of a transformer block to be tested")
-
-REF_NAME = os.environ.get("REF_NAME")
-
-
+@pytest.mark.forked
 def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
     tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
     model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
@@ -31,23 +18,12 @@ def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
     assert len(model.transformer.h) == model.config.n_layer
 
     test_inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
-    parallel_outputs = model.forward(test_inputs).logits
-    assert torch.all(torch.isfinite(parallel_outputs))
-    logger.info("Forward outputs are finite")
 
-    if REF_NAME:
-        with torch.no_grad():
-            ref_model = transformers.AutoModelForCausalLM.from_pretrained(REF_NAME)
-            dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
-            # note: this creates a dummy mask to make the test compatible with older transformer versions
-            # prior to https://github.com/huggingface/transformers/pull/17837
-            ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits
-            assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward)
-            del ref_model, ref_outputs
-    else:
-        logger.warning("Did not test exact match with local model: REF_NAME environment variable is not set")
+    with torch.no_grad():
+        parallel_outputs = model.forward(test_inputs).logits
+        assert torch.all(torch.isfinite(parallel_outputs))
+        logger.info("Forward outputs are finite")
 
-    with torch.inference_mode():
         embs = model.transformer.word_embeddings(test_inputs)
         embs = model.transformer.word_embeddings_layernorm(embs)
         recurrent_outputs = []
@@ -60,5 +36,20 @@ def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
         dictionary = model.transformer.word_embeddings.weight.t()
         recurrent_outputs = recurrent_outputs.to(dictionary.dtype)
         recurrent_outputs = (recurrent_outputs @ dictionary).float()
-    assert torch.allclose(recurrent_outputs, parallel_outputs, rtol=0, atol=atol_inference)
-    logger.info("Inference is consistent with forward")
+        assert torch.allclose(recurrent_outputs, parallel_outputs, rtol=0, atol=atol_inference)
+        logger.info("Inference is consistent with forward")
+
+        del model, recurrent_outputs
+
+        if REF_NAME:
+            ref_model = transformers.AutoModelForCausalLM.from_pretrained(REF_NAME)
+            dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
+            # note: this creates a dummy mask to make the test compatible with older transformer versions
+            # prior to https://github.com/huggingface/transformers/pull/17837
+            ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits
+            assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward)
+            logger.warning(f"Distributed forward is consistent with {type(ref_model)}.forward")
+            del ref_model, ref_outputs, dummy_mask
+        else:
+            logger.warning("Did not test exact match with local model: REF_NAME environment variable is not set")
+            assert False

+ 43 - 0
tests/test_remote_sequential.py

@@ -0,0 +1,43 @@
+import pytest
+import torch
+from hivemind import DHT, get_logger, use_hivemind_log_handler
+from test_utils import *
+
+from src import RemoteSequential
+from src.client.remote_model import DistributedBloomConfig
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
+
+@pytest.mark.forked
+def test_remote_sequential():
+    config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+    dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
+    test_inputs = torch.randn(1, 5, config.hidden_size, requires_grad=True)
+    grad_proj = torch.randn(1, 5, config.hidden_size)
+
+    sequential = RemoteSequential(config, dht)
+
+    full_outputs = sequential(test_inputs)
+    (full_outputs * grad_proj).sum().backward()
+    assert test_inputs.grad is not None
+    full_grad = test_inputs.grad.clone()
+    test_inputs.grad.data.zero_()
+
+    first_half = sequential[: config.n_layer // 2]
+    second_half = sequential[config.n_layer // 2 :]
+    assert len(first_half) + len(second_half) == len(sequential)
+    assert abs(len(first_half) - len(second_half)) == config.n_layer % 2
+    for m in sequential, first_half, second_half:
+        assert isinstance(repr(m), str)
+
+    hidden = first_half(test_inputs)
+    assert isinstance(hidden, torch.Tensor)
+    assert hidden.shape == test_inputs.shape
+    assert hidden.requires_grad
+    second_half_outputs = second_half(hidden)
+    assert torch.allclose(second_half_outputs, full_outputs)
+
+    (second_half_outputs * grad_proj).sum().backward()
+    assert torch.allclose(test_inputs.grad, full_grad)

+ 13 - 0
tests/test_utils.py

@@ -0,0 +1,13 @@
+import os
+
+INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
+if not INITIAL_PEERS:
+    raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
+INITIAL_PEERS = INITIAL_PEERS.split()
+
+
+MODEL_NAME = os.environ.get("MODEL_NAME")
+if not MODEL_NAME:
+    raise RuntimeError("Must specify MODEL_NAME as an index of a transformer block to be tested")
+
+REF_NAME = os.environ.get("REF_NAME")