ソースを参照

mv set_requires_grad to remote_model

dbaranchuk 3 年 前
コミット
21e1f42f04
2 ファイル変更7 行追加7 行削除
  1. + 0 - 7
      src/bloom/model.py
  2. + 7 - 0
      src/client/remote_model.py

+ 0 - 7
src/bloom/model.py

@@ -165,19 +165,12 @@ class BloomModel(BloomPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
 
-        # Forbid accumulate grads for embeddings and layernorm
-        self.set_requires_grad(False)
-
     def get_input_embeddings(self):
         return self.word_embeddings
 
     def set_input_embeddings(self, new_embeddings):
         self.word_embeddings = new_embeddings
 
-    def set_requires_grad(self, value):
-        for p in self.parameters():
-            p.requires_grad = value
-
     @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,

+ 7 - 0
src/client/remote_model.py

@@ -45,6 +45,13 @@ class DistributedBloomModel(BloomModel):
         )
         assert isinstance(dht, hivemind.DHT) and dht.is_alive(), "dht must be a running hivemind.DHT instance"
         self.h = RemoteSequential(config, dht, config.dht_prefix)
+    
+        # Forbid accumulate grads for embeddings and layernorm
+        self.set_requires_grad(False)
+
+    def set_requires_grad(self, value):
+        for p in self.parameters():
+            p.requires_grad = value
 
 
 class DistributedBloomForCausalLM(BloomForCausalLM):