run_aux_peer.py

#!/usr/bin/env python
import threading
import time

import torch
import wandb
import transformers
from transformers import HfArgumentParser
from huggingface_hub import HfFolder, Repository
from hivemind.utils.logging import get_logger, use_hivemind_log_handler

import utils
from arguments import AuxiliaryPeerArguments, CollaborativeArguments, HFTrainerArguments
from task import TrainingTask

transformers.utils.logging.set_verbosity_warning()
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__name__)
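

# The CheckpointHandler keeps this auxiliary peer's copy of the model in sync with the swarm
# and, when an upload interval is configured, pushes checkpoints to a Hugging Face Hub repository.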
class CheckpointHandler:
    def __init__(self, task: TrainingTask, peer_args: AuxiliaryPeerArguments):
        self.task, self.peer_args = task, peer_args
        self.save_checkpoint_step_interval = peer_args.save_checkpoint_step_interval
        self.prefix = peer_args.experiment_prefix
        self.local_path = peer_args.local_path
        self.upload_interval = peer_args.upload_interval
        if self.upload_interval is not None:
            self.token = HfFolder.get_token()
            self.repo = Repository(
                local_dir=self.local_path,
                clone_from=peer_args.repo_url,
                use_auth_token=self.token,
            )
        self.previous_step = -1
        self.previous_timestamp = time.time()

    def should_save_state(self, cur_step):
        if self.save_checkpoint_step_interval is None:
            return False
        elif cur_step - self.previous_step >= self.save_checkpoint_step_interval:
            return True
        else:
            return False
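
    # Note: "saving" here means refreshing this peer's copy of the state by downloading
    # the latest averaged parameters from other peers via the collaborative optimizer.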
    def save_state(self, cur_step):
        logger.info("Saving state from peers")
        self.task.collaborative_optimizer.load_state_from_peers()
        self.previous_step = cur_step

    def is_time_to_upload(self):
        if self.upload_interval is None:
            return False
        elif time.time() - self.previous_timestamp >= self.upload_interval:
            return True
        else:
            return False
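
    # Write the current model and optimizer state to local files, then push them to the Hub
    # repository cloned in __init__ (requires a valid Hugging Face token, see HfFolder.get_token()).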
    def upload_checkpoint(self, current_loss):
        logger.info("Saving model")
        torch.save(self.task.model.state_dict(), f"{self.local_path}/model_state.pt")
        logger.info("Saving optimizer")
        torch.save(self.task.collaborative_optimizer.state_dict(), f"{self.local_path}/optimizer_state.pt")
        self.previous_timestamp = time.time()
        logger.info("Started uploading to Model Hub")
        try:
            # We start by pulling the remote changes (for example, a change in the README file)
            self.repo.git_pull()
            # Then we add / commit and push the changes
            self.repo.push_to_hub(
                commit_message=f"Epoch {self.task.collaborative_optimizer.local_epoch}, loss {current_loss:.3f}"
            )
            logger.info("Finished uploading to Model Hub")
        except OSError as e:
            # The push may fail if another push reached the remote branch after the pull above.
            # In that case, these changes will be pushed with the next commit.
            logger.error(f'The push to hub operation failed with error "{e}"')
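

# Optional helper: lets an auxiliary peer periodically run collaborative_optimizer.step()
# to assist the swarm with averaging rounds. Currently unused — see the NotImplementedError
# in the main block below.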
def assist_averaging_in_background(task: TrainingTask, peer_args: AuxiliaryPeerArguments):
    while True:
        time.sleep(peer_args.assist_refresh)
        task.collaborative_optimizer.step()
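

# Entry point: parse the three argument dataclasses (HfArgumentParser exposes their fields as
# command-line flags, e.g. --experiment_prefix or --wandb_project), build the training task with
# its DHT and collaborative optimizer, then run the monitoring loop.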
if __name__ == "__main__":
    parser = HfArgumentParser((AuxiliaryPeerArguments, HFTrainerArguments, CollaborativeArguments))
    peer_args, trainer_args, collab_args = parser.parse_args_into_dataclasses()

    task = TrainingTask(peer_args, trainer_args, collab_args)
    dht, collaborative_optimizer = task.dht, task.collaborative_optimizer

    if peer_args.wandb_project is not None:
        wandb.init(project=peer_args.wandb_project)

    current_step = 0
    if peer_args.store_checkpoints:
        checkpoint_handler = CheckpointHandler(task, peer_args)

    if peer_args.assist_in_averaging:
        # assert not peer_args.client_mode, "client-mode peers cannot assist in averaging"
        # averaging_thread = threading.Thread(
        #     name="AveragingAuxThread", target=assist_averaging_in_background, args=[task, peer_args], daemon=True
        # )
        # averaging_thread.start()
        raise NotImplementedError("aux peers with hivemind.optim.experimental are not supported yet")
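
    # Monitoring loop: fetch per-peer training metrics from the DHT, aggregate them, log the
    # averaged loss (optionally to Weights & Biases), and handle checkpoint saving/uploading.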
    while True:
        metrics_entry = dht.get(peer_args.experiment_prefix + "_metrics", latest=True)
        if metrics_entry is not None and len(metrics_entry.value) > 0:
            metrics_dict = metrics_entry.value
            metrics = [utils.LocalMetrics.parse_obj(metrics_dict[peer].value) for peer in metrics_dict]
            latest_step = max(item.step for item in metrics)

            if latest_step != current_step:
                logger.debug(f"Got metrics from {len(metrics)} peers")

                for i, metrics_for_peer in enumerate(metrics):
                    logger.debug(f"{i} peer {metrics_for_peer}")

                current_step = latest_step
                alive_peers = 0
                sum_loss = 0
                num_samples = 0
                sum_perf = 0
                sum_mini_steps = 0

                for item in metrics:
                    sum_loss += item.loss
                    alive_peers += 1
                    sum_perf += item.samples_per_second
                    num_samples += item.samples_accumulated
                    sum_mini_steps += item.mini_steps

                current_loss = sum_loss / sum_mini_steps
                logger.info(f"Epoch #{current_step}\tloss = {current_loss:.5f}")

                if peer_args.wandb_project is not None:
                    wandb.log(
                        {
                            "loss": current_loss,
                            "alive peers": alive_peers,
                            "samples": num_samples,
                            "performance": sum_perf,
                            "step": latest_step,
                        }
                    )
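
            # Periodically refresh this peer's state from the swarm and, when due, upload a checkpoint.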
            if peer_args.store_checkpoints:
                if checkpoint_handler.should_save_state(current_step):
                    checkpoint_handler.save_state(current_step)
                    if checkpoint_handler.is_time_to_upload():
                        checkpoint_handler.upload_checkpoint(current_loss)

        logger.debug("Peer is still alive...")
        time.sleep(peer_args.refresh_period)