
Add automated tests (#23)

This PR sets up basic automated tests that run on each subsequent PR:

- convert a small model on every PR
- run existing tests on every PR
- enforce black / isort formatting (see the local commands below)
- require the checks to pass before merging
- make sure tests are not flaky
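For reference, the style checks can be reproduced locally before pushing (a sketch, assuming the pinned dev requirements are installed; black and isort read their settings from the pyproject.toml added in this PR):

    pip install -r requirements-dev.txt
    black --check --diff .    # same version (22.3.0) that the CI job pins
    isort --check-only .      # profile = "black", line_length = 120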

Co-authored-by: Alexander Borzunov <hxrussia@gmail.com>
Co-authored-by: Dmitry Baranchuk <dmitrybaranchuk@gmail.com>
justheuristic 3 years ago
parent
commit
e2711a033b

+ 26 - 0
.github/workflows/check-style.yaml

@@ -0,0 +1,26 @@
+name: Check style
+
+on:
+  push:
+    branches: [ master ]
+  pull_request:
+
+jobs:
+  black:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v2
+      - uses: psf/black@stable
+        with:
+          options: "--check --diff"
+          version: "22.3.0"
+  isort:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v2
+      - uses: actions/setup-python@v2
+        with:
+          python-version: 3.8
+      - uses: isort/isort-action@master
+        with:
+          isortVersion: "5.10.1"

+ 89 - 0
.github/workflows/run-tests.yaml

@@ -0,0 +1,89 @@
+name: Tests
+
+on:
+  push:
+    branches: [ master ]
+  pull_request:
+
+jobs:
+  convert-model:
+    runs-on: ubuntu-latest
+    env:
+      BLOOM_TESTING_WRITE_TOKEN: ${{ secrets.BLOOM_TESTING_WRITE_TOKEN }}
+    timeout-minutes: 15
+    steps:
+      - uses: actions/checkout@v2
+      - name: Set up Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: 3.9
+      - name: Cache dependencies
+        uses: actions/cache@v2
+        with:
+          path: ~/.cache/pip
+          key: Key-v1-py3.9-${{ hashFiles('requirements.txt') }}-${{ hashFiles('requirements-dev.txt') }}
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install -r requirements.txt
+      - name: Delete previous model, if exists
+        run: |
+          python -c "from huggingface_hub import delete_repo; delete_repo(token='$BLOOM_TESTING_WRITE_TOKEN', \
+          name='test-bloomd-350m-$GITHUB_HEAD_REF', organization='bloom-testing')" || true
+      - name: Convert model and push to hub
+        run: |
+          python -m cli.convert_model --model bigscience/bloom-350m  --output_path ./converted_model \
+            --output_repo bloom-testing/test-bloomd-350m-$GITHUB_HEAD_REF --use_auth_token $BLOOM_TESTING_WRITE_TOKEN
+
+
+  run-tests:
+    runs-on: ubuntu-latest
+    needs: convert-model
+    strategy:
+      matrix:
+        python-version: [ 3.7, 3.8, 3.9 ]
+      fail-fast: false
+    timeout-minutes: 15
+    steps:
+      - uses: actions/checkout@v2
+      - name: Set up Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Cache dependencies
+        uses: actions/cache@v2
+        with:
+          path: ~/.cache/pip
+          key: Key-v1-${{ matrix.python-version }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('requirements-dev.txt') }}
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install -r requirements.txt
+          pip install -r requirements-dev.txt
+      - name: Test
+        run: |
+          export MODEL_NAME=bloom-testing/test-bloomd-350m-$GITHUB_HEAD_REF
+          python -m cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:12 \
+            --torch_dtype float32 --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 --throughput 1 &
+          SERVER1_PID=$!
+          
+          export INITIAL_PEERS=/ip4/127.0.0.1/tcp/31337/p2p/QmS9KwZptnVdB9FFV7uGgaTq4sEKBwcYeKZDfSpyKDUd1g
+          # ^-- server 1 multiaddr is determined by --identity and --host_maddrs
+          
+          python -m cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 12:24 \
+            --torch_dtype float32 --initial_peers $INITIAL_PEERS --throughput 1 &> server2.log &
+          SERVER2_PID=$!
+
+          sleep 30  # wait for server to download layers
+          
+          # test individual blocks
+          export PYTHONPATH=.
+          BLOCK_UID=$MODEL_NAME.0 REF_NAME=$MODEL_NAME REF_INDEX=0 pytest tests/test_block_exact_match.py
+          BLOCK_UID=$MODEL_NAME.19 REF_NAME=$MODEL_NAME REF_INDEX=19 pytest tests/test_block_exact_match.py
+
+          REF_NAME=$MODEL_NAME pytest tests/test_chained_calls.py
+          
+          REF_NAME=bigscience/bloom-350m pytest tests/test_full_model.py
+          
+          kill -s SIGINT $SERVER1_PID $SERVER2_PID
+          echo "Done!"
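
For local debugging, a rough equivalent of the workflow's test job is sketched below. It is an approximation only: <your-org>/<your-test-repo> and <your-token> are placeholders, it serves all 24 blocks from a single server rather than two, and the sleep may need adjusting.

    pip install -r requirements.txt -r requirements-dev.txt
    python -m cli.convert_model --model bigscience/bloom-350m --output_path ./converted_model \
        --output_repo <your-org>/<your-test-repo> --use_auth_token <your-token>
    export MODEL_NAME=<your-org>/<your-test-repo>
    python -m cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:24 \
        --torch_dtype float32 --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 --throughput 1 &
    export INITIAL_PEERS=/ip4/127.0.0.1/tcp/31337/p2p/QmS9KwZptnVdB9FFV7uGgaTq4sEKBwcYeKZDfSpyKDUd1g
    sleep 30  # wait for the server to load its blocks, as the workflow does
    PYTHONPATH=. REF_NAME=$MODEL_NAME pytest tests/test_chained_calls.py
    PYTHONPATH=. REF_NAME=bigscience/bloom-350m pytest tests/test_full_model.py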

+ 2 - 1
cli/convert_model.py

@@ -10,8 +10,9 @@ from huggingface_hub import Repository
 from tqdm.auto import tqdm
 
 from src import BloomModel
+from src.bloom.from_pretrained import BLOCK_BRANCH_PREFIX, CLIENT_BRANCH
 from src.client import DistributedBloomConfig
-from src.bloom.from_pretrained import CLIENT_BRANCH, BLOCK_BRANCH_PREFIX
+
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 

The diff for this file is too large to display
+ 242 - 285
cli/speed_test.py


+ 10 - 0
pyproject.toml

@@ -0,0 +1,10 @@
+[tool.black]
+line-length = 120
+required-version = "22.3.0"
+
+[tool.isort]
+profile = "black"
+line_length = 120
+combine_as_imports = true
+combine_star = true
+known_local_folder = ["tests", "cli"]

+ 6 - 0
requirements-dev.txt

@@ -0,0 +1,6 @@
+pytest==6.2.5  # see https://github.com/pytest-dev/pytest/issues/9621
+pytest-forked
+pytest-asyncio==0.16.0
+black==22.3.0
+isort==5.10.1
+psutil

+ 6 - 0
requirements.txt

@@ -0,0 +1,6 @@
+torch==1.12.0
+accelerate==0.10.0
+huggingface-hub==0.7.0
+bitsandbytes-cuda113==0.26.0
+https://github.com/learning-at-home/hivemind/archive/d42c70331da43667da6d9020666df54806d8b561.zip
+https://github.com/huggingface/transformers/archive/6589e510fa4e6c442059de2fab84752535de9b23.zip

+ 9 - 2
src/bloom/block.py

@@ -9,8 +9,15 @@ import torch
 import torch.nn as nn
 import torch.nn.quantized.dynamic.modules.linear
 
-from src.bloom.ops import (BloomGelu, BloomScaledSoftmax, attention_mask_func, build_alibi_tensor, dropout_add,
-                           pre_process_alibi_for_pad, split_tensor_along_last_dim)
+from src.bloom.ops import (
+    BloomGelu,
+    BloomScaledSoftmax,
+    attention_mask_func,
+    build_alibi_tensor,
+    dropout_add,
+    pre_process_alibi_for_pad,
+    split_tensor_along_last_dim,
+)
 
 
 class BloomAttention(nn.Module):

+ 30 - 13
src/bloom/model.py

@@ -10,14 +10,16 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from hivemind import use_hivemind_log_handler
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss, LayerNorm
-from transformers.file_utils import (add_code_sample_docstrings, add_start_docstrings,
-                                     add_start_docstrings_to_model_forward)
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
+from transformers.file_utils import (
+    add_code_sample_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+)
 from transformers.modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
     SequenceClassifierOutputWithPast,
-    TokenClassifierOutput,
 )
 from transformers.modeling_utils import PreTrainedModel
 from transformers.models.bloom.configuration_bloom import BloomConfig
@@ -445,12 +447,27 @@ class LMHead(nn.Module):
         self.word_embeddings = word_embeddings
         self.chunk_size = config.chunk_size_for_efficient_fp16_on_cpu
 
+    @property
+    def in_features(self) -> int:
+        return self.word_embeddings.num_embeddings
+
+    @property
+    def out_features(self) -> int:
+        return self.word_embeddings.embedding_dim
+
+    @property
+    def weight(self):
+        return self.word_embeddings.weight
+
+    @property
+    def bias(self):
+        return None
+
     def forward(self, hidden_states):
         word_embeddings = self.word_embeddings.weight
-        
+
         # We use 'chunked_forward' only when embeddings are in half-precision on CPU.
-        if word_embeddings.dtype in [torch.float16, torch.bfloat16] and \
-            word_embeddings.device.type == 'cpu':
+        if word_embeddings.dtype in [torch.float16, torch.bfloat16] and word_embeddings.device.type == "cpu":
             lm_logits = self.chunked_forward(hidden_states)
         else:
             # Switch dtype in case word_embeddings are fp16/bf16
@@ -459,20 +476,20 @@ class LMHead(nn.Module):
         return lm_logits
 
     def chunked_forward(self, hidden_states):
-        """ Splits word embeddings on chunks and iteratively casts them into fp32 to perform matmul more efficiently on CPU. 
-            chunk_size: provides trade-off between efficiency and extra memory consumption. 
+        """Splits word embeddings on chunks and iteratively casts them into fp32 to perform matmul more efficiently on CPU.
+        chunk_size: provides trade-off between efficiency and extra memory consumption.
         """
         assert self.chunk_size > 0, "Chunk size for chunked forward must be positive"
 
         word_embeddings = self.word_embeddings.weight
         num_embeddings = self.word_embeddings.num_embeddings
 
-        hidden_states = hidden_states.float()    
+        hidden_states = hidden_states.float()
         output = torch.zeros(*hidden_states.shape[:-1], num_embeddings)
 
         for i in range(0, num_embeddings, self.chunk_size):
-            chunk = word_embeddings[i: i + self.chunk_size].float()
-            output[..., i: i + self.chunk_size] = F.linear(hidden_states, chunk)
+            chunk = word_embeddings[i : i + self.chunk_size].float()
+            output[..., i : i + self.chunk_size] = F.linear(hidden_states, chunk)
         return output
 
 
@@ -565,7 +582,7 @@ class BloomForSequenceClassification(BloomPreTrainedModel):
                     f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                     "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
                 )
-            
+
         pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
 
         loss = None

+ 1 - 1
src/client/__init__.py

@@ -1,4 +1,4 @@
 from src.client.remote_block import RemoteTransformerBlock, RemoteTransformerBlockInferenceSession
 from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
-from src.client.remote_sequence_info import RemoteSequenceInfo
 from src.client.remote_sequential import RemoteSequential
+from src.client.sequence_manager import RemoteSequenceManager

+ 51 - 22
src/client/remote_model.py

@@ -2,15 +2,20 @@
 import os
 from typing import Optional, Tuple
 
+import hivemind
 import torch
 import torch.nn as nn
-
-import hivemind
 from hivemind import get_logger, use_hivemind_log_handler
 
-from src.bloom.model import BloomConfig, BloomForCausalLM, BloomModel, BloomPreTrainedModel, LMHead, BloomForSequenceClassification
+from src.bloom.model import (
+    BloomConfig,
+    BloomForCausalLM,
+    BloomForSequenceClassification,
+    BloomModel,
+    BloomPreTrainedModel,
+    LMHead,
+)
 from src.client.remote_sequential import RemoteSequential
-from src.data_structures import UID_DELIMITER
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -25,12 +30,13 @@ class DistributedBloomConfig(BloomConfig):
     initial_peers: Tuple[str, ...] = ()  # a list of initial peers for hivemind DHT
     dht_prefix: str  # a prefix for all dht keys that correspond to this model (usually equal to model name)
     dht: Optional[hivemind.DHT] = None  # a running DHT instance, e.g. when using the same DHT for multiple models
-    chunk_size_for_efficient_fp16_on_cpu: int = 10000 # a chunk size for a LM head for efficient half-precision on CPU
-    num_prefix_tokens: int = 0 # a number of tokens for prompt tuning. 
+    chunk_size_for_efficient_fp16_on_cpu: int = 10000  # a chunk size for a LM head for efficient half-precision on CPU
+    num_prefix_tokens: int = 0  # a number of tokens for prompt tuning.
 
 
 class DistributedBloomModel(BloomModel):
     """BloomModel, but all transformer layers are hosted by the swarm"""
+
     config_class = DistributedBloomConfig
 
     def __init__(self, config: DistributedBloomConfig):
@@ -49,7 +55,7 @@ class DistributedBloomModel(BloomModel):
         )
         assert isinstance(dht, hivemind.DHT) and dht.is_alive(), "dht must be a running hivemind.DHT instance"
         self.h = RemoteSequential(config, dht, config.dht_prefix)
-    
+
         # Forbid accumulate grads for embeddings and layernorm
         self.set_requires_grad(False)
 
@@ -57,6 +63,14 @@ class DistributedBloomModel(BloomModel):
         for p in self.parameters():
             p.requires_grad = value
 
+    def forward(self, *args, use_cache=None, **kwargs):
+        if use_cache:
+            raise ValueError(
+                "Distributed forward does not support use_cache; for efficient cache-aware generation, "
+                "please use model.transformer.inference_session() or model.generate(...)"
+            )
+        return super().forward(*args, use_cache=False, **kwargs)
+
 
 class DistributedBloomPrefix(DistributedBloomModel):
     """DistributedBloomModel with prefix tokens for prompt tuning"""
@@ -76,7 +90,7 @@ class DistributedBloomPrefix(DistributedBloomModel):
         return prompts
 
     def forward(
-        self, 
+        self,
         input_ids: Optional[torch.LongTensor],
         inputs_embeds: Optional[torch.Tensor],
         attention_mask: Optional[torch.Tensor],
@@ -86,14 +100,16 @@ class DistributedBloomPrefix(DistributedBloomModel):
         use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
-        return_dict=None
+        return_dict=None,
     ):
-        assert input_ids is None or inputs_embeds is None, "You cannot specify both input_ids and inputs_embeds at the same time"
+        assert (
+            input_ids is None or inputs_embeds is None
+        ), "You cannot specify both input_ids and inputs_embeds at the same time"
         assert input_ids is not None or inputs_embeds is not None, "You must specify either input_ids or inputs_embeds"
-        
+
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
-    
+
         batch_size = inputs_embeds.shape[0]
 
         if attention_mask is not None:
@@ -104,25 +120,26 @@ class DistributedBloomPrefix(DistributedBloomModel):
         inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
 
         transformer_outputs = super().forward(
-            inputs_embeds=inputs_embeds, 
-            attention_mask=attention_mask, 
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
             past_key_values=past_key_values,
             position_ids=position_ids,
             head_mask=head_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
-            return_dict=return_dict
+            return_dict=return_dict,
         )
 
         # Remove prefix
-        last_hidden_state = transformer_outputs[0][:, self.prefix_length:]
-        transformer_outputs['last_hidden_state'] = last_hidden_state
+        last_hidden_state = transformer_outputs[0][:, self.prefix_length :]
+        transformer_outputs["last_hidden_state"] = last_hidden_state
         return transformer_outputs
 
 
 class DistributedBloomForCausalLM(BloomForCausalLM):
-    """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
+    """Similar to BloomForCausalLM, but all transformer layers are hosted by the swarm"""
+
     config_class = DistributedBloomConfig
 
     def __init__(self, config: DistributedBloomConfig):
@@ -136,11 +153,23 @@ class DistributedBloomForCausalLM(BloomForCausalLM):
         # Initialize weights and apply final processing
         self.post_init()
 
-    def get_output_embeddings(self):
-        return self.lm_head.word_embeddings
+    def get_input_embeddings(self):
+        return self.transformer.word_embeddings
 
-    def set_output_embeddings(self, new_embeddings):
-        self.lm_head.word_embeddings.weight = new_embeddings.weight
+    def get_output_embeddings(self):
+        if self.config.tie_word_embeddings:
+            return None
+        return self.lm_head
+
+    def set_input_embeddings(self, new_embeddings: nn.Embedding):
+        assert isinstance(new_embeddings, nn.Embedding)
+        self.transformer.word_embeddings = self.lm_head.word_embeddings = new_embeddings
+        assert self.lm_head.bias is None or len(self.lm_head.bias) == new_embeddings.num_embeddings
+
+    def set_output_embeddings(self, new_lm_head: nn.Linear):
+        with torch.no_grad():
+            self.lm_head.word_embeddings.weight[...] = new_lm_head.weight
+            self.lm_head.bias[...] = new_lm_head.bias
 
 
 class DistributedBloomForSequenceClassification(BloomForSequenceClassification):

+ 41 - 16
src/client/remote_sequential.py

@@ -3,6 +3,7 @@ from __future__ import annotations
 import contextlib
 import logging
 import random
+from typing import Optional, Union
 
 import torch
 from hivemind import DHT, P2P, get_logger, use_hivemind_log_handler
@@ -12,7 +13,7 @@ from torch import nn
 
 import src
 from src.client.remote_block import RemoteTransformerBlock
-from src.client.remote_sequence_info import RemoteSequenceInfo
+from src.client.sequence_manager import RemoteSequenceManager
 from src.data_structures import UID_DELIMITER
 from src.dht_utils import _create_remote_modules_from_infos
 
@@ -25,7 +26,15 @@ class RemoteSequential(nn.Module):
     A sequence of transformer blocks hosted by the swarm.
     """
 
-    def __init__(self, config: src.DistributedBloomConfig, dht: DHT, prefix: str, max_retries: int = 3):
+    def __init__(
+        self,
+        config: src.DistributedBloomConfig,
+        dht: DHT,
+        prefix: str,
+        max_retries: int = 3,
+        p2p: Optional[P2P] = None,
+        sequence_manager: Optional[RemoteSequenceManager] = None,
+    ):
         logger.warning(f"{self.__class__.__name__} is in active development; expect adventures")
         if prefix.endswith(UID_DELIMITER):
             logger.warning(
@@ -39,12 +48,17 @@ class RemoteSequential(nn.Module):
         self.dht = dht
         self.prefix = prefix
         self.max_retries = max_retries
-        self.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p())
-
-        block_uids = tuple(f"{prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer))
-
-        logger.debug(f"Remote block uids: {block_uids}")
-        self.remote_sequence_info = RemoteSequenceInfo(dht, block_uids)
+        self.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p()) if p2p is None else p2p
+
+        block_uids = [f"{prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)]
+        if sequence_manager is None:
+            logger.debug(f"Creating new sequence manager for block uids: {block_uids}")
+            self.sequence_manager = RemoteSequenceManager(dht, block_uids)
+            self.is_subsequence = False
+        else:
+            assert isinstance(sequence_manager.block_uids, list)
+            logger.debug(f"Reusing sequence manager with {len(self.sequence_manager)}")
+            self.is_subsequence = self.sequence_manager.block_uids == block_uids
 
     def forward(self, inputs: torch.Tensor):
         assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3 and inputs.shape[-1] == self.config.n_embed
@@ -64,27 +78,38 @@ class RemoteSequential(nn.Module):
                         logging.debug(f"Caught {e} when running forward for block {block_index}", exc_info=True)
         return inputs
 
-    def __getitem__(self, block_index: int):
-        assert 0 <= block_index < self.config.n_layer
-        (module,) = _create_remote_modules_from_infos([self.remote_sequence_info.block_infos[block_index]], self.p2p)
-        return module
+    def __getitem__(self, ix: Union[int, slice]) -> Union[RemoteTransformerBlock, RemoteSequential]:
+        assert isinstance(ix, (int, slice))
+        if isinstance(ix, int):
+            assert 0 <= ix < self.config.n_layer
+            (module,) = _create_remote_modules_from_infos([self.sequence_manager.block_infos[ix]], self.p2p)
+            return module
+        else:
+            return RemoteSequential(
+                self.config,
+                self.dht,
+                prefix=self.prefix,
+                max_retries=self.max_retries,
+                p2p=self.p2p,
+                sequence_manager=self.sequence_manager[ix],
+            )
 
     def __iter__(self):
         for block_index in range(self.config.n_layer):
             yield self[block_index]
 
     def __len__(self):
-        return len(self.remote_sequence_info)
+        return len(self.sequence_manager)
 
     def inference_session(self) -> RemoteSequentialInferenceSession:
-        self.remote_sequence_info.update_()
-        return RemoteSequentialInferenceSession(self.remote_sequence_info, self.p2p)
+        self.sequence_manager.update_()
+        return RemoteSequentialInferenceSession(self.sequence_manager, self.p2p)
 
 
 class RemoteSequentialInferenceSession:
     """An interface to a multi-step *inference* session for a sequence of remote transformer blocks"""
 
-    def __init__(self, remote_sequence_info: RemoteSequenceInfo, p2p: P2P):
+    def __init__(self, remote_sequence_info: RemoteSequenceManager, p2p: P2P):
         self.remote_sequence_info = remote_sequence_info
         self.p2p = p2p
         self.closed = False

+ 25 - 14
src/client/remote_sequence_info.py → src/client/sequence_manager.py

@@ -1,29 +1,27 @@
 from __future__ import annotations
 
 import threading
-from typing import List, NamedTuple, Optional, Sequence, Tuple
+from typing import List, Optional, Sequence, Tuple, Union
 
-from hivemind import DHT, PeerID
+from hivemind import DHT, DHTExpiration
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
-from src.data_structures import ModuleUID, RemoteModuleInfo, ServerState
+from src.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
 from src.dht_utils import get_remote_module_infos
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 
-Span = NamedTuple("Span", [("start", int), ("end", Optional[int]), ("peer_id", PeerID)])
-
-
-class RemoteSequenceInfo:
+class RemoteSequenceManager:
     """Keeps and updates the meta-information about which peers host which blocks"""
 
     dht: DHT
     block_uids: List[ModuleUID]
     block_infos: List[Optional[RemoteModuleInfo]]
-    spans_by_priority: List[Span]  # sorted from best to worst
-    spans_containing_block: Tuple[List[Span]]
+    spans_by_priority: List[RemoteSpanInfo]  # sorted from best to worst
+    spans_containing_block: Tuple[List[RemoteSpanInfo], ...]
+    last_update_time: DHTExpiration
     lock_changes: threading.Lock
 
     def __init__(self, dht: DHT, block_uids: Sequence[ModuleUID]):
@@ -32,6 +30,7 @@ class RemoteSequenceInfo:
         self.block_infos = [None] * len(self.block_uids)
         self.spans_by_priority = []
         self.spans_containing_block = tuple(list() for _ in range(len(self.block_uids)))
+        self.last_update_time = -float("inf")
         self.lock_changes = threading.Lock()
         self.update_()
 
@@ -39,6 +38,18 @@ class RemoteSequenceInfo:
             assert info is not None, f"Found no remote peers for block {uid}"
         assert self.spans_by_priority and self.spans_containing_block
 
+    def __getitem__(self, ix: Union[int, slice]) -> RemoteSequenceManager:
+        """Get a RemoteSequenceManager for a sub-sequence of blocks"""
+        assert isinstance(ix, (int, slice))
+        if not isinstance(ix, slice):
+            ix = slice(int(ix), int(ix) + 1, 1)
+        with self.lock_changes:
+            subseq = RemoteSequenceManager(self.dht, self.block_uids[ix])
+            subseq.block_infos = self.block_infos[ix]
+            subseq.spans_by_priority, subseq.spans_containing_block = subseq.compute_spans(subseq.block_infos)
+            subseq.last_update_time = self.last_update_time
+        return subseq
+
     def update_(self):
         with self.lock_changes:
             self.update_block_infos_()
@@ -67,15 +78,15 @@ class RemoteSequenceInfo:
                 if server.state != ServerState.ONLINE:
                     continue
                 if peer_id not in active_spans:
-                    active_spans[peer_id] = Span(start=block_index, end=block_index + 1, peer_id=peer_id)
+                    active_spans[peer_id] = RemoteSpanInfo(start=block_index, end=block_index + 1, peer_id=peer_id)
                 else:  # peer_id in active_spans
-                    active_spans[peer_id] = active_spans[peer_id]._replace(end=block_index + 1)
+                    active_spans[peer_id].end = block_index + 1
 
             for peer_id in list(active_spans.keys()):
                 if (
-                    peer_id not in info.servers or
-                    info.servers[peer_id].state != ServerState.ONLINE or
-                    block_index == len(block_infos) - 1
+                    peer_id not in info.servers
+                    or info.servers[peer_id].state != ServerState.ONLINE
+                    or block_index == len(block_infos) - 1
                 ):
                     closed_spans.append(active_spans.pop(peer_id))
         assert not active_spans

+ 11 - 0
src/data_structures.py

@@ -23,5 +23,16 @@ class ServerInfo:
 
 @dataclass
 class RemoteModuleInfo:
+    """A remote module that is served by one or more servers"""
+
     uid: ModuleUID
     servers: Dict[PeerID, ServerInfo]
+
+
+@dataclass
+class RemoteSpanInfo:
+    """A chain of remote blocks served by one specific remote peer"""
+
+    start: int
+    end: int
+    peer_id: PeerID

+ 6 - 2
src/dht_utils.py

@@ -136,8 +136,12 @@ async def _get_remote_module_infos(
             try:
                 peer_id = PeerID.from_base58(peer_id)
                 state, throughput = server_info.value
-                if not (isinstance(state, int) and isinstance(throughput, float) and
-                        math.isfinite(throughput) and throughput >= 0.0):
+                if not (
+                    isinstance(state, int)
+                    and isinstance(throughput, float)
+                    and math.isfinite(throughput)
+                    and throughput >= 0.0
+                ):
                     raise ValueError(f"Invalid server info: {server_info}")
                 servers[peer_id] = ServerInfo(ServerState(state), throughput)
             except (TypeError, ValueError) as e:

+ 4 - 4
src/server/block_selection.py

@@ -9,10 +9,10 @@ def choose_best_blocks(num_blocks: int, remote_module_infos: List[Optional[Remot
         if module is None:
             throughputs.append(0)
             continue
-        throughputs.append(sum(server.throughput for server in module.servers.values()
-                               if server.state != ServerState.OFFLINE))
+        throughputs.append(
+            sum(server.throughput for server in module.servers.values() if server.state != ServerState.OFFLINE)
+        )
 
-    options = [(sorted(throughputs[i:i + num_blocks]), i)
-               for i in range(0, len(throughputs) - num_blocks + 1)]
+    options = [(sorted(throughputs[i : i + num_blocks]), i) for i in range(0, len(throughputs) - num_blocks + 1)]
     best_start = min(options)[1]
     return list(range(best_start, best_start + num_blocks))

+ 8 - 10
src/server/server.py

@@ -4,7 +4,7 @@ import multiprocessing as mp
 import random
 import threading
 import time
-from typing import Dict, Literal, Optional, Sequence, Union
+from typing import Dict, Optional, Sequence, Union
 
 import torch
 from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
@@ -13,7 +13,7 @@ from hivemind.moe.server.runtime import Runtime
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
-from src import declare_active_modules, BloomConfig
+from src import BloomConfig, declare_active_modules
 from src.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
 from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
 from src.dht_utils import get_remote_module_infos
@@ -98,7 +98,7 @@ class Server(threading.Thread):
         cls,
         prefix: Optional[str],
         converted_model_name_or_path: str,
-        throughput: Union[float, Literal['auto', 'eval']],
+        throughput: Union[float, str],
         num_blocks: Optional[int] = None,
         block_indices: Optional[str] = None,
         num_handlers: Optional[int] = None,
@@ -140,17 +140,15 @@ class Server(threading.Thread):
         device = device or ("cuda" if torch.cuda.is_available() else "cpu")
         memory_cache = MemoryCache(device, cache_size_bytes)
 
-        assert isinstance(throughput, float) or throughput in ['auto', 'eval']
-        if throughput in ['auto', 'eval']:
-            throughput = get_host_throughput(device, force_eval=(throughput == 'eval'))
+        assert isinstance(throughput, float) or throughput in ["auto", "eval"]
+        if throughput in ["auto", "eval"]:
+            throughput = get_host_throughput(device, force_eval=(throughput == "eval"))
 
         if isinstance(torch_dtype, str):
             torch_dtype = DTYPE_MAP[torch_dtype]
         assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
 
-        block_config = BloomConfig.from_pretrained(
-            converted_model_name_or_path, use_auth_token=use_auth_token
-        )
+        block_config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
 
         if block_indices is not None:
             try:
@@ -288,7 +286,7 @@ class ModuleAnnouncerThread(threading.Thread):
         throughput: float,
         update_period: float = 30,
         expiration: float,
-        **kwargs
+        **kwargs,
     ):
         super().__init__(**kwargs)
         self.module_backends = module_backends

+ 14 - 13
src/server/throughput.py

@@ -20,10 +20,10 @@ use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 
-DEFAULT_CACHE_PATH = Path(Path.home(), '.cache', project_name, 'throughput.json')
-DEFAULT_LOCK_PATH = Path(tempfile.gettempdir(), project_name, 'throughput.lock')
+DEFAULT_CACHE_PATH = Path(Path.home(), ".cache", project_name, "throughput.json")
+DEFAULT_LOCK_PATH = Path(tempfile.gettempdir(), project_name, "throughput.lock")
 
-SPEED_TEST_PATH = Path(Path(__file__).absolute().parents[2], 'cli', 'speed_test.py')
+SPEED_TEST_PATH = Path(Path(__file__).absolute().parents[2], "cli", "speed_test.py")
 
 
 @dataclass
@@ -43,7 +43,7 @@ def get_host_throughput(
 
     # We use the system-wide lock since only one process at a time can measure the host throughput
     os.makedirs(lock_path.parent, exist_ok=True)
-    with open(lock_path, 'wb') as lock_fd:
+    with open(lock_path, "wb") as lock_fd:
         logger.info("Loading throughput info")
         fcntl.flock(lock_fd.fileno(), fcntl.LOCK_EX)
         # The OS will release the lock when lock_fd is closed or the process is killed
@@ -63,7 +63,7 @@ def get_host_throughput(
             info = measure_throughput_info()
             try:
                 os.makedirs(cache_path.parent, exist_ok=True)
-                with open(cache_path, 'w') as cache_fd:
+                with open(cache_path, "w") as cache_fd:
                     json.dump(asdict(info), cache_fd)
             except Exception:
                 logger.exception(f"Failed to save throughput info in {cache_path}")
@@ -73,29 +73,30 @@
 
 
 def measure_throughput_info() -> ThroughputInfo:
-    logger.info("Measuring network, CPU, and GPU throughput. "
-                "This takes about a minute and will be cached for future runs")
+    logger.info(
+        "Measuring network, CPU, and GPU throughput. " "This takes about a minute and will be cached for future runs"
+    )
 
     # We measure throughput in "(inference) requests per second" (RPS) using a fixed model
-    config = BloomConfig.from_pretrained('bigscience/test-bloomd-6b3')
+    config = BloomConfig.from_pretrained("bigscience/test-bloomd-6b3")
 
     network_rps = measure_network_rps(config)
 
-    device_rps = {'cpu': measure_device_rps('cpu', config)}
+    device_rps = {"cpu": measure_device_rps("cpu", config)}
     if torch.cuda.is_available():
-        device_rps['cuda'] = measure_device_rps('cuda', config)
+        device_rps["cuda"] = measure_device_rps("cuda", config)
 
     return ThroughputInfo(network_rps=network_rps, device_rps=device_rps)
 
 
 def measure_network_rps(config: BloomConfig) -> float:
-    proc = subprocess.run([SPEED_TEST_PATH, '--json'], capture_output=True)
+    proc = subprocess.run([SPEED_TEST_PATH, "--json"], capture_output=True)
     if proc.returncode != 0:
         raise RuntimeError(f"Failed to measure network throughput (stdout: {proc.stdout}, stderr: {proc.stderr})")
     network_info = json.loads(proc.stdout)
 
     bits_per_request = config.hidden_size * 32
-    network_rps = min(network_info['download'], network_info['upload']) / bits_per_request
+    network_rps = min(network_info["download"], network_info["upload"]) / bits_per_request
 
     logger.info(
         f"Network throughput: "
@@ -120,7 +121,7 @@ def measure_device_rps(device: str, config: BloomConfig, layer_index: int = 0, n
             elapsed += time.perf_counter() - start_time
         device_rps = n_steps / elapsed
 
-    device_name = f"{torch.cuda.get_device_name(0)} GPU" if device == 'cuda' else 'CPU'
+    device_name = f"{torch.cuda.get_device_name(0)} GPU" if device == "cuda" else "CPU"
     logger.info(f"Compute throughput ({device_name}): {device_rps:.2f} RPS")
 
     return device_rps

BIN
tests/test.id


+ 5 - 2
tests/test_block_exact_match.py

@@ -3,6 +3,7 @@ import os
 
 import hivemind
 import torch
+import transformers
 
 from src.bloom.from_pretrained import load_pretrained_block
 from src.client.remote_block import RemoteTransformerBlock
@@ -19,16 +20,18 @@ if not BLOCK_UID:
     raise RuntimeError("Must specify BLOCK_UID as an index of a transformer block to be tested")
 
 REF_NAME = os.environ.get("REF_NAME", "bigscience/test-bloomd-6b3")
-REF_INDEX = int(os.environ.get("REF_INDEX", BLOCK_UID[-1].split(".")[-1]))
+REF_INDEX = int(os.environ.get("REF_INDEX", BLOCK_UID.split(".")[-1]))
 
 
 def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3):
     dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
+
     remote_block = get_remote_module(dht, BLOCK_UID)
     assert remote_block is not None, f"Could not find {BLOCK_UID} in DHT"
     assert isinstance(remote_block, RemoteTransformerBlock)
+    ref_config = transformers.AutoConfig.from_pretrained(REF_NAME)
 
-    inputs = torch.randn(1, 8, 4096)
+    inputs = torch.randn(1, 8, ref_config.hidden_size)
     (outputs_forward,) = remote_block(inputs)
 
     outputs_inference = []

+ 97 - 0
tests/test_chained_calls.py

@@ -0,0 +1,97 @@
+######
+# Warning: this test is a work in progress. It will be modified soon.
+# - if you want more stable tests, see test_block_exact_match
+# - if you want to figure out chained inference, ask yozh
+
+import os
+
+import hivemind
+import torch
+import transformers
+from hivemind.moe.expert_uid import UID_DELIMITER, ExpertInfo
+
+from src.bloom.from_pretrained import load_pretrained_block
+from src.client.remote_block import RemoteTransformerBlock
+from src.dht_utils import get_remote_module
+
+INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
+if not INITIAL_PEERS:
+    raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
+INITIAL_PEERS = INITIAL_PEERS.split()
+
+
+MODEL_NAME = os.environ.get("MODEL_NAME")
+if not MODEL_NAME:
+    raise RuntimeError("Must specify MODEL_NAME as a name of a model to be tested")
+
+REF_NAME = os.environ.get("REF_NAME", "bigscience/test-bloomd-6b3")
+
+
+def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1):
+    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
+    config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
+    remote_block = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}0")
+    assert remote_block is not None, f"Could not find {MODEL_NAME}{UID_DELIMITER}0 in DHT"
+    assert isinstance(remote_block, RemoteTransformerBlock)
+
+    _ = remote_block.info  # lazy-init info now, because otherwise we will _break_ info init by changing _info
+    remote_block._info = ExpertInfo(f"{MODEL_NAME}.3 {MODEL_NAME}.4 {MODEL_NAME}.5", remote_block._info.peer_id)
+
+    ref_blocks = [
+        load_pretrained_block(REF_NAME, 3, torch_dtype=torch.float32),
+        load_pretrained_block(REF_NAME, 4, torch_dtype=torch.float32),
+        load_pretrained_block(REF_NAME, 5, torch_dtype=torch.float32),
+    ]
+    inputs = torch.randn(1, seq_length, config.hidden_size, requires_grad=True)
+    outputs_rpc = remote_block.forward(inputs)[0]
+    outputs_rpc.sum().backward()
+    grads_rpc = inputs.grad
+
+    inputs.grad = None
+    hidden_states = inputs
+    for ref_block in ref_blocks:
+        hidden_states = ref_block.forward(hidden_states)[0]
+    outputs_ref = hidden_states
+    outputs_ref.sum().backward()
+    grads_ref = inputs.grad
+
+    assert torch.allclose(outputs_ref, outputs_rpc, rtol=0, atol=atol_forward)
+    assert torch.allclose(grads_ref, grads_rpc, rtol=0, atol=atol_backward)
+
+
+def test_chained_inference_exact_match(atol_inference=1e-4):
+    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
+    config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
+    remote_block = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}0")
+    assert remote_block is not None, f"Could not find {MODEL_NAME}{UID_DELIMITER}0 in DHT"
+    assert isinstance(remote_block, RemoteTransformerBlock)
+
+    _ = remote_block.info  # lazy-init info now, because otherwise we will _break_ info init by changing _info
+    remote_block._info = ExpertInfo(f"{MODEL_NAME}.3 {MODEL_NAME}.4", remote_block._info.peer_id)
+
+    inputs = torch.randn(1, 8, config.hidden_size)
+
+    outputs_inference = []
+    with remote_block.inference_session() as sess:
+        for i in range(inputs.shape[1]):
+            outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
+    outputs_inference = torch.cat(outputs_inference, dim=1)
+
+    ref_blocks = [
+        load_pretrained_block(REF_NAME, 3, torch_dtype=torch.float32),
+        load_pretrained_block(REF_NAME, 4, torch_dtype=torch.float32),
+    ]
+    outputs_ref = []
+    caches = [None, None]
+    for i in range(inputs.shape[1]):
+        new_caches = []
+        hidden_states = inputs[:, i : i + 1, :]
+        for ref_block, cache in zip(ref_blocks, caches):
+            with torch.no_grad():
+                hidden_states, new_cache = ref_block.forward(hidden_states, use_cache=True, layer_past=cache)
+                new_caches.append(new_cache)
+
+        outputs_ref.append(hidden_states)
+        caches = new_caches
+    outputs_ref = torch.cat(outputs_ref, dim=1)
+    assert torch.allclose(outputs_ref, outputs_inference, rtol=0, atol=atol_inference)

+ 0 - 59
tests/test_chained_forward_backward.py

@@ -1,59 +0,0 @@
-######
-# Warning:torch this test is a work in progress. It will be modified soon.
-# - if you want more stable tests, see test_block_exact_match
-# - if you want to figure out chained inference, ask yozh
-
-import os
-
-import hivemind
-import torch
-from hivemind.moe.expert_uid import ExpertInfo
-
-from src.bloom.from_pretrained import load_pretrained_block
-from src.client.remote_block import RemoteTransformerBlock
-from src.dht_utils import get_remote_module
-
-INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
-if not INITIAL_PEERS:
-    raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
-INITIAL_PEERS = INITIAL_PEERS.split()
-
-
-BLOCK_UID = os.environ.get("BLOCK_UID")
-if not BLOCK_UID:
-    raise RuntimeError("Must specify BLOCK_UID as an index of a transformer block to be tested")
-
-REF_NAME = os.environ.get("REF_NAME", "bigscience/test-bloomd-6b3")
-
-
-# seq_length > 128: rpc_forward_stream & rpc_backward_stream
-# seq_length <= 128: rpc_forward & rpc_backward
-def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1):
-    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
-    (remote_block,) = get_remote_module(dht, BLOCK_UID)
-    assert remote_block is not None, f"Could not find {BLOCK_UID} in DHT"
-    assert isinstance(remote_block, RemoteTransformerBlock)
-
-    _ = remote_block.info  # lazy-init info now, because otherwise we will _break_ info init by chaning _info
-    remote_block._info = ExpertInfo("bloom6b3.3 bloom6b3.4 bloom6b3.5", remote_block._info.peer_id)
-
-    ref_blocks = [
-        load_pretrained_block(REF_NAME, 3, torch_dtype=torch.float32),
-        load_pretrained_block(REF_NAME, 4, torch_dtype=torch.float32),
-        load_pretrained_block(REF_NAME, 5, torch_dtype=torch.float32),
-    ]
-    inputs = torch.randn(1, seq_length, 4096, requires_grad=True)
-    outputs_rpc = remote_block.forward(inputs)[0]
-    outputs_rpc.sum().backward()
-    grads_rpc = inputs.grad
-
-    inputs.grad = None
-    hidden_states = inputs
-    for ref_block in ref_blocks:
-        hidden_states = ref_block.forward(hidden_states)[0]
-    outputs_ref = hidden_states
-    outputs_ref.sum().backward()
-    grads_ref = inputs.grad
-
-    assert torch.allclose(outputs_ref, outputs_rpc, rtol=0, atol=atol_forward)
-    assert torch.allclose(grads_ref, grads_rpc, rtol=0, atol=atol_backward)

+ 0 - 64
tests/test_chained_inference.py

@@ -1,64 +0,0 @@
-######
-# Warning:torch this test is a work in progress. It will be modified soon.
-# - if you want more stable tests, see test_block_exact_match
-# - if you want to figure out chained inference, ask yozh
-
-import os
-
-import hivemind
-import torch
-from hivemind.moe.expert_uid import ExpertInfo
-
-from src.bloom.from_pretrained import load_pretrained_block
-from src.client.remote_block import RemoteTransformerBlock
-from src.dht_utils import get_remote_module
-
-INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
-if not INITIAL_PEERS:
-    raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
-INITIAL_PEERS = INITIAL_PEERS.split()
-
-
-BLOCK_UID = os.environ.get("BLOCK_UID")
-if not BLOCK_UID:
-    raise RuntimeError("Must specify BLOCK_UID as an index of a transformer block to be tested")
-
-REF_NAME = os.environ.get("REF_NAME", "bigscience/test-bloomd-6b3")
-REF_INDEX = int(os.environ.get("REF_INDEX", BLOCK_UID[-1].split(".")[-1]))
-
-
-def test_remote_block_exact_match(atol_inference=1e-4):
-    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
-    remote_block = get_remote_module(dht, BLOCK_UID)
-    assert remote_block is not None, f"Could not find {BLOCK_UID} in DHT"
-    assert isinstance(remote_block, RemoteTransformerBlock)
-
-    _ = remote_block.info  # lazy-init info now, because otherwise we will _break_ info init by chaning _info
-    remote_block._info = ExpertInfo("bloom6b3.3 bloom6b3.4", remote_block._info.peer_id)
-
-    inputs = torch.randn(1, 8, 4096)
-
-    outputs_inference = []
-    with remote_block.inference_session() as sess:
-        for i in range(inputs.shape[1]):
-            outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
-    outputs_inference = torch.cat(outputs_inference, dim=1)
-
-    ref_blocks = [
-        load_pretrained_block(REF_NAME, 3, torch_dtype=torch.float32),
-        load_pretrained_block(REF_NAME, 4, torch_dtype=torch.float32),
-    ]
-    outputs_ref = []
-    caches = [None, None]
-    for i in range(inputs.shape[1]):
-        new_caches = []
-        hidden_states = inputs[:, i : i + 1, :]
-        for ref_block, cache in zip(ref_blocks, caches):
-            with torch.no_grad():
-                hidden_states, new_cache = ref_block.forward(hidden_states, use_cache=True, layer_past=cache)
-                new_caches.append(new_cache)
-
-        outputs_ref.append(hidden_states)
-        caches = new_caches
-    outputs_ref = torch.cat(outputs_ref, dim=1)
-    assert torch.allclose(outputs_ref, outputs_inference, rtol=0, atol=atol_inference)

+ 23 - 19
tests/test_full_model.py

@@ -24,9 +24,10 @@ if not MODEL_NAME:
 REF_NAME = os.environ.get("REF_NAME")
 
 
-def test_full_model_exact_match(atol_forward=1e-5, atol_inference=1e-3, prefix="bloom6b3"):
+def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
     tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
     model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+    assert isinstance(model, DistributedBloomForCausalLM)
     assert len(model.transformer.h) == model.config.n_layer
 
     test_inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
@@ -35,26 +36,29 @@ def test_full_model_exact_match(atol_forward=1e-5, atol_inference=1e-3, prefix="
     logger.info("Forward outputs are finite")
 
     if REF_NAME:
-        ref_model = transformers.AutoModelForCausalLM.from_pretrained(REF_NAME)
-        dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
-        # note: this creates a dummy mask to make the test compatible with older transformer versions
-        # prior to https://github.com/huggingface/transformers/pull/17837
-        ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits
-        assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward)
+        with torch.no_grad():
+            ref_model = transformers.AutoModelForCausalLM.from_pretrained(REF_NAME)
+            dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
+            # note: this creates a dummy mask to make the test compatible with older transformer versions
+            # prior to https://github.com/huggingface/transformers/pull/17837
+            ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits
+            assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward)
+            del ref_model, ref_outputs
     else:
         logger.warning("Did not test exact match with local model: REF_NAME environment variable is not set")
 
-    embs = model.transformer.word_embeddings(test_inputs)
-    embs = model.transformer.word_embeddings_layernorm(embs)
-    recurrent_outputs = []
-    with model.transformer.h.inference_session() as sess:
-        for t in range(embs.shape[1]):
-            recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))
-    recurrent_outputs = torch.cat(recurrent_outputs, dim=1)
-    recurrent_outputs = model.transformer.ln_f(recurrent_outputs)
-
-    dictionary = model.transformer.word_embeddings.weight.t()
-    recurrent_outputs = recurrent_outputs.to(dictionary.dtype)
-    recurrent_outputs = (recurrent_outputs @ dictionary).float()
+    with torch.inference_mode():
+        embs = model.transformer.word_embeddings(test_inputs)
+        embs = model.transformer.word_embeddings_layernorm(embs)
+        recurrent_outputs = []
+        with model.transformer.h.inference_session() as sess:
+            for t in range(embs.shape[1]):
+                recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))
+        recurrent_outputs = torch.cat(recurrent_outputs, dim=1)
+        recurrent_outputs = model.transformer.ln_f(recurrent_outputs)
+
+        dictionary = model.transformer.word_embeddings.weight.t()
+        recurrent_outputs = recurrent_outputs.to(dictionary.dtype)
+        recurrent_outputs = (recurrent_outputs @ dictionary).float()
     assert torch.allclose(recurrent_outputs, parallel_outputs, rtol=0, atol=atol_inference)
     logger.info("Inference is consistent with forward")

Some files were not shown because too many files changed in this diff