@@ -87,10 +87,7 @@ class OptimizedLlamaAttention(LlamaAttention):
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

-        kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
-            kv_seq_len += past_key_value[0].shape[-2]
-        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
+        cos, sin = self.rotary_emb(value_states, position_ids)
         cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)

         if q_len == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
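The removal of the `kv_seq_len` bookkeeping works because the newer rotary embedding derives the cos/sin tables directly from `position_ids`, so the attention layer no longer needs to know the total cached length up front. As a rough illustration of why that is enough, here is a minimal, self-contained RoPE sketch; the function name `rope_cos_sin` and the `base` parameter are hypothetical and this is not the exact `transformers` implementation:

```python
import torch

def rope_cos_sin(position_ids: torch.Tensor, head_dim: int, base: float = 10000.0):
    # position_ids: (batch, seq_len) absolute positions of the tokens being processed.
    # Returns cos/sin of shape (batch, seq_len, head_dim), ready to be broadcast over
    # the head dimension, mirroring the `cos.unsqueeze(1), sin.unsqueeze(1)` step above.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    # One rotation angle per (position, frequency) pair.
    freqs = position_ids[..., None].float() * inv_freq   # (batch, seq_len, head_dim/2)
    emb = torch.cat((freqs, freqs), dim=-1)               # (batch, seq_len, head_dim)
    return emb.cos(), emb.sin()

# Because the angles come straight from position_ids, decoding a single new token
# only needs its own absolute position; there is no separate kv_seq_len to track.
cos, sin = rope_cos_sin(torch.tensor([[5]]), head_dim=128)
```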