run_aux_peer.py

#!/usr/bin/env python
import threading
import time

import torch
import wandb
import transformers
from transformers import HfArgumentParser
from huggingface_hub import HfFolder, Repository
from hivemind.utils.logging import get_logger, use_hivemind_log_handler

import utils
from arguments import AuxiliaryPeerArguments, CollaborativeArguments, HFTrainerArguments
from task import TrainingTask

transformers.utils.logging.disable_default_handler()
transformers.utils.logging.enable_propagation()
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__name__)


class CheckpointHandler:
    """Periodically pulls the latest training state from peers and uploads checkpoints to the Model Hub."""

    def __init__(self, task: TrainingTask, peer_args: AuxiliaryPeerArguments):
        self.task, self.peer_args = task, peer_args
        self.save_checkpoint_step_interval = peer_args.save_checkpoint_step_interval
        self.prefix = peer_args.experiment_prefix
        self.local_path = peer_args.local_path
        self.upload_interval = peer_args.upload_interval
        if self.upload_interval is not None:
            self.token = HfFolder.get_token()
            self.repo = Repository(
                local_dir=self.local_path,
                clone_from=peer_args.repo_url,
                use_auth_token=self.token,
            )
        self.previous_step = -1
        self.previous_timestamp = time.time()

    def should_save_state(self, cur_step):
        if self.save_checkpoint_step_interval is None:
            return False
        elif cur_step - self.previous_step >= self.save_checkpoint_step_interval:
            return True
        else:
            return False

    def save_state(self, cur_step):
        logger.info("Saving state from peers")
        self.task.collaborative_optimizer.load_state_from_peers()
        self.previous_step = cur_step

    def is_time_to_upload(self):
        if self.upload_interval is None:
            return False
        elif time.time() - self.previous_timestamp >= self.upload_interval:
            return True
        else:
            return False

    def upload_checkpoint(self, current_loss):
        logger.info("Saving model")
        torch.save(self.task.model.state_dict(), f"{self.local_path}/model_state.pt")
        logger.info("Saving optimizer")
        torch.save(self.task.collaborative_optimizer.opt.state_dict(), f"{self.local_path}/optimizer_state.pt")
        self.previous_timestamp = time.time()
        logger.info("Started uploading to Model Hub")
        self.repo.push_to_hub(
            commit_message=f"Step {self.task.collaborative_optimizer.local_step}, loss {current_loss:.3f}"
        )
        logger.info("Finished uploading to Model Hub")


def assist_averaging_in_background(task: TrainingTask, peer_args: AuxiliaryPeerArguments):
    # Auxiliary averaging: periodically join averaging rounds without computing gradients.
    while True:
        time.sleep(peer_args.assist_refresh)
        task.collaborative_optimizer.step_aux()


if __name__ == "__main__":
    parser = HfArgumentParser((AuxiliaryPeerArguments, HFTrainerArguments, CollaborativeArguments))
    peer_args, trainer_args, collab_args = parser.parse_args_into_dataclasses()

    task = TrainingTask(peer_args, trainer_args, collab_args)
    dht, collaborative_optimizer = task.dht, task.collaborative_optimizer

    if peer_args.wandb_project is not None:
        wandb.init(project=peer_args.wandb_project)

    current_step = 0
    if peer_args.store_checkpoints:
        checkpoint_handler = CheckpointHandler(task, peer_args)
    if peer_args.assist_in_averaging:
        assert not peer_args.client_mode, "client-mode peers cannot assist in averaging"
        averaging_thread = threading.Thread(
            name="AveragingAuxThread", target=assist_averaging_in_background, args=[task, peer_args], daemon=True
        )
        averaging_thread.start()

    # Monitoring loop: fetch per-peer metrics from the DHT, aggregate them, log/report, and checkpoint.
    while True:
        metrics_entry = dht.get(peer_args.experiment_prefix + "_metrics", latest=True)
        if metrics_entry is not None and len(metrics_entry.value) > 0:
            metrics_dict = metrics_entry.value
            metrics = [utils.LocalMetrics.parse_obj(metrics_dict[peer].value) for peer in metrics_dict]
            latest_step = max(item.step for item in metrics)

            if latest_step != current_step:
                logger.debug(f"Got metrics from {len(metrics)} peers")

                for i, metrics_for_peer in enumerate(metrics):
                    logger.debug(f"{i} peer {metrics_for_peer}")

                current_step = latest_step
                alive_peers = 0
                sum_loss = 0
                num_samples = 0
                sum_perf = 0
                sum_mini_steps = 0

                for item in metrics:
                    sum_loss += item.loss
                    alive_peers += 1
                    sum_perf += item.samples_per_second
                    num_samples += item.samples_accumulated
                    sum_mini_steps += item.mini_steps
                current_loss = sum_loss / sum_mini_steps
                logger.info(f"Step #{current_step}\tloss = {current_loss:.5f}")

                if peer_args.wandb_project is not None:
                    wandb.log(
                        {
                            "loss": current_loss,
                            "alive peers": alive_peers,
                            "samples": num_samples,
                            "performance": sum_perf,
                            "step": latest_step,
                        }
                    )

                if peer_args.store_checkpoints:
                    if checkpoint_handler.should_save_state(current_step):
                        checkpoint_handler.save_state(current_step)
                        if checkpoint_handler.is_time_to_upload():
                            checkpoint_handler.upload_checkpoint(current_loss)

        logger.debug("Peer is still alive...")
        time.sleep(peer_args.refresh_period)
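
Usage sketch: HfArgumentParser turns every field of the three dataclasses into a command-line flag, so an auxiliary peer can be launched roughly as below. Only --experiment_prefix and --wandb_project are taken from attributes referenced in the script; the remaining flags and their defaults live in arguments.py, which is not shown here, so treat this as an assumption, not the canonical invocation.

    python run_aux_peer.py --experiment_prefix my_experiment --wandb_project my_project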