5
0
Эх сурвалжийг харах

Optimize LLaMA for inference (#513)

* Optimize LLaMa for inference
* Fix model type detection in tests
Max Ryabinin 1 жил өмнө
parent
commit
03cbe90234

+ 209 - 10
src/petals/models/llama/block.py

@@ -3,13 +3,219 @@ LLaMA intermediate layer
 Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
 See commit history for authorship.
 """
+import math
 from typing import Optional, Tuple
 
 import torch
-from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers.models.llama.modeling_llama import (
+    LlamaAttention,
+    LlamaConfig,
+    LlamaDecoderLayer,
+    LlamaMLP,
+    LlamaModel,
+    LlamaRMSNorm,
+    repeat_kv,
+    rotate_half,
+)
 
+from petals.utils.cuda_graphs import make_inference_graphed_callable
+
+
+def apply_rotary_pos_emb(q, k, cos, sin):
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+class OptimizedLlamaAttention(LlamaAttention):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._rotary_graph = None
+
+    def _optimized_apply_rotary(self, query_states, key_states, cos, sin):
+        if self._rotary_graph is None:
+            self._rotary_graph = make_inference_graphed_callable(
+                apply_rotary_pos_emb, sample_args=(query_states, key_states, cos, sin)
+            )
+        return self._rotary_graph(query_states, key_states, cos, sin)
 
-class WrappedLlamaBlock(LlamaDecoderLayer):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        assert not output_attentions
+        assert position_ids is None
+        bsz, q_len, _ = hidden_states.size()
+
+        if self.config.pretraining_tp > 1:
+            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
+            query_slices = self.q_proj.weight.split(
+                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
+            )
+            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
+            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
+
+            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
+            query_states = torch.cat(query_states, dim=-1)
+
+            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
+            key_states = torch.cat(key_states, dim=-1)
+
+            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
+            value_states = torch.cat(value_states, dim=-1)
+
+        else:
+            query_states = self.q_proj(hidden_states)
+            key_states = self.k_proj(hidden_states)
+            value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value[0].shape[-2]
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        cos = cos[:, :, kv_seq_len - q_len :]
+        sin = sin[:, :, kv_seq_len - q_len :]
+
+        if q_len == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
+            query_states, key_states = self._optimized_apply_rotary(query_states, key_states, cos, sin)
+        else:
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        if past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+        past_key_value = (key_states, value_states) if use_cache else None
+
+        # repeat k/v heads if n_kv_heads < n_heads
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+        if attention_mask is not None:
+            attn_weights = attn_weights + attention_mask
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+        if self.config.pretraining_tp > 1:
+            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
+            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
+            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
+        else:
+            attn_output = self.o_proj(attn_output)
+
+        return attn_output, None, past_key_value
+
+
+class OptimizedLlamaDecoderLayer(LlamaDecoderLayer):
+    def __init__(self, config: LlamaConfig):
+        nn.Module.__init__(self)
+        self.hidden_size = config.hidden_size
+        self.self_attn = OptimizedLlamaAttention(config=config)
+        self.mlp = LlamaMLP(config)
+        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        self.pre_attn_graph = None
+        self.post_attn_graph = None
+
+    def _optimized_input_layernorm(self, hidden_states):
+        if self.pre_attn_graph is None:
+            self.pre_attn_graph = make_inference_graphed_callable(
+                self.input_layernorm.forward, sample_args=(hidden_states,)
+            )
+        return self.pre_attn_graph(hidden_states)
+
+    def _optimized_output_layernorm(self, hidden_states):
+        if self.post_attn_graph is None:
+            self.post_attn_graph = make_inference_graphed_callable(
+                self.post_attention_layernorm.forward, sample_args=(hidden_states,)
+            )
+        return self.post_attn_graph(hidden_states)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+        """
+
+        residual = hidden_states
+
+        if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
+            hidden_states = self._optimized_input_layernorm(hidden_states)
+        else:
+            hidden_states = self.input_layernorm(hidden_states)
+
+        # Self Attention
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+        )
+
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+
+        if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
+            hidden_states = self._optimized_output_layernorm(hidden_states)
+        else:
+            hidden_states = self.post_attention_layernorm(hidden_states)
+
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        return outputs
+
+
+class WrappedLlamaBlock(OptimizedLlamaDecoderLayer):
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -31,14 +237,7 @@ class WrappedLlamaBlock(LlamaDecoderLayer):
             seq_length_with_past = seq_length_with_past + past_key_values_length
             past_key_value = self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length)
 
-        if position_ids is None:
-            device = hidden_states.device
-            position_ids = torch.arange(
-                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
-            )
-            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
-        else:
-            position_ids = position_ids.view(-1, seq_length).long()
+        assert position_ids is None
 
         # embed positions
         if attention_mask is None:

+ 76 - 0
src/petals/utils/cuda_graphs.py

@@ -0,0 +1,76 @@
+import torch
+from torch.utils._pytree import tree_flatten as _tree_flatten, tree_unflatten as _tree_unflatten
+
+
+def make_inference_graphed_callable(callable: callable, sample_args, num_warmup_iters=3):
+    """Similar to torch.cuda.make_graphed_callables, but takes only one function and does not build a graph for the backward pass"""
+    assert not isinstance(callable, torch.nn.Module)
+    if torch.is_autocast_enabled() and torch.is_autocast_cache_enabled():
+        raise RuntimeError(
+            "make_graphed_callables does not support the autocast caching. Please set `cache_enabled=False`."
+        )
+
+    flatten_arg, _ = _tree_flatten(sample_args)
+    flatten_sample_args = tuple(flatten_arg)
+    assert all(
+        isinstance(arg, torch.Tensor) for arg in flatten_arg
+    ), "In the beta API, sample_args for each callable must contain only Tensors. Other types are not allowed."
+
+    len_user_args = len(sample_args)
+    static_input_surface = flatten_sample_args
+
+    graph = torch.cuda.CUDAGraph()
+
+    # Warmup
+    # Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work
+    # from ending up in any captures.
+    s = torch.cuda.Stream()
+    s.wait_stream(torch.cuda.current_stream())
+    with torch.cuda.stream(s):
+        for _ in range(num_warmup_iters):
+            outputs, _ = _tree_flatten(callable(*sample_args))
+        del outputs
+    torch.cuda.current_stream().wait_stream(s)
+
+    # Capture forward graph
+    with torch.cuda.graph(graph):
+        outputs = callable(*sample_args)
+
+    flatten_outputs, output_unflatten_spec = _tree_flatten(outputs)
+    static_outputs = tuple(flatten_outputs)
+
+    def make_graphed_function(
+        graph,
+        len_user_args,
+        output_unflatten_spec,
+        static_input_surface,
+        static_outputs,
+    ):
+        def replay_graph(*inputs):
+            # At this stage, only the user args may (potentially) be new tensors.
+            for i in range(len_user_args):
+                if static_input_surface[i].data_ptr() != inputs[i].data_ptr():
+                    static_input_surface[i].copy_(inputs[i])
+            graph.replay()
+            assert isinstance(static_outputs, tuple)
+            return tuple(o.detach() for o in static_outputs)
+
+        def functionalized(*user_args):
+            # Runs the autograd function with inputs == all inputs to the graph that might require grad
+            # (explicit user args + module parameters)
+            # Assumes module params didn't change since capture.
+            flatten_user_args, _ = _tree_flatten(user_args)
+            out = replay_graph(*flatten_user_args)
+            return _tree_unflatten(out, output_unflatten_spec)
+
+        return functionalized
+
+    # Put together the final graphed callable
+    graphed = make_graphed_function(
+        graph,
+        len_user_args,
+        output_unflatten_spec,
+        static_input_surface,
+        static_outputs,
+    )
+    return graphed

+ 93 - 5
tests/test_optimized_layers.py

@@ -3,6 +3,7 @@ from typing import Optional, Tuple
 import pytest
 import torch
 from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
 
 from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.convert_block import QuantType, convert_block
@@ -94,10 +95,91 @@ class UnoptimizedWrappedFalconBlock(FalconDecoderLayer):
         return state
 
 
-@pytest.mark.skipif("falcon" not in MODEL_NAME, reason="This test is applicable only to Falcon models")
+class UnoptimizedWrappedLlamaBlock(LlamaDecoderLayer):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        *args,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        layer_past: Optional[Tuple[torch.Tensor]] = None,
+        use_cache: bool = False,
+        **kwargs,
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        batch_size, seq_length, _ = hidden_states.shape
+
+        seq_length_with_past = seq_length
+        past_key_values_length = 0
+
+        past_key_value = layer_past
+        if past_key_value is not None:
+            past_key_values_length = past_key_value[0].shape[2]
+            seq_length_with_past = seq_length_with_past + past_key_values_length
+            past_key_value = self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length)
+
+        if position_ids is None:
+            device = hidden_states.device
+            position_ids = torch.arange(
+                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+            )
+            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+        else:
+            position_ids = position_ids.view(-1, seq_length).long()
+
+        # embed positions
+        if attention_mask is None:
+            attention_mask = torch.ones(
+                (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
+            )
+        attention_mask = LlamaModel._prepare_decoder_attention_mask(
+            None, attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+        )
+
+        outputs = super().forward(
+            hidden_states,
+            *args,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            use_cache=use_cache,
+            **kwargs,
+        )
+
+        if use_cache:
+            present_key_value = outputs[-1]
+            present_key_value = self._reorder_cache_from_llama_to_bloom(
+                present_key_value, batch_size, seq_length_with_past
+            )
+            outputs = outputs[:-1] + (present_key_value,)
+
+        return outputs
+
+    def _reorder_cache_from_bloom_to_llama(
+        self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
+    ) -> Tuple[torch.Tensor]:
+        key_states, value_states = key_value
+        key_states = key_states.permute(0, 2, 1)
+        key_states = key_states.view(
+            batch_size, self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
+        )
+        value_states = value_states.view(*key_states.shape)
+        return (key_states, value_states)
+
+    def _reorder_cache_from_llama_to_bloom(
+        self, key_value: Tuple[torch.Tensor], batch_size: int, seq_length: int
+    ) -> Tuple[torch.Tensor]:
+        key_states, value_states = key_value
+        value_states = value_states.view(
+            batch_size * self.self_attn.num_key_value_heads, seq_length, self.self_attn.head_dim
+        )
+        key_states = key_states.view(*value_states.shape)
+        key_states = key_states.permute(0, 2, 1)
+        return (key_states, value_states)
+
+
 @pytest.mark.parametrize("device", ["cpu", "cuda:0"])
 @pytest.mark.forked
-def test_falcon(device):
+def test_optimized_block(device):
     if device == "cuda:0" and not torch.cuda.is_available():
         pytest.skip("CUDA tests can be run only in CUDA-enabled setups")
 
@@ -108,11 +190,17 @@ def test_falcon(device):
     quant_type = QuantType.NONE
 
     block = config.block_class(config).to(dtype)
-    block = convert_block(block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)
+    block = convert_block(block, 1, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)
+
+    if config.model_type == "falcon":
+        unopt_block = UnoptimizedWrappedFalconBlock(config).to(dtype)
+    elif config.model_type == "llama":
+        unopt_block = UnoptimizedWrappedLlamaBlock(config).to(dtype)
+    else:
+        pytest.skip(f"This test is not applicable to {config.model_type} models")
 
-    unopt_block = UnoptimizedWrappedFalconBlock(config).to(dtype)
     unopt_block = convert_block(
-        unopt_block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True
+        unopt_block, 1, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True
     )
 
     unopt_block.load_state_dict(block.state_dict())