@@ -108,6 +108,7 @@ def pre_process_alibi_for_pad(alibi: torch.Tensor, attention_mask: torch.Tensor)
     alibi = alibi.take_along_dim(unpadded_indices.unsqueeze(0), -1) * attention_mask.unsqueeze(0)
     return alibi.reshape(alibi.shape[0] * alibi.shape[1], 1, -1)

+
 def dropout_add(x, residual, prob, training):
     """
     Dropout add function
@@ -228,7 +229,7 @@ class BloomScaledSoftmax(nn.Module):
             input = input * self.scale

         if mask is None:
-            mask = torch.ones(input.shape[:2], dtype=torch.bool, device=input.device)
+            mask = torch.ones(input.shape[0], max_positions, dtype=torch.bool, device=input.device)

         mask = mask.to(input.device)
         causal_mask = (