callback.py

import os.path
from typing import Any

import hivemind
import torch
import transformers
from transformers import TrainingArguments

from arguments import TrainingPeerArguments
from task import TrainingTask
from utils import LocalMetrics, logger


class CollaborativeCallback(transformers.TrainerCallback):
    """
    This callback monitors and reports collaborative training progress.
    In case of a catastrophic failure, it can also revert training to a backup.
    """

    def __init__(self, task: TrainingTask, args: TrainingPeerArguments):
        super().__init__()
        self.task = task
        self.dht, self.collaborative_optimizer = task.dht, task.collaborative_optimizer
        self.statistics_expiration = args.statistics_expiration
        self.last_reported_collaboration_step = -1
        self.samples = 0
        self.steps = 0
        self.loss = 0
        self.total_samples_processed = 0
        self.backup_every_steps = args.backup_every_steps
        self.state_path = args.state_path

    def on_train_begin(
        self, args: TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs
    ):
        # Restore from a local backup (if any), then sync with other peers
        if os.path.isfile(self.state_path):
            self.restore_from_backup(self.state_path)
            logger.info("Loaded state")

        logger.info("Loading state from peers")
        self.collaborative_optimizer.load_state_from_peers()

        # If the local backup is not older than the state downloaded from peers, prefer the backup
        if os.path.isfile(self.state_path):
            self.restore_from_backup(self.state_path, check_step=True)

    def on_step_end(
        self, args: TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs
    ):
        control.should_log = True
        if not self.params_are_finite():
            if not os.path.exists(self.state_path):
                raise RuntimeError("Encountered broken parameters, but there is no backup to fall back to.")
            logger.warning("Parameters are invalid, reloading model from earlier state")
            self.restore_from_backup(self.state_path)
            return control

        if state.log_history:
            self.loss += state.log_history[-1]["loss"]
            self.steps += 1

            # Report metrics once per collaborative (global) optimizer step
            if self.collaborative_optimizer.local_step != self.last_reported_collaboration_step:
                self.last_reported_collaboration_step = self.collaborative_optimizer.local_step
                self.total_samples_processed += self.samples
                samples_per_second = self.collaborative_optimizer.performance_ema.samples_per_second
                statistics = LocalMetrics(
                    step=self.collaborative_optimizer.local_step,
                    samples_per_second=samples_per_second,
                    samples_accumulated=self.samples,
                    loss=self.loss,
                    mini_steps=self.steps,
                )
                logger.info(f"Step {self.collaborative_optimizer.local_step}")
                logger.info(f"Your current contribution: {self.total_samples_processed} samples")
                logger.info(f"Performance: {samples_per_second} samples per second.")
                if self.steps:
                    logger.info(f"Local loss: {self.loss / self.steps}")

                self.loss = 0
                self.steps = 0
                if self.collaborative_optimizer.is_synchronized:
                    # Publish local metrics to the DHT so other peers and monitors can aggregate them
                    self.dht.store(
                        key=self.collaborative_optimizer.prefix + "_metrics",
                        subkey=self.task.local_public_key,
                        value=statistics.dict(),
                        expiration_time=hivemind.get_dht_time() + self.statistics_expiration,
                        return_future=True,
                    )
                if (
                    self.backup_every_steps is not None
                    and self.collaborative_optimizer.local_step % self.backup_every_steps == 0
                ):
                    self.backup_state()

        self.samples = self.collaborative_optimizer.local_samples_accumulated
        return control

    @torch.no_grad()
    def params_are_finite(self):
        for param in self.task.model.parameters():
            if not torch.all(torch.isfinite(param)):
                return False
        return True

    @torch.no_grad()
    def backup_state(self) -> Any:
        logger.info("Saving backup")
        return torch.save(
            {
                "model": self.task.model.state_dict(),
                "training": self.collaborative_optimizer.state_dict(),
                "scheduler": self.collaborative_optimizer.scheduler.state_dict(),
                "local_step": self.collaborative_optimizer.local_step,
            },
            self.state_path,
        )

    @torch.no_grad()
    def restore_from_backup(self, path, check_step=False):
        state = torch.load(path)
        current_step = self.collaborative_optimizer.local_step
        # Prefer the explicitly saved local_step; fall back to the optimizer state for older backups
        backup_step = state.get("local_step", state["training"]["state"][0]["step"])

        if not check_step or backup_step >= current_step:
            # Cached rotary embedding buffers are removed so they are not restored from the backup
            if (
                "albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention_core.rotary_emb.cos"
                in state["model"]
            ):
                del state["model"][
                    "albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention_core.rotary_emb.cos"
                ]
                del state["model"][
                    "albert.encoder.albert_layer_groups.0.albert_layers.0.attention.attention_core.rotary_emb.sin"
                ]

            if "scheduler" in state:
                self.collaborative_optimizer.scheduler.load_state_dict(state["scheduler"])
            self.collaborative_optimizer.load_state_dict(state["training"])
            self.collaborative_optimizer.averager.local_step = backup_step
            self.task.model.load_state_dict(state["model"], strict=False)
            logger.info("Restored from a backup")
        else:
            logger.info("Bypassed restoring state from local backup: backup state is too old.")
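

# Usage sketch (illustrative, not part of the module above): the callback is meant to be
# registered with a transformers.Trainer via its `callbacks` argument. How TrainingTask and
# TrainingPeerArguments are constructed is assumed here for illustration; in the actual repo
# they come from its own argument-parsing and task-setup code.
#
#     peer_args = TrainingPeerArguments(...)            # assumed: parsed elsewhere, e.g. with HfArgumentParser
#     training_args = TrainingArguments(output_dir="outputs")
#     task = TrainingTask(peer_args, ...)               # assumed constructor signature
#     trainer = transformers.Trainer(
#         model=task.model,
#         args=training_args,
#         train_dataset=task.training_dataset,          # assumed attribute name
#         callbacks=[CollaborativeCallback(task, peer_args)],
#     )
#     trainer.train()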