@@ -169,15 +169,15 @@ class BloomAttention(nn.Module):
 class BloomMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
-        hidden_size = config.hidden_size
+        self.hidden_size = config.hidden_size
         if config.compression == 'qint8':
             self.dense_h_to_4h = nn.quantized.dynamic.modules.Linear(
                 self.hidden_size, 4 * self.hidden_size, bias_=True, dtype=torch.qint8)
             self.dense_4h_to_h = nn.quantized.dynamic.modules.Linear(
                 4 * self.hidden_size, self.hidden_size, bias_=True, dtype=torch.qint8)
         else:
-            self.dense_h_to_4h = nn.Linear(hidden_size, 4 * hidden_size)
-            self.dense_4h_to_h = nn.Linear(4 * hidden_size, hidden_size)
+            self.dense_h_to_4h = nn.Linear(self.hidden_size, 4 * self.hidden_size)
+            self.dense_4h_to_h = nn.Linear(4 * self.hidden_size, self.hidden_size)
         self.hidden_dropout = config.hidden_dropout
         self.gelu_impl = BloomGelu()
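For reference, a minimal standalone sketch of the two layer types this constructor chooses between. The hidden_size value and the input tensor are illustrative assumptions (not taken from the patch), and the qint8 branch assumes a PyTorch build with a quantized CPU backend (e.g. fbgemm or qnnpack):

# Sketch only: exercises the layer types used in BloomMLP.__init__ above.
import torch
import torch.nn as nn

hidden_size = 64                      # assumed value for illustration
x = torch.randn(2, hidden_size)       # float activations in both branches

# compression == 'qint8' branch: dynamically quantized linears
# (qint8 weights, activations quantized on the fly, float in / float out)
dense_h_to_4h_q = nn.quantized.dynamic.modules.Linear(
    hidden_size, 4 * hidden_size, bias_=True, dtype=torch.qint8)
dense_4h_to_h_q = nn.quantized.dynamic.modules.Linear(
    4 * hidden_size, hidden_size, bias_=True, dtype=torch.qint8)
print(dense_4h_to_h_q(dense_h_to_4h_q(x)).shape)    # torch.Size([2, 64])

# default branch: ordinary float nn.Linear layers
dense_h_to_4h = nn.Linear(hidden_size, 4 * hidden_size)
dense_4h_to_h = nn.Linear(4 * hidden_size, hidden_size)
print(dense_4h_to_h(dense_h_to_4h(x)).shape)        # torch.Size([2, 64])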