@@ -3,13 +3,219 @@ LLaMA intermediate layer
 Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
 See commit history for authorship.
 """
+import math
 from typing import Optional, Tuple
 
 import torch
-from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers.models.llama.modeling_llama import (
+    LlamaAttention,
+    LlamaConfig,
+    LlamaDecoderLayer,
+    LlamaMLP,
+    LlamaModel,
+    LlamaRMSNorm,
+    repeat_kv,
+    rotate_half,
+)
+
+from petals.utils.cuda_graphs import make_inference_graphed_callable
+
+
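+# Note: unlike the upstream HF helper, this variant takes no position_ids;
+# the caller slices cos/sin to the query positions, so no gather is needed here.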
+def apply_rotary_pos_emb(q, k, cos, sin):
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+class OptimizedLlamaAttention(LlamaAttention):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._rotary_graph = None
+
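+    # Capture apply_rotary_pos_emb into a CUDA graph on first use and replay it on
+    # later calls; this assumes the input shapes stay fixed (single-token decoding).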
+    def _optimized_apply_rotary(self, query_states, key_states, cos, sin):
+        if self._rotary_graph is None:
+            self._rotary_graph = make_inference_graphed_callable(
+                apply_rotary_pos_emb, sample_args=(query_states, key_states, cos, sin)
+            )
+        return self._rotary_graph(query_states, key_states, cos, sin)
 
-
-class WrappedLlamaBlock(LlamaDecoderLayer):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        assert not output_attentions
+        assert position_ids is None
+        bsz, q_len, _ = hidden_states.size()
+
+        if self.config.pretraining_tp > 1:
+            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
+            query_slices = self.q_proj.weight.split(
+                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
+            )
+            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
+            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
+
+            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
+            query_states = torch.cat(query_states, dim=-1)
+
+            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
+            key_states = torch.cat(key_states, dim=-1)
+
+            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
+            value_states = torch.cat(value_states, dim=-1)
+
+        else:
+            query_states = self.q_proj(hidden_states)
+            key_states = self.k_proj(hidden_states)
+            value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value[0].shape[-2]
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        cos = cos[:, :, kv_seq_len - q_len :]
+        sin = sin[:, :, kv_seq_len - q_len :]
+
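+        # Single-token decode steps on GPU (inference mode) go through the CUDA-graphed
+        # rotary; prompts and CPU runs fall back to the eager implementation.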
+        if q_len == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
+            query_states, key_states = self._optimized_apply_rotary(query_states, key_states, cos, sin)
+        else:
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        if past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+        past_key_value = (key_states, value_states) if use_cache else None
+
+        # repeat k/v heads if n_kv_heads < n_heads
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+        if attention_mask is not None:
+            attn_weights = attn_weights + attention_mask
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+        if self.config.pretraining_tp > 1:
+            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
+            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
+            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
+        else:
+            attn_output = self.o_proj(attn_output)
+
+        return attn_output, None, past_key_value
+
+
+class OptimizedLlamaDecoderLayer(LlamaDecoderLayer):
+    def __init__(self, config: LlamaConfig):
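+        # Call nn.Module.__init__ directly (not LlamaDecoderLayer.__init__) so the
+        # submodules below are built with the graph-optimized attention.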
+        nn.Module.__init__(self)
+        self.hidden_size = config.hidden_size
+        self.self_attn = OptimizedLlamaAttention(config=config)
+        self.mlp = LlamaMLP(config)
+        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        self.pre_attn_graph = None
+        self.post_attn_graph = None
+
+    def _optimized_input_layernorm(self, hidden_states):
+        if self.pre_attn_graph is None:
+            self.pre_attn_graph = make_inference_graphed_callable(
+                self.input_layernorm.forward, sample_args=(hidden_states,)
+            )
+        return self.pre_attn_graph(hidden_states)
+
+    def _optimized_output_layernorm(self, hidden_states):
+        if self.post_attn_graph is None:
+            self.post_attn_graph = make_inference_graphed_callable(
+                self.post_attention_layernorm.forward, sample_args=(hidden_states,)
+            )
+        return self.post_attn_graph(hidden_states)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        """
+        Args:
+            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+            output_attentions (`bool`, *optional*):
+                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+                returned tensors for more detail.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+                (see `past_key_values`).
+            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+        """
+
+        residual = hidden_states
+
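+        # As in the attention above, run RMSNorm through a cached CUDA graph only for
+        # single-token inference steps on GPU.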
+        if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
+            hidden_states = self._optimized_input_layernorm(hidden_states)
+        else:
+            hidden_states = self.input_layernorm(hidden_states)
+
+        # Self Attention
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+        )
+
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+
+        if hidden_states.size(1) == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
+            hidden_states = self._optimized_output_layernorm(hidden_states)
+        else:
+            hidden_states = self.post_attention_layernorm(hidden_states)
+
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        return outputs
+
+
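+# Same public block interface as before, now backed by the graph-optimized layer.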
+class WrappedLlamaBlock(OptimizedLlamaDecoderLayer):
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -31,14 +237,7 @@ class WrappedLlamaBlock(LlamaDecoderLayer):
         seq_length_with_past = seq_length_with_past + past_key_values_length
         past_key_value = self._reorder_cache_from_bloom_to_llama(past_key_value, batch_size, past_key_values_length)
 
-        if position_ids is None:
-            device = hidden_states.device
-            position_ids = torch.arange(
-                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
-            )
-            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
-        else:
-            position_ids = position_ids.view(-1, seq_length).long()
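+        # position_ids are no longer built here; the optimized attention derives the
+        # rotary positions from the KV-cache length instead.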
+        assert position_ids is None
 
         # embed positions
         if attention_mask is None: