浏览代码

Bump transformers and accelerate versions (#554)

Bump versions for transformers and accelerate, remove falcon-rw-1b CI tests
Denis Mazur 1 年之前
父节点
当前提交
0d91bbdac3

+ 0 - 2
.github/workflows/run-tests.yaml

@@ -14,8 +14,6 @@ jobs:
           - { model: 'bigscience/bloom-560m', os: 'ubuntu', python-version: '3.11' }
           - { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.8' }
           - { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.11' }
-          - { model: 'petals-team/falcon-rw-1b', os: 'ubuntu', python-version: '3.8' }
-          - { model: 'petals-team/falcon-rw-1b', os: 'ubuntu', python-version: '3.11' }
           - { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.10' }
           - { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.11' }
       fail-fast: false

+ 2 - 2
setup.cfg

@@ -34,10 +34,10 @@ python_requires = >=3.8
 install_requires =
     torch>=1.12
     bitsandbytes==0.41.1
-    accelerate>=0.22.0
+    accelerate>=0.27.2
     huggingface-hub>=0.11.1,<1.0.0
     tokenizers>=0.13.3
-    transformers>=4.32.0,<4.35.0  # if you change this, please also change version assert in petals/__init__.py
+    transformers==4.37.1  # if you change this, please also change version assert in petals/__init__.py
     speedtest-cli==2.1.3
     pydantic>=1.10,<2.0  # 2.0 is incompatible with hivemind yet
     hivemind==1.1.10.post2

+ 3 - 3
src/petals/__init__.py

@@ -17,13 +17,13 @@ from petals.models import *
 from petals.utils import *
 from petals.utils.logging import initialize_logs as _initialize_logs
 
-__version__ = "2.3.0.dev1"
+__version__ = "2.3.0.dev2"
 
 
 if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
     assert (
-        version.parse("4.32.0") <= version.parse(transformers.__version__) < version.parse("4.35.0")
-    ), "Please install a proper transformers version: pip install transformers>=4.32.0,<4.35.0"
+        version.parse("4.37.1") <= version.parse(transformers.__version__) < version.parse("4.38.0")
+    ), "Please install a proper transformers version: pip install transformers>=4.37.1,<4.38.0"
 
 
 def _override_bfloat16_mode_default():

+ 1 - 0
src/petals/client/inference_session.py

@@ -211,6 +211,7 @@ class InferenceSession:
         self._position = 0
         self._max_length = max_length
         self.output_ids = None
+        self.past_key_values = None
 
     @property
     def num_blocks(self) -> int:

+ 26 - 5
src/petals/client/remote_generation.py

@@ -1,11 +1,13 @@
 import contextlib
 import dataclasses
 from contextvars import ContextVar
-from typing import ContextManager, List, Optional
+from typing import Any, ContextManager, Dict, List, Optional, Tuple
 
 import torch
 import transformers
 from hivemind.utils.logging import get_logger
+from torch import Tensor
+from transformers.cache_utils import Cache, DynamicCache
 from transformers.generation.utils import ModelOutput
 
 from petals.client.inference_session import InferenceSession
@@ -15,15 +17,29 @@ from petals.utils.misc import DUMMY, docstring_from
 logger = get_logger(__name__)
 
 
-@dataclasses.dataclass(frozen=True)
-class RemotePastKeyValues:
-    """A mock class representing the fact that `past_key_values` do exist but are stored on remote servers."""
+class RemotePastKeyValues(Cache):
+    """only keeps the number of seen tokens. pretends to be a legit cache"""
 
-    hypo_ids: Optional[torch.LongTensor] = None
+    def __init__(self) -> None:
+        super().__init__()
+        self.seen_tokens = 0
+        self.hypo_ids: Optional[torch.LongTensor] = None
 
     def __getitem__(self, _index: int) -> List[torch.Tensor]:
         return [DUMMY]  # For compatibility with BloomForCausalLM.prepare_inputs_for_generation()
 
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        return self.seen_tokens
+
+    def get_max_length(self) -> Optional[int]:
+        return None
+
+    def update_seen(self, new_seen: int) -> None:
+        self.seen_tokens += new_seen
+
+    def reorder_cache(self, beam_idx):
+        pass
+
 
 _skipped_tokens = ContextVar("skipped_tokens", default=0)
 
@@ -113,6 +129,11 @@ class RemoteGenerationMixin(_SkipTokensMixin):
                 # but keep them for transformers.GenerationMixin (e.g., to compute repetition_penalty)
                 _skipped_tokens.set(max(0, n_prev_tokens - 1))
 
+            if self._supports_cache_class and "past_key_values" not in kwargs:
+                past_key_values = RemotePastKeyValues()
+                past_key_values.update_seen(session.position)
+                kwargs["past_key_values"] = past_key_values
+
             result = super().generate(inputs, *args, **kwargs)
 
             sequences = result.sequences if isinstance(result, ModelOutput) else result

+ 8 - 1
src/petals/models/bloom/block.py

@@ -6,6 +6,7 @@ See commit history for authorship.
 from typing import Optional, Tuple
 
 import torch
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel, build_alibi_tensor
 
 
@@ -26,7 +27,13 @@ class WrappedBloomBlock(BloomBlock):
         attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
         if alibi is None:
             alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
-        attention_mask = BloomModel._prepare_attn_mask(None, attention_mask, (batch_size, seq_length), past_length)
+        attention_mask = _prepare_4d_causal_attention_mask(
+            attention_mask=attention_mask,
+            input_shape=(batch_size, seq_length),
+            inputs_embeds=hidden_states,
+            past_key_values_length=past_length,
+        )
+        attention_mask = attention_mask.bool()
         return super().forward(
             hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, **kwargs
         )

+ 59 - 1
src/petals/models/bloom/model.py

@@ -4,6 +4,7 @@ import hivemind
 import torch
 import torch.nn as nn
 from hivemind.utils.logging import get_logger
+from transformers.cache_utils import Cache
 from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
 from transformers.models.bloom import BloomForCausalLM, BloomForSequenceClassification, BloomModel, BloomPreTrainedModel
 
@@ -92,12 +93,16 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
         if use_prompts:
             hidden_states = hidden_states[:, self.pre_seq_len :]
 
+        if past_key_values is None:
+            past_key_values = RemotePastKeyValues()
+        past_key_values.update_seen(hidden_states.size(1))
+
         # Add last hidden state
         hidden_states = self.ln_f(hidden_states)
         hidden_states = hidden_states.view(output_shape)
         return BaseModelOutputWithPastAndCrossAttentions(
             last_hidden_state=hidden_states,
-            past_key_values=RemotePastKeyValues(),
+            past_key_values=past_key_values,
             hidden_states=None,
             attentions=None,
         )
@@ -107,6 +112,7 @@ class DistributedBloomForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, Bl
     _keys_to_ignore_on_load_missing = DistributedBloomModel._keys_to_ignore_on_load_missing
     _keys_to_ignore_on_load_missing += [r"^lm_head\."]  # Missing since they are shared with input embeddings
     _keys_to_ignore_on_load_unexpected = DistributedBloomModel._keys_to_ignore_on_load_unexpected
+    _supports_cache_class = True
 
     config_class = DistributedBloomConfig
 
@@ -118,6 +124,58 @@ class DistributedBloomForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, Bl
         # Initialize weights and apply final processing
         self.post_init()
 
+    def prepare_inputs_for_generation(
+        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+    ) -> dict:
+        # Omit tokens covered by past_key_values
+        if past_key_values is not None:
+            if isinstance(past_key_values, Cache):
+                cache_length = past_key_values.get_seq_length()
+                past_length = past_key_values.seen_tokens
+                max_cache_length = past_key_values.get_max_length()
+            else:
+                cache_length = past_length = past_key_values[0][0].shape[2]
+                max_cache_length = None
+
+            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
+                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+            elif past_length < input_ids.shape[1]:
+                input_ids = input_ids[:, past_length:]
+
+            if (
+                max_cache_length is not None
+                and attention_mask is not None
+                and cache_length + input_ids.shape[1] > max_cache_length
+            ):
+                attention_mask = attention_mask[:, -max_cache_length:]
+
+        position_ids = kwargs.get("position_ids", None)
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -input_ids.shape[1] :]
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "position_ids": position_ids,
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attention_mask,
+            }
+        )
+        return model_inputs
+
+    def _temporary_reorder_cache(self, past_key_values, beam_idx):
+        return past_key_values
+
     def get_output_embeddings(self):
         return self.lm_head
 

+ 8 - 4
src/petals/models/llama/block.py

@@ -9,6 +9,7 @@ from typing import Optional, Tuple
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.models.llama.modeling_llama import (
     LlamaAttention,
     LlamaConfig,
@@ -84,8 +85,8 @@ class OptimizedLlamaAttention(LlamaAttention):
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        cos = cos[:, :, kv_seq_len - q_len :]
-        sin = sin[:, :, kv_seq_len - q_len :]
+        cos = cos[kv_seq_len - q_len :]
+        sin = sin[kv_seq_len - q_len :]
 
         if q_len == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
             query_states, key_states = self._optimized_apply_rotary(query_states, key_states, cos, sin)
@@ -244,8 +245,11 @@ class WrappedLlamaBlock(OptimizedLlamaDecoderLayer):
             attention_mask = torch.ones(
                 (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
             )
-        attention_mask = LlamaModel._prepare_decoder_attention_mask(
-            None, attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+        attention_mask = _prepare_4d_causal_attention_mask(
+            attention_mask=attention_mask,
+            input_shape=(batch_size, seq_length),
+            inputs_embeds=hidden_states,
+            past_key_values_length=past_key_values_length,
         )
 
         outputs = super().forward(

+ 6 - 1
src/petals/models/llama/model.py

@@ -90,6 +90,10 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
             hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None,
         )
 
+        if past_key_values is None:
+            past_key_values = RemotePastKeyValues()
+        past_key_values.update_seen(hidden_states.size(1))
+
         # Remove prefix
         if use_prompts:
             hidden_states = hidden_states[:, self.pre_seq_len :]
@@ -97,9 +101,10 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
         # Add last hidden state
         hidden_states = self.norm(hidden_states)
         hidden_states = hidden_states.view(output_shape)
+
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
-            past_key_values=RemotePastKeyValues(),
+            past_key_values=past_key_values,
             hidden_states=None,
             attentions=None,
         )

+ 1 - 3
src/petals/utils/peft.py

@@ -26,9 +26,7 @@ logger = get_logger(__name__)
 
 
 def check_peft_repository(repo_id: str) -> bool:
-    fs = HfFileSystem()
-    list_of_files = fs.glob(f"{repo_id}/{SAFETENSORS_WEIGHTS_NAME}", detail=False)
-    return len(list_of_files) > 0
+    return HfFileSystem().exists(f"{repo_id}/{SAFETENSORS_WEIGHTS_NAME}")
 
 
 def load_specific_module(block_idx: int, filepath: str, framework: str = "pt", device: Optional[int] = None):

+ 13 - 7
tests/test_optimized_layers.py

@@ -2,6 +2,8 @@ from typing import Optional, Tuple
 
 import pytest
 import torch
+from transformers.cache_utils import DynamicCache
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
 
@@ -116,6 +118,8 @@ class UnoptimizedWrappedLlamaBlock(LlamaDecoderLayer):
             past_key_values_length = past_key_value[0].shape[2]
             seq_length_with_past = seq_length_with_past + past_key_values_length
             past_key_value = self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length)
+        elif use_cache:
+            past_key_value = DynamicCache()
 
         if position_ids is None:
             device = hidden_states.device
@@ -131,8 +135,9 @@ class UnoptimizedWrappedLlamaBlock(LlamaDecoderLayer):
             attention_mask = torch.ones(
                 (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
             )
-        attention_mask = LlamaModel._prepare_decoder_attention_mask(
-            None, attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+
+        attention_mask = _prepare_4d_causal_attention_mask(
+            attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
         )
 
         outputs = super().forward(
@@ -156,19 +161,20 @@ class UnoptimizedWrappedLlamaBlock(LlamaDecoderLayer):
 
     def _reorder_cache_from_bloom_to_llama(
         self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
-    ) -> Tuple[torch.Tensor]:
+    ) -> DynamicCache:
         key_states, value_states = key_value
         key_states = key_states.permute(0, 2, 1)
         key_states = key_states.view(
             batch_size, self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
         )
         value_states = value_states.view(*key_states.shape)
-        return (key_states, value_states)
+        past_key_values = ((key_states, value_states),)
+        return DynamicCache.from_legacy_cache(past_key_values)
 
     def _reorder_cache_from_llama_to_bloom(
-        self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
+        self, key_value: DynamicCache, batch_size: int, seq_length: int
     ) -> Tuple[torch.Tensor]:
-        key_states, value_states = key_value
+        key_states, value_states = key_value.to_legacy_cache()[0]
         value_states = value_states.view(
             batch_size * self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
         )
@@ -195,7 +201,7 @@ def test_optimized_block(device):
     if config.model_type == "falcon":
         unopt_block = UnoptimizedWrappedFalconBlock(config).to(dtype)
     elif config.model_type == "llama":
-        unopt_block = UnoptimizedWrappedLlamaBlock(config).to(dtype)
+        unopt_block = UnoptimizedWrappedLlamaBlock(config, layer_idx=0).to(dtype)
     else:
         pytest.skip(f"This test is not applicable to {config.model_type} models")