account for layer_past in alibi

justheuristic 3 years ago
parent
commit
3ccd0b5e2d
1 changed file with 3 additions and 0 deletions
+3 -0
src/block.py

@@ -74,6 +74,9 @@ class BloomAttention(nn.Module):
         output_attentions=False,
     ):
         if alibi is None:  # TODO OPTIMIZE ALIBI CREATION
+            current_sequence_length = hidden_states.shape[1]
+            if layer_past is not None:
+                current_sequence_length += layer_past[0].shape[1]
             alibi = build_alibi_tensor(hidden_states.shape[1], n_head=self.num_heads, dtype=hidden_states.dtype)
         # hidden_states: [batch_size, seq_length, hidden_size]
         # repeat alibi tensor with the batch size
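
Note on the hunk above: the commit computes `current_sequence_length` as the new tokens plus the cached ones, but the unchanged context line still passes `hidden_states.shape[1]` to `build_alibi_tensor`. The sketch below illustrates why the combined length matters for ALiBi. It is a pure-Python illustration, not the code in `src/block.py`: the helper names `total_sequence_length`, `alibi_slopes`, and `alibi_bias` are hypothetical, shapes are modelled as plain tuples, the cached key is assumed to be laid out with the sequence dimension at index 1 (as `layer_past[0].shape[1]` in the diff suggests), and the slope formula assumes a power-of-two head count.

```python
from typing import List, Optional, Sequence


def total_sequence_length(
    hidden_shape: Sequence[int],
    past_key_shape: Optional[Sequence[int]] = None,
) -> int:
    """Count the key positions attention will see: cached plus new tokens.

    Assumes (hypothetically) that both shapes keep the sequence
    dimension at index 1, mirroring the indexing used in the diff.
    """
    length = hidden_shape[1]
    if past_key_shape is not None:
        length += past_key_shape[1]
    return length


def alibi_slopes(n_head: int) -> List[float]:
    """Geometric per-head slopes, valid when n_head is a power of two."""
    start = 2 ** (-8 / n_head)
    return [start ** (i + 1) for i in range(n_head)]


def alibi_bias(length: int, n_head: int) -> List[List[float]]:
    """One bias row per head, with one entry per key position."""
    return [[slope * pos for pos in range(length)] for slope in alibi_slopes(n_head)]


# Incremental decoding: 1 new token, 5 tokens already in the cache.
hidden_shape = (2, 1, 64)        # [batch_size, seq_length, hidden_size]
past_key_shape = (2, 5, 4, 16)   # assumed layout: sequence dimension at index 1

length = total_sequence_length(hidden_shape, past_key_shape)
bias = alibi_bias(length, n_head=4)
# Each head's bias row must cover every key position, cached or new.
assert all(len(row) == length for row in bias)
```

If the bias were built from `hidden_shape[1]` alone, each row would hold a single entry while the attention scores span six key positions, so the two could not be combined position by position.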