Parcourir la source

save hidden size

justheuristic il y a 3 ans
Parent
commit
8f5d022d18
1 fichier modifié avec 3 ajouts et 3 suppressions
  1. 3 3
      src/bloom/block.py

+ 3 - 3
src/bloom/block.py

@@ -204,12 +204,12 @@ class BloomMLP(nn.Module):
 class BloomBlock(nn.Module):
     def __init__(self, config, layer_number=None):
         super().__init__()
-        hidden_size = config.hidden_size
+        self.hidden_size = config.hidden_size
 
-        self.input_layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_epsilon)
         self.n_head = config.n_head
         self.self_attention = BloomAttention(config, layer_number=layer_number)
-        self.post_attention_layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        self.post_attention_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_epsilon)
 
         self.mlp = BloomMLP(config)