|
@@ -64,8 +64,18 @@ class CheckpointHandler:
|
|
|
torch.save(self.task.collaborative_optimizer.state_dict(), f"{self.local_path}/optimizer_state.pt")
|
|
|
self.previous_timestamp = time.time()
|
|
|
logger.info("Started uploading to Model Hub")
|
|
|
- self.repo.push_to_hub(commit_message=f"Epoch {self.task.collaborative_optimizer.local_epoch}, loss {current_loss:.3f}")
|
|
|
- logger.info("Finished uploading to Model Hub")
|
|
|
+ try:
|
|
|
+ # We start by pulling the remote changes (for example a change in the readme file)
|
|
|
+ self.repo.git_pull()
|
|
|
+
|
|
|
+ # Then we add / commit and push the changes
|
|
|
+ self.repo.push_to_hub(commit_message=f"Epoch {self.task.collaborative_optimizer.local_epoch}, loss {current_loss:.3f}")
|
|
|
+ logger.info("Finished uploading to Model Hub")
|
|
|
+ except OSError as e:
|
|
|
+ # There may be an error if a push arrives on the remote branch after the pull performed just above it. In
|
|
|
+ # this case the changes will be pushed with the next commit.
|
|
|
+ logger.error(f'The push to hub operation failed with error "{e}"')
|
|
|
+
|
|
|
|
|
|
|
|
|
def assist_averaging_in_background(task: TrainingTask, peer_args: AuxiliaryPeerArguments):
|