瀏覽代碼

keep hidden_size as property

justheuristic 3 年之前
父節點
當前提交
ffb56a65ed
共有 1 個文件被更改,包括 3 次插入和 3 次删除
  1. 3 3
      src/block.py

+ 3 - 3
src/block.py

@@ -169,15 +169,15 @@ class BloomAttention(nn.Module):
 class BloomMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
-        hidden_size = config.hidden_size
+        self.hidden_size = config.hidden_size
         if config.compression == 'qint8':
             self.dense_h_to_4h = nn.quantized.dynamic.modules.Linear(
                 self.hidden_size, 4 * self.hidden_size, bias_=True, dtype=torch.qint8)
             self.dense_4h_to_h = nn.quantized.dynamic.modules.Linear(
                 4 * self.hidden_size, self.hidden_size, bias_=True, dtype=torch.qint8)
         else:
-            self.dense_h_to_4h = nn.Linear(hidden_size, 4 * hidden_size)
-            self.dense_4h_to_h = nn.Linear(4 * hidden_size, hidden_size)
+            self.dense_h_to_4h = nn.Linear(self.hidden_size, 4 * self.hidden_size)
+            self.dense_4h_to_h = nn.Linear(4 * self.hidden_size, self.hidden_size)
         self.hidden_dropout = config.hidden_dropout
         self.gelu_impl = BloomGelu()