
Add LLaMA support (#323)

This PR:

1. **Abolishes the model conversion procedure.** Now, models are downloaded directly from original repositories like https://huggingface.co/bigscience/bloom. Servers download only shards with blocks to be hosted, and clients download only shards with input/output embeddings and layernorms.

    - BLOOM is loaded from `bigscience/bloom`, but we use the DHT prefix `bigscience/bloom-petals` for backward compatibility. Same with smaller BLOOMs and BLOOMZ.
    - LLaMA can be loaded from any repo like `username/llama-65b-hf`, but we use the DHT prefix `llama-65b-hf` (without the username) to accommodate blocks from different repos (there are a few of them with minor differences, such as `Llama` vs. `LLaMA` in the class name).

2. **Refactors the client to generalize it for multiple models.** Model-specific code now lives in `petals.models` subpackages (e.g. `petals.models.bloom`, `petals.models.llama`), while general code (e.g. the CPU-efficient LM head, p-tuning) stays in `petals.client`.

3. **Introduces** `WrappedLlamaBlock`, `DistributedLlamaConfig`, `DistributedLlamaForCausalLM`, `DistributedLlamaForSequenceClassification`, and `DistributedLlamaModel` compatible with Petals functionality (p-tuning, adapters, etc.).

4. **Introduces** `AutoDistributedConfig`, which automatically chooses the correct config class (`DistributedLlamaConfig` or `DistributedBloomConfig`); see the usage sketch after this list. The refactored configs contain all model-specific info for both clients and servers.
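
For reference, a minimal usage sketch of the new entry points (not part of this diff; it assumes the top-level re-exports added to `petals/__init__.py` and a swarm that actually serves the chosen model):

```python
from petals import AutoDistributedConfig, DistributedBloomForCausalLM

# AutoDistributedConfig resolves to DistributedBloomConfig for a BLOOM checkpoint;
# the DHT prefix defaults to "<repo name>-petals" for backward compatibility (see item 1).
config = AutoDistributedConfig.from_pretrained("bigscience/bloom-560m")
print(config.dht_prefix)  # expected: "bigscience/bloom-560m-petals"

# The client downloads only the shards with input/output embeddings and layernorms;
# transformer blocks are executed remotely. In practice, pass initial_peers=... if you
# are not using the public swarm (or if it does not serve this particular model).
model = DistributedBloomForCausalLM.from_pretrained("bigscience/bloom-560m")

# A LLaMA checkpoint (e.g. "username/llama-65b-hf") would resolve to
# DistributedLlamaConfig / DistributedLlamaForCausalLM instead.
```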

Upgrade instructions:

- Remove disk caches for blocks in the old (converted) format to save disk space. That is, remove the `~/.cache/petals/model--bigscience--bloom-petals` and `~/.cache/petals/model--bigscience--bloomz-petals` directories (if present); a cleanup sketch is shown below.
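
A tiny cleanup sketch for the step above (equivalent to deleting the directories by hand; the paths are assumed to match the defaults named in the instruction):

```python
# Removes leftover caches of blocks in the old (converted) format, if present.
import shutil
from pathlib import Path

for name in ("model--bigscience--bloom-petals", "model--bigscience--bloomz-petals"):
    shutil.rmtree(Path.home() / ".cache" / "petals" / name, ignore_errors=True)
```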
Alexander Borzunov committed 2 years ago
commit cb3f018f9f
45 changed files with 1073 additions and 853 deletions
| File | Additions | Deletions |
| --- | ---: | ---: |
| .github/workflows/run-tests.yaml | 6 | 60 |
| setup.cfg | 3 | 1 |
| src/petals/__init__.py | 11 | 1 |
| src/petals/bloom/__init__.py | 0 | 0 |
| src/petals/bloom/block.py | 0 | 62 |
| src/petals/bloom/from_pretrained.py | 0 | 132 |
| src/petals/cli/config.json | 0 | 20 |
| src/petals/cli/convert_model.py | 0 | 96 |
| src/petals/cli/inference_one_block.py | 1 | 1 |
| src/petals/cli/run_server.py | 1 | 1 |
| src/petals/client/__init__.py | 0 | 6 |
| src/petals/client/from_pretrained.py | 94 | 0 |
| src/petals/client/lm_head.py | 28 | 44 |
| src/petals/client/ptune.py | 88 | 0 |
| src/petals/client/remote_model.py | 0 | 268 |
| src/petals/client/remote_sequential.py | 3 | 4 |
| src/petals/client/routing/sequence_manager.py | 8 | 1 |
| src/petals/models/__init__.py | 2 | 0 |
| src/petals/models/bloom/__init__.py | 7 | 0 |
| src/petals/models/bloom/block.py | 32 | 0 |
| src/petals/models/bloom/config.py | 35 | 0 |
| src/petals/models/bloom/model.py | 134 | 0 |
| src/petals/models/llama/__init__.py | 7 | 0 |
| src/petals/models/llama/block.py | 87 | 0 |
| src/petals/models/llama/config.py | 35 | 0 |
| src/petals/models/llama/model.py | 152 | 0 |
| src/petals/server/backend.py | 11 | 10 |
| src/petals/server/block_utils.py | 4 | 6 |
| src/petals/server/from_pretrained.py | 175 | 0 |
| src/petals/server/server.py | 35 | 29 |
| src/petals/server/throughput.py | 12 | 10 |
| src/petals/utils/__init__.py | 1 | 0 |
| src/petals/utils/auto_config.py | 23 | 0 |
| src/petals/utils/convert_block.py | 15 | 13 |
| src/petals/utils/disk_cache.py | 5 | 3 |
| src/petals/utils/version.py | 19 | 1 |
| tests/test_aux_functions.py | 2 | 2 |
| tests/test_block_exact_match.py | 12 | 58 |
| tests/test_chained_calls.py | 2 | 2 |
| tests/test_dtype.py | 7 | 8 |
| tests/test_full_model.py | 2 | 2 |
| tests/test_remote_sequential.py | 10 | 8 |
| tests/test_sequence_manager.py | 2 | 2 |
| tests/test_server_stats.py | 1 | 1 |
| tests/test_tensor_parallel.py | 1 | 1 |

+ 6 - 60
.github/workflows/run-tests.yaml

@@ -6,57 +6,8 @@ on:
   pull_request:
 
 jobs:
-  convert-model:
-    runs-on: ubuntu-latest
-    env:
-      BLOOM_TESTING_WRITE_TOKEN: ${{ secrets.BLOOM_TESTING_WRITE_TOKEN }}
-    timeout-minutes: 15
-    steps:
-      - name: Checkout
-        uses: actions/checkout@v3
-      - name: Check if the model is cached
-        id: cache-model
-        uses: actions/cache@v3
-        with:
-          path: ~/converted_ok
-          key: model-v1-${{ hashFiles('setup.cfg', 'src/petals/cli/convert_model.py') }}
-      - name: Set up Python
-        if: steps.cache-model.outputs.cache-hit != 'true'
-        uses: actions/setup-python@v3
-        with:
-          python-version: 3.9
-      - name: Cache dependencies
-        if: steps.cache-model.outputs.cache-hit != 'true'
-        uses: actions/cache@v3
-        with:
-          path: ~/.cache/pip
-          key: Key-v1-3.9-${{ hashFiles('setup.cfg') }}
-      - name: Install dependencies
-        if: steps.cache-model.outputs.cache-hit != 'true'
-        run: |
-          python -m pip install --upgrade pip
-          pip install .
-      - name: Delete any test models older than 1 week
-        if: steps.cache-model.outputs.cache-hit != 'true'
-        run: |
-          python tests/scripts/remove_old_models.py --author bloom-testing --use_auth_token $BLOOM_TESTING_WRITE_TOKEN
-      - name: Delete previous version of this model, if exists
-        if: steps.cache-model.outputs.cache-hit != 'true'
-        run: |
-          export HF_TAG=$(python -c "import os; print(os.environ.get('GITHUB_HEAD_REF') or os.environ.get('GITHUB_REF_NAME'))")
-          python -c "from huggingface_hub import delete_repo; delete_repo(token='$BLOOM_TESTING_WRITE_TOKEN', \
-          repo_id='bloom-testing/test-bloomd-560m-$HF_TAG')" || true
-      - name: Convert model and push to hub
-        if: steps.cache-model.outputs.cache-hit != 'true'
-        run: |
-          export HF_TAG=${{ hashFiles('setup.cfg', 'src/petals/cli/convert_model.py') }}
-          python -m petals.cli.convert_model --model bigscience/bloom-560m --output_path ./converted_model \
-            --output_repo bloom-testing/test-bloomd-560m-$HF_TAG --use_auth_token $BLOOM_TESTING_WRITE_TOKEN \
-            --resize_token_embeddings 50000 && touch ~/converted_ok
-
   run-tests:
     runs-on: ubuntu-latest
-    needs: convert-model
     strategy:
       matrix:
         python-version: [ '3.7', '3.8', '3.9', '3.10' ]
@@ -80,8 +31,7 @@ jobs:
           pip install .[dev]
       - name: Test
         run: |
-          export HF_TAG=${{ hashFiles('setup.cfg', 'src/petals/cli/convert_model.py') }}
-          export MODEL_NAME=bloom-testing/test-bloomd-560m-$HF_TAG
+          export MODEL_NAME=bigscience/bloom-560m
           export REF_NAME=bigscience/bloom-560m
 
           python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:12 \
@@ -104,23 +54,19 @@ jobs:
             --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 &> server3.log &
           SERVER3_PID=$!
 
-          python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 4:14 \
-            --torch_dtype float32 --initial_peers $INITIAL_PEERS --throughput 1 &> server4.log &
-          SERVER4_PID=$!
-
           python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --num_blocks 3 \
-            --initial_peers $INITIAL_PEERS --throughput 1 --tensor_parallel_devices cpu cpu  --torch_dtype float32 &> server5.log &
-          SERVER5_PID=$!
+            --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --tensor_parallel_devices cpu cpu &> server4.log &
+          SERVER4_PID=$!
 
           tail -n 100 -f server*.log &
           LOGGER_PID=$!
           sleep 30  # wait for servers to download layers
 
-          kill -0 $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $SERVER5_PID # ensure all servers survived init
+          kill -0 $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID # ensure all servers survived init
 
           pytest tests --durations=0 --durations-min=1.0 -v
 
-          kill -0 $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $SERVER5_PID # ensure all servers survived tests
+          kill -0 $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID # ensure all servers survived tests
 
-          kill -s SIGINT $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $SERVER5_PID $LOGGER_PID
+          kill -s SIGINT $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $LOGGER_PID
           echo "Done!"

+ 3 - 1
setup.cfg

@@ -35,7 +35,8 @@ install_requires =
     bitsandbytes==0.38.0.post2
     accelerate>=0.16.0,<1.0.0
     huggingface-hub>=0.11.1,<1.0.0
-    transformers>=4.25.1,<5.0.0
+    tokenizers>=0.13.3
+    transformers>=4.30.1,<5.0.0
     speedtest-cli==2.1.3
     hivemind==1.1.8
     tensor_parallel==1.0.23
@@ -43,6 +44,7 @@ install_requires =
     async-timeout>=4.0.2
     cpufeature>=0.2.0
     packaging>=20.9
+    sentencepiece>=0.1.99
 
 [options.extras_require]
 dev =

+ 11 - 1
src/petals/__init__.py

@@ -1,11 +1,21 @@
 import os
 
 import hivemind
+import transformers
+from packaging import version
 
 from petals.client import *
+from petals.models import *
+from petals.utils import *
 from petals.utils.logging import initialize_logs as _initialize_logs
 
-__version__ = "1.1.5"
+__version__ = "1.2.0.dev0"
+
+
+if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
+    assert (
+        version.parse("4.30.1") <= version.parse(transformers.__version__) < version.parse("5.0.0")
+    ), "Please install a proper transformers version: pip install transformers>=4.30.1,<5.0.0"
 
 
 def _override_bfloat16_mode_default():

+ 0 - 0
src/petals/bloom/__init__.py


+ 0 - 62
src/petals/bloom/block.py

@@ -1,62 +0,0 @@
-"""
-Bloom intermediate layer
-Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
-See commit history for authorship.
-"""
-import os
-from typing import Optional, Tuple
-
-import torch.nn.quantized.dynamic.modules.linear
-import transformers
-from packaging import version
-from transformers.models.bloom.modeling_bloom import BloomBlock, _expand_mask, _make_causal_mask, build_alibi_tensor
-
-if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
-    assert (
-        version.parse("4.25.1") <= version.parse(transformers.__version__) < version.parse("5.0.0")
-    ), "Please install a proper transformers version: pip install transformers>=4.25.1,<5.0.0"
-
-
-class WrappedBloomBlock(BloomBlock):
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        *args,
-        attention_mask: Optional[torch.Tensor] = None,
-        alibi: Optional[torch.Tensor] = None,
-        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-        **kwargs
-    ):
-        assert attention_mask is None
-        batch_size, seq_length = hidden_states.shape[:2]
-        past_length = 0 if layer_past is None else layer_past[0].shape[-1]
-        seq_length_with_past = seq_length + past_length
-        attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
-        if alibi is None:
-            alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
-        attention_mask = self._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)
-        return super().forward(
-            hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, **kwargs
-        )
-
-    def _prepare_attn_mask(
-        self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
-    ) -> torch.BoolTensor:
-        # create causal mask
-        # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
-        combined_attention_mask = None
-        device = attention_mask.device
-        _, src_length = input_shape
-
-        if src_length > 1:
-            combined_attention_mask = _make_causal_mask(
-                torch.Size(input_shape), device=device, past_key_values_length=past_key_values_length
-            )
-
-        # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
-        expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
-        combined_attention_mask = (
-            expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
-        )
-
-        return combined_attention_mask

+ 0 - 132
src/petals/bloom/from_pretrained.py

@@ -1,132 +0,0 @@
-"""
-Utils for fetching pretrained model parts. Currently, this relies on huggingface transformers' from_pretrained code.
-If necessary, one can rewrite this to implement a different behavior, such as:
- - loading files from a local data source (e.g. S3)
- - load files via BitTorrent ( https://pypi.org/project/libtorrent/ ) or IPFS( https://docs.ipfs.io/how-to )
- - fetch the weights over IPoAC, using a fleet of trained pigeons ( http://www.faqs.org/rfcs/rfc1149.html )
-
-"""
-from __future__ import annotations
-
-import itertools
-import time
-from typing import Optional, OrderedDict, Union
-
-import torch
-from accelerate import init_empty_weights
-from accelerate.utils import set_module_tensor_to_device
-from hivemind.utils.logging import get_logger
-from transformers.modeling_utils import WEIGHTS_NAME
-from transformers.models.bloom.configuration_bloom import BloomConfig
-from transformers.utils import get_file_from_repo
-
-from petals.bloom.block import WrappedBloomBlock
-from petals.server.block_utils import get_block_size, resolve_block_dtype
-from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for
-
-logger = get_logger(__name__)
-
-CLIENT_BRANCH = "main"
-BLOCK_BRANCH_PREFIX = "block_"
-
-
-def load_pretrained_block(
-    converted_model_name_or_path: str,
-    block_index: int,
-    config: Optional[BloomConfig] = None,
-    torch_dtype: Union[torch.dtype, str] = "auto",
-    use_auth_token: Optional[str] = None,
-    cache_dir: Optional[str] = None,
-    max_disk_space: Optional[int] = None,
-) -> WrappedBloomBlock:
-    """Load one BLOOM block from a converted model. See convert_model.py (or README.md) on how to convert it."""
-    assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
-    torch_dtype = resolve_block_dtype(config, torch_dtype)
-
-    if config is None:
-        config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
-    if cache_dir is None:
-        cache_dir = DEFAULT_CACHE_DIR
-
-    with init_empty_weights():
-        block = WrappedBloomBlock(config)
-
-    state_dict = _load_state_dict(
-        converted_model_name_or_path,
-        block_index,
-        config,
-        use_auth_token=use_auth_token,
-        cache_dir=cache_dir,
-        max_disk_space=max_disk_space,
-    )
-
-    # dummy load, check that keys match
-    report = block.load_state_dict(state_dict, strict=True)
-    assert not report.missing_keys, f"Some block weights are missing: {report.missing_keys}"
-
-    for param_name, _ in block.named_parameters():
-        assert param_name in state_dict, f"{param_name} not in state dict"
-        param = state_dict[param_name]
-        if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
-            param = param.to(torch_dtype)
-        set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype)
-
-    logger.info(f"Loaded {converted_model_name_or_path} block {block_index}, {report}")
-    return block
-
-
-def _load_state_dict(
-    pretrained_model_name_or_path: str,
-    block_index: int,
-    config: BloomConfig,
-    *,
-    use_auth_token: Optional[str] = None,
-    cache_dir: str,
-    max_disk_space: Optional[int] = None,
-    min_backoff: float = 5,
-) -> OrderedDict[str, torch.Tensor]:
-    revision = BLOCK_BRANCH_PREFIX + str(block_index)
-
-    # First, try to find the weights locally
-    try:
-        with allow_cache_reads(cache_dir):
-            archive_file = get_file_from_repo(
-                pretrained_model_name_or_path,
-                filename=WEIGHTS_NAME,
-                revision=revision,
-                use_auth_token=use_auth_token,
-                cache_dir=cache_dir,
-                local_files_only=True,
-            )
-            if archive_file is not None:
-                return torch.load(archive_file, map_location="cpu")
-    except Exception:
-        logger.debug(
-            f"Failed to load block {block_index} from cache. The block will be downloaded again", exc_info=True
-        )
-
-    # If not found, ensure that we have enough disk space to download them (maybe remove something)
-    for attempt_no in itertools.count():
-        try:
-            with allow_cache_writes(cache_dir):
-                block_size = get_block_size(config, "disk")
-                free_disk_space_for(
-                    pretrained_model_name_or_path, block_size, cache_dir=cache_dir, max_disk_space=max_disk_space
-                )
-
-                archive_file = get_file_from_repo(
-                    pretrained_model_name_or_path,
-                    filename=WEIGHTS_NAME,
-                    revision=revision,
-                    use_auth_token=use_auth_token,
-                    cache_dir=cache_dir,
-                    local_files_only=False,
-                )
-                return torch.load(archive_file, map_location="cpu")
-        except Exception as e:
-            delay = min_backoff * (2**attempt_no)
-            logger.warning(f"Failed to load block {block_index} from HF Hub (retry in {delay:.0f} sec)", exc_info=True)
-            time.sleep(delay)
-
-
-DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")

+ 0 - 20
src/petals/cli/config.json

@@ -1,20 +0,0 @@
-{
-  "apply_residual_connection_post_layernorm": false,
-  "attention_dropout": 0.0,
-  "attention_softmax_in_fp32": true,
-  "bos_token_id": 1,
-  "eos_token_id": 2,
-  "hidden_dropout": 0.0,
-  "initializer_range": 0.02,
-  "layer_norm_epsilon": 1e-05,
-  "masked_softmax_fusion": true,
-  "model_type": "bloom",
-  "n_embed": 14336,
-  "n_layer": 70,
-  "num_attention_heads": 112,
-  "pretraining_tp": 4,
-  "slow_but_exact": false,
-  "transformers_version": "4.20.0.dev0",
-  "use_cache": true,
-  "vocab_size": 250880
-}

+ 0 - 96
src/petals/cli/convert_model.py

@@ -1,96 +0,0 @@
-import argparse
-import os
-
-import psutil
-import torch.backends.quantized
-import torch.nn as nn
-import transformers
-from hivemind.utils.logging import get_logger
-from huggingface_hub import HfApi, Repository
-from tqdm.auto import tqdm
-from transformers.models.bloom.modeling_bloom import BloomModel
-
-from petals.bloom.from_pretrained import BLOCK_BRANCH_PREFIX, CLIENT_BRANCH, DTYPE_MAP
-from petals.client import DistributedBloomConfig
-
-logger = get_logger(__name__)
-
-
-def main():
-    parser = argparse.ArgumentParser(description="Load bloom layers and convert to 8-bit using torch quantization.")
-
-    parser.add_argument("--model", type=str, default="bigscience/bloom-6b3", help="Model name for from_pretrained")
-    parser.add_argument("--revision", type=str, default=None, help="Optional commit id from HF hub")
-    parser.add_argument("--torch_dtype", type=str, default="auto", help="Load initial model in this dtype")
-    parser.add_argument("--output_path", type=str, default="./converted_model", help="Track output repo to this folder")
-    parser.add_argument("--output_repo", type=str, default="bigscience/test-bloomd", help="Push to this HF hub repo")
-    parser.add_argument("--client_branch", type=str, default=CLIENT_BRANCH, help="Save client version to this branch")
-    parser.add_argument(
-        "--block_branch_prefix", type=str, default=BLOCK_BRANCH_PREFIX, help="Save blocks to branches with this prefix"
-    )
-    parser.add_argument(
-        "--commit_message", type=str, default="push-o-matic", help="Use this commit message for all parts"
-    )
-    parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
-    parser.add_argument("--resize_token_embeddings", type=int, default=None, help="change the vocabulary size")
-    args = parser.parse_args()
-
-    free_ram_gb = psutil.virtual_memory().available / 2**30
-    if args.model == "bigscience/bloom" and free_ram_gb < 400:
-        logger.warning(f"ACHTUNG! converting bloom-176b will use up 350-400GB RAM, you have {free_ram_gb:.3f} free")
-
-    assert args.torch_dtype in DTYPE_MAP, f"torch_dtype must be one of {list(DTYPE_MAP.keys())}"
-    if os.path.exists(args.output_path) and (
-        len(os.listdir(args.output_path)) != 0 or not os.path.isdir(args.output_path)
-    ):
-        raise FileExistsError(f"Output path {args.output_path} already exists and is not an empty directory")
-
-    logger.info(f"Loading source model {args.model} (this may take a few minutes)")
-    config = DistributedBloomConfig.from_pretrained(
-        args.model, use_auth_token=args.use_auth_token, revision=args.revision
-    )
-    config.dht_prefix = args.output_repo
-
-    model = BloomModel.from_pretrained(
-        args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype]
-    )
-    if args.resize_token_embeddings:
-        logger.info(f"Resizing token embeddings, new size = {args.resize_token_embeddings}")
-        model.resize_token_embeddings(args.resize_token_embeddings)
-        config.vocab_size = args.resize_token_embeddings
-
-    tokenizer = transformers.AutoTokenizer.from_pretrained(
-        args.model, use_auth_token=args.use_auth_token, revision=args.revision
-    )
-    os.makedirs(args.output_path, exist_ok=True)
-
-    api = HfApi(token=args.use_auth_token)
-    api.create_repo(args.output_repo, repo_type="model", exist_ok=True)
-    repo = Repository(args.output_path, clone_from=args.output_repo, use_auth_token=args.use_auth_token)
-    repo.git_pull()
-
-    transformer_blocks = model.h
-    logger.info(
-        f"Saving transformer blocks to {args.output_repo}@{args.block_branch_prefix}0"
-        f" - {args.output_repo}@{args.block_branch_prefix}{len(transformer_blocks)}"
-    )
-    for i, block in enumerate(tqdm(transformer_blocks)):
-        repo.git_checkout(args.client_branch, create_branch_ok=True)
-        with repo.commit(
-            commit_message=args.commit_message, branch=args.block_branch_prefix + str(i), track_large_files=True
-        ):
-            torch.save(block.state_dict(), "./pytorch_model.bin")
-
-    logger.info(f"Saving client-side modules to {args.output_repo}@{args.client_branch}")
-    repo.git_checkout(args.client_branch, create_branch_ok=True)
-    with repo.commit(commit_message=args.commit_message, branch=args.client_branch, track_large_files=True):
-        model.h = nn.ModuleList()
-        model.save_pretrained(".")
-        tokenizer.save_pretrained(".")
-        config.save_pretrained(".")
-
-    logger.info(f"Converted {args.model} and pushed to {args.output_repo}")
-
-
-if __name__ == "__main__":
-    main()

+ 1 - 1
src/petals/cli/inference_one_block.py

@@ -6,7 +6,7 @@ from tqdm.auto import trange
 from transformers import BloomConfig
 from transformers.models.bloom.modeling_bloom import build_alibi_tensor
 
-from petals.bloom.block import BloomBlock
+from petals.models.bloom.block import BloomBlock
 
 logger = get_logger(__name__)
 

+ 1 - 1
src/petals/cli/run_server.py

@@ -87,7 +87,7 @@ def main():
     parser.add_argument('--alloc_timeout', type=float, default=60,
                         help='If the cache is full, the server will wait for this number of seconds hoping that some memory will be freed '
                              'before rejecting the request')
-    parser.add_argument('--revision', type=str, default='main',
+    parser.add_argument('--revision', type=str, default=None,
                         help="The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models"
                              "and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.")
 

+ 0 - 6
src/petals/client/__init__.py

@@ -1,10 +1,4 @@
 from petals.client.inference_session import InferenceSession
-from petals.client.remote_model import (
-    DistributedBloomConfig,
-    DistributedBloomForCausalLM,
-    DistributedBloomForSequenceClassification,
-    DistributedBloomModel,
-)
 from petals.client.remote_sequential import RemoteSequential
 from petals.client.routing.sequence_manager import RemoteSequenceManager
 from petals.client.routing.spending_policy import NoSpendingPolicy, SpendingPolicyBase

+ 94 - 0
src/petals/client/from_pretrained.py

@@ -0,0 +1,94 @@
+import contextlib
+import json
+import os
+import re
+import tempfile
+import threading
+from typing import List, Optional, Tuple, Union
+
+import torch
+from hivemind.utils.logging import get_logger
+from transformers import BloomPreTrainedModel, modeling_utils
+
+from petals.utils.version import get_compatible_model_repo
+
+logger = get_logger(__name__)
+
+
+class FromPretrainedMixin:
+    @classmethod
+    def from_pretrained(
+        cls,
+        model_name_or_path: Union[str, os.PathLike, None],
+        *args,
+        low_cpu_mem_usage: Optional[bool] = None,
+        torch_dtype: Optional[Union[str, torch.dtype]] = None,
+        **kwargs,
+    ):
+        model_name_or_path = get_compatible_model_repo(model_name_or_path)
+        if low_cpu_mem_usage is None:
+            low_cpu_mem_usage = True
+        if torch_dtype is None:
+            # torch_dtype=None gives torch.float32 in transformers>=4.26.0. In contrast,
+            # torch_dtype="auto" attempts to (1) use config.torch_dtype (if exists), (2) use dtype of the weights.
+            torch_dtype = "auto"
+
+        with ignore_keys(cls._keys_to_ignore_on_load_unexpected):
+            return super().from_pretrained(
+                model_name_or_path, *args, low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype, **kwargs
+            )
+
+    from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace(
+        "low_cpu_mem_usage(`bool`, *optional*)",
+        "low_cpu_mem_usage(`bool`, *optional*, defaults to `True` in Petals)",
+    ).replace(
+        "torch_dtype (`str` or `torch.dtype`, *optional*)",
+        'torch_dtype (`str` or `torch.dtype`, *optional*, defaults to `"auto"` in Petals)',
+    )
+
+
+_shard_config = threading.local()
+_shard_config.ignored_keys = None
+
+
+@contextlib.contextmanager
+def ignore_keys(patterns: List[str]):
+    try:
+        prev_patterns = _shard_config.ignored_keys
+        _shard_config.ignored_keys = patterns
+        yield
+    finally:
+        _shard_config.ignored_keys = prev_patterns
+
+
+def patched_get_checkpoint_shard_files(
+    pretrained_model_name_or_path, index_filename, *args, **kwargs
+) -> Tuple[List[str], dict]:
+    """Same as modeling_utils.get_checkpoint_shard_files(), but does not download shards for the ignored keys."""
+
+    should_ignore_keys = _shard_config.ignored_keys is not None
+    tempdir_ctx = tempfile.TemporaryDirectory() if should_ignore_keys else contextlib.nullcontext()
+    with tempdir_ctx as tempdir:
+        if should_ignore_keys:
+            with open(index_filename) as f:
+                index = json.load(f)
+            n_original_shards = len(set(index["weight_map"].values()))
+
+            index["weight_map"] = {
+                param_name: filename
+                for param_name, filename in index["weight_map"].items()
+                if all(re.search(pattern, param_name) is None for pattern in _shard_config.ignored_keys)
+            }
+            n_loaded_shards = len(set(index["weight_map"].values()))
+            logger.debug(f"Loading {n_loaded_shards} shards out of {n_original_shards}")
+
+            # Replace the original index with a patched JSON, where ignored keys are removed
+            index_filename = os.path.join(tempdir, "pytorch_model.bin.index.json")
+            with open(index_filename, "w") as f:
+                json.dump(index, f)
+
+        return original_get_checkpoint_shard_files(pretrained_model_name_or_path, index_filename, *args, **kwargs)
+
+
+original_get_checkpoint_shard_files = modeling_utils.get_checkpoint_shard_files
+modeling_utils.get_checkpoint_shard_files = patched_get_checkpoint_shard_files

+ 28 - 44
src/petals/bloom/modeling_utils.py → src/petals/client/lm_head.py

@@ -1,10 +1,6 @@
-"""
-PyTorch BLOOM model that implements several memory-efficient modes.
-Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
-See commit history for authorship.
-"""
-
+import dataclasses
 import platform
+from typing import Optional, Union
 
 import psutil
 import torch
@@ -12,21 +8,30 @@ import torch.nn.functional as F
 import torch.utils.checkpoint
 from hivemind import get_logger
 from torch import nn
-from transformers import BloomConfig
+from transformers import PretrainedConfig
 
 logger = get_logger(__name__)
 
 
-class LMHead(nn.Module):
-    """
-    The modified language modeling head which does not create extra tensor for the linear layer with weights tied to the input
-    embeddings. Thus, it reduces initial memory consumption which might be crucial for large dictionaries.
-    In addition, it provides an effcient way to deal with half-precision word embeddings on CPU.
-    """
+@dataclasses.dataclass
+class LMHeadConfig:
+    # This settings matter for running the client with dtype bfloat16 on CPU.
+    # If the CPU doesn't support AVX512, chunked_forward() significantly speeds up computations.
+    use_chunked_forward: Union[str, bool] = "auto"
+    chunked_forward_step: int = 16384
+
 
-    def __init__(self, config: BloomConfig, word_embeddings: nn.Embedding):
+class LMHead(nn.Module):
+    def __init__(self, config: PretrainedConfig):
         super().__init__()
-        self.word_embeddings = word_embeddings
+
+        if not config.tie_word_embeddings:
+            self.weight = nn.Parameter(torch.zeros((config.vocab_size, config.hidden_size), requires_grad=False))
+        else:
+            self.weight = None  # Will be set to get_input_embeddings().weight during loading the model
+        self.bias = None
+        self.in_features = config.hidden_size  # Similar to nn.Linear attributes
+        self.out_features = config.vocab_size
 
         self.use_chunked_forward = config.use_chunked_forward
         if self.use_chunked_forward == "auto":
@@ -42,35 +47,17 @@ class LMHead(nn.Module):
         self.chunked_forward_step = config.chunked_forward_step
         self._bf16_warning_shown = False
 
-    @property
-    def in_features(self) -> int:
-        return self.word_embeddings.num_embeddings
-
-    @property
-    def out_features(self) -> int:
-        return self.word_embeddings.embedding_dim
-
-    @property
-    def weight(self):
-        return self.word_embeddings.weight
-
-    @property
-    def bias(self):
-        return None
-
     def forward(self, hidden_states):
-        word_embeddings = self.word_embeddings.weight
-
         if (
-            word_embeddings.dtype in [torch.float16, torch.bfloat16]
-            and word_embeddings.device.type == "cpu"
+            self.weight.dtype in [torch.float16, torch.bfloat16]
+            and self.weight.device.type == "cpu"
             and self.use_chunked_forward
         ):
             lm_logits = self.chunked_forward(hidden_states)
         else:
             # Switch dtype in case word_embeddings are fp16/bf16
-            hidden_states = hidden_states.to(word_embeddings.dtype)
-            lm_logits = F.linear(hidden_states, word_embeddings)
+            hidden_states = hidden_states.to(self.weight.dtype)
+            lm_logits = F.linear(hidden_states, self.weight)
         return lm_logits
 
     def chunked_forward(self, hidden_states):
@@ -80,20 +67,17 @@ class LMHead(nn.Module):
         assert self.chunked_forward_step > 0, "Chunk size for chunked forward must be positive"
 
         if not self._bf16_warning_shown:
-            if self.word_embeddings.weight.numel() * 4 < 0.9 * psutil.virtual_memory().total:
+            if self.weight.numel() * 4 < 0.9 * psutil.virtual_memory().total:
                 logger.warning(
                     "Running the client with dtype bfloat16 on CPU may be slow, since your CPU doesn't support AVX512. "
                     "Consider loading the model with torch_dtype='float32'"
                 )
             self._bf16_warning_shown = True
 
-        word_embeddings = self.word_embeddings.weight
-        num_embeddings = self.word_embeddings.num_embeddings
-
         hidden_states = hidden_states.float()
-        output = torch.empty(*hidden_states.shape[:-1], num_embeddings)
+        output = torch.empty(*hidden_states.shape[:-1], self.out_features)
 
-        for i in range(0, num_embeddings, self.chunked_forward_step):
-            chunk = word_embeddings[i : i + self.chunked_forward_step].float()
+        for i in range(0, self.out_features, self.chunked_forward_step):
+            chunk = self.weight[i : i + self.chunked_forward_step].float()
             output[..., i : i + self.chunked_forward_step] = F.linear(hidden_states, chunk)
         return output

+ 88 - 0
src/petals/client/ptune.py

@@ -0,0 +1,88 @@
+import dataclasses
+from contextlib import contextmanager
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from hivemind import get_logger
+from transformers import PretrainedConfig
+
+from petals.utils.misc import DUMMY
+
+logger = get_logger(__name__)
+
+
+@dataclasses.dataclass
+class PTuneConfig:
+    pre_seq_len: int = 0  # a number of tokens for prompt tuning.
+    tuning_mode: Optional[str] = None  # fine-tuning regime, one of [None, "ptune", "deep_ptune"]
+
+
+class PTuneMixin:
+    _keys_to_ignore_on_load_missing = [r"(intermediate_)?prompt_embeddings\.weight$"]
+
+    def init_prompts(self, config: PretrainedConfig) -> None:
+        if config.tuning_mode and "ptune" in config.tuning_mode:
+            assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
+            self.pre_seq_len = config.pre_seq_len
+            self.prefix_tokens = torch.arange(self.pre_seq_len).long()
+
+            with force_non_empty_weights():
+                # Prompt embeddings and their optimizer stats are kept in float32 to increase ptune quality
+                self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size, dtype=torch.float32)
+                if config.tuning_mode == "deep_ptune":
+                    self.intermediate_prompt_embeddings = nn.Embedding(
+                        self.pre_seq_len,
+                        config.num_hidden_layers * config.hidden_size,
+                        # ^-- TODO: should be num_hidden_layers - 1
+                        dtype=torch.float32,
+                    )
+        elif config.tuning_mode:
+            raise NotImplementedError(f"{self.tuning_mode} mode is not supported for now")
+
+    def set_requires_grad(self, value):
+        for p in self.parameters():
+            p.requires_grad = value
+
+    def get_prompt(self, batch_size):
+        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
+        prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)
+        prompts = self.prompt_embeddings(prefix_tokens)
+
+        if self.config.tuning_mode == "deep_ptune":
+            intermediate_prompts = self.intermediate_prompt_embeddings(prefix_tokens)
+            intermediate_prompts = intermediate_prompts.view(
+                batch_size,
+                self.pre_seq_len,
+                self.config.num_hidden_layers,
+                self.config.hidden_size
+                # TODO: should be num_hidden_layers - 1
+            )
+            intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3])
+        else:
+            intermediate_prompts = DUMMY
+
+        dtype = self.word_embeddings.weight.dtype
+        return prompts.to(dtype), intermediate_prompts.to(dtype)
+
+
+_original_register_parameter = nn.Module.register_parameter
+
+
+@contextmanager
+def force_non_empty_weights():
+    """
+    This context manager allows to bypass the accelerate.init_empty_weights() context manager
+    (that forces all nn.Parameters to be PyTorch's meta tensors) used when low_cpu_mem_usage=True.
+    The transformers library should replace all meta tensors by empty tensors by itself
+    but this feature does not work due to a bug ([1] fails if `add_prefix_to_model == True`).
+
+    [1] https://github.com/huggingface/transformers/blob/ab9fe45236cd99b8797df78219438f8f6662bb42/src/transformers/modeling_utils.py#L2515
+    """
+
+    try:
+        possibly_patched_register_parameter = nn.Module.register_parameter
+        nn.Module.register_parameter = _original_register_parameter
+        yield
+    finally:
+        nn.Module.register_parameter = possibly_patched_register_parameter

+ 0 - 268
src/petals/client/remote_model.py

@@ -1,268 +0,0 @@
-from contextlib import contextmanager
-from typing import List, Optional, Union
-
-import hivemind
-import torch
-import torch.nn as nn
-from hivemind.utils.logging import get_logger
-from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
-from transformers.models.bloom import (
-    BloomConfig,
-    BloomForCausalLM,
-    BloomForSequenceClassification,
-    BloomModel,
-    BloomPreTrainedModel,
-)
-
-from petals.bloom.modeling_utils import LMHead
-from petals.client.remote_generation import RemoteGenerationMixin
-from petals.client.remote_sequential import RemoteSequential
-from petals.client.routing.sequence_manager import SequenceManagerConfig
-from petals.constants import PUBLIC_INITIAL_PEERS
-from petals.utils.misc import DUMMY
-
-logger = get_logger(__name__)
-
-
-class DistributedBloomConfig(BloomConfig, SequenceManagerConfig):
-    """
-    A bloom config that contains information about DHT peers.
-    To create a distributed model, one must provide dht_prefix and either initial_peers or dht.
-    """
-
-    initial_peers: List[str] = PUBLIC_INITIAL_PEERS  # a list of initial peers for hivemind DHT
-    dht_prefix: str  # a prefix for all dht keys that correspond to this model (usually equal to model name)
-    daemon_startup_timeout: int = 60  # timeout for the libp2p daemon connecting to initial peers
-
-    pre_seq_len: int = 0  # a number of tokens for prompt tuning.
-    tuning_mode: Optional[str] = None  # fine-tuning regime, one of [None, "ptune", "deep_ptune"]
-
-    # This settings matter for running the client with dtype bfloat16 on CPU.
-    # If the CPU doesn't support AVX512, chunked_forward() significantly speeds up computations.
-    use_chunked_forward: Union[str, bool] = "auto"
-    chunked_forward_step: int = 16384
-
-
-original_register_parameter = nn.Module.register_parameter
-
-
-@contextmanager
-def force_non_empty_weights():
-    """
-    This context manager allows to bypass the accelerate.init_empty_weights() context manager
-    (that forces all nn.Parameters to be PyTorch's meta tensors) used when low_cpu_mem_usage=True.
-    The transformers library should replace all meta tensors by empty tensors by itself
-    but this feature does not work due to a bug ([1] fails if `add_prefix_to_model == True`).
-
-    [1] https://github.com/huggingface/transformers/blob/ab9fe45236cd99b8797df78219438f8f6662bb42/src/transformers/modeling_utils.py#L2515
-    """
-
-    try:
-        possibly_patched_register_parameter = nn.Module.register_parameter
-        nn.Module.register_parameter = original_register_parameter
-        yield
-    finally:
-        nn.Module.register_parameter = possibly_patched_register_parameter
-
-
-class _FromPretrainedDefaultsMixin:
-    @classmethod
-    def from_pretrained(
-        cls,
-        *args,
-        low_cpu_mem_usage: Optional[bool] = None,
-        torch_dtype: Optional[Union[str, torch.dtype]] = None,
-        **kwargs,
-    ):
-        if low_cpu_mem_usage is None:
-            low_cpu_mem_usage = True
-        if torch_dtype is None:
-            # torch_dtype=None gives torch.float32 in transformers>=4.26.0. In contrast,
-            # torch_dtype="auto" attempts to (1) use config.torch_dtype (if exists), (2) use dtype of the weights.
-            torch_dtype = "auto"
-        return super().from_pretrained(*args, low_cpu_mem_usage=low_cpu_mem_usage, torch_dtype=torch_dtype, **kwargs)
-
-    from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace(
-        "low_cpu_mem_usage(`bool`, *optional*)",
-        "low_cpu_mem_usage(`bool`, *optional*, defaults to `True` in Petals)",
-    ).replace(
-        "torch_dtype (`str` or `torch.dtype`, *optional*)",
-        'torch_dtype (`str` or `torch.dtype`, *optional*, defaults to `"auto"` in Petals)',
-    )
-
-
-class DistributedBloomModel(_FromPretrainedDefaultsMixin, BloomModel):
-    """BloomModel, but all transformer layers are hosted by the swarm"""
-
-    _keys_to_ignore_on_load_missing = BloomModel._keys_to_ignore_on_load_missing + [
-        r"^(intermediate_)?prompt_embeddings\.weight$",
-    ]
-
-    config_class = DistributedBloomConfig
-
-    def __init__(self, config: DistributedBloomConfig, *, dht: Optional[hivemind.DHT] = None):
-        assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..."
-        assert config.initial_peers or dht is not None, "Please specify `config.initial_peers` or `dht`"
-
-        n_layer, config.n_layer = config.n_layer, 0  # temporarily set n_layer to 0 to prevent layer initialization
-        super().__init__(config)
-        assert len(self.h) == 0
-        config.n_layer = n_layer
-
-        self.h = RemoteSequential(config, dht=dht)
-
-        # Forbid accumulate grads for embeddings and layernorm
-        self.set_requires_grad(False)
-
-        if config.tuning_mode and "ptune" in config.tuning_mode:
-            assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
-            self.pre_seq_len = config.pre_seq_len
-            self.prefix_tokens = torch.arange(self.pre_seq_len).long()
-
-            with force_non_empty_weights():
-                if self.word_embeddings_layernorm.weight.dtype in (torch.float16, torch.bfloat16):
-                    logger.info(
-                        "Prompt embeddings and their optimizer statistics will be kept in float32 "
-                        "to increase ptune quality"
-                    )
-                self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size, dtype=torch.float32)
-                if config.tuning_mode == "deep_ptune":
-                    self.intermediate_prompt_embeddings = nn.Embedding(
-                        self.pre_seq_len,
-                        config.num_hidden_layers * config.hidden_size,
-                        # ^-- TODO: should be num_hidden_layers - 1
-                        dtype=torch.float32,
-                    )
-        elif config.tuning_mode:
-            raise NotImplementedError(f"{self.tuning_mode} mode is not supported for now")
-
-    def set_requires_grad(self, value):
-        for p in self.parameters():
-            p.requires_grad = value
-
-    def get_prompt(self, batch_size):
-        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
-        prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)
-        prompts = self.prompt_embeddings(prefix_tokens)
-
-        if self.config.tuning_mode == "deep_ptune":
-            intermediate_prompts = self.intermediate_prompt_embeddings(prefix_tokens)
-            intermediate_prompts = intermediate_prompts.view(
-                batch_size, self.pre_seq_len, len(self.h), self.config.hidden_size  # TODO: should be len(self.h) - 1
-            )
-            intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3])
-        else:
-            intermediate_prompts = DUMMY
-
-        dtype = self.word_embeddings.weight.dtype
-        return prompts.to(dtype), intermediate_prompts.to(dtype)
-
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        **kwargs,
-    ):
-        assert attention_mask is None, "DistributedBloomModel does not support attention masks right now"
-
-        for k, v in kwargs.items():
-            if not (v is None or v is False):
-                logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})")
-
-        if input_ids is not None and inputs_embeds is not None:
-            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif input_ids is not None:
-            input_shape = input_ids.size()
-            input_ids = input_ids.view(-1, input_shape[-1])
-        elif inputs_embeds is not None:
-            input_shape = inputs_embeds.size()[:-1]
-        else:
-            raise ValueError("You have to specify either input_ids or inputs_embeds")
-
-        if inputs_embeds is None:
-            inputs_embeds = self.word_embeddings(input_ids)
-
-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
-            batch_size = inputs_embeds.shape[0]
-            prompts, intermediate_prompts = self.get_prompt(batch_size)
-            inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
-
-        hidden_states = self.word_embeddings_layernorm(inputs_embeds)
-        output_shape = input_shape + (hidden_states.size(-1),)
-
-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
-            hidden_states = self.h(hidden_states, prompts=intermediate_prompts)
-        else:
-            hidden_states = self.h(hidden_states)
-
-        # Remove prefix
-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
-            hidden_states = hidden_states[:, self.pre_seq_len :]
-
-        # Add last hidden state
-        hidden_states = self.ln_f(hidden_states)
-        hidden_states = hidden_states.view(output_shape)
-        return BaseModelOutputWithPastAndCrossAttentions(
-            last_hidden_state=hidden_states,
-            past_key_values=None,
-            hidden_states=None,
-            attentions=None,
-        )
-
-
-class DistributedBloomForCausalLM(_FromPretrainedDefaultsMixin, RemoteGenerationMixin, BloomForCausalLM):
-    """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
-
-    _keys_to_ignore_on_load_missing = (
-        BloomForCausalLM._keys_to_ignore_on_load_missing
-        + DistributedBloomModel._keys_to_ignore_on_load_missing
-        + [r"^lm_head.word_embeddings\.weight$"]  # Missing since they are shared with input embeddings
-    )
-
-    config_class = DistributedBloomConfig
-
-    def __init__(self, config: DistributedBloomConfig):
-        BloomPreTrainedModel.__init__(self, config)
-        self.transformer = DistributedBloomModel(config)
-        self.lm_head = LMHead(config, self.transformer.word_embeddings)
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-    def get_input_embeddings(self):
-        return self.transformer.word_embeddings
-
-    def get_output_embeddings(self):
-        if self.config.tie_word_embeddings:
-            return None
-        return self.lm_head
-
-    def set_input_embeddings(self, new_embeddings: nn.Embedding):
-        assert isinstance(new_embeddings, nn.Embedding)
-        self.transformer.word_embeddings = self.lm_head.word_embeddings = new_embeddings
-        assert self.lm_head.bias is None or len(self.lm_head.bias) == new_embeddings.num_embeddings
-
-    def set_output_embeddings(self, new_lm_head: nn.Linear):
-        with torch.no_grad():
-            self.lm_head.word_embeddings.weight[...] = new_lm_head.weight
-            self.lm_head.bias[...] = new_lm_head.bias
-
-
-class DistributedBloomForSequenceClassification(_FromPretrainedDefaultsMixin, BloomForSequenceClassification):
-    _keys_to_ignore_on_load_missing = (
-        BloomForSequenceClassification._keys_to_ignore_on_load_missing
-        + DistributedBloomModel._keys_to_ignore_on_load_missing
-    )
-
-    config_class = DistributedBloomConfig
-
-    def __init__(self, config: DistributedBloomConfig):
-        BloomPreTrainedModel.__init__(self, config)
-        self.num_labels = config.num_labels
-
-        self.transformer = DistributedBloomModel(config)
-        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False).to(config.torch_dtype)
-
-        # Initialize weights and apply final processing
-        self.post_init()

+ 3 - 4
src/petals/client/remote_sequential.py

@@ -6,9 +6,8 @@ import torch
 from hivemind import DHT, get_logger
 from torch import nn
 
-import petals.client
 from petals.client.inference_session import InferenceSession
-from petals.client.routing.sequence_manager import RemoteSequenceManager
+from petals.client.routing.sequence_manager import RemoteSequenceManager, SequenceManagerConfig
 from petals.client.sequential_autograd import _RemoteSequentialAutogradFunction
 from petals.data_structures import UID_DELIMITER
 from petals.utils.misc import DUMMY
@@ -23,7 +22,7 @@ class RemoteSequential(nn.Module):
 
     def __init__(
         self,
-        config: petals.client.DistributedBloomConfig,
+        config: SequenceManagerConfig,
         *,
         sequence_manager: Optional[RemoteSequenceManager] = None,
         dht: Optional[DHT] = None,
@@ -40,7 +39,7 @@ class RemoteSequential(nn.Module):
             if start_block is None:
                 start_block = 0
             if end_block is None:
-                end_block = self.config.n_layer
+                end_block = self.config.num_hidden_layers
             block_uids = tuple(f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block, end_block))
             sequence_manager = RemoteSequenceManager(config, block_uids, dht=dht)
         self.sequence_manager = sequence_manager

+ 8 - 1
src/petals/client/routing/sequence_manager.py

@@ -20,6 +20,7 @@ from hivemind.utils.logging import get_logger
 import petals.dht_utils
 from petals.client.routing.sequence_info import RemoteSequenceInfo
 from petals.client.routing.spending_policy import NoSpendingPolicy
+from petals.constants import PUBLIC_INITIAL_PEERS
 from petals.data_structures import ModuleUID, RemoteSpanInfo, ServerState
 from petals.server.handler import TransformerConnectionHandler
 
@@ -28,6 +29,10 @@ logger = get_logger(__name__)
 
 @dataclasses.dataclass
 class SequenceManagerConfig:
+    initial_peers: Sequence[str] = tuple(PUBLIC_INITIAL_PEERS)  # a list of initial peers for hivemind DHT
+    dht_prefix: Optional[str] = None  # a prefix for all dht keys that correspond to this model (default: model name)
+    daemon_startup_timeout: int = 60  # timeout for the libp2p daemon connecting to initial peers
+
     allowed_servers: Optional[Collection[Union[PeerID, str]]] = None  # if defined, send requests only to these servers
 
     request_timeout: float = 3 * 60  # timeout for forward/backward/inference requests
@@ -73,6 +78,8 @@ class RemoteSequenceManager:
         dht: Optional[DHT] = None,
         state: Optional[SequenceManagerState] = None,
     ):
+        assert config.initial_peers or dht is not None, "Please specify `config.initial_peers` or `dht`"
+        assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..."
         assert len(block_uids) > 0, "Sequences must contain at least one block"
 
         self.config = config
@@ -84,7 +91,7 @@ class RemoteSequenceManager:
             dht = DHT(
                 initial_peers=config.initial_peers,
                 client_mode=True,
-                num_workers=config.n_layer,
+                num_workers=config.num_hidden_layers,
                 startup_timeout=config.daemon_startup_timeout,
                 start=True,
             )

+ 2 - 0
src/petals/models/__init__.py

@@ -0,0 +1,2 @@
+from petals.models.bloom import *
+from petals.models.llama import *

+ 7 - 0
src/petals/models/bloom/__init__.py

@@ -0,0 +1,7 @@
+from petals.models.bloom.block import WrappedBloomBlock
+from petals.models.bloom.config import DistributedBloomConfig
+from petals.models.bloom.model import (
+    DistributedBloomForCausalLM,
+    DistributedBloomForSequenceClassification,
+    DistributedBloomModel,
+)

+ 32 - 0
src/petals/models/bloom/block.py

@@ -0,0 +1,32 @@
+"""
+Bloom intermediate layer
+Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
+See commit history for authorship.
+"""
+from typing import Optional, Tuple
+
+import torch
+from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel, build_alibi_tensor
+
+
+class WrappedBloomBlock(BloomBlock):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        *args,
+        attention_mask: Optional[torch.Tensor] = None,
+        alibi: Optional[torch.Tensor] = None,
+        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        **kwargs
+    ):
+        assert attention_mask is None, "Non-causal attention masks are not supported yet"
+        batch_size, seq_length = hidden_states.shape[:2]
+        past_length = 0 if layer_past is None else layer_past[0].shape[-1]
+        seq_length_with_past = seq_length + past_length
+        attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
+        if alibi is None:
+            alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
+        attention_mask = BloomModel._prepare_attn_mask(None, attention_mask, (batch_size, seq_length), past_length)
+        return super().forward(
+            hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, **kwargs
+        )

+ 35 - 0
src/petals/models/bloom/config.py

@@ -0,0 +1,35 @@
+import os
+from typing import Optional, Union
+
+from hivemind import get_logger
+from transformers.models.bloom import BloomConfig
+from transformers.models.bloom.modeling_bloom import BloomAttention
+
+from petals.client.lm_head import LMHeadConfig
+from petals.client.ptune import PTuneConfig
+from petals.client.routing.sequence_manager import SequenceManagerConfig
+from petals.models.bloom.block import WrappedBloomBlock
+from petals.utils.auto_config import AutoDistributedConfig
+from petals.utils.version import get_compatible_model_repo
+
+logger = get_logger(__name__)
+
+
+class DistributedBloomConfig(BloomConfig, SequenceManagerConfig, PTuneConfig, LMHeadConfig):
+    block_class = WrappedBloomBlock
+    attn_class = BloomAttention
+    block_prefix = "h"
+
+    @classmethod
+    def from_pretrained(
+        cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
+    ):
+        loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
+        if loading_from_repo and dht_prefix is None:
+            # We need "-petals" for backward compatibility with Petals < 1.2.0
+            dht_prefix = str(model_name_or_path) + "-petals"
+            logger.info(f"Using DHT prefix: {dht_prefix}")
+        return super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
+
+
+AutoDistributedConfig.register(DistributedBloomConfig)

+ 134 - 0
src/petals/models/bloom/model.py

@@ -0,0 +1,134 @@
+from typing import Optional
+
+import hivemind
+import torch
+import torch.nn as nn
+from hivemind.utils.logging import get_logger
+from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
+from transformers.models.bloom import BloomForCausalLM, BloomForSequenceClassification, BloomModel, BloomPreTrainedModel
+
+from petals.client.from_pretrained import FromPretrainedMixin
+from petals.client.lm_head import LMHead
+from petals.client.ptune import PTuneMixin
+from petals.client.remote_generation import RemoteGenerationMixin
+from petals.client.remote_sequential import RemoteSequential
+from petals.models.bloom.config import DistributedBloomConfig
+
+logger = get_logger(__name__)
+
+
+class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
+    """BloomModel, but all transformer layers are hosted by the swarm"""
+
+    _keys_to_ignore_on_load_missing = (
+        BloomModel._keys_to_ignore_on_load_missing + PTuneMixin._keys_to_ignore_on_load_missing
+    )
+    _keys_to_ignore_on_load_unexpected = [r"^h\."]
+
+    config_class = DistributedBloomConfig
+
+    def __init__(self, config: DistributedBloomConfig, *, dht: Optional[hivemind.DHT] = None):
+        n_layer, config.num_hidden_layers = config.num_hidden_layers, 0  # Prevent initialization
+        super().__init__(config)
+        assert len(self.h) == 0
+        config.num_hidden_layers = n_layer
+
+        self.h = RemoteSequential(config, dht=dht)
+
+        self.set_requires_grad(False)  # Forbid accumulate grads for embeddings and layernorm
+        self.init_prompts(config)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        assert attention_mask is None, f"{self.__class__.__name__} does not support attention masks right now"
+
+        for k, v in kwargs.items():
+            if not (v is None or v is False):
+                logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})")
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+
+        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+            batch_size = inputs_embeds.shape[0]
+            prompts, intermediate_prompts = self.get_prompt(batch_size)
+            inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
+
+        hidden_states = self.word_embeddings_layernorm(inputs_embeds)
+        output_shape = input_shape + (hidden_states.size(-1),)
+
+        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+            hidden_states = self.h(hidden_states, prompts=intermediate_prompts)
+        else:
+            hidden_states = self.h(hidden_states)
+
+        # Remove prefix
+        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+            hidden_states = hidden_states[:, self.pre_seq_len :]
+
+        # Add last hidden state
+        hidden_states = self.ln_f(hidden_states)
+        hidden_states = hidden_states.view(output_shape)
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=None,
+            hidden_states=None,
+            attentions=None,
+        )
+
+
+class DistributedBloomForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, BloomForCausalLM):
+    _keys_to_ignore_on_load_missing = (
+        BloomForCausalLM._keys_to_ignore_on_load_missing
+        + DistributedBloomModel._keys_to_ignore_on_load_missing
+        + [r"^lm_head\."]  # Missing since they are shared with input embeddings
+    )
+    _keys_to_ignore_on_load_unexpected = DistributedBloomModel._keys_to_ignore_on_load_unexpected
+
+    config_class = DistributedBloomConfig
+
+    def __init__(self, config: DistributedBloomConfig):
+        BloomPreTrainedModel.__init__(self, config)
+        self.transformer = DistributedBloomModel(config)
+        self.lm_head = LMHead(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+
+class DistributedBloomForSequenceClassification(FromPretrainedMixin, BloomForSequenceClassification):
+    _keys_to_ignore_on_load_missing = (
+        BloomForSequenceClassification._keys_to_ignore_on_load_missing
+        + DistributedBloomModel._keys_to_ignore_on_load_missing
+    )
+    _keys_to_ignore_on_load_unexpected = DistributedBloomModel._keys_to_ignore_on_load_unexpected
+
+    config_class = DistributedBloomConfig
+
+    def __init__(self, config: DistributedBloomConfig):
+        BloomPreTrainedModel.__init__(self, config)
+        self.num_labels = config.num_labels
+
+        self.transformer = DistributedBloomModel(config)
+        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False).to(config.torch_dtype)
+
+        # Initialize weights and apply final processing
+        self.post_init()
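
A short end-to-end sketch of the client classes above. This assumes a swarm serving the model is reachable from your machine; the model name is only an example:

```python
import torch
from transformers import AutoTokenizer

from petals import DistributedBloomForCausalLM

model_name = "bigscience/bloom-560m"  # any BLOOM repo served by a swarm you can reach
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = DistributedBloomForCausalLM.from_pretrained(model_name)

inputs = tokenizer("A cat sat on", return_tensors="pt")["input_ids"]
with torch.inference_mode():
    outputs = model.generate(inputs, max_new_tokens=3)
print(tokenizer.decode(outputs[0]))
```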

+ 7 - 0
src/petals/models/llama/__init__.py

@@ -0,0 +1,7 @@
+from petals.models.llama.block import WrappedLlamaBlock
+from petals.models.llama.config import DistributedLlamaConfig
+from petals.models.llama.model import (
+    DistributedLlamaForCausalLM,
+    DistributedLlamaForSequenceClassification,
+    DistributedLlamaModel,
+)

+ 87 - 0
src/petals/models/llama/block.py

@@ -0,0 +1,87 @@
+"""
+LLaMA intermediate layer
+Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
+See commit history for authorship.
+"""
+from typing import Optional, Tuple
+
+import torch
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
+
+
+class WrappedLlamaBlock(LlamaDecoderLayer):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        *args,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        layer_past: Optional[Tuple[torch.Tensor]] = None,
+        use_cache: bool = False,
+        **kwargs,
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        batch_size, seq_length, _ = hidden_states.shape
+
+        seq_length_with_past = seq_length
+        past_key_values_length = 0
+
+        past_key_value = layer_past
+        if past_key_value is not None:
+            past_key_values_length = past_key_value[0].shape[2]
+            seq_length_with_past = seq_length_with_past + past_key_values_length
+            past_key_value = self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length)
+
+        if position_ids is None:
+            device = hidden_states.device
+            position_ids = torch.arange(
+                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+            )
+            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+        else:
+            position_ids = position_ids.view(-1, seq_length).long()
+
+        # embed positions
+        if attention_mask is None:
+            attention_mask = torch.ones(
+                (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
+            )
+        attention_mask = LlamaModel._prepare_decoder_attention_mask(
+            None, attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+        )
+
+        outputs = super().forward(
+            hidden_states,
+            *args,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            use_cache=use_cache,
+            **kwargs,
+        )
+
+        if use_cache:
+            present_key_value = outputs[-1]
+            present_key_value = self._reorder_cache_from_llama_to_bloom(
+                present_key_value, batch_size, seq_length_with_past
+            )
+            outputs = outputs[:-1] + (present_key_value,)
+
+        return outputs
+
+    def _reorder_cache_from_bloom_to_llama(
+        self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
+    ) -> Tuple[torch.Tensor]:
+        key_states, value_states = key_value
+        key_states = key_states.permute(0, 2, 1)
+        key_states = key_states.view(batch_size, self.self_attn.num_heads, seq_length, self.self_attn.head_dim)
+        value_states = value_states.view(*key_states.shape)
+        return (key_states, value_states)
+
+    def _reorder_cache_from_llama_to_bloom(
+        self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
+    ) -> Tuple[torch.Tensor]:
+        key_states, value_states = key_value
+        value_states = value_states.view(batch_size * self.self_attn.num_heads, seq_length, self.self_attn.head_dim)
+        key_states = key_states.view(*value_states.shape)
+        key_states = key_states.permute(0, 2, 1)
+        return (key_states, value_states)
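
The two `_reorder_cache_*` helpers exist because the caches exchanged with Petals servers follow the BLOOM attention layout (keys as `(batch * num_heads, head_dim, seq_len)`, values as `(batch * num_heads, seq_len, head_dim)`), whereas `LlamaDecoderLayer` expects both as `(batch, num_heads, seq_len, head_dim)`. A standalone shape round-trip with arbitrary sizes (using `reshape` for clarity):

```python
import torch

batch_size, num_heads, seq_len, head_dim = 2, 4, 5, 8

key_llama = torch.randn(batch_size, num_heads, seq_len, head_dim)  # LLaMA layout

# LLaMA -> BLOOM: fold heads into the batch dim; keys are additionally transposed
key_bloom = key_llama.reshape(batch_size * num_heads, seq_len, head_dim).permute(0, 2, 1)
assert key_bloom.shape == (batch_size * num_heads, head_dim, seq_len)

# BLOOM -> LLaMA: undo the transpose and split the heads back out
key_roundtrip = key_bloom.permute(0, 2, 1).reshape(batch_size, num_heads, seq_len, head_dim)
assert torch.equal(key_roundtrip, key_llama)
```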

+ 35 - 0
src/petals/models/llama/config.py

@@ -0,0 +1,35 @@
+import os
+from typing import Optional, Union
+
+from hivemind import get_logger
+from transformers.models.llama import LlamaConfig
+from transformers.models.llama.modeling_llama import LlamaAttention
+
+from petals.client.lm_head import LMHeadConfig
+from petals.client.ptune import PTuneConfig
+from petals.client.routing.sequence_manager import SequenceManagerConfig
+from petals.models.llama.block import WrappedLlamaBlock
+from petals.utils.auto_config import AutoDistributedConfig
+
+logger = get_logger(__name__)
+
+
+class DistributedLlamaConfig(LlamaConfig, SequenceManagerConfig, PTuneConfig, LMHeadConfig):
+    block_class = WrappedLlamaBlock
+    attn_class = LlamaAttention
+    block_prefix = "model.layers"
+
+    @classmethod
+    def from_pretrained(
+        cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
+    ):
+        loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
+        if loading_from_repo and dht_prefix is None:
+            dht_prefix = str(model_name_or_path)
+            if "/" in dht_prefix:  # If present, strip the account name so blocks hosted under different accounts merge
+                dht_prefix = dht_prefix[dht_prefix.rfind("/") + 1 :]
+            logger.info(f"Using DHT prefix: {dht_prefix}")
+        return super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
+
+
+AutoDistributedConfig.register(DistributedLlamaConfig)
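
The account name is dropped so that the same model uploaded under different accounts maps to a single DHT prefix. In effect:

```python
model_name = "username/llama-65b-hf"  # hypothetical repo name
dht_prefix = model_name[model_name.rfind("/") + 1 :]
assert dht_prefix == "llama-65b-hf"
```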

+ 152 - 0
src/petals/models/llama/model.py

@@ -0,0 +1,152 @@
+from typing import Optional
+
+import hivemind
+import torch
+import torch.nn as nn
+from hivemind.utils.logging import get_logger
+from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.models.llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaPreTrainedModel
+
+from petals.client.from_pretrained import FromPretrainedMixin
+from petals.client.lm_head import LMHead
+from petals.client.ptune import PTuneMixin
+from petals.client.remote_generation import RemoteGenerationMixin
+from petals.client.remote_sequential import RemoteSequential
+from petals.models.llama.config import DistributedLlamaConfig
+
+logger = get_logger(__name__)
+
+
+class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
+    """LlamaModel, but all transformer layers are hosted by the swarm"""
+
+    _keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing
+    _keys_to_ignore_on_load_unexpected = LlamaModel._keys_to_ignore_on_load_unexpected + [r"^model\.layers\."]
+
+    config_class = DistributedLlamaConfig
+
+    def __init__(self, config: DistributedLlamaConfig, *, dht: Optional[hivemind.DHT] = None):
+        n_layer, config.num_hidden_layers = config.num_hidden_layers, 0  # Prevent initialization
+        super().__init__(config)
+        assert len(self.layers) == 0
+        config.num_hidden_layers = n_layer
+
+        self.layers = RemoteSequential(config, dht=dht)
+
+        self.set_requires_grad(False)  # Forbid accumulate grads for embeddings and layernorm
+        self.init_prompts(config)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> BaseModelOutputWithPast:
+        assert attention_mask is None, f"{self.__class__.__name__} does not support attention masks right now"
+
+        for k, v in kwargs.items():
+            if not (v is None or v is False):
+                logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})")
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+            batch_size = inputs_embeds.shape[0]
+            prompts, intermediate_prompts = self.get_prompt(batch_size)
+            inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
+
+        hidden_states = inputs_embeds
+        output_shape = input_shape + (hidden_states.size(-1),)
+
+        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+            hidden_states = self.layers(hidden_states, prompts=intermediate_prompts)
+        else:
+            hidden_states = self.layers(hidden_states)
+
+        # Remove prefix
+        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+            hidden_states = hidden_states[:, self.pre_seq_len :]
+
+        # Add last hidden state
+        hidden_states = self.norm(hidden_states)
+        hidden_states = hidden_states.view(output_shape)
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=None,
+            hidden_states=None,
+            attentions=None,
+        )
+
+    @property
+    def word_embeddings(self) -> nn.Embedding:  # For compatibility with RemoteGenerationMixin
+        return self.embed_tokens
+
+    @property
+    def word_embeddings_layernorm(self) -> nn.Module:  # For compatibility with RemoteGenerationMixin
+        return nn.Identity()
+
+    @property
+    def h(self) -> RemoteSequential:  # For compatibility with RemoteGenerationMixin
+        return self.layers
+
+    @property
+    def ln_f(self) -> nn.Module:  # For compatibility with RemoteGenerationMixin
+        return self.norm
+
+
+class DistributedLlamaForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, LlamaForCausalLM):
+    _keys_to_ignore_on_load_missing = DistributedLlamaModel._keys_to_ignore_on_load_missing
+    _keys_to_ignore_on_load_unexpected = DistributedLlamaModel._keys_to_ignore_on_load_unexpected
+
+    config_class = DistributedLlamaConfig
+
+    def __init__(self, config: DistributedLlamaConfig):
+        LlamaPreTrainedModel.__init__(self, config)
+        self.model = DistributedLlamaModel(config)
+        self.lm_head = LMHead(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    @property
+    def transformer(self) -> DistributedLlamaModel:  # For compatibility with RemoteGenerationMixin
+        return self.model
+
+
+class DistributedLlamaForSequenceClassification(FromPretrainedMixin, LlamaForSequenceClassification):
+    _keys_to_ignore_on_load_missing = (
+        LlamaForSequenceClassification._keys_to_ignore_on_load_missing
+        + DistributedLlamaModel._keys_to_ignore_on_load_missing
+    )
+    _keys_to_ignore_on_load_unexpected = DistributedLlamaModel._keys_to_ignore_on_load_unexpected
+
+    config_class = DistributedLlamaConfig
+
+    def __init__(self, config):
+        LlamaPreTrainedModel.__init__(self, config)
+        self.num_labels = config.num_labels
+
+        self.model = DistributedLlamaModel(config)
+        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @property
+    def transformer(self) -> DistributedLlamaModel:  # For compatibility with RemoteGenerationMixin
+        return self.model

+ 11 - 10
src/petals/server/backend.py

@@ -1,4 +1,3 @@
-"""Code for serving bloom blocks via hivemind-server"""
 from __future__ import annotations
 
 from collections import Counter
@@ -12,8 +11,7 @@ from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.utils import get_logger
 from tensor_parallel import TensorParallel
 from tensor_parallel.tensor_parallel import PerDeviceTensors
-from transformers import BloomConfig
-from transformers.models.bloom.modeling_bloom import BloomAttention
+from transformers import PretrainedConfig
 
 from petals.data_structures import InferenceMetadata
 from petals.server.memory_cache import MemoryCache
@@ -24,17 +22,19 @@ logger = get_logger(__name__)
 
 
 class TransformerBackend(ModuleBackend):
-    """A wrapper for a BLOOM block that can process requests for BLOOM layer forward, backward and inference"""
+    """A wrapper for a transformer block that can process requests for forward, backward and inference"""
 
-    def __init__(self, *args, config: BloomConfig, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs):
+    def __init__(
+        self, *args, config: PretrainedConfig, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs
+    ):
         super().__init__(*args, **kwargs)
         assert isinstance(self.module, TensorParallel)
         self.config = config
         self.memory_cache = memory_cache
         for name, param in self.module.named_parameters():
-            assert not param.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
+            assert not param.requires_grad, f"Block parameters must not accumulate gradients, but {name} does"
         for name, buf in self.module.named_buffers():
-            assert not buf.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
+            assert not buf.requires_grad, f"Block buffers must not require gradients, but {name} does"
 
         max_batch_size = self.forward_pool.max_batch_size
         device = self.module.devices[self.module.output_device_index]
@@ -52,9 +52,10 @@ class TransformerBackend(ModuleBackend):
         self.shard_num_heads = []
         for shard in self.module.module_shards:
             for submodule in shard.modules():
-                if isinstance(submodule, BloomAttention):
+                if isinstance(submodule, config.attn_class):
                     self.shard_num_heads.append(submodule.num_heads)
-        assert len(self.shard_num_heads) == len(self.module.devices) and sum(self.shard_num_heads) == config.n_head
+        assert len(self.shard_num_heads) == len(self.module.devices)
+        assert sum(self.shard_num_heads) == config.num_attention_heads
 
         self.inference_schema = (
             (
@@ -71,7 +72,7 @@ class TransformerBackend(ModuleBackend):
 
     def get_inference_cache_descriptors(self, batch_size: int, max_length: int) -> Sequence[TensorDescriptor]:
         """Create tensor descriptors for attention cache tensors used during inference_step"""
-        head_dim = self.config.hidden_size // self.config.n_head
+        head_dim = self.config.hidden_size // self.config.num_attention_heads
         cache_tensors = []
         for device, num_heads in zip(self.module.devices, self.shard_num_heads):
             keys = TensorDescriptor((batch_size, num_heads, head_dim, max_length), dtype=self.dtype, device=device)

+ 4 - 6
src/petals/server/block_utils.py

@@ -2,12 +2,10 @@ from typing import Optional, Union
 
 import torch
 from accelerate import init_empty_weights
-from transformers import BloomConfig
+from transformers import PretrainedConfig
 
-from petals.bloom.block import WrappedBloomBlock
 
-
-def resolve_block_dtype(config: BloomConfig, dtype: Union[str, torch.dtype]) -> torch.dtype:
+def resolve_block_dtype(config: PretrainedConfig, dtype: Union[str, torch.dtype]) -> torch.dtype:
     """If dtype is "auto", resolves it using BloomConfig. Returns `dtype` intact otherwise."""
     if dtype not in ("auto", None):
         return dtype
@@ -17,7 +15,7 @@ def resolve_block_dtype(config: BloomConfig, dtype: Union[str, torch.dtype]) ->
 
 
 def get_block_size(
-    config: BloomConfig,
+    config: PretrainedConfig,
     location: str,
     *,
     dtype: Optional[Union[str, torch.dtype]] = None,
@@ -30,7 +28,7 @@ def get_block_size(
         ), 'get_block_size(..., location="memory") requires to specify dtype and load_in_8bit for calculations'
 
     with init_empty_weights(include_buffers=True):
-        block = WrappedBloomBlock(config)
+        block = config.block_class(config)
         n_params = sum(param.numel() for param in block.parameters())
 
     if location == "memory" and load_in_8bit:

+ 175 - 0
src/petals/server/from_pretrained.py

@@ -0,0 +1,175 @@
+"""
+Utils for fetching pretrained model parts. Currently, this relies on huggingface transformers' from_pretrained code.
+If necessary, one can rewrite this to implement a different behavior, such as:
+ - loading files from a local data source (e.g. S3)
+ - loading files via BitTorrent ( https://pypi.org/project/libtorrent/ ) or IPFS ( https://docs.ipfs.io/how-to )
+ - fetching the weights over IPoAC, using a fleet of trained pigeons ( http://www.faqs.org/rfcs/rfc1149.html )
+
+"""
+import json
+import time
+from typing import Dict, Optional, Union
+
+import torch
+import torch.nn as nn
+from accelerate import init_empty_weights
+from accelerate.utils import set_module_tensor_to_device
+from hivemind.utils.logging import get_logger
+from huggingface_hub import get_hf_file_metadata, hf_hub_url
+from transformers import PretrainedConfig
+from transformers.utils import get_file_from_repo
+
+from petals.server.block_utils import resolve_block_dtype
+from petals.utils.auto_config import AutoDistributedConfig
+from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for
+
+logger = get_logger(__name__)
+
+
+def load_pretrained_block(
+    model_name: str,
+    block_index: int,
+    *,
+    config: Optional[PretrainedConfig] = None,
+    torch_dtype: Union[torch.dtype, str] = "auto",
+    revision: Optional[str] = None,
+    use_auth_token: Optional[str] = None,
+    cache_dir: Optional[str] = None,
+    max_disk_space: Optional[int] = None,
+) -> nn.Module:
+    if config is None:
+        config = AutoDistributedConfig.from_pretrained(model_name, use_auth_token=use_auth_token)
+    if cache_dir is None:
+        cache_dir = DEFAULT_CACHE_DIR
+
+    assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
+    torch_dtype = resolve_block_dtype(config, torch_dtype)
+
+    with init_empty_weights():
+        block = config.block_class(config)
+
+    block_prefix = f"{config.block_prefix}.{block_index}."
+    state_dict = _load_state_dict_from_repo(
+        model_name,
+        block_prefix,
+        revision=revision,
+        use_auth_token=use_auth_token,
+        cache_dir=cache_dir,
+        max_disk_space=max_disk_space,
+    )
+
+    # dummy load, check that keys match
+    report = block.load_state_dict(state_dict, strict=True)
+    assert not report.missing_keys, f"Some block weights are missing: {report.missing_keys}"
+
+    for param_name, _ in block.named_parameters():
+        assert param_name in state_dict, f"{param_name} not in state dict"
+        param = state_dict[param_name]
+        if not str(param.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
+            param = param.to(torch_dtype)
+        set_module_tensor_to_device(block, param_name, "cpu", value=param, dtype=param.dtype)
+
+    logger.info(f"Loaded {model_name} block {block_index}, {report}")
+    return block
+
+
+StateDict = Dict[str, torch.Tensor]
+
+
+def _load_state_dict_from_repo(
+    model_name: str,
+    block_prefix: str,
+    *,
+    revision: Optional[str] = None,
+    use_auth_token: Optional[str] = None,
+    cache_dir: str,
+    max_disk_space: Optional[int] = None,
+) -> StateDict:
+    index_file = get_file_from_repo(
+        model_name, filename="pytorch_model.bin.index.json", use_auth_token=use_auth_token, cache_dir=cache_dir
+    )
+    if index_file is not None:  # Sharded model
+        with open(index_file) as f:
+            index = json.load(f)
+        filenames = {
+            filename for param_name, filename in index["weight_map"].items() if param_name.startswith(block_prefix)
+        }
+        if not filenames:
+            raise RuntimeError(f"Block {block_prefix}* not found in the index: {index['weight_map']}")
+    else:  # Non-sharded model
+        filenames = {"pytorch_model.bin"}
+    logger.debug(f"Loading {block_prefix}* from {filenames}")
+
+    state_dict = {}
+    for filename in filenames:
+        shard_state_dict = _load_state_dict_from_file(
+            model_name,
+            filename,
+            revision=revision,
+            use_auth_token=use_auth_token,
+            cache_dir=cache_dir,
+            max_disk_space=max_disk_space,
+        )
+        shard_state_dict = {
+            param_name[len(block_prefix) :]: param
+            for param_name, param in shard_state_dict.items()
+            if param_name.startswith(block_prefix)
+        }  # Remove unused parameters from memory
+        state_dict.update(shard_state_dict)
+    return state_dict
+
+
+def _load_state_dict_from_file(
+    model_name: str,
+    filename: str,
+    *,
+    revision: Optional[str] = None,
+    use_auth_token: Optional[str] = None,
+    cache_dir: str,
+    max_disk_space: Optional[int] = None,
+    delay: float = 30,
+) -> StateDict:
+    # First, try to find the weights locally
+    try:
+        with allow_cache_reads(cache_dir):
+            path = get_file_from_repo(
+                model_name,
+                filename,
+                revision=revision,
+                use_auth_token=use_auth_token,
+                cache_dir=cache_dir,
+                local_files_only=True,
+            )
+            if path is not None:
+                return torch.load(path, map_location="cpu")
+    except Exception:
+        logger.warning(f"Cache for file {filename} is corrupted, it will be downloaded again", exc_info=True)
+
+    # If not found, ensure that we have enough disk space to download them (maybe remove something)
+    while True:
+        try:
+            with allow_cache_writes(cache_dir):
+                url = hf_hub_url(model_name, filename, revision=revision)
+                file_size = get_hf_file_metadata(url, token=use_auth_token).size
+                if file_size is not None:
+                    free_disk_space_for(model_name, file_size, cache_dir=cache_dir, max_disk_space=max_disk_space)
+                else:
+                    logger.warning(f"Failed to fetch size of file {filename} from repo {model_name}")
+
+                path = get_file_from_repo(
+                    model_name,
+                    filename,
+                    revision=revision,
+                    use_auth_token=use_auth_token,
+                    cache_dir=cache_dir,
+                    local_files_only=False,
+                )
+                if path is None:
+                    raise RuntimeError(f"File {filename} does not exist in repo {model_name}")
+                return torch.load(path, map_location="cpu")
+        except Exception as e:
+            logger.warning(f"Failed to load file {filename} from HF Hub (retry in {delay:.0f} sec)", exc_info=True)
+            time.sleep(delay)
+
+
+DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
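
This is the server-side replacement for the old conversion-based loader: it reads the repo's weight index and fetches only the shards that contain the requested block. A minimal sketch (example model name; downloads weights on first use):

```python
from petals.server.from_pretrained import load_pretrained_block

# Fetches only the shard(s) that contain block 3 and materializes it on CPU
block = load_pretrained_block("bigscience/bloom-560m", 3, torch_dtype="auto")
print(sum(p.numel() for p in block.parameters()), "parameters")
```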

+ 35 - 29
src/petals/server/server.py

@@ -14,21 +14,23 @@ from hivemind.moe.server.layers import add_custom_models_from_file
 from hivemind.moe.server.runtime import Runtime
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.logging import get_logger
-from transformers import BloomConfig
+from transformers import PretrainedConfig
 
-from petals.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
 from petals.constants import PUBLIC_INITIAL_PEERS
 from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
 from petals.dht_utils import declare_active_modules, get_remote_module_infos
 from petals.server import block_selection
 from petals.server.backend import TransformerBackend, merge_inference_pools_inplace
 from petals.server.block_utils import get_block_size, resolve_block_dtype
+from petals.server.from_pretrained import DTYPE_MAP, load_pretrained_block
 from petals.server.handler import TransformerConnectionHandler
 from petals.server.memory_cache import MemoryCache
 from petals.server.reachability import ReachabilityProtocol, check_direct_reachability, validate_reachability
 from petals.server.throughput import get_dtype_name, get_server_throughput
+from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.convert_block import check_device_balance, convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
+from petals.utils.version import get_compatible_model_repo
 
 logger = get_logger(__name__)
 
@@ -53,7 +55,7 @@ class Server:
         max_batch_size: int = 2048,
         inference_max_length: int = 2048,
         torch_dtype: str = "auto",
-        revision: str = "main",
+        revision: Optional[str] = None,
         cache_dir: Optional[str] = None,
         max_disk_space: Optional[int] = None,
         attn_cache_tokens: int = 8192,
@@ -83,25 +85,32 @@ class Server:
     ):
         """Create a server with one or more bloom blocks. See run_server.py for documentation."""
 
+        converted_model_name_or_path = get_compatible_model_repo(converted_model_name_or_path)
         self.converted_model_name_or_path = converted_model_name_or_path
+
         self.num_handlers = num_handlers
         self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
         self.inference_max_length = inference_max_length
         self.compression = compression
         self.stats_report_interval, self.update_period = stats_report_interval, update_period
         self.prefetch_batches, self.sender_threads = prefetch_batches, sender_threads
-        self.use_auth_token = use_auth_token
+        self.revision, self.use_auth_token = revision, use_auth_token
 
         if custom_module_path is not None:
             add_custom_models_from_file(custom_module_path)
 
+        self.block_config = AutoDistributedConfig.from_pretrained(
+            converted_model_name_or_path,
+            use_auth_token=use_auth_token,
+            revision=revision,
+        )
+
         if prefix is None:
-            prefix = converted_model_name_or_path
-            assert UID_DELIMITER not in prefix and CHAIN_DELIMITER not in prefix, (
-                f"Cannot use model name as prefix (contains '{UID_DELIMITER}' or '{CHAIN_DELIMITER}'); "
-                f"Please specify --prefix manually when starting a server"
-            )
-            logger.debug(f"Automatic dht prefix: {prefix}")
+            prefix = self.block_config.dht_prefix
+        assert UID_DELIMITER not in prefix and CHAIN_DELIMITER not in prefix, (
+            f"DHT prefix should not contain '{UID_DELIMITER}' or '{CHAIN_DELIMITER}'. "
+            f"Please specify another --prefix manually when starting a server"
+        )
         self.prefix = prefix
 
         if expiration is None:
@@ -111,12 +120,9 @@ class Server:
         self.request_timeout = request_timeout
         self.session_timeout, self.step_timeout = session_timeout, step_timeout
 
-        self.block_config = BloomConfig.from_pretrained(
-            converted_model_name_or_path,
-            use_auth_token=use_auth_token,
-            revision=revision,
-        )
-        self.module_uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]
+        self.module_uids = [
+            f"{self.prefix}.{block_index}" for block_index in range(self.block_config.num_hidden_layers)
+        ]
 
         if dht_client_mode is None:
             is_reachable = check_direct_reachability(initial_peers=initial_peers, use_relay=False, **kwargs)
@@ -125,7 +131,7 @@ class Server:
         self.dht = DHT(
             initial_peers=initial_peers,
             start=True,
-            num_workers=self.block_config.n_layer,
+            num_workers=self.block_config.num_hidden_layers,
             use_relay=use_relay,
             use_auto_relay=use_auto_relay,
             client_mode=dht_client_mode,
@@ -161,10 +167,10 @@ class Server:
         if load_in_8bit is None:
             load_in_8bit = device.type == "cuda"
         self.load_in_8bit = load_in_8bit
-        logger.info(f"Model weights will be loaded in {get_dtype_name(torch_dtype, load_in_8bit)} format")
+        logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, load_in_8bit)} format")
 
-        max_values_in_cache = 2 * self.block_config.hidden_size * attn_cache_tokens
-        self._cache_bytes_per_block = max_values_in_cache * torch.finfo(self.torch_dtype).bits // 8
+        cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens
+        self._cache_bytes_per_block = cache_values_per_block * torch.finfo(self.torch_dtype).bits // 8
 
         assert num_blocks is None or block_indices is None, "Please specify num_blocks or block_indices, not both"
         if num_blocks is None and block_indices is None:
@@ -192,6 +198,7 @@ class Server:
         assert isinstance(throughput, float) or throughput in ["auto", "eval"]
         if throughput in ["auto", "eval"]:
             throughput = get_server_throughput(
+                converted_model_name_or_path,
                 self.block_config,
                 device,
                 torch_dtype,
@@ -239,11 +246,12 @@ class Server:
         num_blocks = math.floor((total_memory - autograd_memory) / (block_size + self._cache_bytes_per_block))
         assert num_blocks >= 1, "Your GPU does not have enough memory to serve at least one block"
 
+        num_blocks = min(num_blocks, self.block_config.num_hidden_layers)
         logger.info(
             f"Server will fill all your GPU memory with {num_blocks} transformer blocks. "
             f"If you want to leave some free GPU memory, please specify a lesser --num_blocks manually"
         )
-        return min(num_blocks, self.block_config.n_layer)
+        return num_blocks
 
     def run(self):
         while True:
@@ -274,6 +282,7 @@ class Server:
                 step_timeout=self.step_timeout,
                 prefetch_batches=self.prefetch_batches,
                 sender_threads=self.sender_threads,
+                revision=self.revision,
                 use_auth_token=self.use_auth_token,
                 load_in_8bit=self.load_in_8bit,
                 tensor_parallel_devices=self.tensor_parallel_devices,
@@ -352,7 +361,7 @@ class ModuleContainer(threading.Thread):
         dht: DHT,
         prefix: str,
         converted_model_name_or_path: str,
-        block_config: BloomConfig,
+        block_config: PretrainedConfig,
         attn_cache_bytes: int,
         alloc_timeout: float,
         throughput: float,
@@ -366,6 +375,7 @@ class ModuleContainer(threading.Thread):
         compression: CompressionType,
         update_period: float,
         expiration: Optional[float],
+        revision: Optional[str],
         use_auth_token: Optional[str],
         load_in_8bit: bool,
         tensor_parallel_devices: Sequence[torch.device],
@@ -394,14 +404,14 @@ class ModuleContainer(threading.Thread):
                 block = load_pretrained_block(
                     converted_model_name_or_path,
                     block_index,
-                    block_config,
+                    config=block_config,
                     torch_dtype=torch_dtype,
+                    revision=revision,
                     use_auth_token=use_auth_token,
                     cache_dir=cache_dir,
                     max_disk_space=max_disk_space,
                 )
                 block = convert_block(block, block_config, tensor_parallel_devices, device, load_in_8bit, freeze=True)
-
                 blocks[module_uid] = TransformerBackend(
                     module_uid,
                     block,
@@ -564,13 +574,9 @@ class ModuleContainer(threading.Thread):
 
         self.ready.clear()
 
+        logger.debug("Shutting down connection handlers")
         for handler in self.conn_handlers:
             handler.shutdown()
-        logger.debug("Connection handlers terminated")
-
-        if self.checkpoint_saver is not None:
-            self.checkpoint_saver.stop.set()
-            self.checkpoint_saver.join()
 
         logger.debug(f"Shutting down pools")
         for pool in self.runtime.pools:
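
A rough worked example for the `_choose_num_blocks` change above (all numbers hypothetical): with 24 GiB of usable GPU memory, 2 GiB reserved for autograd buffers, 3 GiB per block, and 0.5 GiB of attention cache per block, the server would pick `floor((24 - 2) / (3 + 0.5)) = 6` blocks and then cap that at `config.num_hidden_layers`.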

+ 12 - 10
src/petals/server/throughput.py

@@ -5,15 +5,13 @@ import multiprocessing as mp
 import os
 import time
 from collections import Counter
-from hashlib import sha256
 from pathlib import Path
 from typing import Dict, Optional, Sequence, Union
 
 import torch
 from hivemind.utils.logging import get_logger
-from transformers import BloomConfig
+from transformers import PretrainedConfig
 
-from petals.bloom.block import WrappedBloomBlock
 from petals.server.block_utils import resolve_block_dtype
 from petals.utils.convert_block import convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
@@ -35,7 +33,8 @@ if not hasattr(speedtest, "Speedtest"):
 
 
 def get_server_throughput(
-    config: BloomConfig,
+    model_name: str,
+    config: PretrainedConfig,
     device: torch.device,
     dtype: Union[str, torch.dtype],
     *,
@@ -59,7 +58,7 @@ def get_server_throughput(
         fcntl.flock(lock_fd.fileno(), fcntl.LOCK_EX)
         # The OS will release the lock when lock_fd is closed or the process is killed
 
-        cache_key = f"config_{sha256(str(config).encode()).hexdigest()[-16:]}"
+        cache_key = f"model_{model_name}"
         cache_key += f"_device_{get_device_name(device).replace(' ', '_')}"
         cache_key += f"_dtype_{get_dtype_name(dtype, load_in_8bit)}"
         if len(tensor_parallel_devices) > 1:
@@ -101,7 +100,7 @@ def get_server_throughput(
 
 
 def measure_throughput_info(
-    config: BloomConfig,
+    config: PretrainedConfig,
     device: torch.device,
     dtype: torch.dtype,
     *,
@@ -127,7 +126,7 @@ def measure_throughput_info(
     return throughput_info
 
 
-def measure_network_rps(config: BloomConfig, *, timeout: float = 60) -> Optional[float]:
+def measure_network_rps(config: PretrainedConfig, *, timeout: float = 60) -> Optional[float]:
     pipe_recv, pipe_send = mp.Pipe(duplex=False)
     process = mp.Process(target=_measure_bits_per_second, args=(pipe_send,))
     process.start()
@@ -160,7 +159,7 @@ def _measure_bits_per_second(pipe_send: mp.Pipe):
 
 
 def measure_compute_rps(
-    config: BloomConfig,
+    config: PretrainedConfig,
     device: torch.device,
     dtype: torch.dtype,
     *,
@@ -172,7 +171,7 @@ def measure_compute_rps(
     if not tensor_parallel_devices:
         tensor_parallel_devices = (device,)
     with torch.inference_mode():
-        block = WrappedBloomBlock(config).to(dtype)
+        block = config.block_class(config).to(dtype)
         block = convert_block(block, config, tensor_parallel_devices, device, load_in_8bit=load_in_8bit, freeze=True)
 
         cache = None
@@ -203,4 +202,7 @@ def get_device_name(device: torch.device) -> str:
 
 
 def get_dtype_name(dtype: torch.dtype, load_in_8bit: bool) -> str:
-    return "8-bit" if load_in_8bit else str(dtype)
+    name = str(dtype)
+    if load_in_8bit:
+        name += ", 8-bit quantized"
+    return name

+ 1 - 0
src/petals/utils/__init__.py

@@ -0,0 +1 @@
+from petals.utils.auto_config import AutoDistributedConfig

+ 23 - 0
src/petals/utils/auto_config.py

@@ -0,0 +1,23 @@
+from typing import Type
+
+from transformers import AutoConfig, PretrainedConfig
+
+CONFIG_MAPPING = {}  # Populated with AutoDistributedConfig.register()
+
+
+class AutoDistributedConfig:
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs) -> PretrainedConfig:
+        config = AutoConfig.from_pretrained(*args, **kwargs)
+        if config.model_type not in CONFIG_MAPPING:
+            raise ValueError(f"Petals does not support model type {config.model_type}")
+
+        dist_config_class = CONFIG_MAPPING[config.model_type]
+        return dist_config_class.from_pretrained(*args, **kwargs)
+
+    @staticmethod
+    def register(config_class: Type[PretrainedConfig]) -> None:
+        assert issubclass(config_class, PretrainedConfig)
+        assert config_class.model_type not in CONFIG_MAPPING
+
+        CONFIG_MAPPING[config_class.model_type] = config_class

+ 15 - 13
src/petals/utils/convert_block.py

@@ -10,18 +10,15 @@ import torch
 import torch.nn as nn
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from tensor_parallel.slicing_configs import get_bloom_config
-from transformers import BloomConfig
-from transformers.models.bloom.modeling_bloom import BloomAttention
-
-from petals.bloom.block import WrappedBloomBlock
+from transformers import PretrainedConfig
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
 
 
 def convert_block(
-    block: WrappedBloomBlock,
-    config: BloomConfig,
+    block: nn.Module,
+    config: PretrainedConfig,
     tensor_parallel_devices: Sequence[torch.device],
     output_device: torch.device,
     load_in_8bit: bool,
@@ -58,7 +55,7 @@ def convert_block(
     return block
 
 
-def replace_8bit_linear(model: nn.Module, threshold=6.0):
+def replace_8bit_linear(model: nn.Module, threshold=6.0) -> nn.Module:
     """
     A helper function to convert all `torch.nn.Linear` modules to `bnb.nn.Linear8bit` modules from the `bitsandbytes`
     library. This will enable running your models using mixed int8 precision as described by the paper `GPT3.int8():
@@ -100,17 +97,22 @@ def replace_8bit_linear(model: nn.Module, threshold=6.0):
 
 
 def make_tensor_parallel(
-    block: WrappedBloomBlock, model_config: BloomConfig, devices: Sequence[torch.device], output_device: torch.device
-):
-    tp_config = get_bloom_config(model_config, devices)
-    del tp_config.state_rules[re.compile(".*word_embeddings.weight$")]
+    block: nn.Module, model_config: PretrainedConfig, devices: Sequence[torch.device], output_device: torch.device
+) -> nn.Module:
+    if model_config.model_type == "bloom":
+        tp_config = get_bloom_config(model_config, devices)
+        del tp_config.state_rules[re.compile(".*word_embeddings.weight$")]
+    else:
+        if len(devices) > 1:
+            logger.warning("Tensor parallelism is not tested for models other than BLOOM yet, proceed with caution")
+        tp_config = None
     tp_block = tp.TensorParallel(block, devices, config=tp_config, output_device=output_device, delay_init=True)
     total_heads = 0
     for tp_shard in tp_block.module_shards:
         for submodule in tp_shard.modules():
-            if isinstance(submodule, BloomAttention):
+            if isinstance(submodule, model_config.attn_class):
                 total_heads += submodule.num_heads
-    assert total_heads == model_config.n_head
+    assert total_heads == model_config.num_attention_heads
     return tp_block
 
 

+ 5 - 3
src/petals/utils/disk_cache.py

@@ -57,13 +57,16 @@ def free_disk_space_for(
     available_space = shutil.disk_usage(cache_dir).free - os_quota
     if max_disk_space is not None:
         available_space = min(available_space, max_disk_space - occupied_space)
+
+    gib = 1024**3
+    logger.debug(f"Disk space: required {size / gib:.1f} GiB, available {available_space / gib:.1f} GiB")
     if size <= available_space:
         return
 
     revisions = [revision for repo in model_repos for revision in repo.revisions]
     revisions.sort(key=lambda rev: max([item.blob_last_accessed for item in rev.files], default=rev.last_modified))
 
-    # Remove as few least recently used blocks as possible
+    # Remove as few least recently used shards as possible
     pending_removal = []
     freed_space = 0
     extra_space_needed = size - available_space
@@ -73,9 +76,8 @@ def free_disk_space_for(
         if freed_space >= extra_space_needed:
             break
 
-    gib = 1024**3
     if pending_removal:
-        logger.info(f"Removing {len(pending_removal)} blocks to free {freed_space / gib:.1f} GiB of disk space")
+        logger.info(f"Removing {len(pending_removal)} shards to free {freed_space / gib:.1f} GiB of disk space")
         delete_strategy = cache_info.delete_revisions(*pending_removal)
         delete_strategy.execute()
 

+ 19 - 1
src/petals/utils/version.py

@@ -1,3 +1,7 @@
+import os
+import re
+from typing import Union
+
 import requests
 from hivemind.utils.logging import TextStyle, get_logger
 from packaging.version import parse
@@ -7,7 +11,7 @@ import petals
 logger = get_logger(__name__)
 
 
-def validate_version():
+def validate_version() -> None:
     logger.info(f"Running {TextStyle.BOLD}Petals {petals.__version__}{TextStyle.RESET}")
     try:
         r = requests.get("https://pypi.python.org/pypi/petals/json")
@@ -24,3 +28,17 @@ def validate_version():
             )
     except Exception as e:
         logger.warning("Failed to fetch the latest Petals version from PyPI:", exc_info=True)
+
+
+def get_compatible_model_repo(model_name_or_path: Union[str, os.PathLike, None]) -> Union[str, os.PathLike, None]:
+    if model_name_or_path is None:
+        return None
+
+    match = re.fullmatch(r"(bigscience/.+)-petals", str(model_name_or_path))
+    if match is None:
+        return model_name_or_path
+
+    logger.info(
+        f"Loading model from {match.group(1)}, since Petals 1.2.0+ uses original repos instead of converted ones"
+    )
+    return match.group(1)
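
The helper only rewrites the old converted-repo names; everything else passes through untouched:

```python
from petals.utils.version import get_compatible_model_repo

assert get_compatible_model_repo("bigscience/bloom-petals") == "bigscience/bloom"
assert get_compatible_model_repo("username/llama-65b-hf") == "username/llama-65b-hf"  # left as-is
```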

+ 2 - 2
tests/test_aux_functions.py

@@ -1,7 +1,7 @@
 import pytest
 import torch
 
-from petals.client import DistributedBloomConfig
+from petals import AutoDistributedConfig
 from petals.server.throughput import measure_compute_rps
 from test_utils import MODEL_NAME
 
@@ -9,7 +9,7 @@ from test_utils import MODEL_NAME
 @pytest.mark.forked
 @pytest.mark.parametrize("tensor_parallel", [False, True])
 def test_compute_throughput(tensor_parallel: bool):
-    config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
+    config = AutoDistributedConfig.from_pretrained(MODEL_NAME)
     tensor_parallel_devices = ("cpu", "cpu") if tensor_parallel else ()
     compute_rps = measure_compute_rps(
         config,

+ 12 - 58
tests/test_block_exact_match.py

@@ -1,13 +1,10 @@
 import random
-from typing import Union
 
 import pytest
 import torch
-from transformers.models.bloom.configuration_bloom import BloomConfig
 
-from petals.bloom.block import WrappedBloomBlock
-from petals.bloom.from_pretrained import DTYPE_MAP, _load_state_dict, load_pretrained_block
-from petals.client import DistributedBloomConfig, RemoteSequential
+from petals import DistributedBloomConfig, RemoteSequential
+from petals.server.from_pretrained import load_pretrained_block
 from test_utils import *
 
 
@@ -16,21 +13,22 @@ def test_remote_block_exact_match(atol_forward=1e-4, atol_inference=1e-3):
     config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
     remote_sequential = RemoteSequential(config)
 
-    for block_index in random.sample(range(config.n_layer), 3):
+    for block_index in random.sample(range(config.num_hidden_layers), 3):
         remote_block = remote_sequential[block_index]
 
         inputs = torch.randn(1, 8, config.hidden_size)
         outputs_forward = remote_block(inputs)
 
         outputs_inference = []
-        with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
-            for i in range(inputs.shape[1]):
-                outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
-
-            # test that max length is respected
-            with pytest.raises(ValueError, match=r"Maximum length exceeded") as exc_info:
-                sess.step(inputs[:, -1:, :])
-            assert "Maximum length exceeded" in repr(exc_info.value)
+        with torch.inference_mode():
+            with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
+                for i in range(inputs.shape[1]):
+                    outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
+
+                # test that max length is respected
+                with pytest.raises(ValueError, match=r"Maximum length exceeded") as exc_info:
+                    sess.step(inputs[:, -1:, :])
+                assert "Maximum length exceeded" in repr(exc_info.value)
         outputs_inference = torch.cat(outputs_inference, dim=1)
 
         ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
@@ -38,47 +36,3 @@ def test_remote_block_exact_match(atol_forward=1e-4, atol_inference=1e-3):
 
         assert torch.allclose(outputs_local, outputs_forward, rtol=0, atol=atol_forward)
         assert torch.allclose(outputs_local, outputs_inference, rtol=0, atol=atol_inference)
-
-
-def _old_load_pretrained_block(
-    converted_model_name_or_path: str,
-    block_index: int,
-    torch_dtype: Union[torch.dtype, str] = "auto",
-) -> WrappedBloomBlock:
-    """Load the BLOOM block by directly initializing the weights.
-    This test is used to check consistency with the previous implementation and can be removed in the future."""
-    config = BloomConfig.from_pretrained(converted_model_name_or_path)
-
-    block = WrappedBloomBlock(config)
-    state_dict = _load_state_dict(
-        converted_model_name_or_path,
-        block_index,
-        config,
-        cache_dir=None,
-    )
-
-    if torch_dtype == "auto":
-        with torch.no_grad():
-            for name, param in block.named_parameters():
-                assert name in state_dict, f"{name} not in state dict"
-                param.data = param.data.to(state_dict[name].dtype)
-    else:
-        assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
-        block = block.to(dtype=torch_dtype)
-
-    block.load_state_dict(state_dict, strict=True)
-    return block
-
-
-@pytest.mark.forked
-def test_init_pretrained_block(torch_dtype=torch.float32, atol_forward=1e-8):
-    config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
-    torch.random.manual_seed(0)
-    inputs = torch.randn(1, 16, config.hidden_size, dtype=torch_dtype)
-
-    block = load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch_dtype)
-    ref_block = _old_load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch_dtype)
-
-    outputs = block.forward(inputs)[0]
-    outputs_ref = ref_block.forward(inputs)[0]
-    assert torch.allclose(outputs, outputs_ref, rtol=0, atol=atol_forward)

+ 2 - 2
tests/test_chained_calls.py

@@ -7,9 +7,9 @@
 import pytest
 import torch
 
-from petals.bloom.from_pretrained import load_pretrained_block
-from petals.client import DistributedBloomConfig
+from petals import DistributedBloomConfig
 from petals.client.remote_sequential import RemoteSequential
+from petals.server.from_pretrained import load_pretrained_block
 from test_utils import *
 
 

+ 7 - 8
tests/test_dtype.py

@@ -1,17 +1,16 @@
 import pytest
 import torch
 
-from petals.bloom.from_pretrained import load_pretrained_block
-from petals.client import DistributedBloomConfig
 from petals.server.block_utils import resolve_block_dtype
+from petals.server.from_pretrained import load_pretrained_block
+from petals.utils.auto_config import AutoDistributedConfig
 from test_utils import MODEL_NAME
 
 
 @pytest.mark.forked
 @pytest.mark.parametrize("torch_dtype", [torch.float32, torch.float16, "auto"])
-def test_backend_dtype(torch_dtype):
-    config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
-    block = load_pretrained_block(MODEL_NAME, 0, config, torch_dtype=torch_dtype)
-    backend_dtype = resolve_block_dtype(config, torch_dtype)
-    other_backend_dtype = next(block.parameters()).dtype if torch_dtype == "auto" else torch_dtype
-    assert backend_dtype == other_backend_dtype
+def test_block_dtype(torch_dtype):
+    config = AutoDistributedConfig.from_pretrained(MODEL_NAME)
+    block = load_pretrained_block(MODEL_NAME, 0, config=config, torch_dtype=torch_dtype)
+    expected_dtype = resolve_block_dtype(config, torch_dtype)
+    assert all(param.dtype == expected_dtype for param in block.parameters())

+ 2 - 2
tests/test_full_model.py

@@ -5,7 +5,7 @@ from hivemind import get_logger
 from transformers.generation import BeamSearchScorer
 from transformers.models.bloom import BloomForCausalLM
 
-from petals.client.remote_model import DistributedBloomForCausalLM
+from petals import DistributedBloomForCausalLM
 from test_utils import *
 
 logger = get_logger(__name__)
@@ -20,7 +20,7 @@ def test_full_model_exact_match(pass_empty_tensors: bool, atol_forward=1e-3, ato
     )
     config = model.config
     assert isinstance(model, DistributedBloomForCausalLM)
-    assert len(model.transformer.h) == model.config.n_layer
+    assert len(model.transformer.h) == model.config.num_hidden_layers
 
     test_inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
 

+ 10 - 8
tests/test_remote_sequential.py

@@ -4,10 +4,10 @@ import torch.nn.functional as F
 from hivemind import DHT, BatchTensorDescriptor, get_logger
 from hivemind.proto import runtime_pb2
 
-from petals.bloom.from_pretrained import load_pretrained_block
+from petals import DistributedBloomConfig
 from petals.client import RemoteSequenceManager, RemoteSequential
-from petals.client.remote_model import DistributedBloomConfig
 from petals.data_structures import UID_DELIMITER
+from petals.server.from_pretrained import load_pretrained_block
 from test_utils import *
 
 logger = get_logger(__name__)
@@ -28,10 +28,10 @@ def test_remote_sequential():
     full_grad = test_inputs.grad.clone()
     test_inputs.grad.data.zero_()
 
-    first_half = sequential[: config.n_layer // 2]
-    second_half = sequential[config.n_layer // 2 :]
+    first_half = sequential[: config.num_hidden_layers // 2]
+    second_half = sequential[config.num_hidden_layers // 2 :]
     assert len(first_half) + len(second_half) == len(sequential)
-    assert abs(len(first_half) - len(second_half)) == config.n_layer % 2
+    assert abs(len(first_half) - len(second_half)) == config.num_hidden_layers % 2
     for m in sequential, first_half, second_half:
         assert isinstance(repr(m), str)
 
@@ -46,7 +46,7 @@ def test_remote_sequential():
     assert torch.allclose(test_inputs.grad, full_grad, atol=1e-3)
 
     # test RemoteSequential with lossy compression
-    block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)]
+    block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.num_hidden_layers)]
     lossy_sequential = RemoteSequential(
         config, sequence_manager=DummyCustomSequenceManager(config, block_uids, dht=dht)
     )
@@ -90,7 +90,9 @@ def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):
     inputs = F.normalize(torch.randn(batch_size, seq_len, config.hidden_size), dim=-1)
     output_proj = F.normalize(torch.randn(batch_size, seq_len + pre_seq_len, config.hidden_size), dim=-1)
     input_prompts = F.normalize(torch.randn(batch_size, pre_seq_len, config.hidden_size, requires_grad=True), dim=-1)
-    intermediate_prompts = torch.randn(config.n_layer, batch_size, pre_seq_len, config.hidden_size, requires_grad=True)
+    intermediate_prompts = torch.randn(
+        config.num_hidden_layers, batch_size, pre_seq_len, config.hidden_size, requires_grad=True
+    )
 
     input_prompts = input_prompts.detach().requires_grad_(True)
     intermediate_prompts = intermediate_prompts.detach().requires_grad_(True)
@@ -110,7 +112,7 @@ def test_remote_sequential_prompts(batch_size=2, seq_len=5, pre_seq_len=3):
     assert intermediate_prompts_ref.grad is None
 
     outputs_ref = torch.cat([inputs, input_prompts_ref], dim=1)
-    for block_index in range(config.n_layer):
+    for block_index in range(config.num_hidden_layers):
         block_prompt = intermediate_prompts_ref[block_index]
         outputs_ref[:, : block_prompt.shape[1]] += block_prompt
 

+ 2 - 2
tests/test_sequence_manager.py

@@ -5,8 +5,8 @@ import pytest
 import torch
 from hivemind import DHT, get_logger
 
+from petals import DistributedBloomConfig
 from petals.client import RemoteSequenceManager, RemoteSequential
-from petals.client.remote_model import DistributedBloomConfig
 from petals.data_structures import UID_DELIMITER
 from test_utils import *
 
@@ -22,7 +22,7 @@ def test_sequence_manager_basics(mode: str):
     shutdown_evt = threading.Event()
 
     # test RemoteSequential with lossy compression
-    block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.n_layer)]
+    block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.num_hidden_layers)]
     sequential = RemoteSequential(
         config,
         sequence_manager=TestSequenceManager(config, block_uids, dht=dht, _was_shut_down=shutdown_evt),

+ 1 - 1
tests/test_server_stats.py

@@ -4,7 +4,7 @@ import hivemind
 import pytest
 import torch
 
-from petals.client import DistributedBloomConfig, RemoteSequential
+from petals import DistributedBloomConfig, RemoteSequential
 from petals.server.handler import CACHE_TOKENS_AVAILABLE
 from test_utils import *
 

+ 1 - 1
tests/test_tensor_parallel.py

@@ -6,7 +6,7 @@ import transformers
 from tensor_parallel import TensorParallel
 from tensor_parallel.slicing_configs import get_bloom_config
 
-from petals.bloom.from_pretrained import load_pretrained_block
+from petals.server.from_pretrained import load_pretrained_block
 from test_utils import MODEL_NAME