@@ -6,11 +6,17 @@ See commit history for authorship.
 import math

 import torch
-import torch.nn.quantized.dynamic.modules.linear
 import torch.nn as nn
+import torch.nn.quantized.dynamic.modules.linear

-from src.ops import BloomScaledSoftmax, BloomGelu
-from src.ops import attention_mask_func, pre_process_alibi_for_pad, split_tensor_along_last_dim, dropout_add
+from src.ops import (
+    BloomGelu,
+    BloomScaledSoftmax,
+    attention_mask_func,
+    dropout_add,
+    pre_process_alibi_for_pad,
+    split_tensor_along_last_dim,
+)


 class BloomAttention(nn.Module):
@@ -43,11 +49,13 @@ class BloomAttention(nn.Module):
             self.layer_number,
         )

-        if config.compression == 'qint8':
+        if config.compression == "qint8":
             self.query_key_value = nn.quantized.dynamic.modules.Linear(
-                self.hidden_size, 3 * self.hidden_size, bias_=True, dtype=torch.qint8)
+                self.hidden_size, 3 * self.hidden_size, bias_=True, dtype=torch.qint8
+            )
             self.dense = nn.quantized.dynamic.modules.Linear(
-                self.hidden_size, self.hidden_size, bias_=True, dtype=torch.qint8)
+                self.hidden_size, self.hidden_size, bias_=True, dtype=torch.qint8
+            )
         else:
             self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
             self.dense = nn.Linear(self.hidden_size, self.hidden_size)
@@ -120,9 +128,7 @@ class BloomAttention(nn.Module):

         # attention scores and attention mask [b, np, sq, sk]
         max_positions = max(attention_scores.shape[-1], attention_scores.shape[-2])
-        attention_probs = self.scale_mask_softmax(attention_scores, attention_mask, max_positions).to(
-            value_layer.dtype
-        )
+        attention_probs = self.scale_mask_softmax(attention_scores, attention_mask, max_positions).to(value_layer.dtype)
         attention_probs = self.attention_dropout(attention_probs)

         if head_mask is not None:
@@ -170,11 +176,13 @@ class BloomMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.hidden_size = config.hidden_size
-        if config.compression == 'qint8':
+        if config.compression == "qint8":
             self.dense_h_to_4h = nn.quantized.dynamic.modules.Linear(
-                self.hidden_size, 4 * self.hidden_size, bias_=True, dtype=torch.qint8)
+                self.hidden_size, 4 * self.hidden_size, bias_=True, dtype=torch.qint8
+            )
             self.dense_4h_to_h = nn.quantized.dynamic.modules.Linear(
-                4 * self.hidden_size, self.hidden_size, bias_=True, dtype=torch.qint8)
+                4 * self.hidden_size, self.hidden_size, bias_=True, dtype=torch.qint8
+            )
         else:
             self.dense_h_to_4h = nn.Linear(self.hidden_size, 4 * self.hidden_size)
             self.dense_4h_to_h = nn.Linear(4 * self.hidden_size, self.hidden_size)
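
Note on the "qint8" branches touched above: torch.nn.quantized.dynamic Linear modules hold int8 weights, quantize their float input activations on the fly, and return float output, which is why they can stand in for nn.Linear at inference time. Below is a minimal sketch of that constructor in isolation; hidden_size=64 and the random input are illustrative values only (not from the diff), and a real checkpoint would normally have its pretrained float weights converted, e.g. with torch.quantization.quantize_dynamic, rather than starting from a freshly constructed module as here.

# Sketch only: mirrors the constructor call used in BloomAttention/BloomMLP.
# Note the trailing underscore in bias_ and the explicit dtype=torch.qint8.
import torch
import torch.nn as nn
import torch.nn.quantized.dynamic.modules.linear

hidden_size = 64  # illustrative value, not taken from the model config

query_key_value = nn.quantized.dynamic.modules.Linear(
    hidden_size, 3 * hidden_size, bias_=True, dtype=torch.qint8
)

x = torch.randn(8, hidden_size)   # float32 activations
out = query_key_value(x)          # weights are int8, activations quantized dynamically
print(out.shape)                  # torch.Size([8, 192]); output is float32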