Преглед на файлове

Fix Mixtral-related issues (#570)

This PR fixes problems related to #569:
- block initialization
- throughput calculation and cache usage
- mixtral in tests

Beam search is removed for Mixtral and Llama for now. Those models use DynamicCache, which requires special function to change: (see https://github.com/huggingface/transformers/blob/main/src/transformers/cache_utils.py#L161)

---------

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
Artem Chumachenko преди 1 година
родител
ревизия
d6f4f80f3f

+ 2 - 0
.github/workflows/run-tests.yaml

@@ -16,6 +16,8 @@ jobs:
           - { model: 'Maykeye/TinyLLama-v0', os: 'ubuntu', python-version: '3.11' }
           - { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.10' }
           - { model: 'Maykeye/TinyLLama-v0', os: 'macos', python-version: '3.11' }
+          - { model: 'artek0chumak/TestMixtral', os: 'ubuntu', python-version: '3.8' }
+          - { model: 'artek0chumak/TestMixtral', os: 'ubuntu', python-version: '3.11' }
       fail-fast: false
     runs-on: ${{ matrix.os }}-latest
     timeout-minutes: 20

+ 1 - 1
src/petals/client/remote_generation.py

@@ -38,7 +38,7 @@ class RemotePastKeyValues(Cache):
         self.seen_tokens += new_seen
 
     def reorder_cache(self, beam_idx):
-        pass
+        raise NotImplementedError("Beam search reordering is not implemented yet")
 
 
 _skipped_tokens = ContextVar("skipped_tokens", default=0)

+ 6 - 0
src/petals/models/bloom/block.py

@@ -9,6 +9,8 @@ import torch
 from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel, build_alibi_tensor
 
+from petals.utils.misc import is_dummy
+
 
 class WrappedBloomBlock(BloomBlock):
     def forward(
@@ -22,6 +24,10 @@ class WrappedBloomBlock(BloomBlock):
     ):
         assert attention_mask is None, "Non-causal attention masks are not supported yet"
         batch_size, seq_length = hidden_states.shape[:2]
+        if layer_past is not None and is_dummy(layer_past[0]):
+            # Bloom cannot use cache if it was misconsctructed(e.g. Dummy tensors)
+            # In this case, fallback to the old code:
+            layer_past = None
         past_length = 0 if layer_past is None else layer_past[0].shape[-1]
         seq_length_with_past = seq_length + past_length
         attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)

+ 6 - 6
src/petals/models/mixtral/block.py

@@ -1,3 +1,4 @@
+import json
 from typing import Optional, Tuple
 
 import torch
@@ -33,16 +34,15 @@ class WrappedMixtralBlock(MixtralDecoderLayer):
         past_key_values_length = 0
 
         past_key_value = layer_past
+
         if past_key_value is not None:
             past_key_values_length = past_key_value[0].shape[2]
             seq_length_with_past = seq_length_with_past + past_key_values_length
             _past_key_value = self._reorder_cache_from_bloom(past_key_value, batch_size, past_key_values_length)
             past_key_value = DynamicCache()
-            for idx in range(self.layer_idx):
-                past_key_value.update(
-                    torch.empty(_past_key_value[0].size()), torch.empty(_past_key_value[1].size()), idx
-                )
-            past_key_value.update(_past_key_value[0], _past_key_value[1], self.layer_idx)
+            past_key_value.key_cache = [torch.empty(0) for _ in range(self.layer_idx)] + [_past_key_value[0]]
+            past_key_value.value_cache = [torch.empty(0) for _ in range(self.layer_idx)] + [_past_key_value[1]]
+            past_key_value._seen_tokens = past_key_values_length
 
         if self._attn_implementation == "flash_attention_2":
             # 2d mask is passed through the layers
@@ -83,7 +83,7 @@ class WrappedMixtralBlock(MixtralDecoderLayer):
 
         if use_cache:
             present_key_value = outputs[-1]
-            present_key_value = present_key_value.to_legacy_cache()[self.layer_idx]
+            present_key_value = present_key_value[self.layer_idx]
             present_key_value = self._reorder_cache_to_bloom(present_key_value, batch_size, seq_length_with_past)
             outputs = outputs[:-1] + (present_key_value,)
 

+ 15 - 6
src/petals/models/mixtral/model.py

@@ -122,14 +122,20 @@ class DistributedMixtralModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMi
     def word_embeddings(self) -> nn.Embedding:  # For compatibility with RemoteGenerationMixin
         return self.embed_tokens
 
+    @property
+    def word_embeddings_layernorm(self) -> nn.Module:  # For compatibility with RemoteGenerationMixin in tests
+        return nn.Identity()
+
     @property
     def h(self) -> RemoteSequential:  # For compatibility with RemoteGenerationMixin
         return self.layers
 
+    @property
+    def ln_f(self) -> nn.Module:  # For compatibility with RemoteGenerationMixin in tests
+        return self.norm
 
-class DistributedMixtralForCausalLM(
-    DefaultRevisionMixin, FromPretrainedMixin, RemoteGenerationMixin, MixtralForCausalLM
-):
+
+class DistributedMixtralForCausalLM(FromPretrainedMixin, RemoteGenerationMixin, MixtralForCausalLM):
     _keys_to_ignore_on_load_missing = DistributedMixtralModel._keys_to_ignore_on_load_missing
     _keys_to_ignore_on_load_unexpected = DistributedMixtralModel._keys_to_ignore_on_load_unexpected
 
@@ -151,9 +157,12 @@ class DistributedMixtralForCausalLM(
         return self.model
 
 
-class DistributedMixtralForSequenceClassification(
-    DefaultRevisionMixin, FromPretrainedMixin, MixtralForSequenceClassification
-):
+class DistributedMixtralForSequenceClassification(FromPretrainedMixin, MixtralForSequenceClassification):
+    _keys_to_ignore_on_load_missing = DistributedMixtralModel._keys_to_ignore_on_load_missing
+    _keys_to_ignore_on_load_unexpected = DistributedMixtralModel._keys_to_ignore_on_load_unexpected
+
+    config_class = DistributedMixtralConfig
+
     def __init__(self, config: DistributedMixtralConfig):
         MixtralPreTrainedModel.__init__(self, config)
         self.num_labels = config.num_labels

+ 15 - 2
src/petals/server/block_utils.py

@@ -2,8 +2,9 @@ from typing import Optional, Union
 
 import torch
 from accelerate import init_empty_weights
-from transformers import PretrainedConfig
+from transformers import PretrainedConfig, PreTrainedModel
 
+from petals.models.mixtral.block import WrappedMixtralBlock
 from petals.utils.convert_block import QuantType
 from petals.utils.misc import get_size_in_bytes
 
@@ -32,7 +33,7 @@ def get_block_size(
         ), 'get_block_size(..., location="memory") requires to specify dtype and quant_type for calculations'
 
     with init_empty_weights(include_buffers=True):
-        block = config.block_class(config)
+        block = get_model_block(config)
         n_params = sum(param.numel() for param in block.parameters())
 
     if location == "memory":
@@ -50,3 +51,15 @@ def get_block_size(
         bytes_per_value = get_size_in_bytes(dtype)
 
     return round(n_params * bytes_per_value * (1 + eps))
+
+
+def get_model_block(config, layer_idx: int = 0):
+    """
+    The function to create a model block based on the block class
+    kwargs argument **only** is necessary for specific classes, like Mixtral.
+    They will not be passed to other block constructors.
+    """
+    if config.block_class == WrappedMixtralBlock:
+        config = PreTrainedModel._autoset_attn_implementation(config)
+        return config.block_class(config, layer_idx)
+    return config.block_class(config)

+ 2 - 6
src/petals/server/from_pretrained.py

@@ -24,7 +24,7 @@ from transformers.utils import get_file_from_repo
 
 from petals.constants import DTYPE_MAP
 from petals.models.mixtral import WrappedMixtralBlock
-from petals.server.block_utils import resolve_block_dtype
+from petals.server.block_utils import get_model_block, resolve_block_dtype
 from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for
 from petals.utils.hf_auth import always_needs_auth
@@ -52,11 +52,7 @@ def load_pretrained_block(
     torch_dtype = resolve_block_dtype(config, torch_dtype)
 
     with init_empty_weights():
-        if config.block_class == WrappedMixtralBlock:
-            config = PreTrainedModel._autoset_attn_implementation(config)
-            block = config.block_class(config, block_index)
-        else:
-            block = config.block_class(config)
+        block = get_model_block(config, layer_idx=block_index)
 
     block_prefix = f"{config.block_prefix}.{block_index}."
     state_dict = _load_state_dict_from_repo(

+ 13 - 5
src/petals/server/throughput.py

@@ -13,9 +13,10 @@ import torch.mps
 from hivemind.utils.logging import get_logger
 from transformers import PretrainedConfig
 
-from petals.server.block_utils import resolve_block_dtype
+from petals.server.block_utils import get_model_block, resolve_block_dtype
 from petals.utils.convert_block import QuantType, convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
+from petals.utils.misc import DUMMY_KEY_PAST
 
 logger = get_logger(__name__)
 
@@ -201,18 +202,25 @@ def measure_compute_rps(
     if not tensor_parallel_devices:
         tensor_parallel_devices = (device,)
     with torch.inference_mode():
-        block = config.block_class(config).to(dtype)
+        block = get_model_block(config)
+        block = block.to(dtype)
         block = convert_block(block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)
 
-        cache = None
+        cache = (DUMMY_KEY_PAST.to(dtype), DUMMY_KEY_PAST.to(dtype))
         elapsed = 0
         dummy_input = torch.randn(1, n_tokens, config.hidden_size, device=device, dtype=dtype)
-        _, cache = block.forward(dummy_input, use_cache=True)  # Skip the 1st step to exclude the initialization time
+
+        # Skip the 1st step to exclude the initialization time
+        def step(cache_):
+            outputs = block.forward(dummy_input, use_cache=inference, layer_past=cache_ if inference else None)
+            return outputs[1] if inference else None
+
+        cache = step(cache)
         synchronize(device)
 
         start_time = time.perf_counter()
         for _ in range(n_steps):
-            _, cache = block.forward(dummy_input, use_cache=True, layer_past=cache if inference else None)
+            cache = step(cache)
         synchronize(device)
         elapsed = time.perf_counter() - start_time
         device_rps = n_steps * n_tokens / elapsed

+ 2 - 0
src/petals/utils/misc.py

@@ -4,6 +4,8 @@ DUMMY = torch.empty(0)  # dummy tensor that replaces empty prompt or adapter par
 
 DUMMY_INT64 = torch.empty(0, dtype=torch.int64)
 
+DUMMY_KEY_PAST = torch.empty((0, 0, 0))
+
 
 def is_dummy(tensor: torch.Tensor) -> bool:
     return tensor.numel() == 0

+ 2 - 2
src/petals/utils/peft.py

@@ -17,7 +17,7 @@ from safetensors import safe_open
 from safetensors.torch import load_file
 from transformers.utils import get_file_from_repo
 
-from petals.server.block_utils import resolve_block_dtype
+from petals.server.block_utils import get_model_block, resolve_block_dtype
 from petals.utils.convert_block import QuantType
 from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for
 from petals.utils.misc import get_size_in_bytes
@@ -273,7 +273,7 @@ def estimate_adapter_memory_per_block(
 ) -> int:
     """Get the number of extra bytes used to store a set of adapters per given block"""
     with init_empty_weights(include_buffers=True):
-        block = block_config.block_class(block_config)
+        block = get_model_block(block_config)
         base_block_parameters = sum(p.numel() for p in block.parameters())
         create_lora_adapter(block, quant_type=QuantType.NONE)
 

+ 6 - 3
tests/test_chained_calls.py

@@ -10,6 +10,7 @@ import torch
 from petals import AutoDistributedConfig
 from petals.client.remote_sequential import RemoteSequential
 from petals.server.from_pretrained import load_pretrained_block
+from petals.utils.misc import DUMMY_KEY_PAST
 from test_utils import *
 
 
@@ -54,12 +55,14 @@ def test_chained_inference_exact_match(atol_inference=1e-4):
             outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
     outputs_inference = torch.cat(outputs_inference, dim=1)
 
+    dtype = torch.float32
     ref_blocks = [
-        load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch.float32),
-        load_pretrained_block(MODEL_NAME, 4, torch_dtype=torch.float32),
+        load_pretrained_block(MODEL_NAME, 3, torch_dtype=dtype),
+        load_pretrained_block(MODEL_NAME, 4, torch_dtype=dtype),
     ]
     outputs_ref = []
-    caches = [None, None]
+    cache = (DUMMY_KEY_PAST.to(dtype), DUMMY_KEY_PAST.to(dtype))
+    caches = [cache, cache]
     for i in range(inputs.shape[1]):
         new_caches = []
         hidden_states = inputs[:, i : i + 1, :]

+ 4 - 0
tests/test_full_model.py

@@ -141,6 +141,10 @@ def test_sampling(tokenizer, model, ref_model, max_new_tokens=10):
                 ), f"Sampling is not identical to HF with {options=}, {multiple_calls=}, {inputs.shape=}"
 
 
+@pytest.mark.skipif(
+    "bloom" not in MODEL_NAME.lower(),
+    reason="Mixtral and Llama use DynamicCache, which can change based on beam search choices",
+)
 @pytest.mark.forked
 def test_beam_search_generation(tokenizer, model, ref_model, max_new_tokens=4, num_beams=5):
     inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]

+ 5 - 3
tests/test_optimized_layers.py

@@ -7,6 +7,7 @@ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_m
 from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
 
+from petals.server.block_utils import get_model_block
 from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.convert_block import QuantType, convert_block
 from test_utils import MODEL_NAME
@@ -195,8 +196,9 @@ def test_optimized_block(device):
     dtype = torch.bfloat16
     quant_type = QuantType.NONE
 
-    block = config.block_class(config).to(dtype)
-    block = convert_block(block, 1, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)
+    block_idx = 1
+    block = get_model_block(config, layer_idx=block_idx).to(dtype)
+    block = convert_block(block, block_idx, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)
 
     if config.model_type == "falcon":
         unopt_block = UnoptimizedWrappedFalconBlock(config).to(dtype)
@@ -206,7 +208,7 @@ def test_optimized_block(device):
         pytest.skip(f"This test is not applicable to {config.model_type} models")
 
     unopt_block = convert_block(
-        unopt_block, 1, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True
+        unopt_block, block_idx, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True
     )
 
     unopt_block.load_state_dict(block.state_dict())