@@ -8,6 +8,7 @@ import math
import torch
import torch.autograd
from torch import nn
+import torch.nn.functional as F


def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
@@ -55,13 +56,14 @@ def attention_mask_func(attention_scores, attention_mask, causal_mask):
    )


-def build_alibi_tensor(max_seq_len, n_head, dtype=torch.bfloat16):
+def build_alibi_tensor(
+    max_seq_len: int, n_head: int, dtype: torch.dtype = torch.bfloat16, device: torch.device = torch.device("cpu")
+) -> torch.Tensor:
    """
    Link to paper: https://arxiv.org/abs/2108.12409. The alibi tensor here is not causal as in the original paper;
    it relies on the translation invariance of softmax for a quick implementation: with `l` a tensor and `a` a fixed
    scalar, `softmax(l + a) = softmax(l)`. Based on
    https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
-
    Args:
    Returns tensor shaped (n_head, 1, max_seq_len)
        max_seq_len: (`int`, *required*):
@@ -70,67 +72,41 @@ def build_alibi_tensor(max_seq_len, n_head, dtype=torch.bfloat16):
            number of heads
        dtype: (`torch.dtype`, *optional*, default=`torch.bfloat16`):
            dtype of the output tensor
+        device: (`torch.device`, *optional*, default=`torch.device('cpu')`):
+            device of the output alibi tensor
    """
+    closest_power_of_2 = 2 ** math.floor(math.log2(n_head))
+    base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=device, dtype=torch.float32)
+    powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32)
+    slopes = torch.pow(base, powers)
+
+    if closest_power_of_2 != n_head:
+        extra_base = torch.tensor(
+            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=device, dtype=torch.float32
+        )
+        num_remaining_heads = min(closest_power_of_2, n_head - closest_power_of_2)
+        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32)
+        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)

-    def get_slopes(n):
-        def get_slopes_power_of_2(n):
-            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
-            ratio = start
-            return [start * ratio**i for i in range(n)]
-
-        if math.log2(n).is_integer():
-            return get_slopes_power_of_2(n)
-        else:
-            closest_power_of_2 = 2 ** math.floor(math.log2(n))
-            return (
-                get_slopes_power_of_2(closest_power_of_2)
-                + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
-            )
-
-    slopes = torch.Tensor(get_slopes(n_head)).unsqueeze(1).unsqueeze(1)
-    arange_tensor = torch.arange(max_seq_len).unsqueeze(0).unsqueeze(0)
-    alibi = slopes * arange_tensor.expand(n_head, -1, -1)
-
-    alibi = alibi.to(dtype)
-
-    return alibi
+    lengths = torch.arange(max_seq_len, device=device, dtype=torch.int32)
+    return (slopes.view(-1, 1, 1) * lengths.view(1, 1, -1)).to(dtype)


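The rewrite above replaces the recursive `get_slopes` helper with a closed-form, vectorized computation of the ALiBi slopes. A quick way to convince yourself the two agree, and to see the `(n_head, 1, max_seq_len)` shape and the softmax identity the docstring relies on, is a small check along these lines (a hedged sketch, not part of the PR; `get_slopes` is copied from the removed code, and `n_head=12`, `max_seq_len=5` are just toy values):

```python
import math

import torch


# The removed implementation, copied verbatim for the comparison.
def get_slopes(n):
    def get_slopes_power_of_2(n):
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        ratio = start
        return [start * ratio**i for i in range(n)]

    if math.log2(n).is_integer():
        return get_slopes_power_of_2(n)
    closest_power_of_2 = 2 ** math.floor(math.log2(n))
    return (
        get_slopes_power_of_2(closest_power_of_2)
        + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
    )


n_head, max_seq_len = 12, 5  # toy values; 12 is deliberately not a power of two

# The new vectorized computation, as in the diff above (on CPU).
closest_power_of_2 = 2 ** math.floor(math.log2(n_head))
base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32)
slopes = torch.pow(base, torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32))
if closest_power_of_2 != n_head:
    extra_base = torch.tensor(2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32)
    num_remaining_heads = min(closest_power_of_2, n_head - closest_power_of_2)
    extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32)
    slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)

# The vectorized slopes match the recursive ones for a non-power-of-two head count.
assert torch.allclose(slopes, torch.tensor(get_slopes(n_head), dtype=torch.float32))

alibi = (slopes.view(-1, 1, 1) * torch.arange(max_seq_len).view(1, 1, -1)).to(torch.bfloat16)
print(alibi.shape)  # torch.Size([12, 1, 5]), i.e. (n_head, 1, max_seq_len)

# The docstring's identity: adding a constant along the softmax dimension leaves the result unchanged.
l = torch.randn(4)
assert torch.allclose(torch.softmax(l + 3.0, dim=-1), torch.softmax(l, dim=-1), atol=1e-6)
```
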
-def pre_process_alibi_for_pad(alibi, attention_mask, num_heads):
+def pre_process_alibi_for_pad(alibi: torch.Tensor, attention_mask: torch.Tensor):
    """
    Args:
    Pre-process the alibi tensor for padding.
        alibi: ([`torch.tensor`], *required*):
            alibi tensor to pre-process
        attention_mask: ([`torch.tensor`], *required*):
-            attention mask to pre-process"""
-
-    # Sanity check if we are not inferring less tokens than the total sequence length
-    # This usually happens when the inference is done with past_key_values
-    # In this case we re-create the alibi tensor with the correct sequence length
-    if attention_mask.shape[-1] != alibi.shape[-1]:
-        alibi = build_alibi_tensor(attention_mask.shape[-1], num_heads, alibi.dtype).repeat(
-            attention_mask.shape[0], 1, 1
-        )
-    # Get the indexes of the padding tokens
-    index_x0, index_y0 = torch.where(attention_mask == 0.0)
-    index_x1, index_y1 = torch.where(attention_mask == 1.0)
-
-    # Clone the embeddings - we can detach because the embeddings are not learned
-    # Get a refence tensor
-    slice_reference_alibi = build_alibi_tensor(alibi.shape[-1], num_heads, alibi.dtype)
-
-    # Loop over the batch where the padding is and replace the alibi tensor by the reference tensor
-    # Only where you do not have padding. Replace padding tokens by zeros
-    # This operation can be seen as a shifting operation.
-    for i, index in enumerate(torch.unique(index_x0)):
-        slice_to_modify = torch.zeros_like(slice_reference_alibi)
-        index_shift = index_y1[index_x1 == index]
-        shift_value = len(index_shift)
-        slice_to_modify[:, :, index_shift] = slice_reference_alibi[:, :, :shift_value]
-        alibi[index * num_heads : (index + 1) * num_heads] = slice_to_modify
-    return alibi
-
+            attention mask to pre-process
+    """
+    assert attention_mask.ndim == 2, "mask should be [batch_size, seq_length]"
+    unpadded_indices = torch.relu(attention_mask.cumsum(dim=1) - 1)
+    # ^-- [batch, max_len], values correspond to element indices after removing padding
+    # We shift the alibi tensor + replace all the values where attention_mask==0.0 by 0
+    alibi = alibi.take_along_dim(unpadded_indices.unsqueeze(0), -1) * attention_mask.unsqueeze(0)
+    return alibi.reshape(alibi.shape[0] * alibi.shape[1], 1, -1)


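For readers comparing the two versions of `pre_process_alibi_for_pad`: the removed loop rebuilt and shifted an alibi slice per batch row, while the new version does the same realignment with a cumulative sum. Below is a minimal sketch of the index trick on a toy left-padded mask (not part of the PR; it inlines the two key lines and assumes an integer `attention_mask`, since `take_along_dim` needs integer indices):

```python
import torch

# One sequence with two left-padding slots: [PAD, PAD, tok, tok, tok]
attention_mask = torch.tensor([[0, 0, 1, 1, 1]])                  # (batch=1, seq_len=5), integer dtype
slopes = torch.tensor([0.5, 0.25])                                # two toy heads
alibi = slopes.view(-1, 1, 1) * torch.arange(5).view(1, 1, -1)    # (n_head, 1, seq_len), as built above

# Same two lines as the new function body:
unpadded_indices = torch.relu(attention_mask.cumsum(dim=1) - 1)   # tensor([[0, 0, 0, 1, 2]])
shifted = alibi.take_along_dim(unpadded_indices.unsqueeze(0), -1) * attention_mask.unsqueeze(0)

print(shifted[0])  # tensor([[0.0000, 0.0000, 0.0000, 0.5000, 1.0000]])
# Padding positions are zeroed and the real tokens get the distances 0, 1, 2 they would have had
# without padding; the function then reshapes to (n_head * batch, 1, seq_len) for the attention.
```
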
def dropout_add(x, residual, prob, training):
    """
@@ -251,17 +227,17 @@ class BloomScaledSoftmax(nn.Module):
        if self.scale is not None:
            input = input * self.scale

-        if mask is not None:
-            mask = mask.to(input.device)
-            causal_mask = (
-                torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
-                .view(1, 1, max_positions, max_positions)
-                .to(input.device)
-            )
-            mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask)
-            probs = nn.functional.softmax(mask_output, dim=-1, dtype=softmax_dtype) * (~padded_causal_mask)
-        else:
-            probs = nn.functional.softmax(input, dim=-1, dtype=softmax_dtype)
+        if mask is None:
+            mask = torch.ones(input.shape[:2], dtype=torch.bool, device=input.device)
+
+        mask = mask.to(input.device)
+        causal_mask = (
+            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
+            .view(1, 1, max_positions, max_positions)
+            .to(input.device)
+        )
+        mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask)
+        probs = F.softmax(mask_output, dim=-1, dtype=softmax_dtype) * (~padded_causal_mask)

        if input_in_16bit and self.softmax_in_fp32:
            probs = probs.to(dtype=input_dtype)
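One behavioural note on the `BloomScaledSoftmax` hunk: previously a `None` mask skipped masking entirely, whereas now an all-ones mask is materialised so the causal mask is always applied. The toy below is a hedged sketch, not the PR's code: `attention_mask_func` is not shown in this diff, so a plain `masked_fill` stands in for it; it only illustrates the effect of the lower-triangular mask on the softmax.

```python
import torch
import torch.nn.functional as F

max_positions = 4
scores = torch.randn(1, 1, max_positions, max_positions)  # toy attention scores

# Same construction as in the diff: a lower-triangular causal mask.
causal_mask = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
    1, 1, max_positions, max_positions
)

# Stand-in for mask_func (assumption): with the default all-ones padding mask,
# only future positions are removed from the softmax.
masked_scores = scores.masked_fill(~causal_mask, torch.finfo(scores.dtype).min)
probs = F.softmax(masked_scores, dim=-1)
print(probs[0, 0])  # strictly upper-triangular entries are ~0: each query attends only to itself and the past
```

In the diff itself the softmax output is additionally multiplied by `(~padded_causal_mask)`, so masked positions come out as exact zeros.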