Răsfoiți Sursa

Merge branch 'client' into main

justheuristic 3 ani în urmă
părinte
comite
d42e8abd38

+ 4 - 2
README.md

@@ -65,7 +65,7 @@ loss = (outputs * torch.randn_like(outputs)).norm()
 loss.backward()
 
 # test inference, one block
-with layer3.begin_inference_session() as sess:
+with layer3.inference_session() as sess:
     for i in range(10):
         res = sess.step(torch.ones(1, 1, 4096))
 ```
@@ -94,7 +94,9 @@ python -m cli.run_server --prefix bloom6b3 --converted_model_name_or_path bigsci
 export PYTHONPATH=. INITIAL_PEERS="/ip4/TODO_COPY_INITIAL_PEERS_FROM_SERVER_OUTPUT"
 BLOCK_UID=bloom6b3.3 pytest tests/test_block_exact_match.py
 BLOCK_UID=bloom6b3.4 pytest tests/test_block_exact_match.py
-
 # the test below will fail because there is no server that serves layer 7
 # BLOCK_UID=bloom6b3.7 pytest tests/test_block_exact_match.py
+
+# test full model exact match
+MODEL_NAME=bigscience/test-bloomd-6b3 REF_NAME=bigscience/bloom-6b3 pytest tests/test_full_model.py
 ```

+ 2 - 1
cli/run_server.py

@@ -14,11 +14,12 @@ def main():
     parser = configargparse.ArgParser(default_config_files=["config.yml"])
     parser.add('-c', '--config', required=False, is_config_file=True, help='config file path')
 
-    parser.add_argument('--prefix', type=str, required=True, help="Announce all blocks with this prefix")
     parser.add_argument('--converted_model_name_or_path', type=str, default='bigscience/test-bloomd-6b3',
                         help="path or name of a pretrained model, converted with cli/convert_model.py (see README.md)")
     parser.add_argument('--num_blocks', type=int, default=None, help="The number of blocks to serve")
     parser.add_argument('--block_indices', type=str, default=None, help="Specific block indices to serve")
+    parser.add_argument('--prefix', type=str, default=None, help="Announce all blocks with this prefix. By default,"
+                                                                 "use the same name as in the converted model.")
     parser.add_argument('--host_maddrs', nargs='+', default=['/ip4/0.0.0.0/tcp/0'], required=False,
                         help='Multiaddrs to listen for external connections from other p2p instances; default: all IPv4 and TCP: /ip4/0.0.0.0/tcp/0')
     parser.add_argument('--announce_maddrs', nargs='+', default=None, required=False,

+ 2 - 9
src/bloom/block.py

@@ -9,15 +9,8 @@ import torch
 import torch.nn as nn
 import torch.nn.quantized.dynamic.modules.linear
 
-from src.bloom.ops import (
-    BloomGelu,
-    BloomScaledSoftmax,
-    attention_mask_func,
-    build_alibi_tensor,
-    dropout_add,
-    pre_process_alibi_for_pad,
-    split_tensor_along_last_dim,
-)
+from src.bloom.ops import (BloomGelu, BloomScaledSoftmax, attention_mask_func, build_alibi_tensor, dropout_add,
+                           pre_process_alibi_for_pad, split_tensor_along_last_dim)
 
 
 class BloomAttention(nn.Module):

+ 7 - 9
src/bloom/model.py

@@ -11,11 +11,8 @@ import torch.utils.checkpoint
 from hivemind import use_hivemind_log_handler
 from torch import nn
 from torch.nn import CrossEntropyLoss, LayerNorm
-from transformers.file_utils import (
-    add_code_sample_docstrings,
-    add_start_docstrings,
-    add_start_docstrings_to_model_forward,
-)
+from transformers.file_utils import (add_code_sample_docstrings, add_start_docstrings,
+                                     add_start_docstrings_to_model_forward)
 from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.bloom.configuration_bloom import BloomConfig as _VanillaBloomConfig
@@ -208,6 +205,8 @@ class BloomModel(BloomPreTrainedModel):
 
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        if position_ids is not None:
+            logger.warning("position_ids are ignored in this bloom implementation")
         elif input_ids is not None:
             input_shape = input_ids.size()
             input_ids = input_ids.view(-1, input_shape[-1])
@@ -238,9 +237,8 @@ class BloomModel(BloomPreTrainedModel):
 
         # Compute alibi tensor: check build_alibi_tensor documentation
         current_sequence_length = hidden_states.shape[1]
-        if past_key_values[0] is not None:
+        if past_key_values and past_key_values[0]:
             current_sequence_length += past_key_values[0][0].shape[1]
-        alibi = build_alibi_tensor(current_sequence_length, self.n_head, hidden_states.dtype)
 
         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
 
@@ -258,7 +256,7 @@ class BloomModel(BloomPreTrainedModel):
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
                         # None for past_key_value
-                        return module(*inputs, use_cache, output_attentions, alibi)
+                        return module(*inputs, use_cache, output_attentions, alibi=None)
 
                     return custom_forward
 
@@ -277,7 +275,7 @@ class BloomModel(BloomPreTrainedModel):
                     head_mask=head_mask[i],
                     use_cache=use_cache,
                     output_attentions=output_attentions,
-                    alibi=alibi,
+                    alibi=None,
                 )
 
             hidden_states = outputs[0]

+ 23 - 7
src/client/remote_block.py

@@ -11,13 +11,17 @@ from hivemind.moe.client.expert import RemoteExpert, RemoteExpertWorker
 from hivemind.moe.expert_uid import ExpertInfo
 from hivemind.p2p import P2P, StubBase
 from hivemind.proto import runtime_pb2
-from hivemind.utils import anext, nested_flatten
+from hivemind.utils import anext, nested_flatten, use_hivemind_log_handler, get_logger
 
 from src.data_structures import RemoteModuleInfo
 from src.dht_utils import ModuleUID
 from src.server.handler import TransformerConnectionHandler
 
 
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
+
 class RemoteTransformerBlock(RemoteExpert):
     """A class that interacts with a remote module on a specific server for forward/backward or inference"""
 
@@ -29,11 +33,20 @@ class RemoteTransformerBlock(RemoteExpert):
     def stub(self) -> StubBase:
         return TransformerConnectionHandler.get_stub(self.p2p, self.peer_id)
 
-    def begin_inference_session(self) -> RemoteTransformerBlockInferenceSession:
+    def forward(self, inputs: torch.Tensor, **kwargs):
+        for k, v in kwargs.items():
+            assert v is None or v is False, f"Extra keyword arguments are not yet supported (got {k} = {v})"
+        return super().forward(inputs)
+
+    def inference_session(self) -> RemoteTransformerBlockInferenceSession:
         """Initialize a new inference session with the specified remote server"""
         _ = self.info  # create _info manually since the built-in property will not work inside RemoteExpertWorker
         return RemoteExpertWorker.run_coroutine(RemoteTransformerBlockInferenceSession._create(self))
 
+    def begin_inference_session(self):
+        logger.warning("beging_inference_session was renamed to just inference_session")
+        return self.inference_session()
+
 
 class RemoteTransformerBlockInferenceSession:
     """An interface to a single multi-step *inference* session for a specific remote module with a specific server"""
@@ -44,6 +57,7 @@ class RemoteTransformerBlockInferenceSession:
         # using them in any other EventLoop may cause side-effects including, headaches, diarrhea, and loss of sleep
         self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
         self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
+        self.stepped = False
         self.closed = False
 
     @classmethod
@@ -89,6 +103,7 @@ class RemoteTransformerBlockInferenceSession:
     async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse:
         """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
         await self._inputs_queue.put(inputs_serialized)
+        self.stepped = True
         return await anext(self._outputs_stream)
 
     def close(self):
@@ -103,11 +118,12 @@ class RemoteTransformerBlockInferenceSession:
         """Close the inference session. This code is meant to be run inside RemoteExpertWorker"""
         if self._outputs_stream is None:
             return  # already closed
-        await self._inputs_queue.put(runtime_pb2.ExpertRequest())  # empty request will trigger end of session
-        try:
-            await anext(self._outputs_stream)
-        except StopAsyncIteration:
-            pass
+        if self.stepped:
+            await self._inputs_queue.put(runtime_pb2.ExpertRequest())  # empty request will trigger end of session
+            try:
+                await anext(self._outputs_stream)
+            except StopAsyncIteration:
+                pass
 
     def __del__(self):
         self.close()

+ 49 - 0
src/client/remote_model.py

@@ -0,0 +1,49 @@
+# this code is in active development, interfaces may change
+import os
+from typing import Optional, Union
+
+import hivemind
+from hivemind import DHT, get_logger, use_hivemind_log_handler
+
+from src.bloom import BloomForCausalLM, DistributedBloomConfig
+from src.bloom.from_pretrained import CLIENT_BRANCH, _load_state_dict
+from src.client.remote_sequential import RemoteSequential
+from src.data_structures import UID_DELIMITER
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
+
+class DistributedBloomForCausalLM(BloomForCausalLM):
+    """BloomForCausalLM, but all transformer layers are hosted by the swarm"""
+
+    def __init__(self, config: DistributedBloomConfig, dht: DHT, prefix: str):
+        n_layer, config.n_layer = config.n_layer, 0  # temporarily set n_layer to 0 to prevent layer initialization
+        super().__init__(config)
+        assert len(self.transformer.h) == 0
+        config.n_layer = n_layer
+        self.transformer.h = RemoteSequential(config, dht, prefix)
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
+        if 'initial_peers' not in kwargs:
+            raise ValueError("Please specify initial_peers=...")
+        dht = hivemind.DHT(
+            initial_peers=kwargs.pop('initial_peers'), client_mode=kwargs.pop('client_mode', True),
+            start=True)
+
+        if 'prefix' not in kwargs:
+            logger.debug(f"No DHT prefix specified; using automatic prefix {pretrained_model_name_or_path}")
+            assert UID_DELIMITER not in pretrained_model_name_or_path, \
+                f"Cannot infer prefix automatically from {pretrained_model_name_or_path}; please specify prefix=..."
+        prefix = kwargs.pop("prefix", pretrained_model_name_or_path)
+
+        config = DistributedBloomConfig.from_pretrained(pretrained_model_name_or_path, revision=CLIENT_BRANCH, **kwargs)
+        model = cls(config, dht, prefix)
+        model.load_state_dict(_load_state_dict(
+            pretrained_model_name_or_path, use_auth_token=kwargs.get('use_auth_token')
+        ), strict=True)
+        return model
+
+
+

+ 93 - 0
src/client/remote_sequence_info.py

@@ -0,0 +1,93 @@
+from __future__ import annotations
+
+import dataclasses
+import threading
+from functools import partial
+from typing import Tuple, List, Optional, Sequence, NamedTuple
+
+from hivemind import DHT, PeerID
+from hivemind.utils.logging import use_hivemind_log_handler, get_logger
+
+from src.data_structures import ModuleUID, RemoteModuleInfo
+from src.dht_utils import _get_remote_module_infos
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
+
+Span = NamedTuple('Span', [('start', int), ('end', Optional[int]), ('peer_id', PeerID)])
+
+
+@dataclasses.dataclass(frozen=False, init=False)
+class RemoteSequenceInfo:
+    """Keeps and updates the meta-information about which peers host which blocks"""
+    dht: DHT
+    block_uids: List[ModuleUID, ...]
+    block_infos: List[Optional[RemoteModuleInfo], ...]
+    spans_by_priority: List[Span]  # sorted from best to worst
+    spans_containing_block: Tuple[List[Span], ...]
+    lock_changes: threading.Lock
+
+    def __init__(self, dht: DHT, block_uids: Sequence[ModuleUID]):
+        self.dht = dht
+        self.block_uids = list(block_uids)
+        self.block_infos: List[Optional[RemoteModuleInfo], ...] = [None] * len(self.block_uids)
+        self.spans_by_priority = []
+        self.spans_containing_block = tuple(list() for _ in range(len(self.block_uids)))
+        self.lock_changes = threading.Lock()
+        self.update_()
+
+        for uid, info in zip(self.block_uids, self.block_infos):
+            assert info is not None, f"Found no remote peers for block {uid}"
+        assert self.spans_by_priority and self.spans_containing_block
+
+    def update_(self):
+        with self.lock_changes:
+            self.update_block_infos_()
+            self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos)
+
+    def update_block_infos_(self):
+        new_block_infos: Sequence[RemoteModuleInfo] = self.dht.run_coroutine(
+            partial(_get_remote_module_infos, uids=self.block_uids, expiration_time=float("inf")),
+            return_future=False)
+        assert len(new_block_infos) == len(self.block_uids)
+        for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
+            if info is None:
+                logger.warning(f"Found no block info for block {uid}")
+            if not isinstance(info, RemoteModuleInfo):
+                logger.warning(f"Unexpected dht entry type for {uid}: {info}")
+            if not info.peer_ids:
+                logger.warning(f"Found no active peers for block {uid}")
+            if info.uid != uid:
+                logger.warning(f"The DHT entry for {uid} actually points to {info.uid}")
+            if not isinstance(info.peer_ids, set):
+                logger.warning(f"Expected peer_ids for {uid} to be a set, got {type(info.peer_ids)}")
+            self.block_infos[block_index] = info
+
+    @staticmethod
+    def compute_spans(block_infos: Sequence[RemoteModuleInfo]):
+        closed_spans = []
+        active_spans = {}
+        for block_index, info in enumerate(block_infos):
+            for peer_id in info.peer_ids:
+                if peer_id not in active_spans:
+                    active_spans[peer_id] = Span(start=block_index, end=block_index + 1, peer_id=peer_id)
+                else:  # peer_id in active_spans
+                    active_spans[peer_id] = active_spans[peer_id]._replace(end=block_index + 1)
+
+            for peer_id in list(active_spans.keys()):
+                if peer_id not in info.peer_ids or block_index == len(block_infos) - 1:
+                    closed_spans.append(active_spans.pop(peer_id))
+        assert not active_spans
+
+        closed_spans.sort(key=lambda span: span.end - span.start, reverse=True)
+
+        spans_containing_block = tuple(list() for _ in range(len(block_infos)))
+        for span in closed_spans:
+            for block_index in range(span.start, span.end):
+                spans_containing_block[block_index].append(span)
+
+        return closed_spans, spans_containing_block
+
+    def __len__(self):
+        return len(self.block_uids)

+ 134 - 0
src/client/remote_sequential.py

@@ -0,0 +1,134 @@
+from __future__ import annotations
+
+import contextlib
+import logging
+import random
+
+import torch
+from hivemind import DHT, P2P, get_logger, use_hivemind_log_handler
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+from hivemind.moe.expert_uid import ExpertInfo
+from torch import nn
+
+from src import DistributedBloomConfig, RemoteTransformerBlock
+from src.client.remote_sequence_info import RemoteSequenceInfo
+from src.data_structures import UID_DELIMITER
+from src.dht_utils import _create_remote_modules_from_infos
+
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
+
+class RemoteSequential(nn.Module):
+    """
+    A sequence of transformer blocks hosted by the swarm.
+    """
+
+    def __init__(self, config: DistributedBloomConfig, dht: DHT, prefix: str, max_retries: int = 3):
+        logger.warning(f"{self.__class__.__name__} is in active development; expect adventures")
+        if prefix.endswith(UID_DELIMITER):
+            logger.warning(
+                f"dht_prefix {prefix} already ends with '{UID_DELIMITER}'."
+                f"This will cause {self.__class__.__name__} to look for modules under "
+                f"{prefix}{UID_DELIMITER}*. Please make sure this is what you intended."
+            )
+
+        super().__init__()
+        self.config = config
+        self.dht = dht
+        self.prefix = prefix
+        self.max_retries = max_retries
+        self.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
+
+        block_uids = tuple(f"{prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer))
+        logger.debug(f"Remote block uids: {block_uids}")
+        self.remote_sequence_info = RemoteSequenceInfo(dht, block_uids)
+
+    def forward(self, inputs: torch.Tensor):
+        assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3 and inputs.shape[-1] == self.config.n_embed
+        for block_index in range(self.config.n_layer):
+            for retry_index in range(self.max_retries):
+                try:
+                    block = self[block_index]
+                    (outputs,) = block(inputs)
+                    assert isinstance(outputs, torch.Tensor)
+                    assert outputs.shape == inputs.shape, f"Expected {block} output {inputs.shape}, got {outputs.shape}"
+                    inputs = outputs
+                    break
+                except Exception as e:
+                    if retry_index == self.max_retries - 1:
+                        raise e
+                    else:
+                        logging.debug(f"Caught {e} when running forward for block {block_index}", exc_info=True)
+        return inputs
+
+    def __getitem__(self, block_index: int):
+        assert 0 <= block_index < self.config.n_layer
+        (module,) = _create_remote_modules_from_infos([self.remote_sequence_info.block_infos[block_index]], self.p2p)
+        return module
+
+    def __iter__(self):
+        for block_index in range(self.config.n_layer):
+            yield self[block_index]
+
+    def __len__(self):
+        return len(self.remote_sequence_info)
+
+    def inference_session(self) -> RemoteSequentialInferenceSession:
+        self.remote_sequence_info.update_()
+        return RemoteSequentialInferenceSession(self.remote_sequence_info, self.p2p)
+
+
+class RemoteSequentialInferenceSession:
+    """An interface to a multi-step *inference* session for a sequence of remote transformer blocks"""
+
+    def __init__(self, remote_sequence_info: RemoteSequenceInfo, p2p: P2P):
+        self.remote_sequence_info = remote_sequence_info
+        self.p2p = p2p
+        self.closed = False
+        self.stack = contextlib.ExitStack()
+        self.active_sessions = []
+
+    def __enter__(self):
+        assert not self.closed
+        self.stack.__enter__()
+        # TODO(yozh) replace this code with a fault-tolerant chain that can be reconstructed if some peers fail
+        current_block = 0
+        while current_block != len(self.remote_sequence_info):
+            candidate_spans = self.remote_sequence_info.spans_containing_block[current_block]
+            chosen_span = random.choice(candidate_spans)  # TODO this is a temporary code
+            assert chosen_span.start <= current_block < chosen_span.end
+
+            # TODO begin throwaway prototype code
+            remote = RemoteTransformerBlock(self.remote_sequence_info.block_infos[current_block], self.p2p)
+            _=remote.info #TODO fix
+            span_uids = self.remote_sequence_info.block_uids[current_block: chosen_span.end]
+            remote._info = ExpertInfo(" ".join(span_uids), chosen_span.peer_id)
+            self.active_sessions.append(remote.inference_session())
+            self.stack.enter_context(self.active_sessions[-1])
+            current_block = chosen_span.end
+            # TODO end throwaway prototype code
+
+        return self
+
+    def step(self, inputs: torch.Tensor):
+        assert not self.closed
+        for session in self.active_sessions:
+            outputs = session.step(inputs)
+            assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
+            inputs = outputs
+        return inputs
+
+    def close(self, *exc_details):
+        """Finish a given inference session, close the underlying connection"""
+        if not self.closed:
+            self.stack.__exit__(*exc_details or (None, None, None))
+            self.active_sessions.clear()
+            self.closed = True
+
+    def __exit__(self, *exc_details):
+        self.close(*exc_details)
+
+    def __del__(self):
+        self.close()

+ 2 - 1
src/dht_utils.py

@@ -106,7 +106,8 @@ async def _get_remote_module_infos(
     for i, uid in enumerate(uids):
         metadata = found[uid]
         if metadata is None or not isinstance(metadata.value, dict):
-            logger.error(f"Incorrect metadata for {uid}: {metadata}")
+            if metadata is not None:
+                logger.error(f"Incorrect metadata for {uid}: {metadata}")
             continue
         valid_entries = set()
         for maybe_peer_id, _unused_value in metadata.value.items():

+ 23 - 23
src/server/backend.py

@@ -26,29 +26,29 @@ class TransformerBackend(ModuleBackend):
         self.inference_pool = TaskPool(self.inference_step, max_batch_size=1, name=f"{self.name}_inference")
 
     def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
-        attention_cache_handle = int(cache_metadata[0, 0].item())
-        prefix_length = int(cache_metadata[0, 1].item())
-        hidden_states = inputs[0]  # todo: in future, it would be best to support attention mask here
-        assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
-
-        with self.memory_cache.use_cache(attention_cache_handle) as cache:
-            print("METADATA:", cache_metadata)
-            assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5
-            layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length]
-            print("PAST", past_k.shape, past_v.shape)
-            hidden_states, (new_k, new_v) = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
-
-            # todo remove these asserts once we pass all tests
-            new_length = new_v.shape[1]
-            assert new_length > prefix_length
-            assert new_k.shape[0] == past_k.shape[0] and new_v.shape[0] == past_v.shape[0]
-            assert new_k.shape[1] == new_length and new_v.shape[1] == new_length
-            assert new_k.shape[2:] == past_k.shape[2:] and new_v.shape[2:] == past_v.shape[2:]
-            assert torch.allclose(new_v[:, : past_v.shape[1]], past_v)
-            assert torch.allclose(new_k[:, : past_k.shape[1]], past_k)
-            cache[0, :, prefix_length:new_length, :] = new_k[:, prefix_length:new_length]
-            cache[1, :, prefix_length:new_length, :] = new_v[:, prefix_length:new_length]
-            return (hidden_states,)
+        with torch.inference_mode():
+            attention_cache_handle = int(cache_metadata[0, 0].item())
+            prefix_length = int(cache_metadata[0, 1].item())
+            hidden_states = inputs[0]  # todo: in future, it would be best to support attention mask here
+            assert hidden_states.ndim == 3, "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
+
+            with self.memory_cache.use_cache(attention_cache_handle) as cache:
+                assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5
+                layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length]
+                print("METADATA:", cache_metadata, past_k.shape, past_v.shape)
+                hidden_states, (new_k, new_v) = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
+
+                # todo remove these asserts once we pass all tests
+                new_length = new_v.shape[1]
+                assert new_length > prefix_length
+                assert new_k.shape[0] == past_k.shape[0] and new_v.shape[0] == past_v.shape[0]
+                assert new_k.shape[1] == new_length and new_v.shape[1] == new_length
+                assert new_k.shape[2:] == past_k.shape[2:] and new_v.shape[2:] == past_v.shape[2:]
+                assert torch.allclose(new_v[:, : past_v.shape[1]], past_v)
+                assert torch.allclose(new_k[:, : past_k.shape[1]], past_k)
+                cache[0, :, prefix_length:new_length, :] = new_k[:, prefix_length:new_length]
+                cache[1, :, prefix_length:new_length, :] = new_v[:, prefix_length:new_length]
+                return (hidden_states,)
 
     def get_pools(self) -> Sequence[TaskPool]:
         return self.forward_pool, self.backward_pool, self.inference_pool

+ 8 - 1
src/server/server.py

@@ -14,6 +14,7 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
 from src import declare_active_modules
 from src.bloom.from_pretrained import DTYPE_MAP, DistributedBloomConfig, load_pretrained_block
+from src.data_structures import UID_DELIMITER, CHAIN_DELIMITER
 from src.server.backend import TransformerBackend
 from src.server.cache import MemoryCache
 from src.server.handler import TransformerConnectionHandler
@@ -84,7 +85,7 @@ class Server(threading.Thread):
     @classmethod
     def create(
         cls,
-        prefix: str,
+        prefix: Optional[str],
         converted_model_name_or_path: str,
         num_blocks: Optional[int] = None,
         block_indices: Optional[str] = None,
@@ -108,6 +109,12 @@ class Server(threading.Thread):
         """Create a server with one or more bloom blocks. See run_server.py for documentation."""
         if custom_module_path is not None:
             add_custom_models_from_file(custom_module_path)
+        if prefix is None:
+            prefix = converted_model_name_or_path
+            assert UID_DELIMITER not in prefix and CHAIN_DELIMITER not in prefix,\
+                f"Cannot use model name as prefix (contains '{UID_DELIMITER}' or '{CHAIN_DELIMITER}'); " \
+                f"Please specify --prefix manually when starting a server"
+            logger.info(f"Automatic dht prefix: {prefix}")
         assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both"
         dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
         visible_maddrs_str = [str(a) for a in dht.get_visible_maddrs()]

+ 1 - 1
tests/test_block_exact_match.py

@@ -32,7 +32,7 @@ def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3):
     (outputs_forward,) = remote_block(inputs)
 
     outputs_inference = []
-    with remote_block.begin_inference_session() as sess:
+    with remote_block.inference_session() as sess:
         for i in range(inputs.shape[1]):
             outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
     outputs_inference = torch.cat(outputs_inference, dim=1)

+ 1 - 1
tests/test_chained_inference.py

@@ -39,7 +39,7 @@ def test_remote_block_exact_match(atol_inference=1e-4):
     inputs = torch.randn(1, 8, 4096)
 
     outputs_inference = []
-    with remote_block.begin_inference_session() as sess:
+    with remote_block.inference_session() as sess:
         for i in range(inputs.shape[1]):
             outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
     outputs_inference = torch.cat(outputs_inference, dim=1)

+ 57 - 0
tests/test_full_model.py

@@ -0,0 +1,57 @@
+# Note: this code is being actively modified by justheuristic. If you want to change anything about it, please warn me.
+import os
+
+import torch
+import transformers
+from hivemind import use_hivemind_log_handler, get_logger
+
+from src.client.remote_model import DistributedBloomForCausalLM
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__file__)
+
+
+INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
+if not INITIAL_PEERS:
+    raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
+INITIAL_PEERS = INITIAL_PEERS.split()
+
+
+MODEL_NAME = os.environ.get("MODEL_NAME")
+if not MODEL_NAME:
+    raise RuntimeError("Must specify MODEL_NAME as an index of a transformer block to be tested")
+
+REF_NAME = os.environ.get("REF_NAME")
+
+
+def test_full_model_exact_match(atol_forward=1e-5, atol_inference=1e-3):
+    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
+    model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+    assert len(model.transformer.h) == model.config.n_layer
+
+    test_inputs = tokenizer("A cat sat on a mat", return_tensors='pt')['input_ids']
+    parallel_outputs = model.forward(test_inputs).logits
+    assert torch.all(torch.isfinite(parallel_outputs))
+    logger.info("Forward outputs are finite")
+
+    if REF_NAME:
+        ref_model = transformers.AutoModelForCausalLM.from_pretrained(REF_NAME)
+        dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
+        # note: this creates a dummy mask to make the test compatible with older transformer versions
+        # prior to https://github.com/huggingface/transformers/pull/17837
+        ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits
+        assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward)
+    else:
+        logger.warning("Did not test exact match with local model: REF_NAME environment variable is not set")
+
+    embs = model.transformer.word_embeddings(test_inputs)
+    embs = model.transformer.word_embeddings_layernorm(embs)
+    recurrent_outputs = []
+    with model.transformer.h.inference_session() as sess:
+        for t in range(embs.shape[1]):
+            recurrent_outputs.append(sess.step(embs[:, t: t + 1, :]))
+    recurrent_outputs = torch.cat(recurrent_outputs, dim=1)
+    recurrent_outputs = model.transformer.ln_f(recurrent_outputs)
+    recurrent_outputs = model.lm_head(recurrent_outputs)
+    assert torch.allclose(recurrent_outputs, parallel_outputs, rtol=0, atol=atol_inference)
+    logger.info("Inference is consistent with forward")