justheuristic 3 years ago
parent
commit
b843328b65
1 changed file with 2 additions and 1 deletion
  1. src/bloom/ops.py (+2 -1)

src/bloom/ops.py: +2 -1

@@ -108,6 +108,7 @@ def pre_process_alibi_for_pad(alibi: torch.Tensor, attention_mask: torch.Tensor)
     alibi = alibi.take_along_dim(unpadded_indices.unsqueeze(0), -1) * attention_mask.unsqueeze(0)
     return alibi.reshape(alibi.shape[0] * alibi.shape[1], 1, -1)
 
+
 def dropout_add(x, residual, prob, training):
     """
     Dropout add function
@@ -228,7 +229,7 @@ class BloomScaledSoftmax(nn.Module):
             input = input * self.scale
 
         if mask is None:
-            mask = torch.ones(input.shape[:2], dtype=torch.bool, device=input.device)
+            mask = torch.ones(input.shape[0], max_positions, dtype=torch.bool, device=input.device)
 
         mask = mask.to(input.device)
         causal_mask = (
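A minimal sketch of what the second hunk changes: when forward receives no mask, the fallback no longer mirrors input.shape[:2] but allocates max_positions columns per batch row. The sizes below are illustrative assumptions, and max_positions stands in for the sequence-length bound used inside BloomScaledSoftmax.forward, which this diff does not show.

import torch

# Illustrative sizes only; `scores` stands in for the attention-score tensor
# (`input`) that BloomScaledSoftmax.forward receives, and `max_positions`
# comes from surrounding code not included in this diff.
scores = torch.randn(2, 3, 5)
max_positions = 8

# Old fallback: mask shape copied from the first two dims of the input.
old_mask = torch.ones(scores.shape[:2], dtype=torch.bool, device=scores.device)
print(old_mask.shape)  # torch.Size([2, 3])

# New fallback: one row per batch entry, one column per possible position.
new_mask = torch.ones(scores.shape[0], max_positions, dtype=torch.bool, device=scores.device)
print(new_mask.shape)  # torch.Size([2, 8])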