Explorar el Código

Add Mixtral models (#553)

* Add somehow workable version

* Fix generation

* Fixes

* Choose right attn

* style

* fix bloom

* remove unnes

* Update src/petals/models/mixtral/model.py

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>

* fix order of init

---------

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
Artem Chumachenko hace 1 año
padre
commit
d2fcbbc72e

+ 1 - 0
src/petals/models/__init__.py

@@ -1,3 +1,4 @@
 from petals.models.bloom import *
 from petals.models.falcon import *
 from petals.models.llama import *
+from petals.models.mixtral import *

+ 15 - 0
src/petals/models/mixtral/__init__.py

@@ -0,0 +1,15 @@
+from petals.models.mixtral.block import WrappedMixtralBlock
+from petals.models.mixtral.config import DistributedMixtralConfig
+from petals.models.mixtral.model import (
+    DistributedMixtralForCausalLM,
+    DistributedMixtralForSequenceClassification,
+    DistributedMixtralModel,
+)
+from petals.utils.auto_config import register_model_classes
+
+register_model_classes(
+    config=DistributedMixtralConfig,
+    model=DistributedMixtralModel,
+    model_for_causal_lm=DistributedMixtralForCausalLM,
+    model_for_sequence_classification=DistributedMixtralForSequenceClassification,
+)

+ 114 - 0
src/petals/models/mixtral/block.py

@@ -0,0 +1,114 @@
+from typing import Optional, Tuple
+
+import torch
+from transformers import MixtralConfig
+from transformers.cache_utils import DynamicCache
+from transformers.modeling_attn_mask_utils import (
+    _prepare_4d_causal_attention_mask,
+    _prepare_4d_causal_attention_mask_for_sdpa,
+)
+from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralModel
+
+
+class WrappedMixtralBlock(MixtralDecoderLayer):
+    def __init__(self, config: MixtralConfig, layer_idx: int):
+        super().__init__(config, layer_idx)
+
+        self._attn_implementation = config._attn_implementation
+        self.sliding_window = config.sliding_window
+        self.layer_idx = layer_idx
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        *args,
+        attention_mask: Optional[torch.Tensor] = None,
+        layer_past: Optional[Tuple[torch.Tensor]] = None,
+        use_cache: bool = False,
+        **kwargs
+    ):
+        batch_size, seq_length, _ = hidden_states.shape
+
+        seq_length_with_past = seq_length
+        past_key_values_length = 0
+
+        past_key_value = layer_past
+        if past_key_value is not None:
+            past_key_values_length = past_key_value[0].shape[2]
+            seq_length_with_past = seq_length_with_past + past_key_values_length
+            _past_key_value = self._reorder_cache_from_bloom(past_key_value, batch_size, past_key_values_length)
+            past_key_value = DynamicCache()
+            for idx in range(self.layer_idx):
+                past_key_value.update(
+                    torch.empty(_past_key_value[0].size()), torch.empty(_past_key_value[1].size()), idx
+                )
+            past_key_value.update(_past_key_value[0], _past_key_value[1], self.layer_idx)
+
+        if self._attn_implementation == "flash_attention_2":
+            # 2d mask is passed through the layers
+            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+        elif self._attn_implementation == "sdpa":
+            # output_attentions=True can not be supported when using SDPA, and we fall back on
+            # the manual implementation that requires a 4D causal mask in all cases.
+            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                attention_mask,
+                (batch_size, seq_length),
+                hidden_states,
+                past_key_values_length,
+            )
+        else:
+            # 4d mask is passed through the layers
+            attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask,
+                (batch_size, seq_length),
+                hidden_states,
+                past_key_values_length,
+                sliding_window=self.sliding_window,
+            )
+
+        position_ids = torch.arange(
+            past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=hidden_states.device
+        )
+        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+
+        outputs = super().forward(
+            hidden_states,
+            *args,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            use_cache=use_cache,
+            **kwargs
+        )
+
+        if use_cache:
+            present_key_value = outputs[-1]
+            present_key_value = present_key_value.to_legacy_cache()[self.layer_idx]
+            present_key_value = self._reorder_cache_to_bloom(present_key_value, batch_size, seq_length_with_past)
+            outputs = outputs[:-1] + (present_key_value,)
+
+        return outputs
+
+    def _reorder_cache_from_bloom(
+        self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
+    ) -> Tuple[torch.Tensor]:
+        # TODO: Move to mixin
+        key_states, value_states = key_value
+        key_states = key_states.permute(0, 2, 1)
+        key_states = key_states.view(
+            batch_size, self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
+        )
+        value_states = value_states.view(*key_states.shape)
+        return (key_states, value_states)
+
+    def _reorder_cache_to_bloom(
+        self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
+    ) -> Tuple[torch.Tensor]:
+        # TODO: Move to mixin
+        key_states, value_states = key_value
+        value_states = value_states.view(
+            batch_size * self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
+        )
+        key_states = key_states.view(*value_states.shape)
+        key_states = key_states.permute(0, 2, 1)
+        return (key_states, value_states)

+ 36 - 0
src/petals/models/mixtral/config.py

@@ -0,0 +1,36 @@
+import os
+from typing import Optional, Union
+
+from hivemind import get_logger
+from transformers.models.mixtral import MixtralConfig
+from transformers.models.mixtral.modeling_mixtral import MixtralAttention
+
+from petals.client.config import ClientConfig
+from petals.client.lm_head import LMHeadConfig
+from petals.client.ptune import PTuneConfig
+from petals.models.mixtral.block import WrappedMixtralBlock
+
+logger = get_logger(__name__)
+
+
+class DistributedMixtralConfig(MixtralConfig, ClientConfig, PTuneConfig, LMHeadConfig):
+    block_class = WrappedMixtralBlock
+    attn_class = MixtralAttention
+    block_prefix = "model.layers"
+
+    num_key_value_groups = 1
+
+    @classmethod
+    def from_pretrained(
+        cls, model_name_or_path: Union[str, os.PathLike, None], *args, dht_prefix: Optional[str] = None, **kwargs
+    ):
+        loading_from_repo = model_name_or_path is not None and not os.path.isdir(model_name_or_path)
+        if loading_from_repo and dht_prefix is None:
+            dht_prefix = str(model_name_or_path)
+            dht_prefix = dht_prefix.replace(".", "-")
+            logger.info(f"Using DHT prefix: {dht_prefix}")
+        result = super().from_pretrained(model_name_or_path, *args, dht_prefix=dht_prefix, **kwargs)
+        config = result[0] if isinstance(result, tuple) else result
+        if config.pad_token_id is None:
+            config.pad_token_id = 0
+        return result

+ 169 - 0
src/petals/models/mixtral/model.py

@@ -0,0 +1,169 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from hivemind import DHT
+from hivemind.utils.logging import get_logger
+from transformers.modeling_outputs import MoeModelOutputWithPast
+from transformers.models.mixtral import (
+    MixtralForCausalLM,
+    MixtralForSequenceClassification,
+    MixtralModel,
+    MixtralPreTrainedModel,
+)
+
+from petals.client.from_pretrained import FromPretrainedMixin
+from petals.client.lm_head import LMHead
+from petals.client.ptune import PTuneMixin
+from petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues
+from petals.client.remote_sequential import RemoteSequential
+from petals.models.mixtral.config import DistributedMixtralConfig
+from petals.utils.auto_config import DefaultRevisionMixin
+
+logger = get_logger(__name__)
+
+
+class DistributedMixtralModel(DefaultRevisionMixin, FromPretrainedMixin, PTuneMixin, MixtralModel):
+    """MixtralModel, but all transformer layers are hosted by the swarm"""
+
+    _keys_to_ignore_on_load_missing = PTuneMixin._keys_to_ignore_on_load_missing
+    _keys_to_ignore_on_load_unexpected = [r"^model\.layers\."]
+
+    config_class = DistributedMixtralConfig
+
+    def __init__(self, config: DistributedMixtralConfig, *, dht: Optional[DHT] = None):
+        n_layer, config.num_hidden_layers = config.num_hidden_layers, 0  # Prevent initialization
+        super().__init__(config)
+        assert len(self.layers) == 0
+        config.num_hidden_layers = n_layer
+
+        self.layers = RemoteSequential(config, dht=dht)
+
+        self.requires_grad_(False)  # Forbid accumulate grads for embeddings and layernorm
+        self.init_prompts(config)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[RemotePastKeyValues] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_router_logits: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ):
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        # The causal mask will be added on the server-side
+        assert (
+            attention_mask is None or (attention_mask == 1).all()
+        ), f"Custom attention masks are not supported, {attention_mask=}"
+        assert (
+            position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all()
+        ), f"Non-consecutive position_ids are not supported, {position_ids=}"
+        assert head_mask is None, f"Custom head masks are not supported, {head_mask=}"
+        assert use_cache is None or use_cache, f"{use_cache=} is not supported"
+        assert not output_attentions, f"{output_attentions=} is not supported"
+        assert not output_hidden_states, f"{output_hidden_states=} is not supported"
+        assert return_dict is None or return_dict, f"{return_dict=} is not supported"
+        assert not output_router_logits, f"{output_router_logits=} is not supported"
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        use_prompts = self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0
+        if use_prompts:
+            batch_size = inputs_embeds.shape[0]
+            prompts, intermediate_prompts = self.get_prompt(batch_size)
+            inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
+        else:
+            prompts = intermediate_prompts = None
+
+        hidden_states = inputs_embeds
+        output_shape = input_shape + (hidden_states.size(-1),)
+
+        if past_key_values is None:
+            past_key_values = RemotePastKeyValues()
+        past_key_values.update_seen(hidden_states.size(1))
+
+        hidden_states = self.layers(
+            hidden_states,
+            prompts=intermediate_prompts,
+            hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None,
+        )
+
+        # Remove prefix
+        if use_prompts:
+            hidden_states = hidden_states[:, self.pre_seq_len :]
+
+        # Add last hidden state
+        hidden_states = self.norm(hidden_states)
+        hidden_states = hidden_states.view(output_shape)
+        return MoeModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=past_key_values,
+            hidden_states=None,
+            attentions=None,
+        )
+
+    @property
+    def word_embeddings(self) -> nn.Embedding:  # For compatibility with RemoteGenerationMixin
+        return self.embed_tokens
+
+    @property
+    def h(self) -> RemoteSequential:  # For compatibility with RemoteGenerationMixin
+        return self.layers
+
+
+class DistributedMixtralForCausalLM(
+    DefaultRevisionMixin, FromPretrainedMixin, RemoteGenerationMixin, MixtralForCausalLM
+):
+    _keys_to_ignore_on_load_missing = DistributedMixtralModel._keys_to_ignore_on_load_missing
+    _keys_to_ignore_on_load_unexpected = DistributedMixtralModel._keys_to_ignore_on_load_unexpected
+
+    config_class = DistributedMixtralConfig
+
+    def __init__(self, config: DistributedMixtralConfig):
+        MixtralPreTrainedModel.__init__(self, config)
+        self.model = DistributedMixtralModel(config)
+        self.lm_head = LMHead(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    @property
+    def transformer(self) -> DistributedMixtralModel:  # For compatibility with RemoteGenerationMixin
+        return self.model
+
+
+class DistributedMixtralForSequenceClassification(
+    DefaultRevisionMixin, FromPretrainedMixin, MixtralForSequenceClassification
+):
+    def __init__(self, config: DistributedMixtralConfig):
+        MixtralPreTrainedModel.__init__(self, config)
+        self.num_labels = config.num_labels
+
+        self.model = DistributedMixtralModel(config)
+        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @property
+    def transformer(self) -> DistributedMixtralModel:  # For compatibility with RemoteGenerationMixin
+        return self.model

+ 2 - 0
src/petals/server/backend.py

@@ -91,6 +91,8 @@ class TransformerBackend(ModuleBackend):
         cache_tensors = []
         for device, num_heads in zip(self.module.devices, self.shard_num_heads):
             num_heads //= self.config.num_key_value_groups
+            if hasattr(self.config, "num_key_value_heads"):
+                num_heads = self.config.num_key_value_heads
             keys = TensorDescriptor((batch_size, num_heads, head_dim, max_length), dtype=self.dtype, device=device)
             values = TensorDescriptor((batch_size, num_heads, max_length, head_dim), dtype=self.dtype, device=device)
             cache_tensors.extend((keys, values))

+ 7 - 2
src/petals/server/from_pretrained.py

@@ -19,10 +19,11 @@ from accelerate.utils import set_module_tensor_to_device
 from hivemind.utils.logging import get_logger
 from huggingface_hub import get_hf_file_metadata, hf_hub_url
 from huggingface_hub.utils import EntryNotFoundError
-from transformers import PretrainedConfig
+from transformers import PretrainedConfig, PreTrainedModel
 from transformers.utils import get_file_from_repo
 
 from petals.constants import DTYPE_MAP
+from petals.models.mixtral import WrappedMixtralBlock
 from petals.server.block_utils import resolve_block_dtype
 from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR, allow_cache_reads, allow_cache_writes, free_disk_space_for
@@ -51,7 +52,11 @@ def load_pretrained_block(
     torch_dtype = resolve_block_dtype(config, torch_dtype)
 
     with init_empty_weights():
-        block = config.block_class(config)
+        if config.block_class == WrappedMixtralBlock:
+            config = PreTrainedModel._autoset_attn_implementation(config)
+            block = config.block_class(config, block_index)
+        else:
+            block = config.block_class(config)
 
     block_prefix = f"{config.block_prefix}.{block_index}."
     state_dict = _load_state_dict_from_repo(