
Fix llama's lm_head.weight.requires_grad (#330)

By default, LLaMA's `lm_head.weight.requires_grad` was `True`, but we expect it to be `False`.
Alexander Borzunov, 2 years ago
parent
commit 47a2b1ee65

src/petals/client/lm_head.py  (+2 -1)

@@ -26,7 +26,8 @@ class LMHead(nn.Module):
         super().__init__()
 
         if not config.tie_word_embeddings:
-            self.weight = nn.Parameter(torch.zeros((config.vocab_size, config.hidden_size), requires_grad=False))
+            self.weight = nn.Parameter(torch.zeros(config.vocab_size, config.hidden_size))
+            self.weight.requires_grad = False
         else:
             self.weight = None  # Will be set to get_input_embeddings().weight during loading the model
         self.bias = None
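
Why the original line did not achieve this: the `requires_grad=False` there is consumed by `torch.zeros`, but `nn.Parameter(data, requires_grad=True)` defaults its own `requires_grad` argument to `True`, which overrides whatever flag the wrapped tensor carries. A minimal plain-PyTorch sketch of both constructions (toy shapes, no Petals code):

```python
import torch
import torch.nn as nn

# Old construction: requires_grad=False goes to torch.zeros, but
# nn.Parameter re-enables gradients because its own argument defaults to True.
old = nn.Parameter(torch.zeros((4, 8), requires_grad=False))
print(old.requires_grad)  # True -- the head would still accumulate grads

# Fixed construction: build the parameter first, then clear the flag
# (passing requires_grad=False to nn.Parameter directly would also work).
new = nn.Parameter(torch.zeros(4, 8))
new.requires_grad = False
print(new.requires_grad)  # False
```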

src/petals/client/ptune.py  (+0 -4)

@@ -40,10 +40,6 @@ class PTuneMixin:
         elif config.tuning_mode:
             raise NotImplementedError(f"{self.tuning_mode} mode is not supported for now")
 
-    def set_requires_grad(self, value):
-        for p in self.parameters():
-            p.requires_grad = value
-
     def get_prompt(self, batch_size):
         prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
         prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)

src/petals/models/bloom/model.py  (+1 -1)

@@ -35,7 +35,7 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
 
         self.h = RemoteSequential(config, dht=dht)
 
-        self.set_requires_grad(False)  # Forbid accumulate grads for embeddings and layernorm
+        self.requires_grad_(False)  # Forbid accumulate grads for embeddings and layernorm
         self.init_prompts(config)
 
     def forward(

src/petals/models/llama/model.py  (+1 -1)

@@ -33,7 +33,7 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
 
         self.layers = RemoteSequential(config, dht=dht)
 
-        self.set_requires_grad(False)  # Forbid accumulate grads for embeddings and layernorm
+        self.requires_grad_(False)  # Forbid accumulate grads for embeddings and layernorm
         self.init_prompts(config)
 
     def forward(
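
The remaining hunks replace the hand-rolled `PTuneMixin.set_requires_grad` helper with PyTorch's built-in `nn.Module.requires_grad_`, which does the same thing: it walks `self.parameters()` and sets the flag on each one. A minimal sketch with a toy module standing in for the distributed models (the toy layers are illustrative, not Petals code):

```python
import torch.nn as nn

# Toy stand-in for DistributedBloomModel / DistributedLlamaModel.
model = nn.Sequential(nn.Embedding(10, 8), nn.LayerNorm(8))

# Built-in equivalent of the removed set_requires_grad(False):
# Module.requires_grad_ clears requires_grad on every registered parameter.
model.requires_grad_(False)

assert all(not p.requires_grad for p in model.parameters())
```

Since `requires_grad_(False)` only touches parameters that already exist, calling it before `self.init_prompts(config)` presumably leaves the prompt-tuning parameters created there trainable.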