|
@@ -52,6 +52,10 @@ class CoordinatorArguments(BaseTrainingArguments):
|
|
|
default=None,
|
|
|
metadata={"help": "Path to HuggingFace repo in which coordinator will upload the model and optimizer states"}
|
|
|
)
|
|
|
+ repo_url: Optional[str] = field(
|
|
|
+ default=None,
|
|
|
+ metadata={"help": "URL to Hugging Face repository to which the coordinator will upload the model and optimizer states"}
|
|
|
+ )
|
|
|
upload_interval: Optional[float] = field(
|
|
|
default=None,
|
|
|
metadata={"help": "Coordinator will upload model once in this many seconds"}
|
|
@@ -67,6 +71,7 @@ class CheckpointHandler:
|
|
|
averager_args: AveragerArguments, dht: hivemind.DHT):
|
|
|
self.save_checkpoint_step_interval = coordinator_args.save_checkpoint_step_interval
|
|
|
self.repo_path = coordinator_args.repo_path
|
|
|
+ self.repo_url = coordinator_args.repo_url
|
|
|
self.upload_interval = coordinator_args.upload_interval
|
|
|
self.previous_step = -1
|
|
|
|
|
@@ -110,6 +115,7 @@ class CheckpointHandler:
|
|
|
return False
|
|
|
|
|
|
def save_state(self, cur_step):
    """Refresh the local state from peers and remember the step it was taken at.

    Args:
        cur_step: step number to record in ``self.previous_step`` once the
            state has been loaded.
    """
    logger.info("Saving state from peers")
    optimizer = self.collaborative_optimizer
    optimizer.load_state_from_peers()
    # Recorded only after the load call returns, so a failed load leaves
    # ``previous_step`` untouched.
    self.previous_step = cur_step
|
|
|
|
|
@@ -122,17 +128,13 @@ class CheckpointHandler:
|
|
|
return False
|
|
|
|
|
|
def upload_checkpoint(self, current_loss):
    """Save the optimizer state locally and push the model to the Hub.

    Args:
        current_loss: loss value embedded (with 3 decimals) in the commit
            message of the upload.
    """
    logger.info("Saving optimizer")
    torch.save(self.collaborative_optimizer.opt.state_dict(), f"{self.repo_path}/optimizer_state.pt")
    self.previous_timestamp = time.time()

    # The commit message below needs the current optimizer step; without this
    # lookup the f-string raises NameError on every upload.
    current_step = self.collaborative_optimizer.collaboration_state.optimizer_step

    # NOTE(review): errors raised by push_to_hub are not caught here, whereas
    # the previous git-based upload only logged a warning on failure --
    # confirm that propagating upload errors to the caller is intended.
    # NOTE(review): verify that optimizer_state.pt written above is actually
    # part of what push_to_hub uploads -- TODO confirm.
    logger.info('Started uploading model to Hub')
    self.model.push_to_hub(repo_name=self.repo_path, repo_url=self.repo_url,
                           commit_message=f'Step {current_step}, loss {current_loss:.3f}')
    logger.info('Finished uploading model to Hub')
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|