run_aux_peer.py

#!/usr/bin/env python3
import time

import torch
import wandb
import transformers
from transformers import HfArgumentParser
from huggingface_hub import HfFolder, Repository
from hivemind.utils.logging import get_logger, use_hivemind_log_handler

import utils
from arguments import AuxiliaryPeerArguments, CollaborativeArguments, HFTrainerArguments
from task import TrainingTask
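
# This script runs an *auxiliary* peer: it does not compute gradients itself,
# but follows a collaborative training run, aggregates the per-peer metrics
# published in the DHT, reports them (optionally to Weights & Biases), and
# periodically saves/uploads checkpoints.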
transformers.utils.logging.set_verbosity_warning()
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__name__)
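

# CheckpointHandler periodically refreshes this peer's copy of the
# collaborative optimizer state and, if uploading is enabled, pushes model and
# optimizer checkpoints to the Hugging Face Model Hub.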
class CheckpointHandler:
    def __init__(self, task: TrainingTask, peer_args: AuxiliaryPeerArguments):
        self.task, self.peer_args = task, peer_args
        self.save_checkpoint_step_interval = peer_args.save_checkpoint_step_interval
        self.prefix = peer_args.experiment_prefix
        self.local_path = peer_args.local_path
        self.upload_interval = peer_args.upload_interval
        if self.upload_interval is not None:
            assert task.authorizer is not None, 'Model uploading needs Hugging Face auth to be enabled'
            self.repo = Repository(
                local_dir=self.local_path,
                clone_from=peer_args.repo_url,
                use_auth_token=task.authorizer.hf_user_access_token,
            )
            self.last_upload_time = None
        self.previous_step = -1
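
    # Decide whether it is time to refresh our copy of the collaborative
    # state (every `save_checkpoint_step_interval` optimizer steps).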
    def should_save_state(self, cur_step):
        if self.save_checkpoint_step_interval is None:
            return False
        elif cur_step - self.previous_step >= self.save_checkpoint_step_interval:
            return True
        else:
            return False
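
    # "Saving state" here means pulling the latest averaged parameters and
    # optimizer statistics from other peers into this auxiliary peer.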
    def save_state(self, cur_step):
        logger.info("Saving state from peers")
        self.task.collaborative_optimizer.load_state_from_peers()
        self.previous_step = cur_step
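
    # Rate-limit Hub uploads to at most one every `upload_interval` seconds.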
    def is_time_to_upload(self):
        if self.upload_interval is None:
            return False
        elif self.last_upload_time is None or time.time() - self.last_upload_time >= self.upload_interval:
            return True
        else:
            return False
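
    # Serialize the model and optimizer to `local_path`, then commit and push
    # both files to the Hub repository cloned in __init__.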
    def upload_checkpoint(self, current_loss):
        self.last_upload_time = time.time()
        logger.info("Saving model")
        torch.save(self.task.model.state_dict(), f"{self.local_path}/model_state.pt")
        logger.info("Saving optimizer")
        torch.save(self.task.collaborative_optimizer.state_dict(), f"{self.local_path}/optimizer_state.pt")
        logger.info("Started uploading to Model Hub")
        try:
            # We start by pulling the remote changes (for example a change in the readme file)
            self.repo.git_pull()
            # Then we add / commit and push the changes
            self.repo.push_to_hub(commit_message=f"Epoch {self.task.collaborative_optimizer.local_epoch}, loss {current_loss:.3f}")
            logger.info("Finished uploading to Model Hub")
        except Exception:
            logger.exception("Uploading the checkpoint to HF Model Hub failed:")
            logger.warning("Ensure that your access token is valid and has WRITE permissions")
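

# Helper that would let an aux peer contribute to averaging rounds without
# training; currently unused because aux-peer averaging is not supported yet
# (see the NotImplementedError below).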
def assist_averaging_in_background(task: TrainingTask, peer_args: AuxiliaryPeerArguments):
    while True:
        time.sleep(peer_args.assist_refresh)
        task.collaborative_optimizer.step()
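

# Entry point: parse the three argument dataclasses from the CLI, join the
# experiment via TrainingTask, then monitor and report training progress.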
if __name__ == "__main__":
    parser = HfArgumentParser((AuxiliaryPeerArguments, HFTrainerArguments, CollaborativeArguments))
    peer_args, trainer_args, collab_args = parser.parse_args_into_dataclasses()

    task = TrainingTask(peer_args, trainer_args, collab_args)
    dht, collaborative_optimizer = task.dht, task.collaborative_optimizer

    if peer_args.wandb_project is not None:
        wandb.init(project=peer_args.wandb_project)

    current_step = 0
    if peer_args.store_checkpoints:
        checkpoint_handler = CheckpointHandler(task, peer_args)

    if peer_args.assist_in_averaging:
        # assert not peer_args.client_mode, "client-mode peers cannot assist in averaging"
        # averaging_thread = threading.Thread(
        #     name="AveragingAuxThread", target=assist_averaging_in_background, args=[task, peer_args], daemon=True)
        # averaging_thread.start()
        raise NotImplementedError('aux peers with hivemind.optim.experimental are not supported yet')
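
    # Main monitoring loop: fetch per-peer metrics published in the DHT under
    # "<experiment_prefix>_metrics" and aggregate them whenever a newer step
    # appears.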
    while True:
        metrics_entry = dht.get(peer_args.experiment_prefix + "_metrics", latest=True)
        if metrics_entry is not None and len(metrics_entry.value) > 0:
            metrics_dict = metrics_entry.value
            metrics = [utils.LocalMetrics.parse_obj(metrics_dict[peer].value) for peer in metrics_dict]
            latest_step = max(item.step for item in metrics)

            if latest_step != current_step:
                logger.debug(f"Got metrics from {len(metrics)} peers")
                for i, metrics_for_peer in enumerate(metrics):
                    logger.debug(f"{i} peer {metrics_for_peer}")

                current_step = latest_step
                alive_peers = 0
                sum_loss = 0
                num_samples = 0
                sum_perf = 0
                sum_mini_steps = 0

                for item in metrics:
                    sum_loss += item.loss
                    alive_peers += 1
                    sum_perf += item.samples_per_second
                    num_samples += item.samples_accumulated
                    sum_mini_steps += item.mini_steps
                current_loss = sum_loss / sum_mini_steps
                logger.info(f"Epoch #{current_step}\tloss = {current_loss:.5f}")

                if peer_args.wandb_project is not None:
                    wandb.log(
                        {
                            "loss": current_loss,
                            "alive peers": alive_peers,
                            "samples": num_samples,
                            "performance": sum_perf,
                            "step": latest_step,
                        }
                    )

                if peer_args.store_checkpoints:
                    if checkpoint_handler.should_save_state(current_step):
                        checkpoint_handler.save_state(current_step)
                        if checkpoint_handler.is_time_to_upload():
                            checkpoint_handler.upload_checkpoint(current_loss)

        logger.debug("Peer is still alive...")
        time.sleep(peer_args.refresh_period)
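
# Example launch (a sketch, not verbatim: HfArgumentParser turns each dataclass
# field into a CLI flag, but the exact flag names and defaults live in
# arguments.py, which is not shown here):
#
#   python run_aux_peer.py \
#       --experiment_prefix my_experiment \
#       --wandb_project my_wandb_project \
#       --refresh_period 10 \
#       --store_checkpoints \
#       --upload_interval 3600 \
#       --repo_url https://huggingface.co/username/my-model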