run_aux_peer.py

#!/usr/bin/env python

import threading
import time

import torch
import wandb
import transformers
from transformers import HfArgumentParser
from huggingface_hub import HfFolder, Repository
from hivemind.utils.logging import get_logger, use_hivemind_log_handler

import utils
from arguments import AuxiliaryPeerArguments, CollaborativeArguments, HFTrainerArguments
from task import TrainingTask

transformers.utils.logging.set_verbosity_warning()
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__name__)
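

# CheckpointHandler periodically refreshes this peer's copy of the collaborative
# state (via load_state_from_peers) and, if an upload interval is configured,
# pushes model and optimizer checkpoints to the Hugging Face Hub repo at repo_url.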
class CheckpointHandler:
    def __init__(self, task: TrainingTask, peer_args: AuxiliaryPeerArguments):
        self.task, self.peer_args = task, peer_args
        self.save_checkpoint_step_interval = peer_args.save_checkpoint_step_interval
        self.prefix = peer_args.experiment_prefix
        self.local_path = peer_args.local_path
        self.upload_interval = peer_args.upload_interval
        if self.upload_interval is not None:
            self.token = HfFolder.get_token()
            self.repo = Repository(
                local_dir=self.local_path,
                clone_from=peer_args.repo_url,
                use_auth_token=self.token,
            )
        self.previous_step = -1
        self.previous_timestamp = time.time()

    def should_save_state(self, cur_step):
        if self.save_checkpoint_step_interval is None:
            return False
        elif cur_step - self.previous_step >= self.save_checkpoint_step_interval:
            return True
        else:
            return False

    def save_state(self, cur_step):
        logger.info("Saving state from peers")
        self.task.collaborative_optimizer.load_state_from_peers()
        self.previous_step = cur_step

    def is_time_to_upload(self):
        if self.upload_interval is None:
            return False
        elif time.time() - self.previous_timestamp >= self.upload_interval:
            return True
        else:
            return False

    def upload_checkpoint(self, current_loss):
        logger.info("Saving model")
        torch.save(self.task.model.state_dict(), f"{self.local_path}/model_state.pt")
        logger.info("Saving optimizer")
        torch.save(self.task.collaborative_optimizer.opt.state_dict(), f"{self.local_path}/optimizer_state.pt")
        self.previous_timestamp = time.time()
        logger.info("Started uploading to Model Hub")
        self.repo.push_to_hub(
            commit_message=f"Step {self.task.collaborative_optimizer.local_step}, loss {current_loss:.3f}"
        )
        logger.info("Finished uploading to Model Hub")
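

# Auxiliary peers compute no gradients of their own: step_aux() lends this peer's
# bandwidth and compute to the collaborative averaging rounds (the exact semantics
# depend on the hivemind optimizer that TrainingTask constructs).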
def assist_averaging_in_background(task: TrainingTask, peer_args: AuxiliaryPeerArguments):
    while True:
        time.sleep(peer_args.assist_refresh)
        task.collaborative_optimizer.step_aux()
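

# Entry point: parse the three argument dataclasses, optionally start the
# background averaging thread, then poll the DHT for metrics reported under
# "<experiment_prefix>_metrics", aggregate them, log to Weights & Biases if
# configured, and save/upload checkpoints on the configured schedule.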
if __name__ == "__main__":
    parser = HfArgumentParser((AuxiliaryPeerArguments, HFTrainerArguments, CollaborativeArguments))
    peer_args, trainer_args, collab_args = parser.parse_args_into_dataclasses()

    task = TrainingTask(peer_args, trainer_args, collab_args)
    dht, collaborative_optimizer = task.dht, task.collaborative_optimizer

    if peer_args.wandb_project is not None:
        wandb.init(project=peer_args.wandb_project)

    current_step = 0
    if peer_args.store_checkpoints:
        checkpoint_handler = CheckpointHandler(task, peer_args)

    if peer_args.assist_in_averaging:
        assert not peer_args.client_mode, "client-mode peers cannot assist in averaging"
        averaging_thread = threading.Thread(
            name="AveragingAuxThread", target=assist_averaging_in_background, args=[task, peer_args], daemon=True
        )
        averaging_thread.start()

    while True:
        metrics_entry = dht.get(peer_args.experiment_prefix + "_metrics", latest=True)
        if metrics_entry is not None and len(metrics_entry.value) > 0:
            metrics_dict = metrics_entry.value
            metrics = [utils.LocalMetrics.parse_obj(metrics_dict[peer].value) for peer in metrics_dict]
            latest_step = max(item.step for item in metrics)

            if latest_step != current_step:
                logger.debug(f"Got metrics from {len(metrics)} peers")

                for i, metrics_for_peer in enumerate(metrics):
                    logger.debug(f"{i} peer {metrics_for_peer}")

                current_step = latest_step
                alive_peers = 0
                sum_loss = 0
                num_samples = 0
                sum_perf = 0
                sum_mini_steps = 0

                for item in metrics:
                    sum_loss += item.loss
                    alive_peers += 1
                    sum_perf += item.samples_per_second
                    num_samples += item.samples_accumulated
                    sum_mini_steps += item.mini_steps
                current_loss = sum_loss / sum_mini_steps
                logger.info(f"Step #{current_step}\tloss = {current_loss:.5f}")

                if peer_args.wandb_project is not None:
                    wandb.log(
                        {
                            "loss": current_loss,
                            "alive peers": alive_peers,
                            "samples": num_samples,
                            "performance": sum_perf,
                            "step": latest_step,
                        }
                    )

                if peer_args.store_checkpoints:
                    if checkpoint_handler.should_save_state(current_step):
                        checkpoint_handler.save_state(current_step)
                        if checkpoint_handler.is_time_to_upload():
                            checkpoint_handler.upload_checkpoint(current_loss)

        logger.debug("Peer is still alive...")
        time.sleep(peer_args.refresh_period)
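

# A minimal launch sketch; the full set of required flags depends on the dataclass
# definitions in arguments.py, since HfArgumentParser derives one "--<field>"
# option per dataclass field. Values here are placeholders:
#
#   python run_aux_peer.py --experiment_prefix my_run --wandb_project my_project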