瀏覽代碼

save hidden size

justheuristic 3 年之前
父節點
當前提交
8f5d022d18
共有 1 個文件被更改,包括 3 次插入3 次删除
  1. 3 3
      src/bloom/block.py

+ 3 - 3
src/bloom/block.py

@@ -204,12 +204,12 @@ class BloomMLP(nn.Module):
 class BloomBlock(nn.Module):
     def __init__(self, config, layer_number=None):
         super().__init__()
-        hidden_size = config.hidden_size
+        self.hidden_size = config.hidden_size
 
-        self.input_layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_epsilon)
         self.n_head = config.n_head
         self.self_attention = BloomAttention(config, layer_number=layer_number)
-        self.post_attention_layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        self.post_attention_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_epsilon)
 
         self.mlp = BloomMLP(config)