소스 검색

add `self.repo.git_pull` before `self.repo.push_to_hub` (#5)

SaulLu 3 년 전
부모
커밋
7ac2b44807
1개의 변경된 파일12개의 추가작업 그리고 2개의 파일을 삭제
  1. 12 2
      run_aux_peer.py

+ 12 - 2
run_aux_peer.py

@@ -64,8 +64,18 @@ class CheckpointHandler:
         torch.save(self.task.collaborative_optimizer.state_dict(), f"{self.local_path}/optimizer_state.pt")
         self.previous_timestamp = time.time()
         logger.info("Started uploading to Model Hub")
-        self.repo.push_to_hub(commit_message=f"Epoch {self.task.collaborative_optimizer.local_epoch}, loss {current_loss:.3f}")
-        logger.info("Finished uploading to Model Hub")
+        try:
+            # We start by pulling the remote changes (for example a change in the readme file)
+            self.repo.git_pull()
+
+            # Then we add / commmit and push the changes
+            self.repo.push_to_hub(commit_message=f"Epoch {self.task.collaborative_optimizer.local_epoch}, loss {current_loss:.3f}")
+            logger.info("Finished uploading to Model Hub")
+        except OSError as e:
+            # There may be an error if a push arrives on the remote branch after the pull performed just above it. In
+            # this case the changes will be pushed with the next commit.
+            logger.error(f'The push to hub operation failed with error "{e}"')
+        
 
 
 def assist_averaging_in_background(task: TrainingTask, peer_args: AuxiliaryPeerArguments):