@@ -73,14 +73,14 @@ class BloomAttention(nn.Module):
        use_cache=False,
        output_attentions=False,
    ):
-        if alibi is None:
+        if alibi is None:  # TODO OPTIMIZE ALIBI CREATION
            alibi = build_alibi_tensor(hidden_states.shape[1], n_head=self.num_heads, dtype=hidden_states.dtype)
        # hidden_states: [batch_size, seq_length, hidden_size]
        # repeat alibi tensor with the batch size
-        alibi = alibi.repeat(hidden_states.shape[0], 1, 1).to(hidden_states.device)
+        alibi = alibi.repeat(hidden_states.shape[0], 1, 1).to(hidden_states.device)  # TODO eliminate cpu-gpu transfer!

        # apply preprocessing if the input is padded
-        if attention_mask is not None and 0 in attention_mask:
+        if attention_mask is not None and 0 in attention_mask:  # TODO REMOVE CUDA SYNC
            alibi = pre_process_alibi_for_pad(alibi, attention_mask, self.num_heads)

        mixed_x_layer = self.query_key_value(hidden_states)
@@ -123,7 +123,7 @@ class BloomAttention(nn.Module):
            key_layer.transpose(1, 0).transpose(1, 2),
            beta=beta,
            alpha=(1.0 / self.norm_factor),
-        )
+        )  # TODO if end up creating alibi inside forward, consider setting out=sliced_alibi for memory efficiency

        # change view to [batch_size, num_heads, q_length, k_length]
        attention_scores = matmul_result.view(*output_size)
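
The TODOs above all revolve around where the ALiBi bias lives: it is currently built on CPU on every forward pass and then copied to the GPU via `.to(hidden_states.device)`. Below is a minimal sketch of one way to sidestep that transfer, assuming a hypothetical `build_alibi_tensor_on_device` helper that takes an explicit `device` argument (the name and signature are illustrative, not the existing `build_alibi_tensor` API):

```python
import math

import torch


def build_alibi_tensor_on_device(seq_length, n_head, dtype, device):
    # Hypothetical helper: same ALiBi bias as build_alibi_tensor, but allocated
    # directly on `device`, so no CPU->GPU copy is needed inside forward.

    # Per-head geometric slopes, following the ALiBi paper.
    closest_power_of_2 = 2 ** math.floor(math.log2(n_head))
    base = 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3)))
    slopes = [base ** i for i in range(1, closest_power_of_2 + 1)]
    if closest_power_of_2 != n_head:
        extra_base = 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3)))
        num_remaining = min(closest_power_of_2, n_head - closest_power_of_2)
        slopes += [extra_base ** i for i in range(1, 2 * num_remaining + 1, 2)]
    slopes = torch.tensor(slopes, dtype=torch.float32, device=device)

    # Linear position bias of shape [n_head, 1, seq_length], built on `device`.
    positions = torch.arange(seq_length, dtype=torch.float32, device=device)
    alibi = slopes.view(n_head, 1, 1) * positions.view(1, 1, seq_length)
    return alibi.to(dtype)
```

With such a helper, the `if alibi is None:` branch could pass `device=hidden_states.device`, making the later `.to(hidden_states.device)` a no-op; the batch-size `repeat` and the padding preprocessing are separate concerns and would stay as they are.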