Bläddra i källkod

Add Falcon support (#499)

This PR adds:

- Support for models based on `transformers.FalconModel` (the in-library format for Falcon). Tested on Falcon-40B.
- CI tests for Falcon-RW-1B.
- `--throughput dry_run` option to evaluate throughput and exit right away (implemented by @mryab).

Limitations:

- Backward pass support is broken for now, will be fixed in #500.

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
Alexander Borzunov 1 år sedan
förälder
incheckning
dd4a3230bc

+ 6 - 1
.github/workflows/run-tests.yaml

@@ -14,11 +14,13 @@ jobs:
           - { model: 'bigscience/bloom-560m', os: 'ubuntu', python-version: '3.11' }
           - { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.8' }
           - { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.11' }
+          - { model: 'petals-team/falcon-rw-1b', os: 'ubuntu', python-version: '3.8' }
+          - { model: 'petals-team/falcon-rw-1b', os: 'ubuntu', python-version: '3.11' }
           - { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.10' }
           - { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.11' }
       fail-fast: false
     runs-on: ${{ matrix.os }}-latest
-    timeout-minutes: 15
+    timeout-minutes: 20
     steps:
       - name: Increase swap space
         if: ${{ matrix.os == 'ubuntu' }}
@@ -93,6 +95,9 @@ jobs:
 
           # [Step 2] Run PyTest
 
+          # Share disk cache between Petals servers, clients, and HF Transformers
+          export TRANSFORMERS_CACHE=~/.cache/petals
+
           # Necessary for @pytest.mark.forked to work properly on macOS, see https://github.com/kevlened/pytest-parallel/issues/93
           export no_proxy=*
           export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES

+ 3 - 2
src/petals/cli/run_server.py

@@ -106,12 +106,13 @@ def main():
                              "and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.")
 
     parser.add_argument('--throughput',
-                        type=lambda value: value if value in ['auto', 'eval'] else float(value),
+                        type=lambda value: value if value in ['auto', 'eval', 'dry_run'] else float(value),
                         default='auto',
                         help='Expected server throughput (a float measured in RPS). '
                              'If set to "auto" (default), the script evaluates network and compute throughput '
                              'on the first run and uses these estimates for future runs. '
-                             'If set to "eval", the script re-evaluates the throughput and overrides the cache.')
+                             'If set to "eval", the script re-evaluates the throughput and overrides the cache. '
+                             'If set to "dry_run", the script re-evaluates the throughput and exits.')
     parser.add_argument('--update_period', type=float, required=False, default=120,
                         help='Server will report blocks to DHT once in this many seconds')
     parser.add_argument('--expiration', type=float, required=False, default=None,

+ 1 - 0
src/petals/models/__init__.py

@@ -1,2 +1,3 @@
 from petals.models.bloom import *
+from petals.models.falcon import *
 from petals.models.llama import *

+ 15 - 0
src/petals/models/falcon/__init__.py

@@ -0,0 +1,15 @@
+from petals.models.falcon.block import WrappedFalconBlock
+from petals.models.falcon.config import DistributedFalconConfig
+from petals.models.falcon.model import (
+    DistributedFalconForCausalLM,
+    DistributedFalconForSequenceClassification,
+    DistributedFalconModel,
+)
+from petals.utils.auto_config import register_model_classes
+
+register_model_classes(
+    config=DistributedFalconConfig,
+    model=DistributedFalconModel,
+    model_for_causal_lm=DistributedFalconForCausalLM,
+    model_for_sequence_classification=DistributedFalconForSequenceClassification,
+)

+ 94 - 0
src/petals/models/falcon/block.py

@@ -0,0 +1,94 @@
+"""
+Falcon intermediate layer
+Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py
+See commit history for authorship.
+"""
+from typing import Optional, Tuple
+
+import torch
+from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor
+
+KVCache = Tuple[torch.Tensor, torch.Tensor]
+
+
+class WrappedFalconBlock(FalconDecoderLayer):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        *args,
+        attention_mask: Optional[torch.Tensor] = None,
+        alibi: Optional[torch.Tensor] = None,
+        layer_past: Optional[KVCache] = None,
+        use_cache: bool = False,
+        **kwargs
+    ):
+        batch_size, seq_length = hidden_states.shape[:2]
+
+        if layer_past is not None:
+            layer_past = self._reorder_cache_from_bloom_to_falcon(layer_past)
+        past_length = 0 if layer_past is None else layer_past[0].shape[1]
+        seq_length_with_past = seq_length + past_length
+
+        attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
+        if alibi is None and self.config.alibi:
+            alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
+        attention_mask = FalconModel._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)
+
+        outputs = super().forward(
+            hidden_states,
+            *args,
+            attention_mask=attention_mask,
+            alibi=alibi,
+            layer_past=layer_past,
+            use_cache=use_cache,
+            **kwargs
+        )
+
+        if use_cache:
+            present_key_value = outputs[-1]
+            present_key_value = self._reorder_cache_from_falcon_to_bloom(present_key_value)
+            outputs = outputs[:-1] + (present_key_value,)
+
+        return outputs
+
+    def _reorder_cache_from_bloom_to_falcon(self, key_value: KVCache) -> KVCache:
+        key_states, value_states = key_value
+
+        key_states = key_states.permute(0, 2, 1)
+        assert key_states.shape == value_states.shape  # Both are [batch_size * num_kv_heads, seq_len, head_dim]
+
+        if self.config.new_decoder_architecture:
+            key_states = self._expand_states(key_states)
+            value_states = self._expand_states(value_states)
+
+        return (key_states, value_states)
+
+    def _reorder_cache_from_falcon_to_bloom(self, key_value: KVCache) -> KVCache:
+        key_states, value_states = key_value
+
+        if self.config.new_decoder_architecture:
+            key_states = self._collapse_states(key_states)
+            value_states = self._collapse_states(value_states)
+
+        assert key_states.shape == value_states.shape  # Both are [batch_size * num_kv_heads, seq_len, head_dim]
+        key_states = key_states.permute(0, 2, 1)
+
+        return (key_states, value_states)
+
+    def _expand_states(self, state: torch.Tensor) -> torch.Tensor:
+        batch_size_x_num_kv_heads, seq_len, head_dim = state.shape
+        batch_size = batch_size_x_num_kv_heads // self.config.num_kv_heads
+
+        state = state.view(batch_size, self.config.num_kv_heads, 1, seq_len, head_dim)
+        state = state.expand(-1, -1, self.config.num_key_value_groups, -1, -1)  # No copy
+        state = state.reshape(batch_size * self.config.num_attention_heads, seq_len, head_dim)  # Involves a copy
+        return state
+
+    def _collapse_states(self, state: torch.Tensor) -> torch.Tensor:
+        batch_size_x_num_attn_heads, seq_len, head_dim = state.shape
+        batch_size = batch_size_x_num_attn_heads // self.config.num_attention_heads
+
+        state = state.view(batch_size, self.config.num_kv_heads, self.config.num_key_value_groups, seq_len, head_dim)
+        state = state[:, :, 0]
+        state = state.view(batch_size * self.config.num_kv_heads, seq_len, head_dim)
+        return state

+ 45 - 0
src/petals/models/falcon/config.py

@@ -0,0 +1,45 @@
+import os
+from typing import Optional, Union
+
+from hivemind import get_logger
+from transformers.models.falcon import FalconConfig
+from transformers.models.falcon.modeling_falcon import FalconAttention
+
+from petals.client.config import ClientConfig
+from petals.client.lm_head import LMHeadConfig
+from petals.client.ptune import PTuneConfig
+from petals.models.falcon.block import WrappedFalconBlock
+from petals.utils.auto_config import DefaultRevisionMixin
+
+logger = get_logger(__name__)
+
+
+class DistributedFalconConfig(DefaultRevisionMixin, FalconConfig, ClientConfig, PTuneConfig, LMHeadConfig):
+    block_class = WrappedFalconBlock
+    attn_class = FalconAttention
+    block_prefix = "transformer.h"
+
+    @property
+    def num_key_value_groups(self) -> int:
+        if self.new_decoder_architecture:
+            return self.num_attention_heads // self.num_kv_heads
+        if self.multi_query:
+            return self.num_attention_heads
+        return 1
+
+    @classmethod
+    def from_pretrained(
+        cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
+    ):
+        loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
+        if loading_from_repo and dht_prefix is None:
+            dht_prefix = str(model_name_or_path)
+            dht_prefix = dht_prefix.split("/")[-1]  # Use only repo name to merge blocks hosted by different accounts
+            dht_prefix = dht_prefix.replace(".", "-")
+            logger.info(f"Using DHT prefix: {dht_prefix}")
+
+        result = super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
+        config = result[0] if isinstance(result, tuple) else result
+        if config.pad_token_id is None:
+            config.pad_token_id = 0
+        return result

+ 149 - 0
src/petals/models/falcon/model.py

@@ -0,0 +1,149 @@
+from typing import Optional
+
+import hivemind
+import torch
+import torch.nn as nn
+from hivemind.utils.logging import get_logger
+from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
+from transformers.models.falcon import (
+    FalconForCausalLM,
+    FalconForSequenceClassification,
+    FalconModel,
+    FalconPreTrainedModel,
+)
+
+from petals.client.from_pretrained import FromPretrainedMixin
+from petals.client.lm_head import LMHead
+from petals.client.ptune import PTuneMixin
+from petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues
+from petals.client.remote_sequential import RemoteSequential
+from petals.models.falcon.config import DistributedFalconConfig
+from petals.utils.auto_config import DefaultRevisionMixin
+
+logger = get_logger(__name__)
+
+
+class DistributedFalconModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMixin, FalconModel):
+    """FalconModel, but all transformer layers are hosted by the swarm"""
+
+    _keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing
+    _keys_to_ignore_on_load_unexpected = [r"^transformer\.h\."]
+
+    config_class = DistributedFalconConfig
+
+    def __init__(self, config: DistributedFalconConfig, *, dht: Optional[hivemind.DHT] = None):
+        n_layer, config.num_hidden_layers = config.num_hidden_layers, 0  # Prevent initialization
+        super().__init__(config)
+        assert len(self.h) == 0
+        config.num_hidden_layers = n_layer
+
+        self.h = RemoteSequential(config, dht=dht)
+
+        self.requires_grad_(False)  # Forbid accumulate grads for embeddings and layernorm
+        self.init_prompts(config)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[RemotePastKeyValues] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ):
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        # The causal mask will be added on the server-side
+        assert (
+            attention_mask is None or (attention_mask == 1).all()
+        ), f"Custom attention masks are not supported, {attention_mask=}"
+        assert head_mask is None, f"Custom head masks are not supported, {head_mask=}"
+        assert use_cache is None or use_cache, f"{use_cache=} is not supported"
+        assert not output_attentions, f"{output_attentions=} is not supported"
+        assert not output_hidden_states, f"{output_hidden_states=} is not supported"
+        assert return_dict is None or return_dict, f"{return_dict=} is not supported"
+
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+
+        if self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0:
+            batch_size = inputs_embeds.shape[0]
+            prompts, intermediate_prompts = self.get_prompt(batch_size)
+            inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
+        else:
+            prompts = intermediate_prompts = None
+
+        hidden_states = self.word_embeddings_layernorm(inputs_embeds)
+        output_shape = input_shape + (hidden_states.size(-1),)
+
+        hidden_states = self.h(
+            hidden_states,
+            prompts=intermediate_prompts,
+            hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None,
+        )
+
+        # Remove prefix
+        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+            hidden_states = hidden_states[:, self.pre_seq_len :]
+
+        # Add last hidden state
+        hidden_states = self.ln_f(hidden_states)
+        hidden_states = hidden_states.view(output_shape)
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=RemotePastKeyValues(),
+            hidden_states=None,
+            attentions=None,
+        )
+
+    @property
+    def word_embeddings_layernorm(self) -> nn.Module:  # For compatibility with RemoteGenerationMixin
+        return nn.Identity()
+
+
+class DistributedFalconForCausalLM(DefaultRevisionMixin, FromPretrainedMixin, RemoteGenerationMixin, FalconForCausalLM):
+    _keys_to_ignore_on_load_missing = DistributedFalconModel._keys_to_ignore_on_load_missing
+    _keys_to_ignore_on_load_unexpected = DistributedFalconModel._keys_to_ignore_on_load_unexpected
+
+    config_class = DistributedFalconConfig
+
+    def __init__(self, config: DistributedFalconConfig):
+        FalconPreTrainedModel.__init__(self, config)
+        self.transformer = DistributedFalconModel(config)
+        self.lm_head = LMHead(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+
+class DistributedFalconForSequenceClassification(
+    DefaultRevisionMixin, FromPretrainedMixin, FalconForSequenceClassification
+):
+    _keys_to_ignore_on_load_missing = DistributedFalconModel._keys_to_ignore_on_load_missing
+    _keys_to_ignore_on_load_unexpected = DistributedFalconModel._keys_to_ignore_on_load_unexpected
+
+    config_class = DistributedFalconConfig
+
+    def __init__(self, config: DistributedFalconConfig):
+        FalconPreTrainedModel.__init__(self, config)
+        self.num_labels = config.num_labels
+
+        self.transformer = DistributedFalconModel(config)
+        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()

+ 9 - 7
src/petals/server/server.py

@@ -5,6 +5,7 @@ import math
 import multiprocessing as mp
 import os
 import random
+import sys
 import threading
 import time
 from typing import Dict, List, Optional, Sequence, Union
@@ -186,10 +187,7 @@ class Server:
             check_device_balance(self.tensor_parallel_devices)
 
         if quant_type is None:
-            if device.type == "cuda":
-                quant_type = QuantType.NF4 if self.block_config.model_type == "llama" else QuantType.INT8
-            else:
-                quant_type = QuantType.NONE
+            quant_type = QuantType.NF4 if device.type == "cuda" else QuantType.NONE
         self.quant_type = quant_type
         logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, quant_type)} format")
 
@@ -234,8 +232,9 @@ class Server:
         self.attn_cache_bytes = self._cache_bytes_per_block * num_blocks
         logger.info(f"Attention cache for all blocks will consume up to {self.attn_cache_bytes / gib:.2f} GiB")
 
-        assert isinstance(throughput, float) or throughput in ["auto", "eval"]
-        if throughput in ["auto", "eval"]:
+        assert isinstance(throughput, float) or throughput in ["auto", "eval", "dry_run"]
+        if throughput in ["auto", "eval", "dry_run"]:
+            force_eval = throughput in ["eval", "dry_run"]
             throughput_info = get_server_throughput(
                 converted_model_name_or_path,
                 self.block_config,
@@ -245,9 +244,12 @@ class Server:
                 quant_type=quant_type,
                 tensor_parallel_devices=self.tensor_parallel_devices,
                 reachable_via_relay=reachable_via_relay,
-                force_eval=(throughput == "eval"),
+                force_eval=force_eval,
                 cache_dir=cache_dir,
             )
+            if throughput == "dry_run":
+                logger.info("Finished estimating throughput, exiting")
+                sys.exit(0)
         else:
             throughput_info = {"throughput": throughput}
         self.server_info = ServerInfo(

+ 34 - 5
src/petals/utils/auto_config.py

@@ -1,12 +1,14 @@
 import os
-import re
 from dataclasses import dataclass
 from typing import Optional, Type, Union
 
+from hivemind import get_logger
 from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
 
 from petals.utils.hf_auth import always_needs_auth
 
+logger = get_logger(__name__)
+
 
 @dataclass
 class _ModelClasses:
@@ -49,17 +51,44 @@ class _AutoDistributedBase:
         return proper_cls.from_pretrained(model_name_or_path, *args, **kwargs)
 
 
-class AutoDistributedConfig(_AutoDistributedBase):
+class DefaultRevisionMixin:
+    """
+    Petals only supports Falcon loaded in the new in-library format (transformers.FalconModel).
+    TII models were recently converted to this format but then reverted back due to compatibility issues.
+    We chose to support only the new format since HF staff promised to eventually convert these models
+    to the new format again, see https://huggingface.co/tiiuae/falcon-40b/discussions/90#64b4d23bf44fd957492f7602
+    Until it happens, we override the default `main` revision for the TII repos with the commit
+    pointing out to the model in the in-library format.
+    """
+
+    DEFAULT_REVISIONS = {
+        "tiiuae/falcon-40b": "f1ba7d328c06aa6fbb4a8afd3c756f46d7e6b232",
+        "tiiuae/falcon-40b-instruct": "7475ff8cfc36ed9a962b658ae3c33391566a85a5",
+        "tiiuae/falcon-7b": "4e2d06f0a7c6370ebabbc30c6f59377ae8f73d76",
+        "tiiuae/falcon-7b-instruct": "f8dac3fff96d5debd43edf56fb4e1abcfffbef28",
+    }
+
+    @classmethod
+    def from_pretrained(
+        cls, model_name_or_path: Union[str, os.PathLike, None], *args, revision: Optional[str] = None, **kwargs
+    ):
+        if revision is None and model_name_or_path in cls.DEFAULT_REVISIONS:
+            revision = cls.DEFAULT_REVISIONS[model_name_or_path]
+            logger.info(f"Loading {model_name_or_path}, revision {revision}")
+        return super().from_pretrained(model_name_or_path, *args, revision=revision, **kwargs)
+
+
+class AutoDistributedConfig(DefaultRevisionMixin, _AutoDistributedBase):
     _mapping_field = "config"
 
 
-class AutoDistributedModel(_AutoDistributedBase):
+class AutoDistributedModel(DefaultRevisionMixin, _AutoDistributedBase):
     _mapping_field = "model"
 
 
-class AutoDistributedModelForCausalLM(_AutoDistributedBase):
+class AutoDistributedModelForCausalLM(DefaultRevisionMixin, _AutoDistributedBase):
     _mapping_field = "model_for_causal_lm"
 
 
-class AutoDistributedModelForSequenceClassification(_AutoDistributedBase):
+class AutoDistributedModelForSequenceClassification(DefaultRevisionMixin, _AutoDistributedBase):
     _mapping_field = "model_for_sequence_classification"