Browse Source

causal mask by default

justheuristic 3 years ago
parent
commit
5d8f7be546
3 changed files with 54 additions and 80 deletions
  1. 13 15
      src/bloom/block.py
  2. 40 64
      src/bloom/ops.py
  3. 1 1
      src/server/backend.py

+ 13 - 15
src/bloom/block.py

@@ -74,18 +74,19 @@ class BloomAttention(nn.Module):
         use_cache=False,
         output_attentions=False,
     ):
-        if alibi is None:  # TODO OPTIMIZE ALIBI CREATION
-            current_sequence_length = hidden_states.shape[1]
-            if layer_past is not None:
-                current_sequence_length += layer_past[0].shape[1]
-            alibi = build_alibi_tensor(hidden_states.shape[1], n_head=self.num_heads, dtype=hidden_states.dtype)
-        # hidden_states: [batch_size, seq_length, hidden_size]
-        # repeat alibi tensor with the batch size
-        alibi = alibi.repeat(hidden_states.shape[0], 1, 1).to(hidden_states.device)  # TODO eliminate cpu-gpu transfer!
+        if alibi is None:
+            current_sequence_length = hidden_states.shape[1] + (0 if layer_past is None else layer_past[0].shape[1])
+            alibi = build_alibi_tensor(
+                current_sequence_length, n_head=self.num_heads, dtype=hidden_states.dtype, device=hidden_states.device
+            )
 
+        # hidden_states: [batch_size, seq_length, hidden_size]
         # apply preprocessing if the input is padded
-        if attention_mask is not None and 0 in attention_mask:  # TODO REMOVE CUDA SYNC
-            alibi = pre_process_alibi_for_pad(alibi, attention_mask, self.num_heads)
+        if attention_mask is not None:
+            alibi = pre_process_alibi_for_pad(alibi, attention_mask)
+        # otherwise repeat alibi tensor with the batch size
+        else:
+            alibi = alibi.repeat(hidden_states.shape[0], 1, 1)
 
         mixed_x_layer = self.query_key_value(hidden_states)
 
@@ -115,19 +116,16 @@ class BloomAttention(nn.Module):
         # [batch_size, k_length, num_heads, head_dim] -> [k_length, batch_size * num_heads, head_dim]
         key_layer = key_layer.transpose(1, 0).reshape(output_size[3], output_size[0] * output_size[1], -1)
 
-        # slice alibi tensor until the query length
-        sliced_alibi = alibi[: output_size[0] * output_size[1], :, : output_size[3]]
-
         # Raw attention scores. [batch_size * num_heads, q_length, k_length]
         beta = 1.0 / self.layer_number
 
         matmul_result = torch.baddbmm(
-            sliced_alibi,
+            alibi,
             query_layer.transpose(1, 0),
             key_layer.transpose(1, 0).transpose(1, 2),
             beta=beta,
             alpha=(1.0 / self.norm_factor),
-        )  # TODO if end up creating alibi inside forward, consider setting out=sliced_alibi for memory efficiency
+        )
 
         # change view to [batch_size, num_heads, q_length, k_length]
         attention_scores = matmul_result.view(*output_size)

+ 40 - 64
src/bloom/ops.py

@@ -8,6 +8,7 @@ import math
 import torch
 import torch.autograd
 from torch import nn
+import torch.nn.functional as F
 
 
 def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
@@ -55,13 +56,14 @@ def attention_mask_func(attention_scores, attention_mask, causal_mask):
     )
 
 
-def build_alibi_tensor(max_seq_len, n_head, dtype=torch.bfloat16):
+def build_alibi_tensor(
+    max_seq_len: int, n_head: int, dtype: torch.dtype = torch.bfloat16, device: torch.device = torch.device("cpu")
+) -> torch.Tensor:
     """
     Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
     relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
     `softmax(l+a) = softmax(l)`. Based on
     https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
-
     Args:
     Returns tensor shaped (n_head, 1, max_seq_len)
         max_seq_len: (`int`, *required*):
@@ -70,67 +72,41 @@ def build_alibi_tensor(max_seq_len, n_head, dtype=torch.bfloat16):
             number of heads
         dtype: (`torch.dtype`, *optional*, default=`torch.bfloat16`):
             dtype of the output tensor
+        device: (`torch.device`, *optional*, default=`torch.device('cpu')`):
+            device of the output alibi tensor
     """
+    closest_power_of_2 = 2 ** math.floor(math.log2(n_head))
+    base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=device, dtype=torch.float32)
+    powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32)
+    slopes = torch.pow(base, powers)
+
+    if closest_power_of_2 != n_head:
+        extra_base = torch.tensor(
+            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=device, dtype=torch.float32
+        )
+        num_remaining_heads = min(closest_power_of_2, n_head - closest_power_of_2)
+        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32)
+        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
 
-    def get_slopes(n):
-        def get_slopes_power_of_2(n):
-            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
-            ratio = start
-            return [start * ratio**i for i in range(n)]
-
-        if math.log2(n).is_integer():
-            return get_slopes_power_of_2(n)
-        else:
-            closest_power_of_2 = 2 ** math.floor(math.log2(n))
-            return (
-                get_slopes_power_of_2(closest_power_of_2)
-                + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
-            )
-
-    slopes = torch.Tensor(get_slopes(n_head)).unsqueeze(1).unsqueeze(1)
-    arange_tensor = torch.arange(max_seq_len).unsqueeze(0).unsqueeze(0)
-    alibi = slopes * arange_tensor.expand(n_head, -1, -1)
-
-    alibi = alibi.to(dtype)
-
-    return alibi
+    lengths = torch.arange(max_seq_len, device=device, dtype=torch.int32)
+    return (slopes.view(-1, 1, 1) * lengths.view(1, 1, -1)).to(dtype)
 
 
-def pre_process_alibi_for_pad(alibi, attention_mask, num_heads):
+def pre_process_alibi_for_pad(alibi: torch.Tensor, attention_mask: torch.Tensor):
     """
     Args:
     Pre-process the alibi tensor for padding.
         alibi: ([`torch.tensor`], *required*):
             alibi tensor to pre-process
         attention_mask: ([`torch.tensor`], *required*):
-            attention mask to pre-process"""
-
-    # Sanity check if we are not inferring less tokens than the total sequence length
-    # This usually happens when the inference is done with past_key_values
-    # In this case we re-create the alibi tensor with the correct sequence length
-    if attention_mask.shape[-1] != alibi.shape[-1]:
-        alibi = build_alibi_tensor(attention_mask.shape[-1], num_heads, alibi.dtype).repeat(
-            attention_mask.shape[0], 1, 1
-        )
-    # Get the indexes of the padding tokens
-    index_x0, index_y0 = torch.where(attention_mask == 0.0)
-    index_x1, index_y1 = torch.where(attention_mask == 1.0)
-
-    # Clone the embeddings  - we can detach because the embeddings are not learned
-    # Get a refence tensor
-    slice_reference_alibi = build_alibi_tensor(alibi.shape[-1], num_heads, alibi.dtype)
-
-    # Loop over the batch where the padding is and replace the alibi tensor by the reference tensor
-    # Only where you do not have padding. Replace padding tokens by zeros
-    # This operation can be seen as a shifting operation.
-    for i, index in enumerate(torch.unique(index_x0)):
-        slice_to_modify = torch.zeros_like(slice_reference_alibi)
-        index_shift = index_y1[index_x1 == index]
-        shift_value = len(index_shift)
-        slice_to_modify[:, :, index_shift] = slice_reference_alibi[:, :, :shift_value]
-        alibi[index * num_heads : (index + 1) * num_heads] = slice_to_modify
-    return alibi
-
+            attention mask to pre-process
+    """
+    assert attention_mask.shape.ndim == 2, "mask should be [batch_size, seq_length]"
+    unpadded_indices = torch.relu(attention_mask.cumsum(dim=1) - 1)
+    # ^-- [batch, max_len], values correspond to element indices after removing padding
+    # We shift the alibi tensor + replace all the values where attention_mask==0.0 by 0
+    alibi = alibi.take_along_dim(unpadded_indices.unsqueeze(0), -1) * attention_mask.unsqueeze(0)
+    return alibi.reshape(alibi.shape[0] * alibi.shape[1], 1, -1)
 
 def dropout_add(x, residual, prob, training):
     """
@@ -251,17 +227,17 @@ class BloomScaledSoftmax(nn.Module):
         if self.scale is not None:
             input = input * self.scale
 
-        if mask is not None:
-            mask = mask.to(input.device)
-            causal_mask = (
-                torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
-                .view(1, 1, max_positions, max_positions)
-                .to(input.device)
-            )
-            mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask)
-            probs = nn.functional.softmax(mask_output, dim=-1, dtype=softmax_dtype) * (~padded_causal_mask)
-        else:
-            probs = nn.functional.softmax(input, dim=-1, dtype=softmax_dtype)
+        if mask is None:
+            mask = torch.ones(input.shape[:2], dtype=torch.bool, device=input.device)
+
+        mask = mask.to(input.device)
+        causal_mask = (
+            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
+            .view(1, 1, max_positions, max_positions)
+            .to(input.device)
+        )
+        mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask)
+        probs = F.softmax(mask_output, dim=-1, dtype=softmax_dtype) * (~padded_causal_mask)
 
         if input_in_16bit and self.softmax_in_fp32:
             probs = probs.to(dtype=input_dtype)

+ 1 - 1
src/server/backend.py

@@ -35,7 +35,7 @@ class TransformerBackend(ModuleBackend):
             print("METADATA:", cache_metadata)
             assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5
             layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length]
-            print(past_k.shape, past_v.shape)
+            print('PAST', past_k.shape, past_v.shape)
             hidden_states, (new_k, new_v) = self.module.forward(hidden_states, layer_past=layer_past, use_cache=True)
 
             # todo remove these asserts once we pass all tests