@@ -28,9 +28,9 @@ class WrappedBloomBlock(BloomBlock):
         past_length = 0 if layer_past is None else layer_past[0].shape[-1]
         if alibi is None:
             alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
-        attention_mask = self._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)
+        causal_mask = self._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)
         return super().forward(
-            hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, **kwargs
+            hidden_states, *args, attention_mask=causal_mask, alibi=alibi, layer_past=layer_past, **kwargs
         )

     def _prepare_attn_mask(
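The rename makes the distinction explicit: the incoming `attention_mask` is a 2D padding mask, while `_prepare_attn_mask` returns a 4D boolean causal mask, so reusing the same name obscured which object was passed downstream. For intuition, here is a minimal PyTorch sketch of the combination such a helper performs; the function name, shapes, and mask convention below are illustrative assumptions, not the transformers implementation.

```python
import torch

def prepare_attn_mask_sketch(attention_mask: torch.Tensor, past_length: int) -> torch.Tensor:
    """Illustrative sketch only, not the transformers code: combine a
    [batch, key_length] padding mask (1 = real token) with a causal mask
    into a [batch, 1, query_length, key_length] boolean mask where True
    marks positions that must NOT be attended to."""
    key_length = attention_mask.shape[1]          # keys cover past + current tokens
    query_length = key_length - past_length

    # Causal part: query i sits at absolute position past_length + i and may
    # only see keys j <= past_length + i; everything above that diagonal is masked.
    causal = torch.triu(
        torch.ones(query_length, key_length, dtype=torch.bool),
        diagonal=past_length + 1,
    )

    # Padding part: padded key positions are masked for every query.
    padding = ~attention_mask.to(torch.bool)

    # Broadcast to [batch, 1, query_length, key_length].
    return causal[None, None] | padding[:, None, None, :]

# Example: two cached tokens, two new tokens, last key position padded.
mask = torch.tensor([[1, 1, 1, 0]])
print(prepare_attn_mask_sketch(mask, past_length=2).shape)  # torch.Size([1, 1, 2, 4])
```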