浏览代码

Upload the model with push_to_hub in examples (#297)

* Change model uploading method

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
Co-authored-by: justheuristic <justheuristic@gmail.com>
Alexey Bukhtiyarov 4 年之前
父节点
当前提交
2436a3bed9
共有 2 个文件被更改,包括 12 次插入和 10 次删除
  1. 1 1
      examples/albert/requirements.txt
  2. 11 9
      examples/albert/run_training_monitor.py

+ 1 - 1
examples/albert/requirements.txt

@@ -1,4 +1,4 @@
-transformers>=4.5.1
+transformers>=4.6.0
 datasets>=1.5.0
 torch_optimizer>=0.1.0
 wandb>=0.10.26

+ 11 - 9
examples/albert/run_training_monitor.py

@@ -52,6 +52,10 @@ class CoordinatorArguments(BaseTrainingArguments):
         default=None,
         metadata={"help": "Path to HuggingFace repo in which coordinator will upload the model and optimizer states"}
     )
+    repo_url: Optional[str] = field(
+        default=None,
+        metadata={"help": "URL to Hugging Face repository to which the coordinator will upload the model and optimizer states"}
+    )
     upload_interval: Optional[float] = field(
         default=None,
         metadata={"help": "Coordinator will upload model once in this many seconds"}
@@ -67,6 +71,7 @@ class CheckpointHandler:
                  averager_args: AveragerArguments, dht: hivemind.DHT):
         self.save_checkpoint_step_interval = coordinator_args.save_checkpoint_step_interval
         self.repo_path = coordinator_args.repo_path
+        self.repo_url = coordinator_args.repo_url
         self.upload_interval = coordinator_args.upload_interval
         self.previous_step = -1
 
@@ -110,6 +115,7 @@ class CheckpointHandler:
             return False
 
     def save_state(self, cur_step):
+        logger.info("Saving state from peers")
         self.collaborative_optimizer.load_state_from_peers()
         self.previous_step = cur_step
 
@@ -122,17 +128,13 @@ class CheckpointHandler:
             return False
 
     def upload_checkpoint(self, current_loss):
-        self.model.save_pretrained(self.repo_path)
+        logger.info("Saving optimizer")
         torch.save(self.collaborative_optimizer.opt.state_dict(), f"{self.repo_path}/optimizer_state.pt")
         self.previous_timestamp = time.time()
-        try:
-            subprocess.run("git add --all", shell=True, check=True, cwd=self.repo_path)
-            current_step = self.collaborative_optimizer.collaboration_state.optimizer_step
-            subprocess.run(f"git commit -m 'Step {current_step}, loss {current_loss:.3f}'",
-                           shell=True, check=True, cwd=self.repo_path)
-            subprocess.run("git push", shell=True, check=True, cwd=self.repo_path)
-        except subprocess.CalledProcessError as e:
-            logger.warning("Error while uploading model:", e.output)
+        logger.info('Started uploading model to Hub')
+        self.model.push_to_hub(repo_name=self.repo_path, repo_url=self.repo_url,
+                               commit_message=f'Step {current_step}, loss {current_loss:.3f}')
+        logger.info('Finished uploading model to Hub')
 
 
 if __name__ == '__main__':