Преглед на файлове

Merge remote-tracking branch 'origin/main' into forward_kwargs

Your Name преди 1 година
родител
ревизия
6256995bb1

+ 6 - 1
.github/workflows/run-tests.yaml

@@ -14,11 +14,13 @@ jobs:
           - { model: 'bigscience/bloom-560m', os: 'ubuntu', python-version: '3.11' }
           - { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.8' }
           - { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.11' }
+          - { model: 'petals-team/falcon-rw-1b', os: 'ubuntu', python-version: '3.8' }
+          - { model: 'petals-team/falcon-rw-1b', os: 'ubuntu', python-version: '3.11' }
           - { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.10' }
           - { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.11' }
       fail-fast: false
     runs-on: ${{ matrix.os }}-latest
-    timeout-minutes: 15
+    timeout-minutes: 20
     steps:
       - name: Increase swap space
         if: ${{ matrix.os == 'ubuntu' }}
@@ -93,6 +95,9 @@ jobs:
 
           # [Step 2] Run PyTest
 
+          # Share disk cache between Petals servers, clients, and HF Transformers
+          export TRANSFORMERS_CACHE=~/.cache/petals
+
           # Necessary for @pytest.mark.forked to work properly on macOS, see https://github.com/kevlened/pytest-parallel/issues/93
           export no_proxy=*
           export OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES

+ 1 - 1
setup.cfg

@@ -40,7 +40,7 @@ install_requires =
     transformers>=4.32.0,<5.0.0  # if you change this, please also change version assert in petals/__init__.py
     speedtest-cli==2.1.3
     pydantic>=1.10,<2.0  # 2.0 is incompatible with hivemind yet
-    hivemind @ git+https://github.com/learning-at-home/hivemind
+    hivemind==1.1.10.post2
     tensor_parallel==1.0.23
     humanfriendly
     async-timeout>=4.0.2

+ 3 - 2
src/petals/cli/run_server.py

@@ -106,12 +106,13 @@ def main():
                              "and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.")
 
     parser.add_argument('--throughput',
-                        type=lambda value: value if value in ['auto', 'eval'] else float(value),
+                        type=lambda value: value if value in ['auto', 'eval', 'dry_run'] else float(value),
                         default='auto',
                         help='Expected server throughput (a float measured in RPS). '
                              'If set to "auto" (default), the script evaluates network and compute throughput '
                              'on the first run and uses these estimates for future runs. '
-                             'If set to "eval", the script re-evaluates the throughput and overrides the cache.')
+                             'If set to "eval", the script re-evaluates the throughput and overrides the cache. '
+                             'If set to "dry_run", the script re-evaluates the throughput and exits.')
     parser.add_argument('--update_period', type=float, required=False, default=120,
                         help='Server will report blocks to DHT once in this many seconds')
     parser.add_argument('--expiration', type=float, required=False, default=None,

+ 3 - 2
src/petals/client/remote_generation.py

@@ -87,10 +87,11 @@ class RemoteGenerationMixin(_SkipTokensMixin):
                 max_new_tokens is None
             ), "You should set `max_length` or `max_new_tokens` (but not both) to reserve server-side attention caches"
 
+            session_max_length = self.transformer.config.pre_seq_len
             if max_length is not None:
-                session_max_length = max_length
+                session_max_length += max_length
             else:
-                session_max_length = (inputs.shape[1] if inputs is not None else 0) + max_new_tokens
+                session_max_length += (inputs.shape[1] if inputs is not None else 0) + max_new_tokens
             context_manager = self.inference_session(max_length=session_max_length)
 
         with context_manager as session:

+ 1 - 0
src/petals/models/__init__.py

@@ -1,2 +1,3 @@
 from petals.models.bloom import *
+from petals.models.falcon import *
 from petals.models.llama import *

+ 3 - 2
src/petals/models/bloom/model.py

@@ -71,7 +71,8 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
 
-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0:
+        use_prompts = self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0
+        if use_prompts:
             batch_size = inputs_embeds.shape[0]
             prompts, intermediate_prompts = self.get_prompt(batch_size)
             inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
@@ -88,7 +89,7 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
         )
 
         # Remove prefix
-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+        if use_prompts:
             hidden_states = hidden_states[:, self.pre_seq_len :]
 
         # Add last hidden state

+ 15 - 0
src/petals/models/falcon/__init__.py

@@ -0,0 +1,15 @@
+from petals.models.falcon.block import WrappedFalconBlock
+from petals.models.falcon.config import DistributedFalconConfig
+from petals.models.falcon.model import (
+    DistributedFalconForCausalLM,
+    DistributedFalconForSequenceClassification,
+    DistributedFalconModel,
+)
+from petals.utils.auto_config import register_model_classes
+
+register_model_classes(
+    config=DistributedFalconConfig,
+    model=DistributedFalconModel,
+    model_for_causal_lm=DistributedFalconForCausalLM,
+    model_for_sequence_classification=DistributedFalconForSequenceClassification,
+)

+ 480 - 0
src/petals/models/falcon/block.py

@@ -0,0 +1,480 @@
+"""
+Falcon intermediate layer
+Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py
+See commit history for authorship.
+"""
+import math
+from functools import partial
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers.models.falcon.modeling_falcon import (
+    FalconAttention,
+    FalconConfig,
+    FalconDecoderLayer,
+    FalconLinear,
+    FalconMLP,
+    FalconModel,
+    LayerNorm,
+    build_alibi_tensor,
+    dropout_add,
+    rotate_half,
+)
+
+KVCache = Tuple[torch.Tensor, torch.Tensor]
+INFERENCE_MAX_LENGTH = 8192
+
+
+def apply_rotary(query, key, cos, sin):
+    return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin)
+
+
+class OptimizedFalconRotaryEmbedding(nn.Module):
+    def __init__(self, head_dim: int, base=10000):
+        super().__init__()
+        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.head_dim = head_dim
+        self.seq_len_cached = -1
+
+        self.cuda_graph = None
+        self.input_surface = None
+        self.static_outputs = None
+
+    def _optimized_apply_rotary(self, query, key, cos, sin):
+        if self.cuda_graph is None:
+            self.cuda_graph = torch.cuda.CUDAGraph()
+            self.input_surface = (query, key, cos, sin)
+
+            s = torch.cuda.Stream()
+            s.wait_stream(torch.cuda.current_stream())
+            with torch.cuda.stream(s):
+                for _ in range(3):
+                    apply_rotary(*self.input_surface)
+            torch.cuda.current_stream().wait_stream(s)
+
+            with torch.cuda.graph(self.cuda_graph):
+                self.static_outputs = apply_rotary(*self.input_surface)
+
+        inputs = (query, key, cos, sin)
+        for static_input, data in zip(self.input_surface, inputs):
+            static_input.copy_(data)
+        self.cuda_graph.replay()
+        return tuple(o.detach() for o in self.static_outputs)
+
+    def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> torch.Tensor:
+        total_length = seq_len + past_key_values_length
+        if self.seq_len_cached == -1:
+            # warm up the cache
+            total_length = max(INFERENCE_MAX_LENGTH, total_length)
+
+        if total_length > self.seq_len_cached:
+            with torch.inference_mode(False):
+                self.seq_len_cached = total_length
+                t = torch.arange(total_length, device=device, dtype=self.inv_freq.dtype)
+                freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+                emb = torch.cat((freqs, freqs), dim=-1).to(device)
+
+                if dtype in [torch.float16, torch.bfloat16]:
+                    emb = emb.float()
+
+                self.register_buffer("cos_cached", emb.cos()[None, :, :].type(dtype), persistent=False)
+                self.register_buffer("sin_cached", emb.sin()[None, :, :].type(dtype), persistent=False)
+
+        return (
+            self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length].type(dtype),
+            self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length].type(dtype),
+        )
+
+    def forward(self, query, key, past_key_values_length=0):
+        batch, seq_len, head_dim = query.shape
+        cos, sin = self.cos_sin(seq_len, past_key_values_length, query.device, query.dtype)
+        if seq_len == 1 and torch.is_inference_mode_enabled() and query.device.type == "cuda":
+            return self._optimized_apply_rotary(query, key, cos, sin)
+        else:
+            return apply_rotary(query, key, cos, sin)
+
+
+def split_heads(
+    fused_qkv: torch.Tensor, num_heads: int, num_kv_heads: int, head_dim: int
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    batch, seq_len, _ = fused_qkv.shape
+    qkv = fused_qkv.view(batch, seq_len, -1, num_heads // num_kv_heads + 2, head_dim)
+    query, key, value = torch.split(qkv, [num_heads // num_kv_heads, 1, 1], dim=3)
+    key = torch.broadcast_to(key, query.shape)
+    value = torch.broadcast_to(value, query.shape)
+
+    query, key, value = [x.flatten(2, 3) for x in (query, key, value)]
+    return query, key, value
+
+
+class OptimizedFalconAttention(FalconAttention):
+    def __init__(self, config: FalconConfig):
+        nn.Module.__init__(self)
+
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.hidden_size // self.num_heads
+        self.split_size = self.hidden_size
+        self.hidden_dropout = config.hidden_dropout
+
+        if self.head_dim * self.num_heads != self.hidden_size:
+            raise ValueError(
+                f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
+                f" {self.num_heads})."
+            )
+
+        self.maybe_rotary = OptimizedFalconRotaryEmbedding(config.head_dim) if config.rotary else lambda q, k, t: (q, k)
+
+        # Layer-wise attention scaling
+        self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
+        self.beta = self.inv_norm_factor
+        if config.new_decoder_architecture:
+            qkv_out_dim = (config.num_kv_heads * 2 + config.num_attention_heads) * self.head_dim
+        elif config.multi_query:
+            qkv_out_dim = self.hidden_size + 2 * self.head_dim
+        else:
+            qkv_out_dim = 3 * self.hidden_size
+        self.query_key_value = FalconLinear(self.hidden_size, qkv_out_dim, bias=config.bias)
+        self.new_decoder_architecture = config.new_decoder_architecture
+        self.multi_query = config.multi_query
+        self.dense = FalconLinear(self.hidden_size, self.hidden_size, bias=config.bias)
+        self.attention_dropout = nn.Dropout(config.attention_dropout)
+        self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1
+
+        if self.new_decoder_architecture:
+            self._split_heads = partial(
+                split_heads, num_heads=self.num_heads, num_kv_heads=self.num_kv_heads, head_dim=self.head_dim
+            )
+            self.split_graph = None
+            self.input_surface = None
+            self.static_outputs = None
+
+    def _optimized_split_heads(self, fused_qkv):
+        if self.split_graph is None:
+            self.split_graph = torch.cuda.CUDAGraph()
+            self.input_surface = fused_qkv
+
+            s = torch.cuda.Stream()
+            s.wait_stream(torch.cuda.current_stream())
+            with torch.cuda.stream(s):
+                for _ in range(3):
+                    self._split_heads(fused_qkv)
+            torch.cuda.current_stream().wait_stream(s)
+
+            with torch.cuda.graph(self.split_graph):
+                self.static_outputs = self._split_heads(self.input_surface)
+
+        self.input_surface.copy_(fused_qkv)
+        self.split_graph.replay()
+        return tuple(o.detach() for o in self.static_outputs)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        alibi: Optional[torch.Tensor],
+        attention_mask: torch.Tensor,
+        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        use_cache: bool = False,
+        output_attentions: bool = False,
+    ):
+        assert not output_attentions
+
+        fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]
+
+        if (
+            self.new_decoder_architecture
+            and hidden_states.size(1) == 1
+            and torch.is_inference_mode_enabled()
+            and hidden_states.device.type == "cuda"
+        ):
+            query_layer, key_layer, value_layer = self._optimized_split_heads(fused_qkv)
+        else:
+            # 3 x [batch_size, seq_length, num_heads, head_dim]
+            (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
+
+        num_kv_heads = self.num_heads
+        batch_size, query_length, _, _ = query_layer.shape
+
+        query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim)
+        key_layer = key_layer.transpose(1, 2).reshape(
+            batch_size * num_kv_heads,
+            query_length,
+            self.head_dim,
+        )
+        value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)
+
+        past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
+        query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
+
+        if layer_past is not None:
+            past_key, past_value = layer_past
+            # concatenate along seq_length dimension:
+            #  - key: [batch_size * self.num_heads, kv_length, head_dim]
+            #  - value: [batch_size * self.num_heads, kv_length, head_dim]
+            key_layer = torch.cat((past_key, key_layer), dim=1)
+            value_layer = torch.cat((past_value, value_layer), dim=1)
+
+        _, kv_length, _ = key_layer.shape
+        if use_cache:
+            present = (key_layer, value_layer)
+        else:
+            present = None
+
+        query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
+        key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
+        value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
+
+        attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
+
+        if alibi is None:
+            attn_output = F.scaled_dot_product_attention(
+                query_layer_, key_layer_, value_layer_, attn_mask=attention_mask_float, dropout_p=0.0, is_causal=False
+            )
+
+            attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim)
+            attn_output = attn_output.permute(0, 2, 1, 3)
+            attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
+
+            output_tensor = self.dense(attn_output)
+
+            return output_tensor, present
+        else:
+            matmul_result = query_layer_ @ key_layer_.transpose(-1, -2)
+
+            # change view to [batch_size, num_heads, q_length, kv_length]
+            attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length)
+
+            # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
+            input_dtype = attention_scores.dtype
+            # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
+            if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
+                attention_scores = attention_scores.to(torch.float32)
+            # Matt (HF) note: We could possibly use F.scaled_dot_product_attention here too, by
+            # adding (alibi * self.inv_norm_factor) to attention_mask_float. I think this would be mathematically
+            # equivalent and more performant, but there might be a numerical difference. If you're reading this
+            # and you'd like to experiment and maybe file a PR, feel free!
+            attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)
+            attention_logits *= self.inv_norm_factor
+            attention_probs = F.softmax(attention_logits + attention_mask_float, dim=-1, dtype=hidden_states.dtype)
+            # [batch_size, num_heads, q_length, kv_length]
+            attention_probs = self.attention_dropout(attention_probs)
+
+            if head_mask is not None:
+                attention_probs = attention_probs * head_mask
+
+            # change view [batch_size, num_heads, q_length, kv_length]
+            attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length)
+
+            # matmul: [batch_size * num_heads, q_length, head_dim]
+            context_layer = (attention_probs_reshaped @ value_layer_).flatten(0, 1)
+
+            # change view [batch_size, q_length, num_heads * head_dim]
+            context_layer = self._merge_heads(context_layer)
+
+            output_tensor = self.dense(context_layer)
+
+            if output_attentions:
+                return output_tensor, present, attention_probs
+            else:
+                return output_tensor, present
+
+
+class OptimizedFalconDecoderLayer(FalconDecoderLayer):
+    def __init__(self, config: FalconConfig):
+        nn.Module.__init__(self)
+        hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+
+        self.mlp = FalconMLP(config)
+        self.hidden_dropout = config.hidden_dropout
+        self.config = config
+
+        self.self_attention = OptimizedFalconAttention(config)
+
+        if self.config.alibi or not config.new_decoder_architecture:
+            if config.new_decoder_architecture:
+                # The layer norm before self-attention
+                self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+                # The layer norm before the MLP
+                self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+            else:
+                self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+                if not config.parallel_attn:
+                    self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+
+        else:
+            self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+            self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+
+            self.ln_graph = None
+            self.static_input = None
+            self.static_outputs = None
+
+    def _optimized_apply_ln(self, hidden_states):
+        if self.ln_graph is None:
+            self.ln_graph = torch.cuda.CUDAGraph()
+            self.static_input = hidden_states
+
+            s = torch.cuda.Stream()
+            s.wait_stream(torch.cuda.current_stream())
+            with torch.cuda.stream(s):
+                for _ in range(3):
+                    self.ln_attn(hidden_states)
+                    self.ln_mlp(hidden_states)
+            torch.cuda.current_stream().wait_stream(s)
+
+            with torch.cuda.graph(self.ln_graph):
+                ln_attn_output = self.ln_attn(hidden_states)
+                ln_mlp_output = self.ln_mlp(hidden_states)
+                self.static_outputs = (ln_attn_output, ln_mlp_output)
+
+        self.static_input.copy_(hidden_states)
+        self.ln_graph.replay()
+        return tuple(o.detach() for o in self.static_outputs)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        alibi: Optional[torch.Tensor],
+        attention_mask: torch.Tensor,
+        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        use_cache: bool = False,
+        output_attentions: bool = False,
+    ):
+        residual = hidden_states
+
+        if self.config.new_decoder_architecture:
+            if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
+                attention_layernorm_out, mlp_layernorm_out = self._optimized_apply_ln(hidden_states)
+            else:
+                attention_layernorm_out = self.ln_attn(hidden_states)
+                mlp_layernorm_out = self.ln_mlp(hidden_states)
+        else:
+            attention_layernorm_out = self.input_layernorm(hidden_states)
+
+        attn_outputs = self.self_attention(
+            attention_layernorm_out,
+            layer_past=layer_past,
+            attention_mask=attention_mask,
+            alibi=alibi,
+            head_mask=head_mask,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+        )
+
+        attention_output = attn_outputs[0]
+
+        if not self.config.new_decoder_architecture:
+            if self.config.parallel_attn:
+                mlp_layernorm_out = attention_layernorm_out
+            else:
+                residual = dropout_add(
+                    attention_output, residual, self.config.attention_dropout, training=self.training
+                )
+                mlp_layernorm_out = self.post_attention_layernorm(residual)
+
+        outputs = attn_outputs[1:]
+
+        mlp_output = self.mlp(mlp_layernorm_out)
+
+        if self.config.new_decoder_architecture or self.config.parallel_attn:
+            mlp_output += attention_output
+
+        output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training)
+
+        if use_cache:
+            outputs = (output,) + outputs
+        else:
+            outputs = (output,) + outputs[1:]
+
+        return outputs  # hidden_states, present, attentions
+
+
+class WrappedFalconBlock(OptimizedFalconDecoderLayer):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        *args,
+        attention_mask: Optional[torch.Tensor] = None,
+        alibi: Optional[torch.Tensor] = None,
+        layer_past: Optional[KVCache] = None,
+        use_cache: bool = False,
+        **kwargs,
+    ):
+        assert attention_mask is None
+
+        batch_size, seq_length = hidden_states.shape[:2]
+
+        if layer_past is not None:
+            layer_past = self._reorder_cache_from_bloom_to_falcon(layer_past)
+        past_length = 0 if layer_past is None else layer_past[0].shape[1]
+        seq_length_with_past = seq_length + past_length
+
+        attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
+        if alibi is None and self.config.alibi:
+            alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
+        attention_mask = FalconModel._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)
+
+        outputs = super().forward(
+            hidden_states,
+            *args,
+            attention_mask=attention_mask,
+            alibi=alibi,
+            layer_past=layer_past,
+            use_cache=use_cache,
+            **kwargs,
+        )
+
+        if use_cache:
+            present_key_value = outputs[-1]
+            present_key_value = self._reorder_cache_from_falcon_to_bloom(present_key_value)
+            outputs = outputs[:-1] + (present_key_value,)
+
+        return outputs
+
+    def _reorder_cache_from_bloom_to_falcon(self, key_value: KVCache) -> KVCache:
+        key_states, value_states = key_value
+
+        key_states = key_states.permute(0, 2, 1)
+        assert key_states.shape == value_states.shape  # Both are [batch_size * num_kv_heads, seq_len, head_dim]
+
+        if self.config.new_decoder_architecture:
+            key_states = self._expand_states(key_states)
+            value_states = self._expand_states(value_states)
+
+        return (key_states, value_states)
+
+    def _reorder_cache_from_falcon_to_bloom(self, key_value: KVCache) -> KVCache:
+        key_states, value_states = key_value
+
+        if self.config.new_decoder_architecture:
+            key_states = self._collapse_states(key_states)
+            value_states = self._collapse_states(value_states)
+
+        assert key_states.shape == value_states.shape  # Both are [batch_size * num_kv_heads, seq_len, head_dim]
+        key_states = key_states.permute(0, 2, 1)
+
+        return (key_states, value_states)
+
+    def _expand_states(self, state: torch.Tensor) -> torch.Tensor:
+        batch_size_x_num_kv_heads, seq_len, head_dim = state.shape
+        batch_size = batch_size_x_num_kv_heads // self.config.num_kv_heads
+
+        state = state.view(batch_size, self.config.num_kv_heads, 1, seq_len, head_dim)
+        state = state.expand(-1, -1, self.config.num_key_value_groups, -1, -1)  # No copy
+        state = state.reshape(batch_size * self.config.num_attention_heads, seq_len, head_dim)  # Involves a copy
+        return state
+
+    def _collapse_states(self, state: torch.Tensor) -> torch.Tensor:
+        batch_size_x_num_attn_heads, seq_len, head_dim = state.shape
+        batch_size = batch_size_x_num_attn_heads // self.config.num_attention_heads
+
+        state = state.view(batch_size, self.config.num_kv_heads, self.config.num_key_value_groups, seq_len, head_dim)
+        state = state[:, :, 0]
+        state = state.view(batch_size * self.config.num_kv_heads, seq_len, head_dim)
+        return state

+ 45 - 0
src/petals/models/falcon/config.py

@@ -0,0 +1,45 @@
+import os
+from typing import Optional, Union
+
+from hivemind import get_logger
+from transformers.models.falcon import FalconConfig
+from transformers.models.falcon.modeling_falcon import FalconAttention
+
+from petals.client.config import ClientConfig
+from petals.client.lm_head import LMHeadConfig
+from petals.client.ptune import PTuneConfig
+from petals.models.falcon.block import WrappedFalconBlock
+from petals.utils.auto_config import DefaultRevisionMixin
+
+logger = get_logger(__name__)
+
+
+class DistributedFalconConfig(DefaultRevisionMixin, FalconConfig, ClientConfig, PTuneConfig, LMHeadConfig):
+    block_class = WrappedFalconBlock
+    attn_class = FalconAttention
+    block_prefix = "transformer.h"
+
+    @property
+    def num_key_value_groups(self) -> int:
+        if self.new_decoder_architecture:
+            return self.num_attention_heads // self.num_kv_heads
+        if self.multi_query:
+            return self.num_attention_heads
+        return 1
+
+    @classmethod
+    def from_pretrained(
+        cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
+    ):
+        loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
+        if loading_from_repo and dht_prefix is None:
+            dht_prefix = str(model_name_or_path)
+            dht_prefix = dht_prefix.split("/")[-1]  # Use only repo name to merge blocks hosted by different accounts
+            dht_prefix = dht_prefix.replace(".", "-")
+            logger.info(f"Using DHT prefix: {dht_prefix}")
+
+        result = super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
+        config = result[0] if isinstance(result, tuple) else result
+        if config.pad_token_id is None:
+            config.pad_token_id = 0
+        return result

+ 150 - 0
src/petals/models/falcon/model.py

@@ -0,0 +1,150 @@
+from typing import Optional
+
+import hivemind
+import torch
+import torch.nn as nn
+from hivemind.utils.logging import get_logger
+from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
+from transformers.models.falcon import (
+    FalconForCausalLM,
+    FalconForSequenceClassification,
+    FalconModel,
+    FalconPreTrainedModel,
+)
+
+from petals.client.from_pretrained import FromPretrainedMixin
+from petals.client.lm_head import LMHead
+from petals.client.ptune import PTuneMixin
+from petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues
+from petals.client.remote_sequential import RemoteSequential
+from petals.models.falcon.config import DistributedFalconConfig
+from petals.utils.auto_config import DefaultRevisionMixin
+
+logger = get_logger(__name__)
+
+
+class DistributedFalconModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMixin, FalconModel):
+    """FalconModel, but all transformer layers are hosted by the swarm"""
+
+    _keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing
+    _keys_to_ignore_on_load_unexpected = [r"^transformer\.h\."]
+
+    config_class = DistributedFalconConfig
+
+    def __init__(self, config: DistributedFalconConfig, *, dht: Optional[hivemind.DHT] = None):
+        n_layer, config.num_hidden_layers = config.num_hidden_layers, 0  # Prevent initialization
+        super().__init__(config)
+        assert len(self.h) == 0
+        config.num_hidden_layers = n_layer
+
+        self.h = RemoteSequential(config, dht=dht)
+
+        self.requires_grad_(False)  # Forbid accumulate grads for embeddings and layernorm
+        self.init_prompts(config)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[RemotePastKeyValues] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ):
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        # The causal mask will be added on the server-side
+        assert (
+            attention_mask is None or (attention_mask == 1).all()
+        ), f"Custom attention masks are not supported, {attention_mask=}"
+        assert head_mask is None, f"Custom head masks are not supported, {head_mask=}"
+        assert use_cache is None or use_cache, f"{use_cache=} is not supported"
+        assert not output_attentions, f"{output_attentions=} is not supported"
+        assert not output_hidden_states, f"{output_hidden_states=} is not supported"
+        assert return_dict is None or return_dict, f"{return_dict=} is not supported"
+
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+
+        use_prompts = self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0
+        if use_prompts:
+            batch_size = inputs_embeds.shape[0]
+            prompts, intermediate_prompts = self.get_prompt(batch_size)
+            inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
+        else:
+            prompts = intermediate_prompts = None
+
+        hidden_states = self.word_embeddings_layernorm(inputs_embeds)
+        output_shape = input_shape + (hidden_states.size(-1),)
+
+        hidden_states = self.h(
+            hidden_states,
+            prompts=intermediate_prompts,
+            hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None,
+        )
+
+        # Remove prefix
+        if use_prompts:
+            hidden_states = hidden_states[:, self.pre_seq_len :]
+
+        # Add last hidden state
+        hidden_states = self.ln_f(hidden_states)
+        hidden_states = hidden_states.view(output_shape)
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=RemotePastKeyValues(),
+            hidden_states=None,
+            attentions=None,
+        )
+
+    @property
+    def word_embeddings_layernorm(self) -> nn.Module:  # For compatibility with RemoteGenerationMixin
+        return nn.Identity()
+
+
+class DistributedFalconForCausalLM(DefaultRevisionMixin, FromPretrainedMixin, RemoteGenerationMixin, FalconForCausalLM):
+    _keys_to_ignore_on_load_missing = DistributedFalconModel._keys_to_ignore_on_load_missing
+    _keys_to_ignore_on_load_unexpected = DistributedFalconModel._keys_to_ignore_on_load_unexpected
+
+    config_class = DistributedFalconConfig
+
+    def __init__(self, config: DistributedFalconConfig):
+        FalconPreTrainedModel.__init__(self, config)
+        self.transformer = DistributedFalconModel(config)
+        self.lm_head = LMHead(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+
+class DistributedFalconForSequenceClassification(
+    DefaultRevisionMixin, FromPretrainedMixin, FalconForSequenceClassification
+):
+    _keys_to_ignore_on_load_missing = DistributedFalconModel._keys_to_ignore_on_load_missing
+    _keys_to_ignore_on_load_unexpected = DistributedFalconModel._keys_to_ignore_on_load_unexpected
+
+    config_class = DistributedFalconConfig
+
+    def __init__(self, config: DistributedFalconConfig):
+        FalconPreTrainedModel.__init__(self, config)
+        self.num_labels = config.num_labels
+
+        self.transformer = DistributedFalconModel(config)
+        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()

+ 1 - 0
src/petals/models/llama/config.py

@@ -43,4 +43,5 @@ class DistributedLlamaConfig(LlamaConfig, ClientConfig, PTuneConfig, LMHeadConfi
         result = super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
         config = result[0] if isinstance(result, tuple) else result
         config.pretraining_tp = 1  # This may give less accurate results but it doesn't matter if we use quantization
+        config.use_cache = True  # use_cache=False leads to identical results but is slower and not supported by Petals
         return result

+ 3 - 2
src/petals/models/llama/model.py

@@ -73,7 +73,8 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.layers.position == 0:
+        use_prompts = self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.layers.position == 0
+        if use_prompts:
             batch_size = inputs_embeds.shape[0]
             prompts, intermediate_prompts = self.get_prompt(batch_size)
             inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
@@ -90,7 +91,7 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
         )
 
         # Remove prefix
-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+        if use_prompts:
             hidden_states = hidden_states[:, self.pre_seq_len :]
 
         # Add last hidden state

+ 9 - 7
src/petals/server/server.py

@@ -5,6 +5,7 @@ import math
 import multiprocessing as mp
 import os
 import random
+import sys
 import threading
 import time
 from typing import Dict, List, Optional, Sequence, Union
@@ -186,10 +187,7 @@ class Server:
             check_device_balance(self.tensor_parallel_devices)
 
         if quant_type is None:
-            if device.type == "cuda":
-                quant_type = QuantType.NF4 if self.block_config.model_type == "llama" else QuantType.INT8
-            else:
-                quant_type = QuantType.NONE
+            quant_type = QuantType.NF4 if device.type == "cuda" else QuantType.NONE
         self.quant_type = quant_type
         logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, quant_type)} format")
 
@@ -234,8 +232,9 @@ class Server:
         self.attn_cache_bytes = self._cache_bytes_per_block * num_blocks
         logger.info(f"Attention cache for all blocks will consume up to {self.attn_cache_bytes / gib:.2f} GiB")
 
-        assert isinstance(throughput, float) or throughput in ["auto", "eval"]
-        if throughput in ["auto", "eval"]:
+        assert isinstance(throughput, float) or throughput in ["auto", "eval", "dry_run"]
+        if throughput in ["auto", "eval", "dry_run"]:
+            force_eval = throughput in ["eval", "dry_run"]
             throughput_info = get_server_throughput(
                 converted_model_name_or_path,
                 self.block_config,
@@ -245,9 +244,12 @@ class Server:
                 quant_type=quant_type,
                 tensor_parallel_devices=self.tensor_parallel_devices,
                 reachable_via_relay=reachable_via_relay,
-                force_eval=(throughput == "eval"),
+                force_eval=force_eval,
                 cache_dir=cache_dir,
             )
+            if throughput == "dry_run":
+                logger.info("Finished estimating throughput, exiting")
+                sys.exit(0)
         else:
             throughput_info = {"throughput": throughput}
         self.server_info = ServerInfo(

+ 34 - 5
src/petals/utils/auto_config.py

@@ -1,12 +1,14 @@
 import os
-import re
 from dataclasses import dataclass
 from typing import Optional, Type, Union
 
+from hivemind import get_logger
 from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
 
 from petals.utils.hf_auth import always_needs_auth
 
+logger = get_logger(__name__)
+
 
 @dataclass
 class _ModelClasses:
@@ -49,17 +51,44 @@ class _AutoDistributedBase:
         return proper_cls.from_pretrained(model_name_or_path, *args, **kwargs)
 
 
-class AutoDistributedConfig(_AutoDistributedBase):
+class DefaultRevisionMixin:
+    """
+    Petals only supports Falcon loaded in the new in-library format (transformers.FalconModel).
+    TII models were recently converted to this format but then reverted back due to compatibility issues.
+    We chose to support only the new format since HF staff promised to eventually convert these models
+    to the new format again, see https://huggingface.co/tiiuae/falcon-40b/discussions/90#64b4d23bf44fd957492f7602
+    Until it happens, we override the default `main` revision for the TII repos with the commit
+    pointing out to the model in the in-library format.
+    """
+
+    DEFAULT_REVISIONS = {
+        "tiiuae/falcon-40b": "f1ba7d328c06aa6fbb4a8afd3c756f46d7e6b232",
+        "tiiuae/falcon-40b-instruct": "7475ff8cfc36ed9a962b658ae3c33391566a85a5",
+        "tiiuae/falcon-7b": "4e2d06f0a7c6370ebabbc30c6f59377ae8f73d76",
+        "tiiuae/falcon-7b-instruct": "f8dac3fff96d5debd43edf56fb4e1abcfffbef28",
+    }
+
+    @classmethod
+    def from_pretrained(
+        cls, model_name_or_path: Union[str, os.PathLike, None], *args, revision: Optional[str] = None, **kwargs
+    ):
+        if revision is None and model_name_or_path in cls.DEFAULT_REVISIONS:
+            revision = cls.DEFAULT_REVISIONS[model_name_or_path]
+            logger.info(f"Loading {model_name_or_path}, revision {revision}")
+        return super().from_pretrained(model_name_or_path, *args, revision=revision, **kwargs)
+
+
+class AutoDistributedConfig(DefaultRevisionMixin, _AutoDistributedBase):
     _mapping_field = "config"
 
 
-class AutoDistributedModel(_AutoDistributedBase):
+class AutoDistributedModel(DefaultRevisionMixin, _AutoDistributedBase):
     _mapping_field = "model"
 
 
-class AutoDistributedModelForCausalLM(_AutoDistributedBase):
+class AutoDistributedModelForCausalLM(DefaultRevisionMixin, _AutoDistributedBase):
     _mapping_field = "model_for_causal_lm"
 
 
-class AutoDistributedModelForSequenceClassification(_AutoDistributedBase):
+class AutoDistributedModelForSequenceClassification(DefaultRevisionMixin, _AutoDistributedBase):
     _mapping_field = "model_for_sequence_classification"

+ 128 - 0
tests/test_optimized_layers.py

@@ -0,0 +1,128 @@
+from typing import Optional, Tuple
+
+import pytest
+import torch
+from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor
+
+from petals.utils.auto_config import AutoDistributedConfig
+from petals.utils.convert_block import QuantType, convert_block
+from test_utils import MODEL_NAME
+
+KVCache = Tuple[torch.Tensor, torch.Tensor]
+
+
+class UnoptimizedWrappedFalconBlock(FalconDecoderLayer):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        *args,
+        attention_mask: Optional[torch.Tensor] = None,
+        alibi: Optional[torch.Tensor] = None,
+        layer_past: Optional[KVCache] = None,
+        use_cache: bool = False,
+        **kwargs,
+    ):
+        batch_size, seq_length = hidden_states.shape[:2]
+
+        if layer_past is not None:
+            layer_past = self._reorder_cache_from_bloom_to_falcon(layer_past)
+        past_length = 0 if layer_past is None else layer_past[0].shape[1]
+        seq_length_with_past = seq_length + past_length
+
+        attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
+        if alibi is None and self.config.alibi:
+            alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
+        attention_mask = FalconModel._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)
+
+        outputs = super().forward(
+            hidden_states,
+            *args,
+            attention_mask=attention_mask,
+            alibi=alibi,
+            layer_past=layer_past,
+            use_cache=use_cache,
+            **kwargs,
+        )
+
+        if use_cache:
+            present_key_value = outputs[-1]
+            present_key_value = self._reorder_cache_from_falcon_to_bloom(present_key_value)
+            outputs = outputs[:-1] + (present_key_value,)
+
+        return outputs
+
+    def _reorder_cache_from_bloom_to_falcon(self, key_value: KVCache) -> KVCache:
+        key_states, value_states = key_value
+
+        key_states = key_states.permute(0, 2, 1)
+        assert key_states.shape == value_states.shape  # Both are [batch_size * num_kv_heads, seq_len, head_dim]
+
+        if self.config.new_decoder_architecture:
+            key_states = self._expand_states(key_states)
+            value_states = self._expand_states(value_states)
+
+        return (key_states, value_states)
+
+    def _reorder_cache_from_falcon_to_bloom(self, key_value: KVCache) -> KVCache:
+        key_states, value_states = key_value
+
+        if self.config.new_decoder_architecture:
+            key_states = self._collapse_states(key_states)
+            value_states = self._collapse_states(value_states)
+
+        assert key_states.shape == value_states.shape  # Both are [batch_size * num_kv_heads, seq_len, head_dim]
+        key_states = key_states.permute(0, 2, 1)
+
+        return (key_states, value_states)
+
+    def _expand_states(self, state: torch.Tensor) -> torch.Tensor:
+        batch_size_x_num_kv_heads, seq_len, head_dim = state.shape
+        batch_size = batch_size_x_num_kv_heads // self.config.num_kv_heads
+
+        state = state.view(batch_size, self.config.num_kv_heads, 1, seq_len, head_dim)
+        state = state.expand(-1, -1, self.config.num_key_value_groups, -1, -1)  # No copy
+        state = state.reshape(batch_size * self.config.num_attention_heads, seq_len, head_dim)  # Involves a copy
+        return state
+
+    def _collapse_states(self, state: torch.Tensor) -> torch.Tensor:
+        batch_size_x_num_attn_heads, seq_len, head_dim = state.shape
+        batch_size = batch_size_x_num_attn_heads // self.config.num_attention_heads
+
+        state = state.view(batch_size, self.config.num_kv_heads, self.config.num_key_value_groups, seq_len, head_dim)
+        state = state[:, :, 0]
+        state = state.view(batch_size * self.config.num_kv_heads, seq_len, head_dim)
+        return state
+
+
+@pytest.mark.skipif("falcon" not in MODEL_NAME, reason="This test is applicable only to Falcon models")
+@pytest.mark.parametrize("device", ["cpu", "cuda:0"])
+@pytest.mark.forked
+def test_falcon(device):
+    if device == "cuda:0" and not torch.cuda.is_available():
+        pytest.skip("CUDA tests can be run only in CUDA-enabled setups")
+
+    config = AutoDistributedConfig.from_pretrained(MODEL_NAME)
+
+    tensor_parallel_devices = (device,)
+    dtype = torch.bfloat16
+    quant_type = QuantType.NONE
+
+    block = config.block_class(config).to(dtype)
+    block = convert_block(block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)
+
+    unopt_block = UnoptimizedWrappedFalconBlock(config).to(dtype)
+    unopt_block = convert_block(
+        unopt_block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True
+    )
+
+    unopt_block.load_state_dict(block.state_dict())
+    cache = unopt_cache = None
+
+    with torch.inference_mode():
+        for length in [10, 1, 1, 1]:
+            dummy_input = torch.randn(1, length, config.hidden_size, device=device, dtype=dtype)
+            block_output, cache = block(dummy_input, layer_past=cache, use_cache=True)
+            unopt_block_output, unopt_cache = unopt_block(dummy_input, layer_past=unopt_cache, use_cache=True)
+            assert torch.allclose(block_output, unopt_block_output, atol=1e-6, rtol=0), length
+            assert torch.allclose(cache[0], unopt_cache[0], atol=1e-6, rtol=0), length
+            assert torch.allclose(cache[1], unopt_cache[1], atol=1e-6, rtol=0), length