
Add LLaMA support (#323)

This PR:

1. **Abolishes the model conversion procedure.** Models are now downloaded directly from the original repositories, such as https://huggingface.co/bigscience/bloom. Servers download only the shards with the blocks they host, and clients download only the shards with the input/output embeddings and layernorms (see the usage sketch after this list).

    - BLOOM is loaded from `bigscience/bloom`, but we use the DHT prefix `bigscience/bloom-petals` for backward compatibility. Same with smaller BLOOMs and BLOOMZ.
    - LLaMA can be loaded from any repo like `username/llama-65b-hf`, but we use the DHT prefix `llama-65b-hf` (without the username) to accommodate blocks from different repos (there are a few of them, with minor differences such as `Llama` vs. `LLaMA` in the class name).

2. **Refactors the client to generalize it for multiple models.** Model-specific code now lives in the `petals.models` subpackages (e.g. `petals.models.bloom`, `petals.models.llama`), while general code (e.g. the CPU-efficient LM head, p-tuning) stays in `petals.client`.

3. **Introduces** `WrappedLlamaBlock`, `DistributedLlamaConfig`, `DistributedLlamaForCausalLM`, `DistributedLlamaForSequenceClassification`, and `DistributedLlamaModel` compatible with Petals functionality (p-tuning, adapters, etc.).

4. **Introduces** `AutoDistributedConfig`, which automatically chooses the correct config class (`DistributedLlamaConfig` or `DistributedBloomConfig`). The refactored configs contain all model-specific info needed by both clients and servers (see the usage sketch below).
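For illustration, here is a minimal client-side sketch (not part of this diff) of how the new classes from items 1-4 fit together. It assumes a reachable swarm serving the model; `username/llama-65b-hf` is the same placeholder repo name as in item 1.

```python
from transformers import AutoTokenizer

from petals import AutoDistributedConfig, DistributedLlamaForCausalLM

MODEL_NAME = "username/llama-65b-hf"  # placeholder repo name, as in item 1 above

# AutoDistributedConfig resolves to DistributedLlamaConfig here
# (or DistributedBloomConfig for a BLOOM repo)
config = AutoDistributedConfig.from_pretrained(MODEL_NAME)
print(type(config).__name__, config.dht_prefix)  # the DHT prefix is derived from the repo name, see item 1

# The client downloads only the shards with embeddings and layernorms;
# the transformer blocks themselves are served by the swarm
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = DistributedLlamaForCausalLM.from_pretrained(MODEL_NAME)

inputs = tokenizer("A quick brown fox", return_tensors="pt")["input_ids"]
outputs = model.generate(inputs, max_new_tokens=5)
print(tokenizer.decode(outputs[0]))
```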

Upgrade instructions:

- Remove disk caches with blocks in the old (converted) format to save disk space. That is, remove the `~/.cache/petals/model--bigscience--bloom-petals` and `~/.cache/petals/model--bigscience--bloomz-petals` directories (if present). A cleanup sketch follows.
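A minimal cleanup sketch (not from the PR), assuming the default cache location `~/.cache/petals`; adjust the path if you use a custom `cache_dir`:

```python
import shutil
from pathlib import Path

cache_dir = Path.home() / ".cache" / "petals"  # default Petals cache location (assumption)
for name in ("model--bigscience--bloom-petals", "model--bigscience--bloomz-petals"):
    path = cache_dir / name
    if path.is_dir():
        print(f"Removing {path}")
        shutil.rmtree(path)
    else:
        print(f"{path} not found, nothing to do")
```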
Alexander Borzunov 2 years ago
parent
commit
cb3f018f9f
45 changed files with 1073 additions and 853 deletions
  1. .github/workflows/run-tests.yaml (+6 -60)
  2. setup.cfg (+3 -1)
  3. src/petals/__init__.py (+11 -1)
  4. src/petals/bloom/__init__.py (+0 -0)
  5. src/petals/bloom/block.py (+0 -62)
  6. src/petals/bloom/from_pretrained.py (+0 -132)
  7. src/petals/cli/config.json (+0 -20)
  8. src/petals/cli/convert_model.py (+0 -96)
  9. src/petals/cli/inference_one_block.py (+1 -1)
  10. src/petals/cli/run_server.py (+1 -1)
  11. src/petals/client/__init__.py (+0 -6)
  12. src/petals/client/from_pretrained.py (+94 -0)
  13. src/petals/client/lm_head.py (+28 -44)
  14. src/petals/client/ptune.py (+88 -0)
  15. src/petals/client/remote_model.py (+0 -268)
  16. src/petals/client/remote_sequential.py (+3 -4)
  17. src/petals/client/routing/sequence_manager.py (+8 -1)
  18. src/petals/models/__init__.py (+2 -0)
  19. src/petals/models/bloom/__init__.py (+7 -0)
  20. src/petals/models/bloom/block.py (+32 -0)
  21. src/petals/models/bloom/config.py (+35 -0)
  22. src/petals/models/bloom/model.py (+134 -0)
  23. src/petals/models/llama/__init__.py (+7 -0)
  24. src/petals/models/llama/block.py (+87 -0)
  25. src/petals/models/llama/config.py (+35 -0)
  26. src/petals/models/llama/model.py (+152 -0)
  27. src/petals/server/backend.py (+11 -10)
  28. src/petals/server/block_utils.py (+4 -6)
  29. src/petals/server/from_pretrained.py (+175 -0)
  30. src/petals/server/server.py (+35 -29)
  31. src/petals/server/throughput.py (+12 -10)
  32. src/petals/utils/__init__.py (+1 -0)
  33. src/petals/utils/auto_config.py (+23 -0)
  34. src/petals/utils/convert_block.py (+15 -13)
  35. src/petals/utils/disk_cache.py (+5 -3)
  36. src/petals/utils/version.py (+19 -1)
  37. tests/test_aux_functions.py (+2 -2)
  38. tests/test_block_exact_match.py (+12 -58)
  39. tests/test_chained_calls.py (+2 -2)
  40. tests/test_dtype.py (+7 -8)
  41. tests/test_full_model.py (+2 -2)
  42. tests/test_remote_sequential.py (+10 -8)
  43. tests/test_sequence_manager.py (+2 -2)
  44. tests/test_server_stats.py (+1 -1)
  45. tests/test_tensor_parallel.py (+1 -1)

+ 6 - 60
.github/workflows/run-tests.yaml

@@ -6,57 +6,8 @@ on:
  pull_request:

 jobs:
-  convert-model:
-    runs-on: ubuntu-latest
-    env:
-      BLOOM_TESTING_WRITE_TOKEN: ${{ secrets.BLOOM_TESTING_WRITE_TOKEN }}
-    timeout-minutes: 15
-    steps:
-      - name: Checkout
-        uses: actions/checkout@v3
-      - name: Check if the model is cached
-        id: cache-model
-        uses: actions/cache@v3
-        with:
-          path: ~/converted_ok
-          key: model-v1-${{ hashFiles('setup.cfg', 'src/petals/cli/convert_model.py') }}
-      - name: Set up Python
-        if: steps.cache-model.outputs.cache-hit != 'true'
-        uses: actions/setup-python@v3
-        with:
-          python-version: 3.9
-      - name: Cache dependencies
-        if: steps.cache-model.outputs.cache-hit != 'true'
-        uses: actions/cache@v3
-        with:
-          path: ~/.cache/pip
-          key: Key-v1-3.9-${{ hashFiles('setup.cfg') }}
-      - name: Install dependencies
-        if: steps.cache-model.outputs.cache-hit != 'true'
-        run: |
-          python -m pip install --upgrade pip
-          pip install .
-      - name: Delete any test models older than 1 week
-        if: steps.cache-model.outputs.cache-hit != 'true'
-        run: |
-          python tests/scripts/remove_old_models.py --author bloom-testing --use_auth_token $BLOOM_TESTING_WRITE_TOKEN
-      - name: Delete previous version of this model, if exists
-        if: steps.cache-model.outputs.cache-hit != 'true'
-        run: |
-          export HF_TAG=$(python -c "import os; print(os.environ.get('GITHUB_HEAD_REF') or os.environ.get('GITHUB_REF_NAME'))")
-          python -c "from huggingface_hub import delete_repo; delete_repo(token='$BLOOM_TESTING_WRITE_TOKEN', \
-          repo_id='bloom-testing/test-bloomd-560m-$HF_TAG')" || true
-      - name: Convert model and push to hub
-        if: steps.cache-model.outputs.cache-hit != 'true'
-        run: |
-          export HF_TAG=${{ hashFiles('setup.cfg', 'src/petals/cli/convert_model.py') }}
-          python -m petals.cli.convert_model --model bigscience/bloom-560m --output_path ./converted_model \
-            --output_repo bloom-testing/test-bloomd-560m-$HF_TAG --use_auth_token $BLOOM_TESTING_WRITE_TOKEN \
-            --resize_token_embeddings 50000 && touch ~/converted_ok
-
   run-tests:
     runs-on: ubuntu-latest
-    needs: convert-model
     strategy:
       matrix:
         python-version: [ '3.7', '3.8', '3.9', '3.10' ]
@@ -80,8 +31,7 @@ jobs:
           pip install .[dev]
       - name: Test
         run: |
-          export HF_TAG=${{ hashFiles('setup.cfg', 'src/petals/cli/convert_model.py') }}
-          export MODEL_NAME=bloom-testing/test-bloomd-560m-$HF_TAG
+          export MODEL_NAME=bigscience/bloom-560m
           export REF_NAME=bigscience/bloom-560m

           python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:12 \
@@ -104,23 +54,19 @@ jobs:
             --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 &> server3.log &
           SERVER3_PID=$!

-          python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 4:14 \
-            --torch_dtype float32 --initial_peers $INITIAL_PEERS --throughput 1 &> server4.log &
-          SERVER4_PID=$!
-
           python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --num_blocks 3 \
-            --initial_peers $INITIAL_PEERS --throughput 1 --tensor_parallel_devices cpu cpu  --torch_dtype float32 &> server5.log &
-          SERVER5_PID=$!
+            --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --tensor_parallel_devices cpu cpu &> server4.log &
+          SERVER4_PID=$!

           tail -n 100 -f server*.log &
           LOGGER_PID=$!
           sleep 30  # wait for servers to download layers

-          kill -0 $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $SERVER5_PID # ensure all servers survived init
+          kill -0 $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID # ensure all servers survived init

           pytest tests --durations=0 --durations-min=1.0 -v

-          kill -0 $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $SERVER5_PID # ensure all servers survived tests
+          kill -0 $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID # ensure all servers survived tests

-          kill -s SIGINT $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $SERVER5_PID $LOGGER_PID
+          kill -s SIGINT $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $LOGGER_PID
           echo "Done!"

+ 3 - 1
setup.cfg

@@ -35,7 +35,8 @@ install_requires =
     bitsandbytes==0.38.0.post2
     accelerate>=0.16.0,<1.0.0
     huggingface-hub>=0.11.1,<1.0.0
-    transformers>=4.25.1,<5.0.0
+    tokenizers>=0.13.3
+    transformers>=4.30.1,<5.0.0
     speedtest-cli==2.1.3
     hivemind==1.1.8
     tensor_parallel==1.0.23
@@ -43,6 +44,7 @@ install_requires =
     async-timeout>=4.0.2
     cpufeature>=0.2.0
     packaging>=20.9
+    sentencepiece>=0.1.99

 [options.extras_require]
 dev =

+ 11 - 1
src/petals/__init__.py

@@ -1,11 +1,21 @@
 import os

 import hivemind
+import transformers
+from packaging import version

 from petals.client import *
+from petals.models import *
+from petals.utils import *
 from petals.utils.logging import initialize_logs as _initialize_logs

-__version__ = "1.1.5"
+__version__ = "1.2.0.dev0"
+
+
+if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
+    assert (
+        version.parse("4.30.1") <= version.parse(transformers.__version__) < version.parse("5.0.0")
+    ), "Please install a proper transformers version: pip install transformers>=4.30.1,<5.0.0"


 def _override_bfloat16_mode_default():

+ 0 - 0
src/petals/bloom/__init__.py


+ 0 - 62
src/petals/bloom/block.py

@@ -1,62 +0,0 @@
-"""
-Bloom intermediate layer
-Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
-See commit history for authorship.
-"""
-import os
-from typing import Optional, Tuple
-
-import torch.nn.quantized.dynamic.modules.linear
-import transformers
-from packaging import version
-from transformers.models.bloom.modeling_bloom import BloomBlock, _expand_mask, _make_causal_mask, build_alibi_tensor
-
-if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
-    assert (
-        version.parse("4.25.1") <= version.parse(transformers.__version__) < version.parse("5.0.0")
-    ), "Please install a proper transformers version: pip install transformers>=4.25.1,<5.0.0"
-
-
-class WrappedBloomBlock(BloomBlock):
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        *args,
-        attention_mask: Optional[torch.Tensor] = None,
-        alibi: Optional[torch.Tensor] = None,
-        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-        **kwargs
-    ):
-        assert attention_mask is None
-        batch_size, seq_length = hidden_states.shape[:2]
-        past_length = 0 if layer_past is None else layer_past[0].shape[-1]
-        seq_length_with_past = seq_length + past_length
-        attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
-        if alibi is None:
-            alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
-        attention_mask = self._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)
-        return super().forward(
-            hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, **kwargs
-        )
-
-    def _prepare_attn_mask(
-        self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
-    ) -> torch.BoolTensor:
-        # create causal mask
-        # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
-        combined_attention_mask = None
-        device = attention_mask.device
-        _, src_length = input_shape
-
-        if src_length > 1:
-            combined_attention_mask = _make_causal_mask(
-                torch.Size(input_shape), device=device, past_key_values_length=past_key_values_length
-            )
-
-        # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
-        expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
-        combined_attention_mask = (
-            expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
-        )
-
-        return combined_attention_mask

+ 0 - 132
src/petals/bloom/from_pretrained.py

@@ -1,132 +0,0 @@
-"""
-Utils for fetching pretrained model parts. Currently, this relies on huggingface transformers' from_pretrained code.
-If necessary, one can rewrite this to implement a different behavior, such as:
- - loading files from a local data source (e.g. S3)
- - load files via BitTorrent ( https://pypi.org/project/libtorrent/ ) or IPFS( https://docs.ipfs.io/how-to )
- - fetch the weights over IPoAC, using a fleet of trained pigeons ( http://www.faqs.org/rfcs/rfc1149.html )
-
-"""
-from __future__ import annotations
-
-import itertools
-import time
-from typing import Optional, OrderedDict, Union
-
-import torch
-from accelerate import init_empty_weights
-from accelerate.utils import set_module_tensor_to_device
-from hivemind.utils.logging import get_logger
-from transformers.modeling_utils import WEIGHTS_NAME
-from transformers.models.bloom.configuration_bloom import BloomConfig
-from transformers.utils import get_file_from_repo
-
-from petals.bloom.block import WrappedBloomBlock
-from petals.server.block_utils import get_block_size, resolve_block_dtype
-from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for
-
-logger = get_logger(__name__)
-
-CLIENT_BRANCH = "main"
-BLOCK_BRANCH_PREFIX = "block_"
-
-
-def load_pretrained_block(
-    converted_model_name_or_path: str,
-    block_index: int,
-    config: Optional[BloomConfig] = None,
-    torch_dtype: Union[torch.dtype, str] = "auto",
-    use_auth_token: Optional[str] = None,
-    cache_dir: Optional[str] = None,
-    max_disk_space: Optional[int] = None,
-) -> WrappedBloomBlock:
-    """Load one BLOOM block from a converted model. See convert_model.py (or README.md) on how to convert it."""
-    assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
-    torch_dtype = resolve_block_dtype(config, torch_dtype)
-
-    if config is None:
-        config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
-    if cache_dir is None:
-        cache_dir = DEFAULT_CACHE_DIR
-
-    with init_empty_weights():
-        block = WrappedBloomBlock(config)
-
-    state_dict = _load_state_dict(
-        converted_model_name_or_path,
-        block_index,
-        config,
-        use_auth_token=use_auth_token,
-        cache_dir=cache_dir,
-        max_disk_space=max_disk_space,
-    )
-
-    # dummy load, check that keys match
-    report = block.load_state_dict(state_dict, strict=True)
-    assert not report.missing_keys, f"Some block weights are missing: {report.missing_keys}"
-
-    for param_name, _ in block.named_parameters():
-        assert param_name in state_dict, f"{param_name} not in state dict"
-        param = state_dict[param_name]
-        if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
-            param = param.to(torch_dtype)
-        set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype)
-
-    logger.info(f"Loaded {converted_model_name_or_path} block {block_index}, {report}")
-    return block
-
-
-def _load_state_dict(
-    pretrained_model_name_or_path: str,
-    block_index: int,
-    config: BloomConfig,
-    *,
-    use_auth_token: Optional[str] = None,
-    cache_dir: str,
-    max_disk_space: Optional[int] = None,
-    min_backoff: float = 5,
-) -> OrderedDict[str, torch.Tensor]:
-    revision = BLOCK_BRANCH_PREFIX + str(block_index)
-
-    # First, try to find the weights locally
-    try:
-        with allow_cache_reads(cache_dir):
-            archive_file = get_file_from_repo(
-                pretrained_model_name_or_path,
-                filename=WEIGHTS_NAME,
-                revision=revision,
-                use_auth_token=use_auth_token,
-                cache_dir=cache_dir,
-                local_files_only=True,
-            )
-            if archive_file is not None:
-                return torch.load(archive_file, map_location="cpu")
-    except Exception:
-        logger.debug(
-            f"Failed to load block {block_index} from cache. The block will be downloaded again", exc_info=True
-        )
-
-    # If not found, ensure that we have enough disk space to download them (maybe remove something)
-    for attempt_no in itertools.count():
-        try:
-            with allow_cache_writes(cache_dir):
-                block_size = get_block_size(config, "disk")
-                free_disk_space_for(
-                    pretrained_model_name_or_path, block_size, cache_dir=cache_dir, max_disk_space=max_disk_space
-                )
-
-                archive_file = get_file_from_repo(
-                    pretrained_model_name_or_path,
-                    filename=WEIGHTS_NAME,
-                    revision=revision,
-                    use_auth_token=use_auth_token,
-                    cache_dir=cache_dir,
-                    local_files_only=False,
-                )
-                return torch.load(archive_file, map_location="cpu")
-        except Exception as e:
-            delay = min_backoff * (2**attempt_no)
-            logger.warning(f"Failed to load block {block_index} from HF Hub (retry in {delay:.0f} sec)", exc_info=True)
-            time.sleep(delay)
-
-
-DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")

+ 0 - 20
src/petals/cli/config.json

@@ -1,20 +0,0 @@
-{
-  "apply_residual_connection_post_layernorm": false,
-  "attention_dropout": 0.0,
-  "attention_softmax_in_fp32": true,
-  "bos_token_id": 1,
-  "eos_token_id": 2,
-  "hidden_dropout": 0.0,
-  "initializer_range": 0.02,
-  "layer_norm_epsilon": 1e-05,
-  "masked_softmax_fusion": true,
-  "model_type": "bloom",
-  "n_embed": 14336,
-  "n_layer": 70,
-  "num_attention_heads": 112,
-  "pretraining_tp": 4,
-  "slow_but_exact": false,
-  "transformers_version": "4.20.0.dev0",
-  "use_cache": true,
-  "vocab_size": 250880
-}

+ 0 - 96
src/petals/cli/convert_model.py

@@ -1,96 +0,0 @@
-import argparse
-import os
-
-import psutil
-import torch.backends.quantized
-import torch.nn as nn
-import transformers
-from hivemind.utils.logging import get_logger
-from huggingface_hub import HfApi, Repository
-from tqdm.auto import tqdm
-from transformers.models.bloom.modeling_bloom import BloomModel
-
-from petals.bloom.from_pretrained import BLOCK_BRANCH_PREFIX, CLIENT_BRANCH, DTYPE_MAP
-from petals.client import DistributedBloomConfig
-
-logger = get_logger(__name__)
-
-
-def main():
-    parser = argparse.ArgumentParser(description="Load bloom layers and convert to 8-bit using torch quantization.")
-
-    parser.add_argument("--model", type=str, default="bigscience/bloom-6b3", help="Model name for from_pretrained")
-    parser.add_argument("--revision", type=str, default=None, help="Optional commit id from HF hub")
-    parser.add_argument("--torch_dtype", type=str, default="auto", help="Load initial model in this dtype")
-    parser.add_argument("--output_path", type=str, default="./converted_model", help="Track output repo to this folder")
-    parser.add_argument("--output_repo", type=str, default="bigscience/test-bloomd", help="Push to this HF hub repo")
-    parser.add_argument("--client_branch", type=str, default=CLIENT_BRANCH, help="Save client version to this branch")
-    parser.add_argument(
-        "--block_branch_prefix", type=str, default=BLOCK_BRANCH_PREFIX, help="Save blocks to branches with this prefix"
-    )
-    parser.add_argument(
-        "--commit_message", type=str, default="push-o-matic", help="Use this commit message for all parts"
-    )
-    parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
-    parser.add_argument("--resize_token_embeddings", type=int, default=None, help="change the vocabulary size")
-    args = parser.parse_args()
-
-    free_ram_gb = psutil.virtual_memory().available / 2**30
-    if args.model == "bigscience/bloom" and free_ram_gb < 400:
-        logger.warning(f"ACHTUNG! converting bloom-176b will use up 350-400GB RAM, you have {free_ram_gb:.3f} free")
-
-    assert args.torch_dtype in DTYPE_MAP, f"torch_dtype must be one of {list(DTYPE_MAP.keys())}"
-    if os.path.exists(args.output_path) and (
-        len(os.listdir(args.output_path)) != 0 or not os.path.isdir(args.output_path)
-    ):
-        raise FileExistsError(f"Output path {args.output_path} already exists and is not an empty directory")
-
-    logger.info(f"Loading source model {args.model} (this may take a few minutes)")
-    config = DistributedBloomConfig.from_pretrained(
-        args.model, use_auth_token=args.use_auth_token, revision=args.revision
-    )
-    config.dht_prefix = args.output_repo
-
-    model = BloomModel.from_pretrained(
-        args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype]
-    )
-    if args.resize_token_embeddings:
-        logger.info(f"Resizing token embeddings, new size = {args.resize_token_embeddings}")
-        model.resize_token_embeddings(args.resize_token_embeddings)
-        config.vocab_size = args.resize_token_embeddings
-
-    tokenizer = transformers.AutoTokenizer.from_pretrained(
-        args.model, use_auth_token=args.use_auth_token, revision=args.revision
-    )
-    os.makedirs(args.output_path, exist_ok=True)
-
-    api = HfApi(token=args.use_auth_token)
-    api.create_repo(args.output_repo, repo_type="model", exist_ok=True)
-    repo = Repository(args.output_path, clone_from=args.output_repo, use_auth_token=args.use_auth_token)
-    repo.git_pull()
-
-    transformer_blocks = model.h
-    logger.info(
-        f"Saving transformer blocks to {args.output_repo}@{args.block_branch_prefix}0"
-        f" - {args.output_repo}@{args.block_branch_prefix}{len(transformer_blocks)}"
-    )
-    for i, block in enumerate(tqdm(transformer_blocks)):
-        repo.git_checkout(args.client_branch, create_branch_ok=True)
-        with repo.commit(
-            commit_message=args.commit_message, branch=args.block_branch_prefix + str(i), track_large_files=True
-        ):
-            torch.save(block.state_dict(), "./pytorch_model.bin")
-
-    logger.info(f"Saving client-side modules to {args.output_repo}@{args.client_branch}")
-    repo.git_checkout(args.client_branch, create_branch_ok=True)
-    with repo.commit(commit_message=args.commit_message, branch=args.client_branch, track_large_files=True):
-        model.h = nn.ModuleList()
-        model.save_pretrained(".")
-        tokenizer.save_pretrained(".")
-        config.save_pretrained(".")
-
-    logger.info(f"Converted {args.model} and pushed to {args.output_repo}")
-
-
-if __name__ == "__main__":
-    main()

+ 1 - 1
src/petals/cli/inference_one_block.py

@@ -6,7 +6,7 @@ from tqdm.auto import trange
 from transformers import BloomConfig
 from transformers.models.bloom.modeling_bloom import build_alibi_tensor

-from petals.bloom.block import BloomBlock
+from petals.models.bloom.block import BloomBlock

 logger = get_logger(__name__)


+ 1 - 1
src/petals/cli/run_server.py

@@ -87,7 +87,7 @@ def main():
     parser.add_argument('--alloc_timeout', type=float, default=60,
                         help='If the cache is full, the server will wait for this number of seconds hoping that some memory will be freed '
                              'before rejecting the request')
-    parser.add_argument('--revision', type=str, default='main',
+    parser.add_argument('--revision', type=str, default=None,
                         help="The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models"
                              "and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.")


+ 0 - 6
src/petals/client/__init__.py

@@ -1,10 +1,4 @@
 from petals.client.inference_session import InferenceSession
-from petals.client.remote_model import (
-    DistributedBloomConfig,
-    DistributedBloomForCausalLM,
-    DistributedBloomForSequenceClassification,
-    DistributedBloomModel,
-)
 from petals.client.remote_sequential import RemoteSequential
 from petals.client.routing.sequence_manager import RemoteSequenceManager
 from petals.client.routing.spending_policy import NoSpendingPolicy, SpendingPolicyBase

+ 94 - 0
src/petals/client/from_pretrained.py

@@ -0,0 +1,94 @@
+import contextlib
+import json
+import os
+import re
+import tempfile
+import threading
+from typing import List, Optional, Tuple, Union
+
+import torch
+from hivemind.utils.logging import get_logger
+from transformers import BloomPreTrainedModel, modeling_utils
+
+from petals.utils.version import get_compatible_model_repo
+
+logger = get_logger(__name__)
+
+
+class FromPretrainedMixin:
+    @classmethod
+    def from_pretrained(
+        cls,
+        model_name_or_path: Union[str, os.PathLike, None],
+        *args,
+        low_cpu_mem_usage: Optional[bool] = None,
+        torch_dtype: Optional[Union[str, torch.dtype]] = None,
+        **kwargs,
+    ):
+        model_name_or_path = get_compatible_model_repo(model_name_or_path)
+        if low_cpu_mem_usage is None:
+            low_cpu_mem_usage = True
+        if torch_dtype is None:
+            # torch_dtype=None gives torch.float32 in transformers>=4.26.0. In contrast,
+            # torch_dtype="auto" attempts to (1) use config.torch_dtype (if exists), (2) use dtype of the weights.
+            torch_dtype = "auto"
+
+        with ignore_keys(cls._keys_to_ignore_on_load_unexpected):
+            return super().from_pretrained(
+                model_name_or_path, *args, low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype, **kwargs
+            )
+
+    from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace(
+        "low_cpu_mem_usage(`bool`, *optional*)",
+        "low_cpu_mem_usage(`bool`, *optional*, defaults to `True` in Petals)",
+    ).replace(
+        "torch_dtype (`str` or `torch.dtype`, *optional*)",
+        'torch_dtype (`str` or `torch.dtype`, *optional*, defaults to `"auto"` in Petals)',
+    )
+
+
+_shard_config = threading.local()
+_shard_config.ignored_keys = None
+
+
+@contextlib.contextmanager
+def ignore_keys(patterns: List[str]):
+    try:
+        prev_patterns = _shard_config.ignored_keys
+        _shard_config.ignored_keys = patterns
+        yield
+    finally:
+        _shard_config.ignored_keys = prev_patterns
+
+
+def patched_get_checkpoint_shard_files(
+    pretrained_model_name_or_path, index_filename, *args, **kwargs
+) -> Tuple[List[str], dict]:
+    """Same as modeling_utils.get_checkpoint_shard_files(), but does not download shards for the ignored keys."""
+
+    should_ignore_keys = _shard_config.ignored_keys is not None
+    tempdir_ctx = tempfile.TemporaryDirectory() if should_ignore_keys else contextlib.nullcontext()
+    with tempdir_ctx as tempdir:
+        if should_ignore_keys:
+            with open(index_filename) as f:
+                index = json.load(f)
+            n_original_shards = len(set(index["weight_map"].values()))
+
+            index["weight_map"] = {
+                param_name: filename
+                for param_name, filename in index["weight_map"].items()
+                if all(re.search(pattern, param_name) is None for pattern in _shard_config.ignored_keys)
+            }
+            n_loaded_shards = len(set(index["weight_map"].values()))
+            logger.debug(f"Loading {n_loaded_shards} shards out of {n_original_shards}")
+
+            # Replace the original index with a patched JSON, where ignored keys are removed
+            index_filename = os.path.join(tempdir, "pytorch_model.bin.index.json")
+            with open(index_filename, "w") as f:
+                json.dump(index, f)
+
+        return original_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)
+
+
+original_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
+modeling_utils.get_checkpoint_shard_files = patched_get_checkpoint_shard_files

+ 28 - 44
src/petals/bloom/modeling_utils.py → src/petals/client/lm_head.py

@@ -1,10 +1,6 @@
-"""
-PyTorch BLOOM model that implements several memory-efficient modes.
-Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
-See commit history for authorship.
-"""
-
+import dataclasses
 import platform
+from typing import Optional, Union

 import psutil
 import torch
@@ -12,21 +8,30 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from hivemind import get_logger
 from torch import nn
-from transformers import BloomConfig
+from transformers import PretrainedConfig

 logger = get_logger(__name__)


-class LMHead(nn.Module):
-    """
-    The modified language modeling head which does not create extra tensor for the linear layer with weights tied to the input
-    embeddings. Thus, it reduces initial memory consumption which might be crucial for large dictionaries.
-    In addition, it provides an effcient way to deal with half-precision word embeddings on CPU.
-    """
+@dataclasses.dataclass
+class LMHeadConfig:
+    # This settings matter for running the client with dtype bfloat16 on CPU.
+    # If the CPU doesn't support AVX512, chunked_forward() significantly speeds up computations.
+    use_chunked_forward: Union[str, bool] = "auto"
+    chunked_forward_step: int = 16384
+

-    def __init__(self, config: BloomConfig, word_embeddings: nn.Embedding):
+class LMHead(nn.Module):
+    def __init__(self, config: PretrainedConfig):
         super().__init__()
-        self.word_embeddings = word_embeddings
+
+        if not config.tie_word_embeddings:
+            self.weight = nn.Parameter(torch.zeros((config.vocab_size, config.hidden_size), requires_grad=False))
+        else:
+            self.weight = None  # Will be set to get_input_embeddings().weight during loading the model
+        self.bias = None
+        self.in_features = config.hidden_size  # Similar to nn.Linear attributes
+        self.out_features = config.vocab_size

         self.use_chunked_forward = config.use_chunked_forward
         if self.use_chunked_forward == "auto":
@@ -42,35 +47,17 @@ class LMHead(nn.Module):
         self.chunked_forward_step = config.chunked_forward_step
         self._bf16_warning_shown = False

-    @property
-    def in_features(self) -> int:
-        return self.word_embeddings.num_embeddings
-
-    @property
-    def out_features(self) -> int:
-        return self.word_embeddings.embedding_dim
-
-    @property
-    def weight(self):
-        return self.word_embeddings.weight
-
-    @property
-    def bias(self):
-        return None
-
     def forward(self, hidden_states):
-        word_embeddings = self.word_embeddings.weight
-
         if (
-            word_embeddings.dtype in [torch.float16, torch.bfloat16]
-            and word_embeddings.device.type == "cpu"
+            self.weight.dtype in [torch.float16, torch.bfloat16]
+            and self.weight.device.type == "cpu"
             and self.use_chunked_forward
         ):
             lm_logits = self.chunked_forward(hidden_states)
         else:
             # Switch dtype in case word_embeddings are fp16/bf16
-            hidden_states = hidden_states.to(word_embeddings.dtype)
-            lm_logits = F.linear(hidden_states, word_embeddings)
+            hidden_states = hidden_states.to(self.weight.dtype)
+            lm_logits = F.linear(hidden_states, self.weight)
         return lm_logits

     def chunked_forward(self, hidden_states):
@@ -80,20 +67,17 @@ class LMHead(nn.Module):
         assert self.chunked_forward_step > 0, "Chunk size for chunked forward must be positive"

         if not self._bf16_warning_shown:
-            if self.word_embeddings.weight.numel() * 4 < 0.9 * psutil.virtual_memory().total:
+            if self.weight.numel() * 4 < 0.9 * psutil.virtual_memory().total:
                 logger.warning(
                     "Running the client with dtype bfloat16 on CPU may be slow, since your CPU doesn't support AVX512. "
                     "Consider loading the model with torch_dtype='float32'"
                 )
             self._bf16_warning_shown = True

-        word_embeddings = self.word_embeddings.weight
-        num_embeddings = self.word_embeddings.num_embeddings
-
         hidden_states = hidden_states.float()
-        output = torch.empty(*hidden_states.shape[:-1], num_embeddings)
+        output = torch.empty(*hidden_states.shape[:-1], self.out_features)

-        for i in range(0, num_embeddings, self.chunked_forward_step):
-            chunk = word_embeddings[i : i + self.chunked_forward_step].float()
+        for i in range(0, self.out_features, self.chunked_forward_step):
+            chunk = self.weight[i : i + self.chunked_forward_step].float()
             output[..., i : i + self.chunked_forward_step] = F.linear(hidden_states, chunk)
         return output

+ 88 - 0
src/petals/client/ptune.py

@@ -0,0 +1,88 @@
+import dataclasses
+from contextlib import contextmanager
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from hivemind import get_logger
+from transformers import PretrainedConfig
+
+from petals.utils.misc import DUMMY
+
+logger = get_logger(__name__)
+
+
+@dataclasses.dataclass
+class PTuneConfig:
+    pre_seq_len: int = 0  # a number of tokens for prompt tuning.
+    tuning_mode: Optional[str] = None  # fine-tuning regime, one of [None, "ptune", "deep_ptune"]
+
+
+class PTuneMixin:
+    _keys_to_ignore_on_load_missing = [r"(intermediate_)?prompt_embeddings\.weight$"]
+
+    def init_prompts(self, config: PretrainedConfig) -> None:
+        if config.tuning_mode and "ptune" in config.tuning_mode:
+            assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
+            self.pre_seq_len = config.pre_seq_len
+            self.prefix_tokens = torch.arange(self.pre_seq_len).long()
+
+            with force_non_empty_weights():
+                # Prompt embeddings and their optimizer stats are kept in float32 to increase ptune quality
+                self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size, dtype=torch.float32)
+                if config.tuning_mode == "deep_ptune":
+                    self.intermediate_prompt_embeddings = nn.Embedding(
+                        self.pre_seq_len,
+                        config.num_hidden_layers * config.hidden_size,
+                        # ^-- TODO: should be num_hidden_layers - 1
+                        dtype=torch.float32,
+                    )
+        elif config.tuning_mode:
+            raise NotImplementedError(f"{self.tuning_mode} mode is not supported for now")
+
+    def set_requires_grad(self, value):
+        for p in self.parameters():
+            p.requires_grad = value
+
+    def get_prompt(self, batch_size):
+        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
+        prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)
+        prompts = self.prompt_embeddings(prefix_tokens)
+
+        if self.config.tuning_mode == "deep_ptune":
+            intermediate_prompts = self.intermediate_prompt_embeddings(prefix_tokens)
+            intermediate_prompts = intermediate_prompts.view(
+                batch_size,
+                self.pre_seq_len,
+                self.config.num_hidden_layers,
+                self.config.hidden_size
+                # TODO: should be num_hidden_layers - 1
+            )
+            intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3])
+        else:
+            intermediate_prompts = DUMMY
+
+        dtype = self.word_embeddings.weight.dtype
+        return prompts.to(dtype), intermediate_prompts.to(dtype)
+
+
+_original_register_parameter = nn.Module.register_parameter
+
+
+@contextmanager
+def force_non_empty_weights():
+    """
+    This context manager allows to bypass the accelerate.init_empty_weights() context manager
+    (that forces all nn.Parameters to be PyTorch's meta tensors) used when low_cpu_mem_usage=True.
+    The transformers library should replace all meta tensors by empty tensors by itself
+    but this feature does not work due to a bug ([1] fails if `add_prefix_to_model == True`).
+
+    [1] https://github.com/huggingface/transformers/blob/ab9fe45236cd99b8797df78219438f8f6662bb42/src/transformers/modeling_utils.py#L2515
+    """
+
+    try:
+        possibly_patched_register_parameter = nn.Module.register_parameter
+        nn.Module.register_parameter = _original_register_parameter
+        yield
+    finally:
+        nn.Module.register_parameter = possibly_patched_register_parameter

+ 0 - 268
src/petals/client/remote_model.py

@@ -1,268 +0,0 @@
-from contextlib import contextmanager
-from typing import List, Optional, Union
-
-import hivemind
-import torch
-import torch.nn as nn
-from hivemind.utils.logging import get_logger
-from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
-from transformers.models.bloom import (
-    BloomConfig,
-    BloomForCausalLM,
-    BloomForSequenceClassification,
-    BloomModel,
-    BloomPreTrainedModel,
-)
-
-from petals.bloom.modeling_utils import LMHead
-from petals.client.remote_generation import RemoteGenerationMixin
-from petals.client.remote_sequential import RemoteSequential
-from petals.client.routing.sequence_manager import SequenceManagerConfig
-from petals.constants import PUBLIC_INITIAL_PEERS
-from petals.utils.misc import DUMMY
-
-logger = get_logger(__name__)
-
-
-class DistributedBloomConfig(BloomConfig, SequenceManagerConfig):
-    """
-    A bloom config that contains information about DHT peers.
-    To create a distributed model, one must provide dht_prefix and either initial_peers or dht.
-    """
-
-    initial_peers: List[str] = PUBLIC_INITIAL_PEERS  # a list of initial peers for hivemind DHT
-    dht_prefix: str  # a prefix for all dht keys that correspond to this model (usually equal to model name)
-    daemon_startup_timeout: int = 60  # timeout for the libp2p daemon connecting to initial peers
-
-    pre_seq_len: int = 0  # a number of tokens for prompt tuning.
-    tuning_mode: Optional[str] = None  # fine-tuning regime, one of [None, "ptune", "deep_ptune"]
-
-    # This settings matter for running the client with dtype bfloat16 on CPU.
-    # If the CPU doesn't support AVX512, chunked_forward() significantly speeds up computations.
-    use_chunked_forward: Union[str, bool] = "auto"
-    chunked_forward_step: int = 16384
-
-
-original_register_parameter = nn.Module.register_parameter
-
-
-@contextmanager
-def force_non_empty_weights():
-    """
-    This context manager allows to bypass the accelerate.init_empty_weights() context manager
-    (that forces all nn.Parameters to be PyTorch's meta tensors) used when low_cpu_mem_usage=True.
-    The transformers library should replace all meta tensors by empty tensors by itself
-    but this feature does not work due to a bug ([1] fails if `add_prefix_to_model == True`).
-
-    [1] https://github.com/huggingface/transformers/blob/ab9fe45236cd99b8797df78219438f8f6662bb42/src/transformers/modeling_utils.py#L2515
-    """
-
-    try:
-        possibly_patched_register_parameter = nn.Module.register_parameter
-        nn.Module.register_parameter = original_register_parameter
-        yield
-    finally:
-        nn.Module.register_parameter = possibly_patched_register_parameter
-
-
-class _FromPretrainedDefaultsMixin:
-    @classmethod
-    def from_pretrained(
-        cls,
-        *args,
-        low_cpu_mem_usage: Optional[bool] = None,
-        torch_dtype: Optional[Union[str, torch.dtype]] = None,
-        **kwargs,
-    ):
-        if low_cpu_mem_usage is None:
-            low_cpu_mem_usage = True
-        if torch_dtype is None:
-            # torch_dtype=None gives torch.float32 in transformers>=4.26.0. In contrast,
-            # torch_dtype="auto" attempts to (1) use config.torch_dtype (if exists), (2) use dtype of the weights.
-            torch_dtype = "auto"
-        return super().from_pretrained(*args, low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype, **kwargs)
-
-    from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace(
-        "low_cpu_mem_usage(`bool`, *optional*)",
-        "low_cpu_mem_usage(`bool`, *optional*, defaults to `True` in Petals)",
-    ).replace(
-        "torch_dtype (`str` or `torch.dtype`, *optional*)",
-        'torch_dtype (`str` or `torch.dtype`, *optional*, defaults to `"auto"` in Petals)',
-    )
-
-
-class DistributedBloomModel(_FromPretrainedDefaultsMixin, BloomModel):
-    """BloomModel, but all transformer layers are hosted by the swarm"""
-
-    _keys_to_ignore_on_load_missing = BloomModel._keys_to_ignore_on_load_missing + [
-        r"^(intermediate_)?prompt_embeddings\.weight$",
-    ]
-
-    config_class = DistributedBloomConfig
-
-    def __init__(self, config: DistributedBloomConfig, *, dht: Optional[hivemind.DHT] = None):
-        assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..."
-        assert config.initial_peers or dht is not None, "Please specify `config.initial_peers` or `dht`"
-
-        n_layer, config.n_layer = config.n_layer, 0  # temporarily set n_layer to 0 to prevent layer initialization
-        super().__init__(config)
-        assert len(self.h) == 0
-        config.n_layer = n_layer
-
-        self.h = RemoteSequential(config, dht=dht)
-
-        # Forbid accumulate grads for embeddings and layernorm
-        self.set_requires_grad(False)
-
-        if config.tuning_mode and "ptune" in config.tuning_mode:
-            assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
-            self.pre_seq_len = config.pre_seq_len
-            self.prefix_tokens = torch.arange(self.pre_seq_len).long()
-
-            with force_non_empty_weights():
-                if self.word_embeddings_layernorm.weight.dtype in (torch.float16, torch.bfloat16):
-                    logger.info(
-                        "Prompt embeddings and their optimizer statistics will be kept in float32 "
-                        "to increase ptune quality"
-                    )
-                self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size, dtype=torch.float32)
-                if config.tuning_mode == "deep_ptune":
-                    self.intermediate_prompt_embeddings = nn.Embedding(
-                        self.pre_seq_len,
-                        config.num_hidden_layers * config.hidden_size,
-                        # ^-- TODO: should be num_hidden_layers - 1
-                        dtype=torch.float32,
-                    )
-        elif config.tuning_mode:
-            raise NotImplementedError(f"{self.tuning_mode} mode is not supported for now")
-
-    def set_requires_grad(self, value):
-        for p in self.parameters():
-            p.requires_grad = value
-
-    def get_prompt(self, batch_size):
-        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
-        prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)
-        prompts = self.prompt_embeddings(prefix_tokens)
-
-        if self.config.tuning_mode == "deep_ptune":
-            intermediate_prompts = self.intermediate_prompt_embeddings(prefix_tokens)
-            intermediate_prompts = intermediate_prompts.view(
-                batch_size, self.pre_seq_len, len(self.h), self.config.hidden_size  # TODO: should be len(self.h) - 1
-            )
-            intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3])
-        else:
-            intermediate_prompts = DUMMY
-
-        dtype = self.word_embeddings.weight.dtype
-        return prompts.to(dtype), intermediate_prompts.to(dtype)
-
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        **kwargs,
-    ):
-        assert attention_mask is None, "DistributedBloomModel does not support attention masks right now"
-
-        for k, v in kwargs.items():
-            if not (v is None or v is False):
-                logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})")
-
-        if input_ids is not None and inputs_embeds is not None:
-            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif input_ids is not None:
-            input_shape = input_ids.size()
-            input_ids = input_ids.view(-1, input_shape[-1])
-        elif inputs_embeds is not None:
-            input_shape = inputs_embeds.size()[:-1]
-        else:
-            raise ValueError("You have to specify either input_ids or inputs_embeds")
-
-        if inputs_embeds is None:
-            inputs_embeds = self.word_embeddings(input_ids)
-
-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
-            batch_size = inputs_embeds.shape[0]
-            prompts, intermediate_prompts = self.get_prompt(batch_size)
-            inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
-
-        hidden_states = self.word_embeddings_layernorm(inputs_embeds)
-        output_shape = input_shape + (hidden_states.size(-1),)
-
-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
-            hidden_states = self.h(hidden_states, prompts=intermediate_prompts)
-        else:
-            hidden_states = self.h(hidden_states)
-
-        # Remove prefix
-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
-            hidden_states = hidden_states[:, self.pre_seq_len :]
-
-        # Add last hidden state
-        hidden_states = self.ln_f(hidden_states)
-        hidden_states = hidden_states.view(output_shape)
-        return BaseModelOutputWithPastAndCrossAttentions(
-            last_hidden_state=hidden_states,
-            past_key_values=None,
-            hidden_states=None,
-            attentions=None,
-        )
-
-
-class DistributedBloomForCausalLM(_FromPretrainedDefaultsMixin, RemoteGenerationMixin, BloomForCausalLM):
-    """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
-
-    _keys_to_ignore_on_load_missing = (
-        BloomForCausalLM._keys_to_ignore_on_load_missing
-        + DistributedBloomModel._keys_to_ignore_on_load_missing
-        + [r"^lm_head.word_embeddings\.weight$"]  # Missing since they are shared with input embeddings
-    )
-
-    config_class = DistributedBloomConfig
-
-    def __init__(self, config: DistributedBloomConfig):
-        BloomPreTrainedModel.__init__(self, config)
-        self.transformer = DistributedBloomModel(config)
-        self.lm_head = LMHead(config, self.transformer.word_embeddings)
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-    def get_input_embeddings(self):
-        return self.transformer.word_embeddings
-
-    def get_output_embeddings(self):
-        if self.config.tie_word_embeddings:
-            return None
-        return self.lm_head
-
-    def set_input_embeddings(self, new_embeddings: nn.Embedding):
-        assert isinstance(new_embeddings, nn.Embedding)
-        self.transformer.word_embeddings = self.lm_head.word_embeddings = new_embeddings
-        assert self.lm_head.bias is None or len(self.lm_head.bias) == new_embeddings.num_embeddings
-
-    def set_output_embeddings(self, new_lm_head: nn.Linear):
-        with torch.no_grad():
-            self.lm_head.word_embeddings.weight[...] = new_lm_head.weight
-            self.lm_head.bias[...] = new_lm_head.bias
-
-
-class DistributedBloomForSequenceClassification(_FromPretrainedDefaultsMixin, BloomForSequenceClassification):
-    _keys_to_ignore_on_load_missing = (
-        BloomForSequenceClassification._keys_to_ignore_on_load_missing
-        + DistributedBloomModel._keys_to_ignore_on_load_missing
-    )
-
-    config_class = DistributedBloomConfig
-
-    def __init__(self, config: DistributedBloomConfig):
-        BloomPreTrainedModel.__init__(self, config)
-        self.num_labels = config.num_labels
-
-        self.transformer = DistributedBloomModel(config)
-        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False).to(config.torch_dtype)
-
-        # Initialize weights and apply final processing
-        self.post_init()

+ 3 - 4
src/petals/client/remote_sequential.py

@@ -6,9 +6,8 @@ import torch
 from hivemind import DHT, get_logger
 from torch import nn

-import petals.client
 from petals.client.inference_session import InferenceSession
-from petals.client.routing.sequence_manager import RemoteSequenceManager
+from petals.client.routing.sequence_manager import RemoteSequenceManager, SequenceManagerConfig
 from petals.client.sequential_autograd import _RemoteSequentialAutogradFunction
 from petals.data_structures import UID_DELIMITER
 from petals.utils.misc import DUMMY
@@ -23,7 +22,7 @@ class RemoteSequential(nn.Module):

     def __init__(
         self,
-        config: petals.client.DistributedBloomConfig,
+        config: SequenceManagerConfig,
         *,
         sequence_manager: Optional[RemoteSequenceManager] = None,
         dht: Optional[DHT] = None,
@@ -40,7 +39,7 @@ class RemoteSequential(nn.Module):
             if start_block is None:
                 start_block = 0
             if end_block is None:
-                end_block = self.config.n_layer
+                end_block = self.config.num_hidden_layers
             block_uids = tuple(f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block, end_block))
             sequence_manager = RemoteSequenceManager(config, block_uids, dht=dht)
         self.sequence_manager = sequence_manager

+ 8 - 1
src/petals/client/routing/sequence_manager.py

@@ -20,6 +20,7 @@ from hivemind.utils.logging import get_logger
 import petals.dht_utils
 from petals.client.routing.sequence_info import RemoteSequenceInfo
 from petals.client.routing.spending_policy import NoSpendingPolicy
+from petals.constants import PUBLIC_INITIAL_PEERS
 from petals.data_structures import ModuleUID, RemoteSpanInfo, ServerState
 from petals.server.handler import TransformerConnectionHandler

@@ -28,6 +29,10 @@ logger = get_logger(__name__)

 @dataclasses.dataclass
 class SequenceManagerConfig:
+    initial_peers: Sequence[str] = tuple(PUBLIC_INITIAL_PEERS)  # a list of initial peers for hivemind DHT
+    dht_prefix: Optional[str] = None  # a prefix for all dht keys that correspond to this model (default: model name)
+    daemon_startup_timeout: int = 60  # timeout for the libp2p daemon connecting to initial peers
+
     allowed_servers: Optional[Collection[Union[PeerID, str]]] = None  # if defined, send requests only to these servers

     request_timeout: float = 3 * 60  # timeout for forward/backward/inference requests
@@ -73,6 +78,8 @@ class RemoteSequenceManager:
         dht: Optional[DHT] = None,
         state: Optional[SequenceManagerState] = None,
     ):
+        assert config.initial_peers or dht is not None, "Please specify `config.initial_peers` or `dht`"
+        assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..."
         assert len(block_uids) > 0, "Sequences must contain at least one block"

         self.config = config
@@ -84,7 +91,7 @@ class RemoteSequenceManager:
             dht = DHT(
             dht = DHT(
                 initial_peers=config.initial_peers,
                 initial_peers=config.initial_peers,
                 client_mode=True,
                 client_mode=True,
-                num_workers=config.n_layer,
+                num_workers=config.num_hidden_layers,
                 startup_timeout=config.daemon_startup_timeout,
                 startup_timeout=config.daemon_startup_timeout,
                 start=True,
                 start=True,
             )
             )

+ 2 - 0
src/petals/models/__init__.py

@@ -0,0 +1,2 @@
+from petals.models.bloom import *
+from petals.models.llama import *

+ 7 - 0
src/petals/models/bloom/__init__.py

@@ -0,0 +1,7 @@
+from petals.models.bloom.block import WrappedBloomBlock
+from petals.models.bloom.config import DistributedBloomConfig
+from petals.models.bloom.model import (
+    DistributedBloomForCausalLM,
+    DistributedBloomForSequenceClassification,
+    DistributedBloomModel,
+)

+ 32 - 0
src/petals/models/bloom/block.py

@@ -0,0 +1,32 @@
+"""
+Bloom intermediate layer
+Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
+See commit history for authorship.
+"""
+from typing import Optional, Tuple
+
+import torch
+from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel, build_alibi_tensor
+
+
+class WrappedBloomBlock(BloomBlock):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        *args,
+        attention_mask: Optional[torch.Tensor] = None,
+        alibi: Optional[torch.Tensor] = None,
+        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        **kwargs
+    ):
+        assert attention_mask is None, "Non-causal attention masks are not supported yet"
+        batch_size, seq_length = hidden_states.shape[:2]
+        past_length = 0 if layer_past is None else layer_past[0].shape[-1]
+        seq_length_with_past = seq_length + past_length
+        attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
+        if alibi is None:
+            alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
+        attention_mask = BloomModel._prepare_attn_mask(None, attention_mask, (batch_size, seq_length), past_length)
+        return super().forward(
+            hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, **kwargs
+        )

+ 35 - 0
src/petals/models/bloom/config.py

@@ -0,0 +1,35 @@
+import os
+from typing import Optional, Union
+
+from hivemind import get_logger
+from transformers.models.bloom import BloomConfig
+from transformers.models.bloom.modeling_bloom import BloomAttention
+
+from petals.client.lm_head import LMHeadConfig
+from petals.client.ptune import PTuneConfig
+from petals.client.routing.sequence_manager import SequenceManagerConfig
+from petals.models.bloom.block import WrappedBloomBlock
+from petals.utils.auto_config import AutoDistributedConfig
+from petals.utils.version import get_compatible_model_repo
+
+logger = get_logger(__name__)
+
+
+class DistributedBloomConfig(BloomConfig, SequenceManagerConfig, PTuneConfig, LMHeadConfig):
+    block_class = WrappedBloomBlock
+    attn_class = BloomAttention
+    block_prefix = "h"
+
+    @classmethod
+    def from_pretrained(
+        cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
+    ):
+        loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
+        if loading_from_repo and dht_prefix is None:
+            # We need "-petals" for backward compatibility with Petals < 1.2.0
+            dht_prefix = str(model_name_or_path) + "-petals"
+            logger.info(f"Using DHT prefix: {dht_prefix}")
+        return super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
+
+
+AutoDistributedConfig.register(DistributedBloomConfig)

+ 134 - 0
src/petals/models/bloom/model.py

@@ -0,0 +1,134 @@
+from typing import Optional
+
+import hivemind
+import torch
+import torch.nn as nn
+from hivemind.utils.logging import get_logger
+from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
+from transformers.models.bloom import BloomForCausalLM, BloomForSequenceClassification, BloomModel, BloomPreTrainedModel
+
+from petals.client.from_pretrained import FromPretrainedMixin
+from petals.client.lm_head import LMHead
+from petals.client.ptune import PTuneMixin
+from petals.client.remote_generation import RemoteGenerationMixin
+from petals.client.remote_sequential import RemoteSequential
+from petals.models.bloom.config import DistributedBloomConfig
+
+logger = get_logger(__name__)
+
+
+class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
+    """BloomModel, but all transformer layers are hosted by the swarm"""
+
+    _keys_to_ignore_on_load_missing = (
+        BloomModel._keys_to_ignore_on_load_missing + PTuneMixin._keys_to_ignore_on_load_missing
+    )
+    _keys_to_ignore_on_load_unexpected = [r"^h\."]
+
+    config_class = DistributedBloomConfig
+
+    def __init__(self, config: DistributedBloomConfig, *, dht: Optional[hivemind.DHT] = None):
+        n_layer, config.num_hidden_layers = config.num_hidden_layers, 0  # Prevent initialization
+        super().__init__(config)
+        assert len(self.h) == 0
+        config.num_hidden_layers = n_layer
+
+        self.h = RemoteSequential(config, dht=dht)
+
+        self.set_requires_grad(False)  # Forbid accumulate grads for embeddings and layernorm
+        self.init_prompts(config)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        assert attention_mask is None, f"{self.__class__.__name__} does not support attention masks right now"
+
+        for k, v in kwargs.items():
+            if not (v is None or v is False):
+                logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})")
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+
+        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+            batch_size = inputs_embeds.shape[0]
+            prompts, intermediate_prompts = self.get_prompt(batch_size)
+            inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
+
+        hidden_states = self.word_embeddings_layernorm(inputs_embeds)
+        output_shape = input_shape + (hidden_states.size(-1),)
+
+        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+            hidden_states = self.h(hidden_states, prompts=intermediate_prompts)
+        else:
+            hidden_states = self.h(hidden_states)
+
+        # Remove prefix
+        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+            hidden_states = hidden_states[:, self.pre_seq_len :]
+
+        # Add last hidden state
+        hidden_states = self.ln_f(hidden_states)
+        hidden_states = hidden_states.view(output_shape)
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=None,
+            hidden_states=None,
+            attentions=None,
+        )
+
+
+class DistributedBloomForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, BloomForCausalLM):
+    _keys_to_ignore_on_load_missing = (
+        BloomForCausalLM._keys_to_ignore_on_load_missing
+        + DistributedBloomModel._keys_to_ignore_on_load_missing
+        + [r"^lm_head\."]  # Missing since they are shared with input embeddings
+    )
+    _keys_to_ignore_on_load_unexpected = DistributedBloomModel._keys_to_ignore_on_load_unexpected
+
+    config_class = DistributedBloomConfig
+
+    def __init__(self, config: DistributedBloomConfig):
+        BloomPreTrainedModel.__init__(self, config)
+        self.transformer = DistributedBloomModel(config)
+        self.lm_head = LMHead(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+
+class DistributedBloomForSequenceClassification(FromPretrainedMixin, BloomForSequenceClassification):
+    _keys_to_ignore_on_load_missing = (
+        BloomForSequenceClassification._keys_to_ignore_on_load_missing
+        + DistributedBloomModel._keys_to_ignore_on_load_missing
+    )
+    _keys_to_ignore_on_load_unexpected = DistributedBloomModel._keys_to_ignore_on_load_unexpected
+
+    config_class = DistributedBloomConfig
+
+    def __init__(self, config: DistributedBloomConfig):
+        BloomPreTrainedModel.__init__(self, config)
+        self.num_labels = config.num_labels
+
+        self.transformer = DistributedBloomModel(config)
+        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False).to(config.torch_dtype)
+
+        # Initialize weights and apply final processing
+        self.post_init()
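
For orientation, a minimal client-side sketch of how these classes are intended to be used (not part of this diff; the model name and prompt are examples, and clients joining a private swarm would additionally pass `initial_peers=[...]`, one of the new `SequenceManagerConfig` fields):

```python
# A minimal usage sketch; the repo name, prompt, and dtype are illustrative.
import torch
from transformers import AutoTokenizer

from petals.models.bloom import DistributedBloomForCausalLM

model_name = "bigscience/bloom"  # loaded directly from the original repo, no conversion step
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = DistributedBloomForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)

inputs = tokenizer("A cat sat on", return_tensors="pt")["input_ids"]
outputs = model.generate(inputs, max_new_tokens=5)  # inference runs over the swarm via RemoteGenerationMixin
print(tokenizer.decode(outputs[0]))
```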

+ 7 - 0
src/petals/models/llama/__init__.py

@@ -0,0 +1,7 @@
+from petals.models.llama.block import WrappedLlamaBlock
+from petals.models.llama.config import DistributedLlamaConfig
+from petals.models.llama.model import (
+    DistributedLlamaForCausalLM,
+    DistributedLlamaForSequenceClassification,
+    DistributedLlamaModel,
+)

+ 87 - 0
src/petals/models/llama/block.py

@@ -0,0 +1,87 @@
+"""
+LLaMA intermediate layer
+Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
+See commit history for authorship.
+"""
+from typing import Optional, Tuple
+
+import torch
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
+
+
+class WrappedLlamaBlock(LlamaDecoderLayer):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        *args,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        layer_past: Optional[Tuple[torch.Tensor]] = None,
+        use_cache: bool = False,
+        **kwargs,
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        batch_size, seq_length, _ = hidden_states.shape
+
+        seq_length_with_past = seq_length
+        past_key_values_length = 0
+
+        past_key_value = layer_past
+        if past_key_value is not None:
+            past_key_values_length = past_key_value[0].shape[2]
+            seq_length_with_past = seq_length_with_past + past_key_values_length
+            past_key_value = self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length)
+
+        if position_ids is None:
+            device = hidden_states.device
+            position_ids = torch.arange(
+                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+            )
+            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+        else:
+            position_ids = position_ids.view(-1, seq_length).long()
+
+        # embed positions
+        if attention_mask is None:
+            attention_mask = torch.ones(
+                (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
+            )
+        attention_mask = LlamaModel._prepare_decoder_attention_mask(
+            None, attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+        )
+
+        outputs = super().forward(
+            hidden_states,
+            *args,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            use_cache=use_cache,
+            **kwargs,
+        )
+
+        if use_cache:
+            present_key_value = outputs[-1]
+            present_key_value = self._reorder_cache_from_llama_to_bloom(
+                present_key_value, batch_size, seq_length_with_past
+            )
+            outputs = outputs[:-1] + (present_key_value,)
+
+        return outputs
+
+    def _reorder_cache_from_bloom_to_llama(
+        self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
+    ) -> Tuple[torch.Tensor]:
+        key_states, value_states = key_value
+        key_states = key_states.permute(0, 2, 1)
+        key_states = key_states.view(batch_size, self.self_attn.num_heads, seq_length, self.self_attn.head_dim)
+        value_states = value_states.view(*key_states.shape)
+        return (key_states, value_states)
+
+    def _reorder_cache_from_llama_to_bloom(
+        self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
+    ) -> Tuple[torch.Tensor]:
+        key_states, value_states = key_value
+        value_states = value_states.view(batch_size * self.self_attn.num_heads, seq_length, self.self_attn.head_dim)
+        key_states = key_states.view(*value_states.shape)
+        key_states = key_states.permute(0, 2, 1)
+        return (key_states, value_states)
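
The `_reorder_cache_*` helpers above exist because Petals servers keep attention caches in the BLOOM layout, while `LlamaDecoderLayer` expects `(batch, num_heads, seq_len, head_dim)` tensors. A shape-only illustration of that conversion (sizes are arbitrary; `reshape` is used so the standalone example also works on non-contiguous tensors):

```python
import torch

# Arbitrary example sizes, chosen only to show the shapes involved
batch_size, num_heads, head_dim, seq_len = 2, 8, 16, 5

# BLOOM-style cache, as stored by Petals servers:
#   keys:   (batch * num_heads, head_dim, seq_len)
#   values: (batch * num_heads, seq_len, head_dim)
key_bloom = torch.randn(batch_size * num_heads, head_dim, seq_len)
value_bloom = torch.randn(batch_size * num_heads, seq_len, head_dim)

# Reordered to the LLaMA layout, mirroring _reorder_cache_from_bloom_to_llama above
key_llama = key_bloom.permute(0, 2, 1).reshape(batch_size, num_heads, seq_len, head_dim)
value_llama = value_bloom.reshape(batch_size, num_heads, seq_len, head_dim)
assert key_llama.shape == value_llama.shape == (batch_size, num_heads, seq_len, head_dim)
```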

+ 35 - 0
src/petals/models/llama/config.py

@@ -0,0 +1,35 @@
+import os
+from typing import Optional, Union
+
+from hivemind import get_logger
+from transformers.models.llama import LlamaConfig
+from transformers.models.llama.modeling_llama import LlamaAttention
+
+from petals.client.lm_head import LMHeadConfig
+from petals.client.ptune import PTuneConfig
+from petals.client.routing.sequence_manager import SequenceManagerConfig
+from petals.models.llama.block import WrappedLlamaBlock
+from petals.utils.auto_config import AutoDistributedConfig
+
+logger = get_logger(__name__)
+
+
+class DistributedLlamaConfig(LlamaConfig, SequenceManagerConfig, PTuneConfig, LMHeadConfig):
+    block_class = WrappedLlamaBlock
+    attn_class = LlamaAttention
+    block_prefix = "model.layers"
+
+    @classmethod
+    def from_pretrained(
+        cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
+    ):
+        loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
+        if loading_from_repo and dht_prefix is None:
+            dht_prefix = str(model_name_or_path)
+            if "/" in dht_prefix:  # If present, strip repository name to merge blocks hosted by different accounts
+                dht_prefix = dht_prefix[dht_prefix.rfind("/") + 1 :]
+            logger.info(f"Using DHT prefix: {dht_prefix}")
+        return super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
+
+
+AutoDistributedConfig.register(DistributedLlamaConfig)
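
The effect of the `from_pretrained` override above is that blocks of the same LLaMA model hosted from different accounts share one DHT prefix. A self-contained restatement of that rule (the helper and repo name below are illustrative, not part of Petals):

```python
def llama_dht_prefix(model_name_or_path: str) -> str:
    """Illustrative restatement of the default dht_prefix rule implemented above."""
    prefix = str(model_name_or_path)
    if "/" in prefix:  # strip the account name so blocks hosted from different repos are merged
        prefix = prefix[prefix.rfind("/") + 1 :]
    return prefix


assert llama_dht_prefix("username/llama-65b-hf") == "llama-65b-hf"  # hypothetical repo name
```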

+ 152 - 0
src/petals/models/llama/model.py

@@ -0,0 +1,152 @@
+from typing import Optional
+
+import hivemind
+import torch
+import torch.nn as nn
+from hivemind.utils.logging import get_logger
+from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.models.llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaPreTrainedModel
+
+from petals.client.from_pretrained import FromPretrainedMixin
+from petals.client.lm_head import LMHead
+from petals.client.ptune import PTuneMixin
+from petals.client.remote_generation import RemoteGenerationMixin
+from petals.client.remote_sequential import RemoteSequential
+from petals.models.llama.config import DistributedLlamaConfig
+
+logger = get_logger(__name__)
+
+
+class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
+    """LlamaModel, but all transformer layers are hosted by the swarm"""
+
+    _keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing
+    _keys_to_ignore_on_load_unexpected = LlamaModel._keys_to_ignore_on_load_unexpected + [r"^model\.layers\."]
+
+    config_class = DistributedLlamaConfig
+
+    def __init__(self, config: DistributedLlamaConfig, *, dht: Optional[hivemind.DHT] = None):
+        n_layer, config.num_hidden_layers = config.num_hidden_layers, 0  # Prevent initialization
+        super().__init__(config)
+        assert len(self.layers) == 0
+        config.num_hidden_layers = n_layer
+
+        self.layers = RemoteSequential(config, dht=dht)
+
+        self.set_requires_grad(False)  # Forbid accumulate grads for embeddings and layernorm
+        self.init_prompts(config)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> BaseModelOutputWithPast:
+        assert attention_mask is None, f"{self.__class__.__name__} does not support attention masks right now"
+
+        for k, v in kwargs.items():
+            if not (v is None or v is False):
+                logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})")
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+            batch_size = inputs_embeds.shape[0]
+            prompts, intermediate_prompts = self.get_prompt(batch_size)
+            inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
+
+        hidden_states = inputs_embeds
+        output_shape = input_shape + (hidden_states.size(-1),)
+
+        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+            hidden_states = self.layers(hidden_states, prompts=intermediate_prompts)
+        else:
+            hidden_states = self.layers(hidden_states)
+
+        # Remove prefix
+        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+            hidden_states = hidden_states[:, self.pre_seq_len :]
+
+        # Add last hidden state
+        hidden_states = self.norm(hidden_states)
+        hidden_states = hidden_states.view(output_shape)
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=None,
+            hidden_states=None,
+            attentions=None,
+        )
+
+    @property
+    def word_embeddings(self) -> nn.Embedding:  # For compatibility with RemoteGenerationMixin
+        return self.embed_tokens
+
+    @property
+    def word_embeddings_layernorm(self) -> nn.Module:  # For compatibility with RemoteGenerationMixin
+        return nn.Identity()
+
+    @property
+    def h(self) -> RemoteSequential:  # For compatibility with RemoteGenerationMixin
+        return self.layers
+
+    @property
+    def ln_f(self) -> nn.Module:  # For compatibility with RemoteGenerationMixin
+        return self.norm
+
+
+class DistributedLlamaForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, LlamaForCausalLM):
+    _keys_to_ignore_on_load_missing = DistributedLlamaModel._keys_to_ignore_on_load_missing
+    _keys_to_ignore_on_load_unexpected = DistributedLlamaModel._keys_to_ignore_on_load_unexpected
+
+    config_class = DistributedLlamaConfig
+
+    def __init__(self, config: DistributedLlamaConfig):
+        LlamaPreTrainedModel.__init__(self, config)
+        self.model = DistributedLlamaModel(config)
+        self.lm_head = LMHead(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    @property
+    def transformer(self) -> DistributedLlamaModel:  # For compatibility with RemoteGenerationMixin
+        return self.model
+
+
+class DistributedLlamaForSequenceClassification(FromPretrainedMixin, LlamaForSequenceClassification):
+    _keys_to_ignore_on_load_missing = (
+        LlamaForSequenceClassification._keys_to_ignore_on_load_missing
+        + DistributedLlamaModel._keys_to_ignore_on_load_missing
+    )
+    _keys_to_ignore_on_load_unexpected = DistributedLlamaModel._keys_to_ignore_on_load_unexpected
+
+    config_class = DistributedLlamaConfig
+
+    def __init__(self, config):
+        LlamaPreTrainedModel.__init__(self, config)
+        self.num_labels = config.num_labels
+
+        self.model = DistributedLlamaModel(config)
+        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @property
+    def transformer(self) -> DistributedLlamaModel:  # For compatibility with RemoteGenerationMixin
+        return self.model

+ 11 - 10
src/petals/server/backend.py

@@ -1,4 +1,3 @@
-"""Code for serving bloom blocks via hivemind-server"""
 from __future__ import annotations
 
 from collections import Counter
@@ -12,8 +11,7 @@ from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.utils import get_logger
 from tensor_parallel import TensorParallel
 from tensor_parallel.tensor_parallel import PerDeviceTensors
-from transformers import BloomConfig
-from transformers.models.bloom.modeling_bloom import BloomAttention
+from transformers import PretrainedConfig
 
 from petals.data_structures import InferenceMetadata
 from petals.server.memory_cache import MemoryCache
@@ -24,17 +22,19 @@ logger = get_logger(__name__)
 
 
 class TransformerBackend(ModuleBackend):
-    """A wrapper for a BLOOM block that can process requests for BLOOM layer forward, backward and inference"""
+    """A wrapper for a transformer block that can process requests for forward, backward and inference"""
 
-    def __init__(self, *args, config: BloomConfig, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs):
+    def __init__(
+        self, *args, config: PretrainedConfig, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs
+    ):
         super().__init__(*args, **kwargs)
         assert isinstance(self.module, TensorParallel)
         self.config = config
         self.memory_cache = memory_cache
         for name, param in self.module.named_parameters():
-            assert not param.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
+            assert not param.requires_grad, f"Block parameters must not accumulate gradients, but {name} does"
         for name, buf in self.module.named_buffers():
-            assert not buf.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
+            assert not buf.requires_grad, f"Block parameters must not accumulate gradients, but {name} does"
 
         max_batch_size = self.forward_pool.max_batch_size
         device = self.module.devices[self.module.output_device_index]
@@ -52,9 +52,10 @@ class TransformerBackend(ModuleBackend):
         self.shard_num_heads = []
         for shard in self.module.module_shards:
             for submodule in shard.modules():
-                if isinstance(submodule, BloomAttention):
+                if isinstance(submodule, config.attn_class):
                     self.shard_num_heads.append(submodule.num_heads)
-        assert len(self.shard_num_heads) == len(self.module.devices) and sum(self.shard_num_heads) == config.n_head
+        assert len(self.shard_num_heads) == len(self.module.devices)
+        assert sum(self.shard_num_heads) == config.num_attention_heads
 
         self.inference_schema = (
             (
@@ -71,7 +72,7 @@ class TransformerBackend(ModuleBackend):
 
     def get_inference_cache_descriptors(self, batch_size: int, max_length: int) -> Sequence[TensorDescriptor]:
        """Create tensor descriptors for attention cache tensors used during inference_step"""
-        head_dim = self.config.hidden_size // self.config.n_head
+        head_dim = self.config.hidden_size // self.config.num_attention_heads
         cache_tensors = []
         for device, num_heads in zip(self.module.devices, self.shard_num_heads):
             keys = TensorDescriptor((batch_size, num_heads, head_dim, max_length), dtype=self.dtype, device=device)

+ 4 - 6
src/petals/server/block_utils.py

@@ -2,12 +2,10 @@ from typing import Optional, Union
 
 import torch
 from accelerate import init_empty_weights
-from transformers import BloomConfig
+from transformers import PretrainedConfig
 
-from petals.bloom.block import WrappedBloomBlock
 
-
-def resolve_block_dtype(config: BloomConfig, dtype: Union[str, torch.dtype]) -> torch.dtype:
+def resolve_block_dtype(config: PretrainedConfig, dtype: Union[str, torch.dtype]) -> torch.dtype:
     """If dtype is "auto", resolves it using BloomConfig. Returns `dtype` intact otherwise."""
     if dtype not in ("auto", None):
         return dtype
@@ -17,7 +15,7 @@ def resolve_block_dtype(config: BloomConfig, dtype: Union[str, torch.dtype]) ->
 
 
 def get_block_size(
-    config: BloomConfig,
+    config: PretrainedConfig,
     location: str,
     *,
     dtype: Optional[Union[str, torch.dtype]] = None,
@@ -30,7 +28,7 @@ def get_block_size(
         ), 'get_block_size(..., location="memory") requires to specify dtype and load_in_8bit for calculations'
 
     with init_empty_weights(include_buffers=True):
-        block = WrappedBloomBlock(config)
+        block = config.block_class(config)
         n_params = sum(param.numel() for param in block.parameters())
 
     if location == "memory" and load_in_8bit:

+ 175 - 0
src/petals/server/from_pretrained.py

@@ -0,0 +1,175 @@
+"""
+Utils for fetching pretrained model parts. Currently, this relies on huggingface transformers' from_pretrained code.
+If necessary, one can rewrite this to implement a different behavior, such as:
+ - loading files from a local data source (e.g. S3)
+ - load files via BitTorrent ( https://pypi.org/project/libtorrent/ ) or IPFS( https://docs.ipfs.io/how-to )
+ - fetch the weights over IPoAC, using a fleet of trained pigeons ( http://www.faqs.org/rfcs/rfc1149.html )
+
+"""
+import json
+import time
+from typing import Dict, Optional, Union
+
+import torch
+import torch.nn as nn
+from accelerate import init_empty_weights
+from accelerate.utils import set_module_tensor_to_device
+from hivemind.utils.logging import get_logger
+from huggingface_hub import get_hf_file_metadata, hf_hub_url
+from transformers import PretrainedConfig
+from transformers.utils import get_file_from_repo
+
+from petals.server.block_utils import resolve_block_dtype
+from petals.utils.auto_config import AutoDistributedConfig
+from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for
+
+logger = get_logger(__name__)
+
+
+def load_pretrained_block(
+    model_name: str,
+    block_index: int,
+    *,
+    config: Optional[PretrainedConfig] = None,
+    torch_dtype: Union[torch.dtype, str] = "auto",
+    revision: Optional[str] = None,
+    use_auth_token: Optional[str] = None,
+    cache_dir: Optional[str] = None,
+    max_disk_space: Optional[int] = None,
+) -> nn.Module:
+    if config is None:
+        config = AutoDistributedConfig.from_pretrained(model_name, use_auth_token=use_auth_token)
+    if cache_dir is None:
+        cache_dir = DEFAULT_CACHE_DIR
+
+    assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
+    torch_dtype = resolve_block_dtype(config, torch_dtype)
+
+    with init_empty_weights():
+        block = config.block_class(config)
+
+    block_prefix = f"{config.block_prefix}.{block_index}."
+    state_dict = _load_state_dict_from_repo(
+        model_name,
+        block_prefix,
+        revision=revision,
+        use_auth_token=use_auth_token,
+        cache_dir=cache_dir,
+        max_disk_space=max_disk_space,
+    )
+
+    # dummy load, check that keys match
+    report = block.load_state_dict(state_dict, strict=True)
+    assert not report.missing_keys, f"Some block weights are missing: {report.missing_keys}"
+
+    for param_name, _ in block.named_parameters():
+        assert param_name in state_dict, f"{param_name} not in state dict"
+        param = state_dict[param_name]
+        if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
+            param = param.to(torch_dtype)
+        set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype)
+
+    logger.info(f"Loaded {model_name} block {block_index}, {report}")
+    return block
+
+
+StateDict = Dict[str, torch.Tensor]
+
+
+def _load_state_dict_from_repo(
+    model_name: str,
+    block_prefix: str,
+    *,
+    revision: Optional[str] = None,
+    use_auth_token: Optional[str] = None,
+    cache_dir: str,
+    max_disk_space: Optional[int] = None,
+) -> StateDict:
+    index_file = get_file_from_repo(
+        model_name, filename="pytorch_model.bin.index.json", use_auth_token=use_auth_token, cache_dir=cache_dir
+    )
+    if index_file is not None:  # Sharded model
+        with open(index_file) as f:
+            index = json.load(f)
+        filenames = {
+            filename for param_name, filename in index["weight_map"].items() if param_name.startswith(block_prefix)
+        }
+        if not filenames:
+            raise RuntimeError(f"Block {block_prefix}* not found in the index: {index['weight_map']}")
+    else:  # Non-sharded model
+        filenames = {"pytorch_model.bin"}
+    logger.debug(f"Loading {block_prefix}* from {filenames}")
+
+    state_dict = {}
+    for filename in filenames:
+        shard_state_dict = _load_state_dict_from_file(
+            model_name,
+            filename,
+            revision=revision,
+            use_auth_token=use_auth_token,
+            cache_dir=cache_dir,
+            max_disk_space=max_disk_space,
+        )
+        shard_state_dict = {
+            param_name[len(block_prefix) :]: param
+            for param_name, param in shard_state_dict.items()
+            if param_name.startswith(block_prefix)
+        }  # Remove unused parameters from memory
+        state_dict.update(shard_state_dict)
+    return state_dict
+
+
+def _load_state_dict_from_file(
+    model_name: str,
+    filename: str,
+    *,
+    revision: Optional[str] = None,
+    use_auth_token: Optional[str] = None,
+    cache_dir: str,
+    max_disk_space: Optional[int] = None,
+    delay: float = 30,
+) -> StateDict:
+    # First, try to find the weights locally
+    try:
+        with allow_cache_reads(cache_dir):
+            path = get_file_from_repo(
+                model_name,
+                filename,
+                revision=revision,
+                use_auth_token=use_auth_token,
+                cache_dir=cache_dir,
+                local_files_only=True,
+            )
+            if path is not None:
+                return torch.load(path, map_location="cpu")
+    except Exception:
+        logger.warning(f"Cache for file {filename} is corrupted, it will be downloaded again", exc_info=True)
+
+    # If not found, ensure that we have enough disk space to download them (maybe remove something)
+    while True:
+        try:
+            with allow_cache_writes(cache_dir):
+                url = hf_hub_url(model_name, filename, revision=revision)
+                file_size = get_hf_file_metadata(url, token=use_auth_token).size
+                if file_size is not None:
+                    free_disk_space_for(model_name, file_size, cache_dir=cache_dir, max_disk_space=max_disk_space)
+                else:
+                    logger.warning(f"Failed to fetch size of file {filename} from repo {model_name}")
+
+                path = get_file_from_repo(
+                    model_name,
+                    filename,
+                    revision=revision,
+                    use_auth_token=use_auth_token,
+                    cache_dir=cache_dir,
+                    local_files_only=False,
+                )
+                if path is None:
+                    raise RuntimeError(f"File {filename} does not exist in repo {model_name}")
+                return torch.load(path, map_location="cpu")
+        except Exception as e:
+            logger.warning(f"Failed to load file {filename} from HF Hub (retry in {delay:.0f} sec)", exc_info=True)
+            time.sleep(delay)
+
+
+DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
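
A hedged sketch of calling this helper directly (the model name and block index are examples; inside the server this happens in `ModuleContainer`, see the `server.py` changes below):

```python
# Minimal sketch, assuming the public bigscience/bloom repo; not part of this diff.
from petals.server.from_pretrained import load_pretrained_block
from petals.utils.auto_config import AutoDistributedConfig

config = AutoDistributedConfig.from_pretrained("bigscience/bloom")
block = load_pretrained_block(
    "bigscience/bloom",
    block_index=0,       # only shards containing "h.0.*" weights are downloaded
    config=config,
    torch_dtype="auto",  # resolved via resolve_block_dtype()
)
print(sum(p.numel() for p in block.parameters()), "parameters loaded on CPU")
```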

+ 35 - 29
src/petals/server/server.py

@@ -14,21 +14,23 @@ from hivemind.moe.server.layers import add_custom_models_from_file
 from hivemind.moe.server.runtime import Runtime
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.logging import get_logger
-from transformers import BloomConfig
+from transformers import PretrainedConfig
 
-from petals.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
 from petals.constants import PUBLIC_INITIAL_PEERS
 from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
 from petals.dht_utils import declare_active_modules, get_remote_module_infos
 from petals.server import block_selection
 from petals.server.backend import TransformerBackend, merge_inference_pools_inplace
 from petals.server.block_utils import get_block_size, resolve_block_dtype
+from petals.server.from_pretrained import DTYPE_MAP, load_pretrained_block
 from petals.server.handler import TransformerConnectionHandler
 from petals.server.memory_cache import MemoryCache
 from petals.server.reachability import ReachabilityProtocol, check_direct_reachability, validate_reachability
 from petals.server.throughput import get_dtype_name, get_server_throughput
+from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.convert_block import check_device_balance, convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
+from petals.utils.version import get_compatible_model_repo
 
 logger = get_logger(__name__)
 
@@ -53,7 +55,7 @@ class Server:
         max_batch_size: int = 2048,
         inference_max_length: int = 2048,
         torch_dtype: str = "auto",
-        revision: str = "main",
+        revision: Optional[str] = None,
         cache_dir: Optional[str] = None,
         max_disk_space: Optional[int] = None,
         attn_cache_tokens: int = 8192,
@@ -83,25 +85,32 @@ class Server:
     ):
         """Create a server with one or more bloom blocks. See run_server.py for documentation."""
 
+        converted_model_name_or_path = get_compatible_model_repo(converted_model_name_or_path)
         self.converted_model_name_or_path = converted_model_name_or_path
+
         self.num_handlers = num_handlers
         self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
         self.inference_max_length = inference_max_length
         self.compression = compression
         self.stats_report_interval, self.update_period = stats_report_interval, update_period
         self.prefetch_batches, self.sender_threads = prefetch_batches, sender_threads
-        self.use_auth_token = use_auth_token
+        self.revision, self.use_auth_token = revision, use_auth_token
 
         if custom_module_path is not None:
             add_custom_models_from_file(custom_module_path)
 
+        self.block_config = AutoDistributedConfig.from_pretrained(
+            converted_model_name_or_path,
+            use_auth_token=use_auth_token,
+            revision=revision,
+        )
+
         if prefix is None:
-            prefix = converted_model_name_or_path
-            assert UID_DELIMITER not in prefix and CHAIN_DELIMITER not in prefix, (
-                f"Cannot use model name as prefix (contains '{UID_DELIMITER}' or '{CHAIN_DELIMITER}'); "
-                f"Please specify --prefix manually when starting a server"
-            )
-            logger.debug(f"Automatic dht prefix: {prefix}")
+            prefix = self.block_config.dht_prefix
+        assert UID_DELIMITER not in prefix and CHAIN_DELIMITER not in prefix, (
+            f"DHT prefix should not contain '{UID_DELIMITER}' or '{CHAIN_DELIMITER}'. "
+            f"Please specify another --prefix manually when starting a server"
+        )
         self.prefix = prefix
 
         if expiration is None:
@@ -111,12 +120,9 @@
         self.request_timeout = request_timeout
         self.session_timeout, self.step_timeout = session_timeout, step_timeout
 
-        self.block_config = BloomConfig.from_pretrained(
-            converted_model_name_or_path,
-            use_auth_token=use_auth_token,
-            revision=revision,
-        )
-        self.module_uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]
+        self.module_uids = [
+            f"{self.prefix}.{block_index}" for block_index in range(self.block_config.num_hidden_layers)
+        ]
 
         if dht_client_mode is None:
             is_reachable = check_direct_reachability(initial_peers=initial_peers, use_relay=False, **kwargs)
@@ -125,7 +131,7 @@
         self.dht = DHT(
             initial_peers=initial_peers,
             start=True,
-            num_workers=self.block_config.n_layer,
+            num_workers=self.block_config.num_hidden_layers,
             use_relay=use_relay,
             use_auto_relay=use_auto_relay,
             client_mode=dht_client_mode,
@@ -161,10 +167,10 @@
         if load_in_8bit is None:
             load_in_8bit = device.type == "cuda"
         self.load_in_8bit = load_in_8bit
-        logger.info(f"Model weights will be loaded in {get_dtype_name(torch_dtype, load_in_8bit)} format")
+        logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, load_in_8bit)} format")
 
-        max_values_in_cache = 2 * self.block_config.hidden_size * attn_cache_tokens
-        self._cache_bytes_per_block = max_values_in_cache * torch.finfo(self.torch_dtype).bits // 8
+        cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens
+        self._cache_bytes_per_block = cache_values_per_block * torch.finfo(self.torch_dtype).bits // 8
 
         assert num_blocks is None or block_indices is None, "Please specify num_blocks or block_indices, not both"
         if num_blocks is None and block_indices is None:
@@ -192,6 +198,7 @@
         assert isinstance(throughput, float) or throughput in ["auto", "eval"]
         if throughput in ["auto", "eval"]:
             throughput = get_server_throughput(
+                converted_model_name_or_path,
                 self.block_config,
                 device,
                 torch_dtype,
@@ -239,11 +246,12 @@
         num_blocks = math.floor((total_memory - autograd_memory) / (block_size + self._cache_bytes_per_block))
         assert num_blocks >= 1, "Your GPU does not have enough memory to serve at least one block"
 
+        num_blocks = min(num_blocks, self.block_config.num_hidden_layers)
         logger.info(
             f"Server will fill all your GPU memory with {num_blocks} transformer blocks. "
             f"If you want to leave some free GPU memory, please specify a lesser --num_blocks manually"
         )
-        return min(num_blocks, self.block_config.n_layer)
+        return num_blocks
 
     def run(self):
         while True:
@@ -274,6 +282,7 @@
                 step_timeout=self.step_timeout,
                 prefetch_batches=self.prefetch_batches,
                 sender_threads=self.sender_threads,
+                revision=self.revision,
                 use_auth_token=self.use_auth_token,
                 load_in_8bit=self.load_in_8bit,
                 tensor_parallel_devices=self.tensor_parallel_devices,
@@ -352,7 +361,7 @@ class ModuleContainer(threading.Thread):
         dht: DHT,
         prefix: str,
         converted_model_name_or_path: str,
-        block_config: BloomConfig,
+        block_config: PretrainedConfig,
         attn_cache_bytes: int,
         alloc_timeout: float,
         throughput: float,
@@ -366,6 +375,7 @@ class ModuleContainer(threading.Thread):
         compression: CompressionType,
         update_period: float,
         expiration: Optional[float],
+        revision: Optional[str],
         use_auth_token: Optional[str],
         load_in_8bit: bool,
         tensor_parallel_devices: Sequence[torch.device],
@@ -394,14 +404,14 @@ class ModuleContainer(threading.Thread):
                 block = load_pretrained_block(
                     converted_model_name_or_path,
                     block_index,
-                    block_config,
+                    config=block_config,
                     torch_dtype=torch_dtype,
+                    revision=revision,
                     use_auth_token=use_auth_token,
                     cache_dir=cache_dir,
                     max_disk_space=max_disk_space,
                 )
                 block = convert_block(block, block_config, tensor_parallel_devices, device, load_in_8bit, freeze=True)
-
                 blocks[module_uid] = TransformerBackend(
                     module_uid,
                     block,
@@ -564,13 +574,9 @@ class ModuleContainer(threading.Thread):
 
         self.ready.clear()
 
+        logger.debug("Shutting down connection handlers")
         for handler in self.conn_handlers:
             handler.shutdown()
-        logger.debug("Connection handlers terminated")
-
-        if self.checkpoint_saver is not None:
-            self.checkpoint_saver.stop.set()
-            self.checkpoint_saver.join()
 
         logger.debug(f"Shutting down pools")
         for pool in self.runtime.pools:

+ 12 - 10
src/petals/server/throughput.py

@@ -5,15 +5,13 @@ import multiprocessing as mp
 import os
 import time
 from collections import Counter
-from hashlib import sha256
 from pathlib import Path
 from typing import Dict, Optional, Sequence, Union
 
 import torch
 from hivemind.utils.logging import get_logger
-from transformers import BloomConfig
+from transformers import PretrainedConfig
 
-from petals.bloom.block import WrappedBloomBlock
 from petals.server.block_utils import resolve_block_dtype
 from petals.utils.convert_block import convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
@@ -35,7 +33,8 @@ if not hasattr(speedtest, "Speedtest"):
 
 
 def get_server_throughput(
-    config: BloomConfig,
+    model_name: str,
+    config: PretrainedConfig,
     device: torch.device,
     dtype: Union[str, torch.dtype],
     *,
@@ -59,7 +58,7 @@
         fcntl.flock(lock_fd.fileno(), fcntl.LOCK_EX)
         # The OS will release the lock when lock_fd is closed or the process is killed
 
-        cache_key = f"config_{sha256(str(config).encode()).hexdigest()[-16:]}"
+        cache_key = f"model_{model_name}"
         cache_key += f"_device_{get_device_name(device).replace(' ', '_')}"
         cache_key += f"_dtype_{get_dtype_name(dtype, load_in_8bit)}"
         if len(tensor_parallel_devices) > 1:
@@ -101,7 +100,7 @@
 
 
 def measure_throughput_info(
-    config: BloomConfig,
+    config: PretrainedConfig,
     device: torch.device,
     dtype: torch.dtype,
     *,
@@ -127,7 +126,7 @@
     return throughput_info
 
 
-def measure_network_rps(config: BloomConfig, *, timeout: float = 60) -> Optional[float]:
+def measure_network_rps(config: PretrainedConfig, *, timeout: float = 60) -> Optional[float]:
     pipe_recv, pipe_send = mp.Pipe(duplex=False)
     process = mp.Process(target=_measure_bits_per_second, args=(pipe_send,))
     process.start()
@@ -160,7 +159,7 @@ def _measure_bits_per_second(pipe_send: mp.Pipe):
 
 
 def measure_compute_rps(
-    config: BloomConfig,
+    config: PretrainedConfig,
     device: torch.device,
     dtype: torch.dtype,
     *,
@@ -172,7 +171,7 @@
     if not tensor_parallel_devices:
         tensor_parallel_devices = (device,)
     with torch.inference_mode():
-        block = WrappedBloomBlock(config).to(dtype)
+        block = config.block_class(config).to(dtype)
         block = convert_block(block, config, tensor_parallel_devices, device, load_in_8bit=load_in_8bit, freeze=True)
 
         cache = None
@@ -203,4 +202,7 @@ def get_device_name(device: torch.device) -> str:
 
 
 def get_dtype_name(dtype: torch.dtype, load_in_8bit: bool) -> str:
-    return "8-bit" if load_in_8bit else str(dtype)
+    name = str(dtype)
+    if load_in_8bit:
+        name += ", 8-bit quantized"
+    return name

+ 1 - 0
src/petals/utils/__init__.py

@@ -0,0 +1 @@
+from petals.utils.auto_config import AutoDistributedConfig

+ 23 - 0
src/petals/utils/auto_config.py

@@ -0,0 +1,23 @@
+from typing import Type
+
+from transformers import AutoConfig, PretrainedConfig
+
+CONFIG_MAPPING = {}  # Populated with AutoDistributedConfig.register()
+
+
+class AutoDistributedConfig:
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs) -> PretrainedConfig:
+        config = AutoConfig.from_pretrained(*args, **kwargs)
+        if config.model_type not in CONFIG_MAPPING:
+            raise ValueError(f"Petals does not support model type {config.model_type}")
+
+        dist_config_class = CONFIG_MAPPING[config.model_type]
+        return dist_config_class.from_pretrained(*args, **kwargs)
+
+    @staticmethod
+    def register(config_class: Type[PretrainedConfig]) -> None:
+        assert issubclass(config_class, PretrainedConfig)
+        assert config_class.model_type not in CONFIG_MAPPING
+
+        CONFIG_MAPPING[config_class.model_type] = config_class
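
A small usage sketch (the repo name is an example). The mapping is filled at import time by the `AutoDistributedConfig.register(...)` calls in `petals.models.bloom.config` and `petals.models.llama.config`, so those modules must be imported first:

```python
import petals.models  # noqa: F401 -- importing this package runs the register() calls above
from petals.utils.auto_config import AutoDistributedConfig

config = AutoDistributedConfig.from_pretrained("bigscience/bloom")
print(type(config).__name__)  # DistributedBloomConfig
print(config.dht_prefix)      # "bigscience/bloom-petals" (legacy "-petals" suffix kept for BLOOM)
print(config.block_prefix, config.num_hidden_layers)  # e.g. "h" and 70 for the full BLOOM model
```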

+ 15 - 13
src/petals/utils/convert_block.py

@@ -10,18 +10,15 @@ import torch
 import torch.nn as nn
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from tensor_parallel.slicing_configs import get_bloom_config
-from transformers import BloomConfig
-from transformers.models.bloom.modeling_bloom import BloomAttention
-
-from petals.bloom.block import WrappedBloomBlock
+from transformers import PretrainedConfig
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
 
 
 def convert_block(
-    block: WrappedBloomBlock,
-    config: BloomConfig,
+    block: nn.Module,
+    config: PretrainedConfig,
     tensor_parallel_devices: Sequence[torch.device],
     output_device: torch.device,
     load_in_8bit: bool,
@@ -58,7 +55,7 @@
     return block
 
 
-def replace_8bit_linear(model: nn.Module, threshold=6.0):
+def replace_8bit_linear(model: nn.Module, threshold=6.0) -> nn.Module:
     """
     A helper function to convert all `torch.nn.Linear` modules to `bnb.nn.Linear8bit` modules from the `bitsandbytes`
     library. This will enable running your models using mixed int8 precision as described by the paper `GPT3.int8():
@@ -100,17 +97,22 @@
 
 
 def make_tensor_parallel(
-    block: WrappedBloomBlock, model_config: BloomConfig, devices: Sequence[torch.device], output_device: torch.device
-):
-    tp_config = get_bloom_config(model_config, devices)
-    del tp_config.state_rules[re.compile(".*word_embeddings.weight$")]
+    block: nn.Module, model_config: PretrainedConfig, devices: Sequence[torch.device], output_device: torch.device
+) -> nn.Module:
+    if model_config.model_type == "bloom":
+        tp_config = get_bloom_config(model_config, devices)
+        del tp_config.state_rules[re.compile(".*word_embeddings.weight$")]
+    else:
+        if len(devices) > 1:
+            logger.warning("Tensor parallelism is not tested for models other than BLOOM yet, proceed with caution")
+        tp_config = None
     tp_block = tp.TensorParallel(block, devices, config=tp_config, output_device=output_device, delay_init=True)
     total_heads = 0
     for tp_shard in tp_block.module_shards:
         for submodule in tp_shard.modules():
-            if isinstance(submodule, BloomAttention):
+            if isinstance(submodule, model_config.attn_class):
                 total_heads += submodule.num_heads
-    assert total_heads == model_config.n_head
+    assert total_heads == model_config.num_attention_heads
     return tp_block
 
 

+ 5 - 3
src/petals/utils/disk_cache.py

@@ -57,13 +57,16 @@ def free_disk_space_for(
     available_space = shutil.disk_usage(cache_dir).free - os_quota
     if max_disk_space is not None:
         available_space = min(available_space, max_disk_space - occupied_space)
+
+    gib = 1024**3
+    logger.debug(f"Disk space: required {size / gib:.1f} GiB, available {available_space / gib:.1f} GiB")
     if size <= available_space:
         return
 
     revisions = [revision for repo in model_repos for revision in repo.revisions]
     revisions.sort(key=lambda rev: max([item.blob_last_accessed for item in rev.files], default=rev.last_modified))
 
-    # Remove as few least recently used blocks as possible
+    # Remove as few least recently used shards as possible
     pending_removal = []
     freed_space = 0
     extra_space_needed = size - available_space
@@ -73,9 +76,8 @@ def free_disk_space_for(
         if freed_space >= extra_space_needed:
             break
 
-    gib = 1024**3
     if pending_removal:
-        logger.info(f"Removing {len(pending_removal)} blocks to free {freed_space / gib:.1f} GiB of disk space")
+        logger.info(f"Removing {len(pending_removal)} shards to free {freed_space / gib:.1f} GiB of disk space")
         delete_strategy = cache_info.delete_revisions(*pending_removal)
         delete_strategy.execute()
 
+ 19 - 1
src/petals/utils/version.py

@@ -1,3 +1,7 @@
+import os
+import re
+from typing import Union
+
 import requests
 from hivemind.utils.logging import TextStyle, get_logger
 from packaging.version import parse
@@ -7,7 +11,7 @@ import petals
 logger = get_logger(__name__)
 
 
-def validate_version():
+def validate_version() -> None:
     logger.info(f"Running {TextStyle.BOLD}Petals {petals.__version__}{TextStyle.RESET}")
     try:
         r = requests.get("https://pypi.python.org/pypi/petals/json")
@@ -24,3 +28,17 @@ def validate_version():
             )
     except Exception as e:
         logger.warning("Failed to fetch the latest Petals version from PyPI:", exc_info=True)
+
+
+def get_compatible_model_repo(model_name_or_path: Union[str, os.PathLike, None]) -> Union[str, os.PathLike, None]:
+    if model_name_or_path is None:
+        return None
+
+    match = re.fullmatch(r"(bigscience/.+)-petals", str(model_name_or_path))
+    if match is None:
+        return model_name_or_path
+
+    logger.info(
+        f"Loading model from {match.group(1)}, since Petals 1.2.0+ uses original repos instead of converted ones"
+    )
+    return match.group(1)
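
A sketch of the intended mapping (illustrative calls, not part of the diff): old converted repo names are redirected to the original repos, while everything else passes through unchanged.

```python
from petals.utils.version import get_compatible_model_repo

assert get_compatible_model_repo("bigscience/bloom-petals") == "bigscience/bloom"  # logs an info message
assert get_compatible_model_repo("username/llama-65b-hf") == "username/llama-65b-hf"
assert get_compatible_model_repo(None) is None
```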

+ 2 - 2
tests/test_aux_functions.py

@@ -1,7 +1,7 @@
 import pytest
 import torch
 
-from petals.client import DistributedBloomConfig
+from petals import AutoDistributedConfig
 from petals.server.throughput import measure_compute_rps
 from test_utils import MODEL_NAME
 
@@ -9,7 +9,7 @@ from test_utils import MODEL_NAME
 @pytest.mark.forked
 @pytest.mark.parametrize("tensor_parallel", [False, True])
 def test_compute_throughput(tensor_parallel: bool):
-    config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
+    config = AutoDistributedConfig.from_pretrained(MODEL_NAME)
     tensor_parallel_devices = ("cpu", "cpu") if tensor_parallel else ()
     compute_rps = measure_compute_rps(
         config,

+ 12 - 58
tests/test_block_exact_match.py

@@ -1,13 +1,10 @@
 import random
-from typing import Union
 
 import pytest
 import torch
-from transformers.models.bloom.configuration_bloom import BloomConfig
 
-from petals.bloom.block import WrappedBloomBlock
-from petals.bloom.from_pretrained import DTYPE_MAP, _load_state_dict, load_pretrained_block
-from petals.client import DistributedBloomConfig, RemoteSequential
+from petals import DistributedBloomConfig, RemoteSequential
+from petals.server.from_pretrained import load_pretrained_block
 from test_utils import *
 
 
@@ -16,21 +13,22 @@ def test_remote_block_exact_match(atol_forward=1e-4, atol_inference=1e-3):
     config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
     remote_sequential = RemoteSequential(config)
 
-    for block_index in random.sample(range(config.n_layer), 3):
+    for block_index in random.sample(range(config.num_hidden_layers), 3):
         remote_block = remote_sequential[block_index]
 
         inputs = torch.randn(1, 8, config.hidden_size)
         outputs_forward = remote_block(inputs)
 
         outputs_inference = []
-        with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
-            for i in range(inputs.shape[1]):
-                outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
-
-            # test that max length is respected
-            with pytest.raises(ValueError, match=r"Maximum length exceeded") as exc_info:
-                sess.step(inputs[:, -1:, :])
-            assert "Maximum length exceeded" in repr(exc_info.value)
+        with torch.inference_mode():
+            with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
+                for i in range(inputs.shape[1]):
+                    outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
+
+                # test that max length is respected
+                with pytest.raises(ValueError, match=r"Maximum length exceeded") as exc_info:
+                    sess.step(inputs[:, -1:, :])
+                assert "Maximum length exceeded" in repr(exc_info.value)
         outputs_inference = torch.cat(outputs_inference, dim=1)
 
         ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
@@ -38,47 +36,3 @@ def test_remote_block_exact_match(atol_forward=1e-4, atol_inference=1e-3):
 
         assert torch.allclose(outputs_local, outputs_forward, rtol=0, atol=atol_forward)
         assert torch.allclose(outputs_local, outputs_inference, rtol=0, atol=atol_inference)
-
-
-def _old_load_pretrained_block(
-    converted_model_name_or_path: str,
-    block_index: int,
-    torch_dtype: Union[torch.dtype, str] = "auto",
-) -> WrappedBloomBlock:
-    """Load the BLOOM block by directly initializing the weights.
-    This test is used to check consistency with the previous implementation and can be removed in the future."""
-    config = BloomConfig.from_pretrained(converted_model_name_or_path)
-
-    block = WrappedBloomBlock(config)
-    state_dict = _load_state_dict(
-        converted_model_name_or_path,
-        block_index,
-        config,
-        cache_dir=None,
-    )
-
-    if torch_dtype == "auto":
-        with torch.no_grad():
-            for name, param in block.named_parameters():
-                assert name in state_dict, f"{name} not in state dict"
-                param.data = param.data.to(state_dict[name].dtype)
-    else:
-        assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
-        block = block.to(dtype=torch_dtype)
-
-    block.load_state_dict(state_dict, strict=True)
-    return block
-
-
-@pytest.mark.forked
-def test_init_pretrained_block(torch_dtype=torch.float32, atol_forward=1e-8):
-    config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
-    torch.random.manual_seed(0)
-    inputs = torch.randn(1, 16, config.hidden_size, dtype=torch_dtype)
-
-    block = load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch_dtype)
-    ref_block = _old_load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch_dtype)
-
-    outputs = block.forward(inputs)[0]
-    outputs_ref = ref_block.forward(inputs)[0]
-    assert torch.allclose(outputs, outputs_ref, rtol=0, atol=atol_forward)

+ 2 - 2
tests/test_chained_calls.py

@@ -7,9 +7,9 @@
 import pytest
 import torch
 
-from petals.bloom.from_pretrained import load_pretrained_block
-from petals.client import DistributedBloomConfig
+from petals import DistributedBloomConfig
 from petals.client.remote_sequential import RemoteSequential
+from petals.server.from_pretrained import load_pretrained_block
 from test_utils import *
 
 

+ 7 - 8
tests/test_dtype.py

@@ -1,17 +1,16 @@
 import pytest
 import torch
 
-from petals.bloom.from_pretrained import load_pretrained_block
-from petals.client import DistributedBloomConfig
 from petals.server.block_utils import resolve_block_dtype
+from petals.server.from_pretrained import load_pretrained_block
+from petals.utils.auto_config import AutoDistributedConfig
 from test_utils import MODEL_NAME
 
 
 @pytest.mark.forked
 @pytest.mark.parametrize("torch_dtype", [torch.float32, torch.float16, "auto"])
-def test_backend_dtype(torch_dtype):
-    config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
-    block = load_pretrained_block(MODEL_NAME, 0, config, torch_dtype=torch_dtype)
-    backend_dtype = resolve_block_dtype(config, torch_dtype)
-    other_backend_dtype = next(block.parameters()).dtype if torch_dtype == "auto" else torch_dtype
-    assert backend_dtype == other_backend_dtype
+def test_block_dtype(torch_dtype):
+    config = AutoDistributedConfig.from_pretrained(MODEL_NAME)
+    block = load_pretrained_block(MODEL_NAME, 0, config=config, torch_dtype=torch_dtype)
+    expected_dtype = resolve_block_dtype(config, torch_dtype)
+    assert all(param.dtype == expected_dtype for param in block.parameters())

+ 2 - 2
tests/test_full_model.py

@@ -5,7 +5,7 @@ from hivemind import get_logger
 from transformers.generation import BeamSearchScorer
 from transformers.models.bloom import BloomForCausalLM
 
-from petals.client.remote_model import DistributedBloomForCausalLM
+from petals import DistributedBloomForCausalLM
 from test_utils import *
 
 logger = get_logger(__name__)
@@ -20,7 +20,7 @@ def test_full_model_exact_match(pass_empty_tensors: bool, atol_forward=1e-3, ato
     )
     config = model.config
     assert isinstance(model, DistributedBloomForCausalLM)
-    assert len(model.transformer.h) == model.config.n_layer
+    assert len(model.transformer.h) == model.config.num_hidden_layers
 
     test_inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
 

+ 10 - 8
tests/test_remote_sequential.py

@@ -4,10 +4,10 @@ import torch.nn.functional as F
 from hivemind import DHT, BatchTensorDescriptor, get_logger
 from hivemind.proto import runtime_pb2
 
-from petals.bloom.from_pretrained import load_pretrained_block
+from petals import DistributedBloomConfig
 from petals.client import RemoteSequenceManager, RemoteSequential
-from petals.client.remote_model import DistributedBloomConfig
 from petals.data_structures import UID_DELIMITER
+from petals.server.from_pretrained import load_pretrained_block
 from test_utils import *
 
 logger = get_logger(__name__)
@@ -28,10 +28,10 @@ def test_remote_sequential():
     full_grad = test_inputs.grad.clone()
     test_inputs.grad.data.zero_()
 
-    first_half = sequential[: config.n_layer // 2]
-    second_half = sequential[config.n_layer // 2 :]
+    first_half = sequential[: config.num_hidden_layers // 2]
+    second_half = sequential[config.num_hidden_layers // 2 :]
     assert len(first_half) + len(second_half) == len(sequential)
-    assert abs(len(first_half) - len(second_half)) == config.n_layer % 2
+    assert abs(len(first_half) - len(second_half)) == config.num_hidden_layers % 2
     for m in sequential, first_half, second_half:
         assert isinstance(repr(m), str)
 
@@ -46,7 +46,7 @@ def test_remote_sequential():
     assert torch.allclose(test_inputs.grad, full_grad, atol=1e-3)
 
     # test RemoteSequential with lossy compression
-    block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)]
+    block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.num_hidden_layers)]
     lossy_sequential = RemoteSequential(
         config, sequence_manager=DummyCustomSequenceManager(config, block_uids, dht=dht)
     )
@@ -90,7 +90,9 @@ def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):
     inputs = F.normalize(torch.randn(batch_size, seq_len, config.hidden_size), dim=-1)
     output_proj = F.normalize(torch.randn(batch_size, seq_len + pre_seq_len, config.hidden_size), dim=-1)
     input_prompts = F.normalize(torch.randn(batch_size, pre_seq_len, config.hidden_size, requires_grad=True), dim=-1)
-    intermediate_prompts = torch.randn(config.n_layer, batch_size, pre_seq_len, config.hidden_size, requires_grad=True)
+    intermediate_prompts = torch.randn(
+        config.num_hidden_layers, batch_size, pre_seq_len, config.hidden_size, requires_grad=True
+    )
 
     input_prompts = input_prompts.detach().requires_grad_(True)
     intermediate_prompts = intermediate_prompts.detach().requires_grad_(True)
@@ -110,7 +112,7 @@ def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):
     assert intermediate_prompts_ref.grad is None
 
     outputs_ref = torch.cat([inputs, input_prompts_ref], dim=1)
-    for block_index in range(config.n_layer):
+    for block_index in range(config.num_hidden_layers):
         block_prompt = intermediate_prompts_ref[block_index]
         outputs_ref[:, : block_prompt.shape[1]] += block_prompt
 

+ 2 - 2
tests/test_sequence_manager.py

@@ -5,8 +5,8 @@ import pytest
 import torch
 from hivemind import DHT, get_logger
 
+from petals import DistributedBloomConfig
 from petals.client import RemoteSequenceManager, RemoteSequential
-from petals.client.remote_model import DistributedBloomConfig
 from petals.data_structures import UID_DELIMITER
 from test_utils import *
 
@@ -22,7 +22,7 @@ def test_sequence_manager_basics(mode: str):
     shutdown_evt = threading.Event()
 
     # test RemoteSequential with lossy compression
-    block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)]
+    block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.num_hidden_layers)]
     sequential = RemoteSequential(
         config,
         sequence_manager=TestSequenceManager(config, block_uids, dht=dht, _was_shut_down=shutdown_evt),

+ 1 - 1
tests/test_server_stats.py

@@ -4,7 +4,7 @@ import hivemind
 import pytest
 import torch
 
-from petals.client import DistributedBloomConfig, RemoteSequential
+from petals import DistributedBloomConfig, RemoteSequential
 from petals.server.handler import CACHE_TOKENS_AVAILABLE
 from test_utils import *
 

+ 1 - 1
tests/test_tensor_parallel.py

@@ -6,7 +6,7 @@ import transformers
 from tensor_parallel import TensorParallel
 from tensor_parallel.slicing_configs import get_bloom_config
 
-from petals.bloom.from_pretrained import load_pretrained_block
+from petals.server.from_pretrained import load_pretrained_block
 from test_utils import MODEL_NAME