run_aux_peer.py

#!/usr/bin/env python
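"""Auxiliary peer: instead of training, it monitors the collaborative run via the hivemind
DHT, aggregates and logs peer-reported metrics (optionally to Weights & Biases), saves
checkpoints, and uploads them to the Hugging Face Model Hub; it can also assist in averaging."""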
import threading
import time

import torch
import wandb
from transformers import HfArgumentParser
from huggingface_hub import HfFolder, Repository
from hivemind.utils.logging import get_logger, use_hivemind_log_handler

import utils
from arguments import AuxiliaryPeerArguments, CollaborativeArguments, HFTrainerArguments
from task import TrainingTask

use_hivemind_log_handler("in_root_logger")
logger = get_logger()

class CheckpointHandler:
    def __init__(self, task: TrainingTask, peer_args: AuxiliaryPeerArguments):
        self.task, self.peer_args = task, peer_args
        self.save_checkpoint_step_interval = peer_args.save_checkpoint_step_interval
        self.prefix = peer_args.experiment_prefix
        self.local_path = peer_args.local_path
        self.upload_interval = peer_args.upload_interval
        if self.upload_interval is not None:
            # Uploading requires an authenticated local clone of the Model Hub repository
            self.token = HfFolder.get_token()
            self.repo = Repository(
                local_dir=self.local_path,
                clone_from=peer_args.repo_url,
                use_auth_token=self.token,
            )
        self.previous_step = -1
        self.previous_timestamp = time.time()

    def should_save_state(self, cur_step):
        if self.save_checkpoint_step_interval is None:
            return False
        elif cur_step - self.previous_step >= self.save_checkpoint_step_interval:
            return True
        else:
            return False

    def save_state(self, cur_step):
        logger.info("Saving state from peers")
        self.task.collaborative_optimizer.load_state_from_peers()
        self.previous_step = cur_step

    def is_time_to_upload(self):
        if self.upload_interval is None:
            return False
        elif time.time() - self.previous_timestamp >= self.upload_interval:
            return True
        else:
            return False

    def upload_checkpoint(self, current_loss):
        logger.info("Saving model")
        self.task.model.save_pretrained(self.local_path)
        logger.info("Saving optimizer")
        torch.save(self.task.collaborative_optimizer.opt.state_dict(), f"{self.local_path}/optimizer_state.pt")
        self.previous_timestamp = time.time()
        logger.info("Started uploading to Model Hub")
        self.repo.push_to_hub(
            commit_message=f"Step {self.task.collaborative_optimizer.local_step}, loss {current_loss:.3f}"
        )
        logger.info("Finished uploading to Model Hub")
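
# Note: step_aux() is the auxiliary-peer entry point of hivemind's collaborative optimizer;
# it lets this peer take part in averaging rounds without running a local training loop,
# which is what "assisting in averaging" refers to below.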
def assist_averaging_in_background(task: TrainingTask, peer_args: AuxiliaryPeerArguments):
    while True:
        time.sleep(peer_args.assist_refresh)
        task.collaborative_optimizer.step_aux()
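
# Illustrative launch command (a sketch: the actual flag names are defined by the
# AuxiliaryPeerArguments / HFTrainerArguments / CollaborativeArguments dataclasses
# and may differ in your checkout), e.g.:
#   python run_aux_peer.py --experiment_prefix my_exp --wandb_project my_project \
#       --store_checkpoints --upload_interval 3600 --repo_url <hub_repo> --local_path ./checkpoint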

if __name__ == "__main__":
    parser = HfArgumentParser((AuxiliaryPeerArguments, HFTrainerArguments, CollaborativeArguments))
    peer_args, trainer_args, collab_args = parser.parse_args_into_dataclasses()

    task = TrainingTask(peer_args, trainer_args, collab_args)
    dht, collaborative_optimizer = task.dht, task.collaborative_optimizer

    if peer_args.wandb_project is not None:
        wandb.init(project=peer_args.wandb_project)

    current_step = 0
    if peer_args.store_checkpoints:
        checkpoint_handler = CheckpointHandler(task, peer_args)
    if peer_args.assist_in_averaging:
        assert not peer_args.client_mode, "client-mode peers cannot assist in averaging"
        averaging_thread = threading.Thread(
            name="AveragingAuxThread", target=assist_averaging_in_background, args=[task, peer_args], daemon=True
        )
        averaging_thread.start()
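    # Monitoring loop: poll the DHT for metrics published by training peers, aggregate
    # them whenever a new global step appears, log to Weights & Biases (if configured),
    # and save/upload checkpoints according to the settings above.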
    while True:
        metrics_entry = dht.get(peer_args.experiment_prefix + "_metrics", latest=True)
        if metrics_entry is not None and len(metrics_entry.value) > 0:
            metrics_dict = metrics_entry.value
            metrics = [utils.LocalMetrics.parse_obj(metrics_dict[peer].value) for peer in metrics_dict]
            latest_step = max(item.step for item in metrics)

            if latest_step != current_step:
                logger.debug(f"Got metrics from {len(metrics)} peers")

                for i, metrics_for_peer in enumerate(metrics):
                    logger.debug(f"{i} peer {metrics_for_peer}")

                current_step = latest_step
                alive_peers = 0
                sum_loss = 0
                num_samples = 0
                sum_perf = 0
                sum_mini_steps = 0

                for item in metrics:
                    sum_loss += item.loss
                    alive_peers += 1
                    sum_perf += item.samples_per_second
                    num_samples += item.samples_accumulated
                    sum_mini_steps += item.mini_steps

                # Average the reported loss over all gradient accumulation mini-steps
                current_loss = sum_loss / sum_mini_steps
                logger.info(f"Step #{current_step}\tloss = {current_loss:.5f}")

                if peer_args.wandb_project is not None:
                    wandb.log(
                        {
                            "loss": current_loss,
                            "alive peers": alive_peers,
                            "samples": num_samples,
                            "performance": sum_perf,
                            "step": latest_step,
                        }
                    )

                if peer_args.store_checkpoints:
                    if checkpoint_handler.should_save_state(current_step):
                        checkpoint_handler.save_state(current_step)
                        if checkpoint_handler.is_time_to_upload():
                            checkpoint_handler.upload_checkpoint(current_loss)

        logger.debug("Peer is still alive...")
        time.sleep(peer_args.refresh_period)