Browse Source

Bump transformers to 4.25.1 (#151)

- latest accelerate, transformers, huggingface_hub
- rearrange attention caches to support https://github.com/huggingface/transformers/pull/18344
- remove unused code
- fix edge case where session crashes when receiving seq length 0
- assert transformer version when importing WrappedBloomBlock

Co-authored-by: Alexander Borzunov <borzunov.alexander@gmail.com>
Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic 2 years ago
parent
commit
b04982c1a2

+ 3 - 3
setup.cfg

@@ -33,9 +33,9 @@ python_requires = >=3.7
 install_requires =
     torch>=1.12
     bitsandbytes==0.34.0
-    accelerate==0.10.0
-    huggingface-hub==0.7.0
-    transformers==4.21.3
+    accelerate==0.15.0
+    huggingface-hub==0.11.1
+    transformers==4.25.1
     protobuf>=3.20.3,<4.0dev
     hivemind==1.1.3
     humanfriendly

+ 0 - 2
src/petals/bloom/__init__.py

@@ -1,2 +0,0 @@
-from petals.bloom.block import BloomBlock
-from petals.bloom.model import BloomConfig, BloomForCausalLM, BloomModel, BloomPreTrainedModel

+ 40 - 236
src/petals/bloom/block.py

@@ -3,253 +3,57 @@ Bloom intermediate layer
 Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
 See commit history for authorship.
 """
-import math
+import os
+from typing import Optional, Tuple
 
-import torch
-import torch.nn as nn
 import torch.nn.quantized.dynamic.modules.linear
+import transformers
+from transformers.models.bloom.modeling_bloom import BloomBlock, _expand_mask, _make_causal_mask, build_alibi_tensor
 
-from petals.bloom.ops import (
-    BloomGelu,
-    BloomScaledSoftmax,
-    attention_mask_func,
-    build_alibi_tensor,
-    dropout_add,
-    pre_process_alibi_for_pad,
-    split_tensor_along_last_dim,
-)
+if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
+    assert transformers.__version__.startswith("4.25."), "Please install transformers 4.25.1"
 
 
-class BloomAttention(nn.Module):
-    def __init__(self, config, layer_number=None):
-        super().__init__()
-
-        self.hidden_size = config.hidden_size
-        self.num_heads = config.n_head
-        self.head_dim = self.hidden_size // self.num_heads
-        self.split_size = self.hidden_size
-        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
-        self.masked_softmax_fusion = config.masked_softmax_fusion
-        self.hidden_dropout = config.hidden_dropout
-
-        if self.head_dim * self.num_heads != self.hidden_size:
-            raise ValueError(
-                f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
-                f" {self.num_heads})."
-            )
-
-        # Layer-wise attention scaling
-        self.layer_number = max(1, layer_number)
-        self.norm_factor = math.sqrt(self.head_dim) * self.layer_number
-
-        # Scaled Softmax
-        self.scale_mask_softmax = BloomScaledSoftmax(
-            self.masked_softmax_fusion,
-            attention_mask_func,
-            self.attention_softmax_in_fp32,
-            self.layer_number,
-        )
-
-        self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
-        self.dense = nn.Linear(self.hidden_size, self.hidden_size)
-
-        self.attention_dropout = nn.Dropout(config.attention_dropout)
-
+class WrappedBloomBlock(BloomBlock):
     def forward(
         self,
-        hidden_states,
-        residual,
-        layer_past=None,
-        attention_mask=None,
-        alibi=None,
-        head_mask=None,
-        use_cache=False,
-        output_attentions=False,
+        hidden_states: torch.Tensor,
+        *args,
+        attention_mask: Optional[torch.Tensor] = None,
+        alibi: Optional[torch.Tensor] = None,
+        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        **kwargs
     ):
+        assert attention_mask is None
+        batch_size, seq_length = hidden_states.shape[:2]
+        past_length = 0 if layer_past is None else layer_past[0].shape[-1]
+        seq_length_with_past = seq_length + past_length
+        attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
         if alibi is None:
-            current_sequence_length = hidden_states.shape[1] + (0 if layer_past is None else layer_past[0].shape[1])
-            alibi = build_alibi_tensor(
-                current_sequence_length, n_head=self.num_heads, dtype=hidden_states.dtype, device=hidden_states.device
-            )
-
-        # hidden_states: [batch_size, seq_length, hidden_size]
-        # apply preprocessing if the input is padded
-        if attention_mask is not None:
-            alibi = pre_process_alibi_for_pad(alibi, attention_mask)
-        # otherwise repeat alibi tensor with the batch size
-        else:
-            alibi = alibi.repeat(hidden_states.shape[0], 1, 1)
-
-        mixed_x_layer = self.query_key_value(hidden_states)
-
-        # [batch_size, seq_length, 3 x hidden_size] --> [batch_size, seq_length, num_heads, 3 x head_dim]
-        new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_heads, 3 * self.head_dim)
-        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
-
-        # [batch_size, seq_length, num_heads, 3 x head_dim] --> 3  [batch_size, seq_length, num_heads, head_dim]
-        (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
-
-        if layer_past is not None:
-            past_key, past_value = layer_past
-            key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=1)
-            value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=1)
-
-        if use_cache is True:
-            present = (key_layer, value_layer)
-        else:
-            present = None
-
-        # [batch_size, head_dim, q_length, k_length]
-        output_size = (query_layer.size(0), query_layer.size(2), query_layer.size(1), key_layer.size(1))
-
-        # [batch_size, q_length, num_heads, head_dim] -> [q_length, batch_size * num_heads, head_dim]
-        query_layer = query_layer.transpose(1, 0).reshape(output_size[2], output_size[0] * output_size[1], -1)
-
-        # [batch_size, k_length, num_heads, head_dim] -> [k_length, batch_size * num_heads, head_dim]
-        key_layer = key_layer.transpose(1, 0).reshape(output_size[3], output_size[0] * output_size[1], -1)
-
-        # Raw attention scores. [batch_size * num_heads, q_length, k_length]
-        beta = 1.0 / self.layer_number
-
-        matmul_result = torch.baddbmm(
-            alibi,
-            query_layer.transpose(1, 0),
-            key_layer.transpose(1, 0).transpose(1, 2),
-            beta=beta,
-            alpha=(1.0 / self.norm_factor),
+            alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
+        attention_mask = self._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)
+        return super().forward(
+            hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, **kwargs
         )
 
-        # change view to [batch_size, num_heads, q_length, k_length]
-        attention_scores = matmul_result.view(*output_size)
-
-        # attention scores and attention mask [b, np, sq, sk]
-        max_positions = max(attention_scores.shape[-1], attention_scores.shape[-2])
-        attention_probs = self.scale_mask_softmax(attention_scores, attention_mask, max_positions).to(value_layer.dtype)
-        attention_probs = self.attention_dropout(attention_probs)
-
-        if head_mask is not None:
-            attention_probs = attention_probs * head_mask
-
-        # context layer shape: [batch_size, num_heads, q_length, head_dim]
-        output_size = (value_layer.size(0), value_layer.size(2), query_layer.size(0), value_layer.size(3))
-
-        # change view [k_length, batch_size x num_heads, head_dim]
-        value_layer = value_layer.transpose(1, 0).reshape(value_layer.size(1), output_size[0] * output_size[1], -1)
-
-        # change view [batch_size x num_heads, q_length, k_length]
-        attention_probs_reshaped = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
-
-        # matmul: [batch_size * num_heads, q_length, head_dim]
-        context_layer = torch.bmm(attention_probs_reshaped, value_layer.transpose(0, 1))
-
-        # change view [batch_size, num_heads, q_length, head_dim]
-        context_layer = context_layer.view(*output_size)
-
-        # [batchs_size, num_heads, q_length, head_dim] --> [q_length, batch_size, num_heads, head_dim]
-        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
-
-        # [q_length, batch_size, num_heads, head_dim] --> [q_length, batch_size, hidden_size]
-        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,)
-
-        context_layer = context_layer.view(*new_context_layer_shape)
-
-        # Output. [q_length, batch_size, hidden_size]
-
-        # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
-        output_tensor = self.dense(context_layer)
-        output = output_tensor.transpose(1, 0)
-
-        output = dropout_add(output, residual, self.hidden_dropout, self.training)
-
-        outputs = (output, present)
-        if output_attentions:
-            outputs += (attention_probs,)
-
-        return outputs
-
-
-class BloomMLP(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        self.hidden_size = config.hidden_size
-        self.dense_h_to_4h = nn.Linear(self.hidden_size, 4 * self.hidden_size)
-        self.dense_4h_to_h = nn.Linear(4 * self.hidden_size, self.hidden_size)
-        self.hidden_dropout = config.hidden_dropout
-        self.gelu_impl = BloomGelu()
-
-    def forward(self, hidden_states, residual):
-        hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states))
-        intermediate_output = self.dense_4h_to_h(hidden_states)
-        output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
-        return output
-
-
-class BloomBlock(nn.Module):
-    def __init__(self, config, layer_number=None):
-        super().__init__()
-        self.hidden_size = config.hidden_size
-
-        self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_epsilon)
-        self.n_head = config.n_head
-        self.self_attention = BloomAttention(config, layer_number=layer_number)
-        self.post_attention_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_epsilon)
-
-        self.mlp = BloomMLP(config)
-
-        self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
-        self.hidden_dropout = config.hidden_dropout
-
-    def forward(
-        self,
-        hidden_states,
-        layer_past=None,
-        attention_mask=None,
-        head_mask=None,
-        use_cache=False,
-        output_attentions=False,
-        alibi=None,
-    ):
-        # hidden_states: [batch_size, seq_length, hidden_size]
-
-        # Layer norm at the beginning of the transformer layer.
-        layernorm_output = self.input_layernorm(hidden_states)
-
-        # Layer norm post the self attention.
-        if self.apply_residual_connection_post_layernorm:
-            residual = layernorm_output
-        else:
-            residual = hidden_states
+    def _prepare_attn_mask(
+        self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
+    ) -> torch.BoolTensor:
+        # create causal mask
+        # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
+        combined_attention_mask = None
+        device = attention_mask.device
+        _, src_length = input_shape
+
+        if src_length > 1:
+            combined_attention_mask = _make_causal_mask(
+                torch.Size(input_shape), device=device, past_key_values_length=past_key_values_length
+            )
 
-        # Self attention.
-        attn_outputs = self.self_attention(
-            layernorm_output,
-            residual,
-            layer_past=layer_past,
-            attention_mask=attention_mask,
-            alibi=alibi,
-            head_mask=head_mask,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
+        # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
+        expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
+        combined_attention_mask = (
+            expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
         )
 
-        attention_output = attn_outputs[0]
-
-        outputs = attn_outputs[1:]
-
-        layernorm_output = self.post_attention_layernorm(attention_output)
-
-        # Get residual
-        if self.apply_residual_connection_post_layernorm:
-            residual = layernorm_output
-        else:
-            residual = attention_output
-
-        # MLP.
-        output = self.mlp(layernorm_output, residual)
-
-        if use_cache:
-            outputs = (output,) + outputs
-        else:
-            outputs = (output,) + outputs[1:]
-
-        return outputs  # hidden_states, present, attentions
+        return combined_attention_mask

+ 12 - 21
src/petals/bloom/from_pretrained.py

@@ -13,9 +13,10 @@ from typing import Optional, OrderedDict, Union
 import torch
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from transformers.modeling_utils import WEIGHTS_NAME
-from transformers.utils.hub import cached_path, hf_bucket_url
+from transformers.models.bloom.configuration_bloom import BloomConfig
+from transformers.utils import get_file_from_repo
 
-from petals.bloom import BloomBlock, BloomConfig
+from petals.bloom.block import WrappedBloomBlock
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
 
 use_hivemind_log_handler("in_root_logger")
@@ -23,10 +24,6 @@ logger = get_logger(__file__)
 
 CLIENT_BRANCH = "main"
 BLOCK_BRANCH_PREFIX = "block_"
-USER_AGENT = {"file_type": "model", "framework": "pytorch", "from_auto_class": False}
-FORCE_DOWNLOAD = False
-RESUME_DOWNLOAD = False
-LOCAL_FILES_ONLY = False
 
 
 def load_pretrained_block(
@@ -36,15 +33,15 @@ def load_pretrained_block(
     torch_dtype: Union[torch.dtype, str] = "auto",
     use_auth_token: Optional[str] = None,
     cache_dir: Optional[str] = None,
-) -> BloomBlock:
-    """Load one BloomBlock from a converted model. See convert_model.py (or README.md) on how to convert it."""
+) -> WrappedBloomBlock:
+    """Load one BLOOM block from a converted model. See convert_model.py (or README.md) on how to convert it."""
 
     if config is None:
         config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
     if cache_dir is None:
         cache_dir = DEFAULT_CACHE_DIR
 
-    block = BloomBlock(config, layer_number=block_index)
+    block = WrappedBloomBlock(config)
     state_dict = _load_state_dict(
         converted_model_name_or_path, block_index, use_auth_token=use_auth_token, cache_dir=cache_dir
     )
@@ -70,20 +67,14 @@ def _load_state_dict(
     cache_dir: Optional[str] = None,
 ) -> OrderedDict[str, torch.Tensor]:
     revision = BLOCK_BRANCH_PREFIX + str(block_index) if block_index is not None else CLIENT_BRANCH
-    archive_file = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME, revision=revision, mirror=None)
-
-    # Load from URL or cache if already cached
-    resolved_archive_file = cached_path(
-        archive_file,
-        cache_dir=cache_dir,
-        force_download=FORCE_DOWNLOAD,
-        proxies=None,
-        resume_download=RESUME_DOWNLOAD,
-        local_files_only=LOCAL_FILES_ONLY,
+    archive_file = get_file_from_repo(
+        pretrained_model_name_or_path,
+        filename=WEIGHTS_NAME,
+        revision=revision,
         use_auth_token=use_auth_token,
-        user_agent=USER_AGENT,
+        cache_dir=cache_dir,
     )
-    state_dict = torch.load(resolved_archive_file, map_location="cpu")
+    state_dict = torch.load(archive_file, map_location="cpu")
     return state_dict
 
 

+ 0 - 595
src/petals/bloom/model.py

@@ -1,595 +0,0 @@
-"""
-PyTorch BLOOM model that implements several memory-efficient modes.
-Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
-See commit history for authorship.
-"""
-from typing import Optional, Tuple, Union
-
-import torch
-import torch.nn.functional as F
-import torch.utils.checkpoint
-from hivemind import use_hivemind_log_handler
-from torch import nn
-from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
-from transformers.file_utils import (
-    add_code_sample_docstrings,
-    add_start_docstrings,
-    add_start_docstrings_to_model_forward,
-)
-from transformers.modeling_outputs import (
-    BaseModelOutputWithPastAndCrossAttentions,
-    CausalLMOutputWithCrossAttentions,
-    SequenceClassifierOutputWithPast,
-)
-from transformers.models.bloom.configuration_bloom import BloomConfig
-from transformers.models.bloom.modeling_bloom import BloomPreTrainedModel
-from transformers.utils import logging
-
-from petals.bloom.block import BloomBlock
-
-use_hivemind_log_handler("in_root_logger")
-logger = logging.get_logger(__file__)
-
-_CHECKPOINT_FOR_DOC = "bigscience/Bloom"
-_CONFIG_FOR_DOC = "BloomConfig"
-_TOKENIZER_FOR_DOC = "BloomTokenizer"
-
-
-BLOOM_START_DOCSTRING = r"""
-
-    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
-    library implements for all its model (such as downloading or saving, resizing the input embeddings etc.)
-
-    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
-    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
-    and behavior.
-
-    Parameters:
-        config ([`MemoryEfficientBloomConfig`]): Model configuration class with all the parameters of the model.
-            Initializing with a config file does not load the weights associated with the model, only the
-            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
-"""
-
-BLOOM_INPUTS_DOCSTRING = r"""
-    Args:
-        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
-            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
-            `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
-            sequence tokens in the vocabulary.
-
-            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
-            `input_ids`.
-
-            Indices can be obtained using [`BloomTokenizer`]. See [`PreTrainedTokenizer.encode`] and
-            [`PreTrainedTokenizer.__call__`] for details.
-
-            [What are input IDs?](../glossary#input-ids)
-        past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`):
-            Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
-            `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
-            their past given to this model should not be passed as `input_ids` as they have already been computed.
-        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
-            - 1 for tokens that are **not masked**,
-            - 0 for tokens that are **masked**.
-
-            [What are attention masks?](../glossary#attention-mask)
-        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
-            config.max_position_embeddings - 1]`.
-
-            [What are position IDs?](../glossary#position-ids)
-        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
-            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
-
-            - 1 indicates the head is **not masked**,
-            - 0 indicates the head is **masked**.
-
-        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
-            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
-            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
-            model's internal embedding lookup matrix.
-
-            If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
-            `past_key_values`).
-        use_cache (`bool`, *optional*):
-            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
-            `past_key_values`).
-        output_attentions (`bool`, *optional*):
-            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
-            tensors for more detail.
-        output_hidden_states (`bool`, *optional*):
-            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
-            more detail.
-        return_dict (`bool`, *optional*):
-            Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
-"""
-
-
-class _BloomPreTrainedModelWithModifiedDefaults(BloomPreTrainedModel):
-    @classmethod
-    def from_pretrained(cls, *args, low_cpu_mem_usage: Optional[bool] = None, **kwargs):
-        if low_cpu_mem_usage is None:
-            low_cpu_mem_usage = True
-        return super().from_pretrained(*args, low_cpu_mem_usage=low_cpu_mem_usage, **kwargs)
-
-    from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace(
-        "low_cpu_mem_usage(`bool`, *optional*)",
-        "low_cpu_mem_usage(`bool`, *optional*, defaults to `True` in Petals)",
-    )
-
-
-@add_start_docstrings(
-    "The bare Bloom Model transformer outputting raw hidden-states without any specific head on top.",
-    BLOOM_START_DOCSTRING,
-)
-class BloomModel(_BloomPreTrainedModelWithModifiedDefaults):
-    def __init__(self, config):
-        super().__init__(config)
-        assert not config.slow_but_exact, "slow_but_exact mode was removed for code simplicity"
-
-        self.embed_dim = config.hidden_size
-        self.n_head = config.n_head
-
-        # Embedding + LN Embedding
-        self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
-        self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
-
-        # Transformer blocks
-        self.h = nn.ModuleList([BloomBlock(config, layer_number=i) for i in range(config.num_hidden_layers)])
-
-        # Final Layer Norm
-        self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
-
-        self.gradient_checkpointing = False
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-    def get_input_embeddings(self):
-        return self.word_embeddings
-
-    def set_input_embeddings(self, new_embeddings):
-        self.word_embeddings = new_embeddings
-
-    @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
-    @add_code_sample_docstrings(
-        processor_class=_TOKENIZER_FOR_DOC,
-        checkpoint=_CHECKPOINT_FOR_DOC,
-        output_type=BaseModelOutputWithPastAndCrossAttentions,
-        config_class=_CONFIG_FOR_DOC,
-    )
-    def forward(
-        self,
-        input_ids=None,
-        past_key_values=None,
-        attention_mask=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        if input_ids is not None and inputs_embeds is not None:
-            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        if position_ids is not None:
-            logger.warning("position_ids are ignored in this bloom implementation")
-        elif input_ids is not None:
-            input_shape = input_ids.size()
-            input_ids = input_ids.view(-1, input_shape[-1])
-        elif inputs_embeds is not None:
-            input_shape = inputs_embeds.size()[:-1]
-        else:
-            raise ValueError("You have to specify either input_ids or inputs_embeds")
-
-        if past_key_values is None:
-            past_key_values = tuple([None] * len(self.h))
-
-        # Prepare head mask if needed
-        # 1.0 in head_mask indicate we keep the head
-        # attention_probs has shape bsz x n_head x N x N
-        # head_mask has shape n_layer x batch x n_head x N x N
-        head_mask = self.get_head_mask(head_mask, self.config.n_layer)
-
-        if inputs_embeds is None:
-            inputs_embeds = self.word_embeddings(input_ids)
-
-        # Note: it supports only float32 or bfloat16 inputs
-        hidden_states = self.word_embeddings_layernorm(inputs_embeds)
-
-        output_shape = input_shape + (hidden_states.size(-1),)
-
-        presents = () if use_cache else None
-        all_self_attentions = () if output_attentions else None
-        all_hidden_states = () if output_hidden_states else None
-
-        # Compute alibi tensor: check build_alibi_tensor documentation
-        current_sequence_length = hidden_states.shape[1]
-        if past_key_values and past_key_values[0]:
-            current_sequence_length += past_key_values[0][0].shape[1]
-
-        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
-
-            if output_hidden_states:
-                all_hidden_states = all_hidden_states + (hidden_states,)
-
-            if self.gradient_checkpointing and self.training:
-
-                if use_cache:
-                    logger.warning(
-                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                    )
-                    use_cache = False
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, use_cache, output_attentions, alibi=None)
-
-                    return custom_forward
-
-                outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
-                    hidden_states,
-                    None,
-                    attention_mask,
-                    head_mask[i],
-                )
-            else:
-                outputs = block(
-                    hidden_states,
-                    layer_past=layer_past,
-                    attention_mask=attention_mask,
-                    head_mask=head_mask[i],
-                    use_cache=use_cache,
-                    output_attentions=output_attentions,
-                    alibi=None,
-                )
-
-            hidden_states = outputs[0]
-            if use_cache is True:
-                presents = presents + (outputs[1],)
-
-            if output_attentions:
-                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
-
-        # Add last hidden state
-        hidden_states = self.ln_f(hidden_states)
-
-        if output_hidden_states:
-            all_hidden_states = all_hidden_states + (hidden_states,)
-
-        hidden_states = hidden_states.view(output_shape)
-
-        if not return_dict:
-            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
-
-        return BaseModelOutputWithPastAndCrossAttentions(
-            last_hidden_state=hidden_states,
-            past_key_values=presents,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attentions,
-        )
-
-
-@add_start_docstrings(
-    """
-    The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input
-    embeddings).
-    """,
-    BLOOM_START_DOCSTRING,
-)
-class BloomForCausalLM(_BloomPreTrainedModelWithModifiedDefaults):
-    _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
-
-    def __init__(self, config):
-        super().__init__(config)
-        self.transformer = BloomModel(config)
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-    def get_output_embeddings(self):
-        return self.lm_head
-
-    def set_output_embeddings(self, new_embeddings):
-        self.lm_head = new_embeddings
-
-    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
-        # only last token for inputs_ids if past is defined in kwargs
-        if past:
-            input_ids = input_ids[:, -1].unsqueeze(-1)
-
-        attention_mask = kwargs.get("attention_mask", None)
-        position_ids = kwargs.get("position_ids", None)
-
-        if attention_mask is not None and position_ids is None:
-            # create position_ids on the fly for batch generation
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-            if past:
-                position_ids = position_ids[:, -1].unsqueeze(-1)
-        else:
-            position_ids = None
-        return {
-            "input_ids": input_ids,
-            "past_key_values": past,
-            "use_cache": kwargs.get("use_cache"),
-            "position_ids": position_ids,
-            "attention_mask": attention_mask,
-        }
-
-    @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
-    @add_code_sample_docstrings(
-        processor_class=_TOKENIZER_FOR_DOC,
-        checkpoint=_CHECKPOINT_FOR_DOC,
-        output_type=CausalLMOutputWithCrossAttentions,
-        config_class=_CONFIG_FOR_DOC,
-    )
-    def forward(
-        self,
-        input_ids=None,
-        past_key_values=None,
-        attention_mask=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
-        r"""
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
-            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
-            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
-        """
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        transformer_outputs = self.transformer(
-            input_ids,
-            past_key_values=past_key_values,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-        hidden_states = transformer_outputs[0]
-
-        lm_logits = self.lm_head(hidden_states)
-
-        loss = None
-        if labels is not None:
-            # Shift so that tokens < n predict n
-            shift_logits = lm_logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
-
-        if not return_dict:
-            output = (lm_logits,) + transformer_outputs[1:]
-            return ((loss,) + output) if loss is not None else output
-
-        return CausalLMOutputWithCrossAttentions(
-            loss=loss,
-            logits=lm_logits,
-            past_key_values=transformer_outputs.past_key_values,
-            hidden_states=transformer_outputs.hidden_states,
-            attentions=transformer_outputs.attentions,
-        )
-
-    @staticmethod
-    def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
-        """
-        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
-        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
-        beam_idx at every generation step.
-        """
-        return tuple(
-            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
-            for layer_past in past
-        )
-
-
-@add_start_docstrings(
-    """
-    The modified language modeling head which does not create extra tensor for the linear layer with weights tied to the input
-    embeddings. Thus, it reduces initial memory consumption which might be crucial for large dictionaries.
-    In addition, it provides an effcient way to deal with half-precision word embeddings on CPU.
-    """,
-    BLOOM_START_DOCSTRING,
-)
-class LMHead(nn.Module):
-    def __init__(self, config, word_embeddings: nn.Embedding):
-        super().__init__()
-        self.word_embeddings = word_embeddings
-        self.chunk_size = config.chunk_size_for_efficient_fp16_on_cpu
-
-    @property
-    def in_features(self) -> int:
-        return self.word_embeddings.num_embeddings
-
-    @property
-    def out_features(self) -> int:
-        return self.word_embeddings.embedding_dim
-
-    @property
-    def weight(self):
-        return self.word_embeddings.weight
-
-    @property
-    def bias(self):
-        return None
-
-    def forward(self, hidden_states):
-        word_embeddings = self.word_embeddings.weight
-
-        # We use 'chunked_forward' only when embeddings are in half-precision on CPU.
-        if word_embeddings.dtype in [torch.float16, torch.bfloat16] and word_embeddings.device.type == "cpu":
-            lm_logits = self.chunked_forward(hidden_states)
-        else:
-            # Switch dtype in case word_embeddings are fp16/bf16
-            hidden_states = hidden_states.to(word_embeddings.dtype)
-            lm_logits = F.linear(hidden_states, word_embeddings)
-        return lm_logits
-
-    def chunked_forward(self, hidden_states):
-        """Splits word embeddings on chunks and iteratively casts them into fp32 to perform matmul more efficiently on CPU.
-        chunk_size: provides trade-off between efficiency and extra memory consumption.
-        """
-        assert self.chunk_size > 0, "Chunk size for chunked forward must be positive"
-
-        word_embeddings = self.word_embeddings.weight
-        num_embeddings = self.word_embeddings.num_embeddings
-
-        hidden_states = hidden_states.float()
-        output = torch.zeros(*hidden_states.shape[:-1], num_embeddings)
-
-        for i in range(0, num_embeddings, self.chunk_size):
-            chunk = word_embeddings[i : i + self.chunk_size].float()
-            output[..., i : i + self.chunk_size] = F.linear(hidden_states, chunk)
-        return output
-
-
-@add_start_docstrings(
-    """
-    The Bloom Model transformer with a sequence classification head on top (linear layer).
-    [`BloomForSequenceClassification`] uses the last token in order to do the classification, as other causal models
-    (e.g. GPT-1) do.
-    Since it does classification on the last token, it requires to know the position of the last token. If a
-    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
-    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
-    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
-    each row of the batch).
-    """,
-    BLOOM_START_DOCSTRING,
-)
-class BloomForSequenceClassification(_BloomPreTrainedModelWithModifiedDefaults):
-    _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
-
-    def __init__(self, config):
-        super().__init__(config)
-        self.num_labels = config.num_labels
-        self.transformer = BloomModel(config)
-        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-    @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
-    @add_code_sample_docstrings(
-        processor_class=_TOKENIZER_FOR_DOC,
-        checkpoint=_CHECKPOINT_FOR_DOC,
-        output_type=SequenceClassifierOutputWithPast,
-        config_class=_CONFIG_FOR_DOC,
-    )
-    def forward(
-        self,
-        input_ids=None,
-        past_key_values=None,
-        attention_mask=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
-        r"""
-        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
-            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
-            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
-            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
-        """
-
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        transformer_outputs = self.transformer(
-            input_ids,
-            past_key_values=past_key_values,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-
-        hidden_states = transformer_outputs[0]
-        logits = self.score(hidden_states)
-
-        if input_ids is not None:
-            batch_size = input_ids.shape[0]
-        else:
-            batch_size = inputs_embeds.shape[0]
-
-        if self.config.pad_token_id is None and batch_size != 1:
-            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
-        if self.config.pad_token_id is None:
-            sequence_lengths = -1
-        else:
-            if input_ids is not None:
-                sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
-            else:
-                sequence_lengths = -1
-                logger.warning(
-                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
-                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
-                )
-
-        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
-        loss = None
-        if labels is not None:
-            if self.config.problem_type is None:
-                if self.num_labels == 1:
-                    self.config.problem_type = "regression"
-                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
-                    self.config.problem_type = "single_label_classification"
-                else:
-                    self.config.problem_type = "multi_label_classification"
-
-            if self.config.problem_type == "regression":
-                loss_fct = MSELoss()
-                if self.num_labels == 1:
-                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
-                else:
-                    loss = loss_fct(pooled_logits, labels)
-            elif self.config.problem_type == "single_label_classification":
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
-            elif self.config.problem_type == "multi_label_classification":
-                loss_fct = BCEWithLogitsLoss()
-                loss = loss_fct(pooled_logits, labels)
-        if not return_dict:
-            output = (pooled_logits,) + transformer_outputs[1:]
-            return ((loss,) + output) if loss is not None else output
-
-        return SequenceClassifierOutputWithPast(
-            loss=loss,
-            logits=pooled_logits,
-            past_key_values=transformer_outputs.past_key_values,
-            hidden_states=transformer_outputs.hidden_states,
-            attentions=transformer_outputs.attentions,
-        )

+ 74 - 0
src/petals/bloom/modeling_utils.py

@@ -0,0 +1,74 @@
+"""
+PyTorch BLOOM model that implements several memory-efficient modes.
+Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
+See commit history for authorship.
+"""
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from hivemind import use_hivemind_log_handler
+from torch import nn
+from transformers import BloomConfig
+from transformers.utils import logging
+
+use_hivemind_log_handler("in_root_logger")
+logger = logging.get_logger(__file__)
+
+
+class LMHead(nn.Module):
+    """
+    The modified language modeling head which does not create extra tensor for the linear layer with weights tied to the input
+    embeddings. Thus, it reduces initial memory consumption which might be crucial for large dictionaries.
+    In addition, it provides an effcient way to deal with half-precision word embeddings on CPU.
+    """
+
+    def __init__(self, config: BloomConfig, word_embeddings: nn.Embedding):
+        super().__init__()
+        self.word_embeddings = word_embeddings
+        self.chunk_size = config.chunk_size_for_efficient_fp16_on_cpu
+
+    @property
+    def in_features(self) -> int:
+        return self.word_embeddings.num_embeddings
+
+    @property
+    def out_features(self) -> int:
+        return self.word_embeddings.embedding_dim
+
+    @property
+    def weight(self):
+        return self.word_embeddings.weight
+
+    @property
+    def bias(self):
+        return None
+
+    def forward(self, hidden_states):
+        word_embeddings = self.word_embeddings.weight
+
+        # We use 'chunked_forward' only when embeddings are in half-precision on CPU.
+        if word_embeddings.dtype in [torch.float16, torch.bfloat16] and word_embeddings.device.type == "cpu":
+            lm_logits = self.chunked_forward(hidden_states)
+        else:
+            # Switch dtype in case word_embeddings are fp16/bf16
+            hidden_states = hidden_states.to(word_embeddings.dtype)
+            lm_logits = F.linear(hidden_states, word_embeddings)
+        return lm_logits
+
+    def chunked_forward(self, hidden_states):
+        """Splits word embeddings on chunks and iteratively casts them into fp32 to perform matmul more efficiently on CPU.
+        chunk_size: provides trade-off between efficiency and extra memory consumption.
+        """
+        assert self.chunk_size > 0, "Chunk size for chunked forward must be positive"
+
+        word_embeddings = self.word_embeddings.weight
+        num_embeddings = self.word_embeddings.num_embeddings
+
+        hidden_states = hidden_states.float()
+        output = torch.empty(*hidden_states.shape[:-1], num_embeddings)
+
+        for i in range(0, num_embeddings, self.chunk_size):
+            chunk = word_embeddings[i : i + self.chunk_size].float()
+            output[..., i : i + self.chunk_size] = F.linear(hidden_states, chunk)
+        return output

+ 0 - 242
src/petals/bloom/ops.py

@@ -1,242 +0,0 @@
-"""
-Utility operations used in the the BLOOM model
-Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
-See commit history for authorship.
-"""
-import math
-
-import torch
-import torch.autograd
-import torch.nn.functional as F
-from torch import nn
-
-
-def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
-    """Split a tensor along its last dimension.
-
-    Args:
-        tensor: ([`torch.tensor`], *required*):
-            input tensor to split
-        num_partitions ([`int`], *required*):
-            number of partitions to split the tensor
-        contiguous_split_chunks ([`bool`], *optional*, default=`False`)::
-            If True, make each chunk contiguous in memory.
-    """
-    # Get the size and dimension.
-    last_dim = tensor.dim() - 1
-    numerator, denominator = tensor.size()[last_dim], num_partitions
-    if not (numerator % denominator == 0):
-        raise ValueError(f"{numerator} is not divisible by {denominator}")
-    last_dim_size = numerator // denominator
-    # Split.
-    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
-    # Note: torch.split does not create contiguous tensors by default.
-    if contiguous_split_chunks:
-        return tuple(chunk.contiguous() for chunk in tensor_list)
-
-    return tensor_list
-
-
-def attention_mask_func(attention_scores, attention_mask, causal_mask):
-    if attention_mask.dtype == torch.bool:
-        attention_mask_bool = ~attention_mask
-    else:
-        attention_mask_bool = (1 - attention_mask).bool()
-
-    query_length, key_length, n_heads = attention_scores.size(2), attention_scores.size(3), attention_scores.size(1)
-    padded_causal_mask = (
-        attention_mask_bool[:, None, key_length - query_length : key_length, None]
-        + ~causal_mask[:, :, key_length - query_length : key_length, :key_length]
-    ).bool()
-    padded_causal_mask = padded_causal_mask + attention_mask_bool[:, None, None, :key_length].bool()
-    # Make use of floats
-    return (
-        attention_scores.masked_fill_(padded_causal_mask.expand(-1, n_heads, -1, -1), -10000.0),
-        padded_causal_mask,
-    )
-
-
-def build_alibi_tensor(
-    max_seq_len: int, n_head: int, dtype: torch.dtype = torch.bfloat16, device: torch.device = torch.device("cpu")
-) -> torch.Tensor:
-    """
-    Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
-    relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
-    `softmax(l+a) = softmax(l)`. Based on
-    https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
-    Args:
-    Returns tensor shaped (n_head, 1, max_seq_len)
-        max_seq_len: (`int`, *required*):
-            max sequence length
-        n_head: (`int`, *required*):
-            number of heads
-        dtype: (`torch.dtype`, *optional*, default=`torch.bfloat16`):
-            dtype of the output tensor
-        device: (`torch.device`, *optional*, default=`torch.device('cpu')`):
-            device of the output alibi tensor
-    """
-    closest_power_of_2 = 2 ** math.floor(math.log2(n_head))
-    base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=device, dtype=torch.float32)
-    powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32)
-    slopes = torch.pow(base, powers)
-
-    if closest_power_of_2 != n_head:
-        extra_base = torch.tensor(
-            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=device, dtype=torch.float32
-        )
-        num_remaining_heads = min(closest_power_of_2, n_head - closest_power_of_2)
-        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32)
-        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
-
-    lengths = torch.arange(max_seq_len, device=device, dtype=torch.int32)
-    return (slopes.view(-1, 1, 1) * lengths.view(1, 1, -1)).to(dtype)
-
-
-def pre_process_alibi_for_pad(alibi: torch.Tensor, attention_mask: torch.Tensor):
-    """
-    Args:
-    Pre-process the alibi tensor for padding.
-        alibi: ([`torch.tensor`], *required*):
-            alibi tensor to pre-process
-        attention_mask: ([`torch.tensor`], *required*):
-            attention mask to pre-process
-    """
-    assert attention_mask.ndim == 2, "mask should be [batch_size, seq_length]"
-    unpadded_indices = torch.relu(attention_mask.cumsum(dim=1) - 1)
-    # ^-- [batch, max_len], values correspond to element indices after removing padding
-    # We shift the alibi tensor + replace all the values where attention_mask==0.0 by 0
-    alibi = alibi.take_along_dim(unpadded_indices.unsqueeze(0), -1) * attention_mask.unsqueeze(0)
-    return alibi.reshape(alibi.shape[0] * alibi.shape[1], 1, -1)
-
-
-def dropout_add(x, residual, prob, training):
-    """
-    Dropout add function
-
-    Args:
-        x (`torch.tensor`, *required*):
-            input tensor
-        residual (`torch.tensor`, *required*):
-            esidual tensor
-        prob (`float`, *required*):
-            dropout probability
-        training (`bool`, *required*):
-            training mode
-    """
-    out = nn.functional.dropout(x, p=prob, training=training)
-    out = residual + out
-    return out
-
-
-def bloom_gelu_forward(x):
-    """
-    Custom bias GELU function. Adapted from Megatron-DeepSpeed code. Here we use a simple implementation (inference) to
-    make the model jitable.
-
-    Args:
-        x (`torch.tensor`, *required*):
-            input hidden states
-    """
-    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
-
-
-def bloom_gelu_back(g, x):
-    """
-    gradient of tanh approximation of gelu gradient of actual gelu is: 0.5 * (1. + torch.erf(x * 0.70710678)) +
-    0.3989423 * x * torch.exp(-0.5 * x * x)
-
-    Args:
-        g (`torch.tensor`, *required*):
-            gradient output tensor
-        x (`torch.tensor`, *required*):
-            input tensor
-    """
-    x = x[0]  # x is a tuple of 1 element, needs to unpack it first
-    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
-    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
-    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
-    return ff * g
-
-
-class GeLUFunction(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, input):
-        ctx.save_for_backward(input)
-        return bloom_gelu_forward(input)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        input = ctx.saved_tensors
-        tmp = bloom_gelu_back(grad_output, input)
-        return tmp
-
-
-class BloomGelu(nn.Module):
-    """
-    BloomBiasGelu wrapper function that make use of the simple function on inference mode to make the model
-    torchscriptable and use the autograd function in training mode to get the accurate results of the gradients Partly
-    copied from Megatron-DeepSpeed code and adapted for our needs
-
-    See here why autograd functions are not torchscriptable: https://github.com/pytorch/pytorch/issues/22329
-
-    """
-
-    def __init__(self):
-        super().__init__()
-
-    def forward(self, x):
-        if self.training:
-            return GeLUFunction.apply(x)
-        else:
-            return bloom_gelu_forward(x)
-
-
-class BloomScaledSoftmax(nn.Module):
-    """
-    fused operation: scaling + mask + softmax
-
-    Args:
-        scaled_masked_softmax_fusion (`bool`, *required*):
-            flag to indicate user want to use softmax fusion
-        mask_func (`function`, *required*):
-            mask function to be applied.
-        softmax_in_fp32 (`bool`, *required*):
-            if true, softmax in performed at fp32 precision.
-        scale (`float`, *required*):
-            scaling factor used in input tensor scaling.
-    """
-
-    def __init__(self, scaled_masked_softmax_fusion, mask_func, softmax_in_fp32, scale):
-        super().__init__()
-        self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
-        self.mask_func = mask_func
-        self.softmax_in_fp32 = softmax_in_fp32
-        self.scale = scale
-
-        if not (self.scale is None or softmax_in_fp32):
-            raise ValueError("softmax should be in fp32 when scaled")
-
-    def forward(self, input, mask, max_positions):
-        input_dtype = input.dtype
-        input_in_16bit = input_dtype in [torch.float16, torch.bfloat16]
-        softmax_dtype = torch.float32 if self.softmax_in_fp32 else input_dtype
-
-        if self.scale is not None:
-            input = input * self.scale
-
-        if mask is None:
-            mask = torch.ones(input.shape[0], max_positions, dtype=torch.bool, device=input.device)
-
-        mask = mask.to(input.device)
-        causal_mask = (
-            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
-            .view(1, 1, max_positions, max_positions)
-            .to(input.device)
-        )
-        mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask)
-        probs = F.softmax(mask_output, dim=-1, dtype=softmax_dtype) * (~padded_causal_mask)
-
-        if input_in_16bit and self.softmax_in_fp32:
-            probs = probs.to(dtype=input_dtype)
-
-        return probs

+ 1 - 1
src/petals/cli/convert_model.py

@@ -8,8 +8,8 @@ import transformers
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from huggingface_hub import Repository
 from tqdm.auto import tqdm
+from transformers.models.bloom.modeling_bloom import BloomModel
 
-from petals.bloom import BloomModel
 from petals.bloom.from_pretrained import BLOCK_BRANCH_PREFIX, CLIENT_BRANCH
 from petals.client import DistributedBloomConfig
 

+ 3 - 4
src/petals/cli/inference_one_block.py

@@ -3,10 +3,10 @@ import argparse
 import torch
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from tqdm.auto import trange
+from transformers import BloomConfig
+from transformers.models.bloom.modeling_bloom import build_alibi_tensor
 
 from petals.bloom.block import BloomBlock
-from petals.bloom.model import BloomConfig
-from petals.bloom.ops import build_alibi_tensor
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -31,7 +31,6 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Run a single bloom block locally on dummy data")
     parser.add_argument("--config", required=True, type=str, help="Path to a config json file")
     parser.add_argument("--state_dict", default=None, type=str, help="Optional path to saved block state dict")
-    parser.add_argument("--layer_index", default=0, type=int, help="Optional path to saved block state dict")
     parser.add_argument("--num_steps", default=500, type=int, help="How many inference steps to run")
     parser.add_argument("--device", default=None, type=str, help="Run inference on this device")
     args = parser.parse_args()
@@ -40,7 +39,7 @@ if __name__ == "__main__":
         args.device = "cuda" if torch.cuda.is_available() else "cpu"
 
     config = BloomConfig.from_json_file(args.config)
-    block = BloomBlock(config, args.layer_index).to(args.device)
+    block = BloomBlock(config).to(args.device)
 
     cache = None
 

+ 19 - 6
src/petals/client/remote_model.py

@@ -7,15 +7,15 @@ import torch
 import torch.nn as nn
 from hivemind.utils.logging import get_logger, loglevel, use_hivemind_log_handler
 from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
-
-from petals.bloom.model import (
+from transformers.models.bloom import (
     BloomConfig,
     BloomForCausalLM,
     BloomForSequenceClassification,
     BloomModel,
     BloomPreTrainedModel,
-    LMHead,
 )
+
+from petals.bloom.modeling_utils import LMHead
 from petals.client.remote_generation import RemoteGenerationMixin
 from petals.client.remote_sequential import RemoteSequential
 from petals.constants import PUBLIC_INITIAL_PEERS
@@ -66,7 +66,20 @@ def force_non_empty_weights():
         nn.Module.register_parameter = possibly_patched_register_parameter
 
 
-class DistributedBloomModel(BloomModel):
+class _LowCPUMemoryMixin:
+    @classmethod
+    def from_pretrained(cls, *args, low_cpu_mem_usage: Optional[bool] = None, **kwargs):
+        if low_cpu_mem_usage is None:
+            low_cpu_mem_usage = True
+        return super().from_pretrained(*args, low_cpu_mem_usage=low_cpu_mem_usage, **kwargs)
+
+    from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace(
+        "low_cpu_mem_usage(`bool`, *optional*)",
+        "low_cpu_mem_usage(`bool`, *optional*, defaults to `True` in Petals)",
+    )
+
+
+class DistributedBloomModel(_LowCPUMemoryMixin, BloomModel):
     """BloomModel, but all transformer layers are hosted by the swarm"""
 
     _keys_to_ignore_on_load_missing = BloomModel._keys_to_ignore_on_load_missing + [
@@ -192,7 +205,7 @@ class DistributedBloomModel(BloomModel):
         )
 
 
-class DistributedBloomForCausalLM(RemoteGenerationMixin, BloomForCausalLM):
+class DistributedBloomForCausalLM(_LowCPUMemoryMixin, RemoteGenerationMixin, BloomForCausalLM):
     """DistributedBloomForCausalLM, but all transformer layers are hosted by the swarm"""
 
     _keys_to_ignore_on_load_missing = (
@@ -230,7 +243,7 @@ class DistributedBloomForCausalLM(RemoteGenerationMixin, BloomForCausalLM):
             self.lm_head.bias[...] = new_lm_head.bias
 
 
-class DistributedBloomForSequenceClassification(BloomForSequenceClassification):
+class DistributedBloomForSequenceClassification(_LowCPUMemoryMixin, BloomForSequenceClassification):
     _keys_to_ignore_on_load_missing = (
         BloomForSequenceClassification._keys_to_ignore_on_load_missing
         + DistributedBloomModel._keys_to_ignore_on_load_missing

+ 1 - 1
src/petals/client/routing/sequence_manager.py

@@ -57,7 +57,7 @@ class RemoteSequenceManager:
         update_period: float = 30,
         request_timeout: float = 30,
         min_backoff: float = 1,
-        ban_timeout: float = 60,
+        ban_timeout: float = 15,
         sequence_info: Optional[RemoteSequenceInfo] = None,
         rpc_info: Optional[dict] = None,
         banned_peers: Optional[Blacklist] = None,

+ 24 - 16
src/petals/server/backend.py

@@ -6,7 +6,7 @@ from hivemind import BatchTensorDescriptor, use_hivemind_log_handler
 from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.utils import get_logger
 
-from petals.bloom.from_pretrained import BloomBlock
+from petals.bloom.block import WrappedBloomBlock
 from petals.server.memory_cache import MemoryCache
 from petals.server.task_pool import PrioritizedTaskPool
 from petals.utils.misc import is_dummy
@@ -16,11 +16,11 @@ logger = get_logger(__file__)
 
 
 class TransformerBackend(ModuleBackend):
-    """A wrapper for BloomBlock that can process requests for bloom layer forward, forward_incremental, and backward"""
+    """A wrapper for a BLOOM block that can process requests for BLOOM layer forward, backward and inference"""
 
     def __init__(self, *args, memory_cache: MemoryCache, backend_dtype: torch.dtype, **kwargs):
         super().__init__(*args, **kwargs)
-        assert isinstance(self.module, BloomBlock)
+        assert isinstance(self.module, WrappedBloomBlock)
         self.memory_cache = memory_cache
         for name, param in self.module.named_parameters():
             assert not param.requires_grad, f"Bloom layer parameters must not accumulate gradients, but {name} does"
@@ -50,6 +50,7 @@ class TransformerBackend(ModuleBackend):
         )
 
     def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
+        num_heads, head_dim = self.module.self_attention.num_heads, self.module.self_attention.head_dim
         with torch.inference_mode():
             attention_cache_handle = int(cache_metadata[0, 0].item())
             prefix_length = int(cache_metadata[0, 1].item())
@@ -59,24 +60,31 @@ class TransformerBackend(ModuleBackend):
             ), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"
 
             with self.memory_cache.use_cache(attention_cache_handle) as cache:
-                assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5
+                batch_size = cache.shape[1]
+                max_length = cache.numel() // (2 * batch_size * head_dim * num_heads)
+                assert isinstance(self.module, WrappedBloomBlock) and cache.shape[0] == 2 and cache.ndim == 3
                 if not is_dummy(hypo_ids):
                     assert hypo_ids.shape[0] == cache.shape[1]
                     cache[:, :] = cache[:, hypo_ids]  # in-place reorder cache by hypo ids
-                layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length]
-                logger.debug(f"Metadata: {cache_metadata}, past_k.shape={past_k.shape}, past_v.shape={past_v.shape}")
-                hidden_states, (new_k, new_v) = self.module.forward(
-                    hidden_states, layer_past=layer_past, use_cache=True
-                )
+                key_cache = cache[0].view(batch_size, num_heads, head_dim, max_length)
+                value_cache = cache[1].view(batch_size, num_heads, max_length, head_dim)
 
-                # todo remove these asserts once we pass all tests
-                new_length = new_v.shape[1]
+                key_past = key_cache.flatten(0, 1)[:, :, :prefix_length]  # [batch * num_heads, head_dim, kv_length]
+                value_past = value_cache.flatten(0, 1)[:, :prefix_length, :]  # [batch * num_heads, kv_length, head_dim]
+                logger.debug(
+                    f"Metadata: {cache_metadata}, past_k.shape={key_past.shape}, past_v.shape={value_past.shape}"
+                )
+                hidden_states, (new_key, new_value) = self.module.forward(
+                    hidden_states, layer_past=(key_past, value_past), use_cache=True
+                )
+                new_length = new_key.shape[-1]
                 assert new_length > prefix_length
-                assert new_k.shape[0] == past_k.shape[0] and new_v.shape[0] == past_v.shape[0]
-                assert new_k.shape[1] == new_length and new_v.shape[1] == new_length
-                assert new_k.shape[2:] == past_k.shape[2:] and new_v.shape[2:] == past_v.shape[2:]
-                cache[0, :, prefix_length:new_length, :] = new_k[:, prefix_length:new_length]
-                cache[1, :, prefix_length:new_length, :] = new_v[:, prefix_length:new_length]
+                assert new_key.shape[0] == key_past.shape[0] and new_value.shape[0] == value_past.shape[0]
+                assert new_key.shape[-1] == new_length and new_value.shape[-2] == new_length
+                new_key = new_key.view(batch_size, num_heads, head_dim, -1)
+                new_value = new_value.view(batch_size, num_heads, -1, head_dim)
+                key_cache[:, :, :, prefix_length:new_length] = new_key[:, :, :, prefix_length:new_length]
+                value_cache[:, :, prefix_length:new_length, :] = new_value[:, :, prefix_length:new_length, :]
                 return (hidden_states,)
 
     def get_pools(self) -> Sequence[PrioritizedTaskPool]:

+ 3 - 3
src/petals/server/block_utils.py

@@ -2,8 +2,9 @@ from typing import Optional, Union
 
 import torch
 from accelerate import init_empty_weights
+from transformers import BloomConfig
 
-from petals.bloom import BloomBlock, BloomConfig
+from petals.bloom.block import WrappedBloomBlock
 
 
 def resolve_block_dtype(config: BloomConfig, dtype: Union[str, torch.dtype]) -> Union[str, torch.dtype]:
@@ -22,7 +23,6 @@ def get_block_size(
     *,
     dtype: Optional[Union[str, torch.dtype]] = None,
     load_in_8bit: Optional[bool] = None,
-    layer_index: int = 0,
     eps: float = 0.01,  # eps accounts for ~1% of metainfo for tensor descriptions, quantization tables, etc.
 ) -> int:
     if location == "memory":
@@ -31,7 +31,7 @@ def get_block_size(
         ), 'get_block_size(..., location="memory") requires to specify dtype and load_in_8bit for calculations'
 
     with init_empty_weights():
-        block = BloomBlock(config, layer_index)
+        block = WrappedBloomBlock(config)
         n_params = sum(param.numel() for param in block.parameters())
 
     if location == "memory" and load_in_8bit:

+ 5 - 4
src/petals/server/handler.py

@@ -146,6 +146,9 @@ class TransformerConnectionHandler(ConnectionHandler):
                         for backend, prompt, cache_handle in zip(requested_backends, prompts, cache_handles):
                             if not is_dummy(prompt):
                                 hidden_states[:, : prompt.shape[1]] += prompt
+                            if hidden_states.numel() == 0:
+                                continue  # user passed a tensor with 0 tokens. This is a special case that occurs, e.g.
+                                # when user wants to pre-allocate cache or check that server *can* allocate that cache
 
                             cache_metadata[:, 0], cache_metadata[:, 1] = cache_handle, prefix_length
                             assert isinstance(
@@ -343,10 +346,8 @@ class TransformerConnectionHandler(ConnectionHandler):
             for backend in backends:
                 num_heads = backend.module.self_attention.num_heads
                 head_dim = backend.module.self_attention.head_dim
-
-                descr = TensorDescriptor(size=(2, batch_size, max_length, num_heads, head_dim), dtype=backend.dtype)
-                # [key_or_value, batch_size, max_length, num_heads, head_dim]
-
+                descr = TensorDescriptor(size=(2, batch_size, num_heads * head_dim * max_length), dtype=backend.dtype)
+                # ^-- flattened batch-first tensor of both keys and values; based on BLOOM layer_past layout
                 handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(descr)))
                 total_size += descr.numel() * torch.finfo(descr.dtype).bits // 8
 

+ 1 - 1
src/petals/server/server.py

@@ -16,9 +16,9 @@ from hivemind.moe.server.layers import add_custom_models_from_file
 from hivemind.moe.server.runtime import Runtime
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from transformers import BloomConfig
 
 from petals.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
-from petals.bloom.model import BloomConfig
 from petals.constants import PUBLIC_INITIAL_PEERS
 from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
 from petals.dht_utils import declare_active_modules, get_remote_module_infos

+ 4 - 7
src/petals/server/throughput.py

@@ -9,10 +9,9 @@ from typing import Optional, Union
 
 import torch
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from transformers import BloomConfig
 
-from petals.bloom.block import BloomBlock
-from petals.bloom.model import BloomConfig
-from petals.bloom.ops import build_alibi_tensor
+from petals.bloom.block import WrappedBloomBlock
 from petals.server.block_utils import resolve_block_dtype
 from petals.utils.convert_8bit import replace_8bit_linear
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
@@ -115,10 +114,9 @@ def measure_compute_rps(
     load_in_8bit: bool,
     n_tokens: int = 16,
     n_steps: int = 500,
-    layer_index: int = 0,
 ) -> float:
     with torch.inference_mode():
-        block = BloomBlock(config, layer_index).to(dtype)
+        block = WrappedBloomBlock(config).to(dtype)
         if load_in_8bit:
             block = replace_8bit_linear(block)
         block = block.to(device)
@@ -127,10 +125,9 @@ def measure_compute_rps(
         elapsed = 0
         for step in range(n_steps + 1):
             dummy_input = torch.randn(n_tokens, 1, config.hidden_size, device=device, dtype=dtype)
-            alibi = build_alibi_tensor(step + 1, config.num_attention_heads, device=device, dtype=dtype)
 
             start_time = time.perf_counter()
-            _, cache = block.forward(dummy_input, alibi=alibi, use_cache=True, layer_past=cache)
+            _, cache = block.forward(dummy_input, use_cache=True, layer_past=cache)
             if step >= 1:  # Skip the 1st step to exclude the initialization time
                 elapsed += time.perf_counter() - start_time
         device_rps = n_steps * n_tokens / elapsed

+ 15 - 0
tests/test_aux_functions.py

@@ -0,0 +1,15 @@
+import pytest
+import torch
+from test_utils import MODEL_NAME
+
+from petals.client import DistributedBloomConfig
+from petals.server.throughput import measure_compute_rps
+
+
+@pytest.mark.forked
+def test_throughput_basic():
+    config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
+    throughput = measure_compute_rps(
+        config, device=torch.device("cpu"), dtype=torch.bfloat16, load_in_8bit=False, n_steps=10
+    )
+    assert isinstance(throughput, float) and throughput > 0

+ 11 - 3
tests/test_full_model.py

@@ -3,9 +3,9 @@ import torch
 import transformers
 from hivemind import get_logger, use_hivemind_log_handler
 from test_utils import *
-from transformers.generation_utils import BeamSearchScorer
+from transformers.generation import BeamSearchScorer
+from transformers.models.bloom import BloomForCausalLM
 
-from petals.bloom.model import BloomForCausalLM
 from petals.client.remote_model import DistributedBloomForCausalLM
 
 use_hivemind_log_handler("in_root_logger")
@@ -13,7 +13,8 @@ logger = get_logger(__file__)
 
 
 @pytest.mark.forked
-def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
+@pytest.mark.parametrize("pass_empty_tensors", (True, False))
+def test_full_model_exact_match(pass_empty_tensors: bool, atol_forward=1e-3, atol_inference=1e-3):
     tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
     model = DistributedBloomForCausalLM.from_pretrained(
         MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
@@ -33,8 +34,15 @@ def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
         embs = model.transformer.word_embeddings_layernorm(embs)
         recurrent_outputs = []
         with model.transformer.h.inference_session(max_length=embs.shape[1]) as sess:
+            if pass_empty_tensors:
+                recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
+
             for t in range(embs.shape[1]):
                 recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))
+                if t == int(embs.shape[1] // 2) and pass_empty_tensors:
+                    recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
+                    recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
+
         recurrent_outputs = torch.cat(recurrent_outputs, dim=1)
         recurrent_outputs = model.transformer.ln_f(recurrent_outputs)
         recurrent_outputs = model.lm_head(recurrent_outputs)