run_training_monitor.py

#!/usr/bin/env python3
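"""
Auxiliary training monitor for collaborative ALBERT pretraining with hivemind.

This peer computes no gradients itself: it joins the DHT, periodically fetches the
training metrics that worker peers publish there, reports them (optionally to Weights
& Biases), and, if checkpointing is enabled, downloads the averaged training state
and uploads model/optimizer checkpoints to the Hugging Face Hub.
"""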

import time
from dataclasses import asdict, dataclass, field
from ipaddress import ip_address
from typing import Optional

import requests
import torch
import wandb
from torch_optimizer import Lamb
from transformers import AlbertConfig, AlbertForPreTraining, HfArgumentParser, get_linear_schedule_with_warmup

import hivemind
from hivemind.optim.state_averager import TrainingStateAverager
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from hivemind.utils.networking import log_visible_maddrs

import utils
from arguments import AveragerArguments, BaseTrainingArguments, OptimizerArguments

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__name__)


@dataclass
class TrainingMonitorArguments(BaseTrainingArguments):
    """
    Note: you might want to have several initial peers so that if one dies,
    new workers can still join the collaboration via the alive initial peers' addresses.
    Specify the initial_peers argument for that purpose.
    """

    use_google_dns: bool = field(
        default=False,
        metadata={
            "help": "Use a public IP-echo service (api.ipify.org) to determine the public IP address "
            "of this machine (and add it to --announce_maddrs)"
        },
    )
    refresh_period: float = field(default=30, metadata={"help": "Period (in seconds) for fetching the keys from DHT"})
    wandb_project: Optional[str] = field(
        default=None, metadata={"help": "Name of Weights & Biases project to report the training progress to"}
    )
    store_checkpoints: bool = field(default=True, metadata={"help": "If False, disables periodic checkpoint saving"})
    save_checkpoint_step_interval: int = field(
        default=5, metadata={"help": "Frequency (in steps) of fetching and saving state from peers"}
    )
    model_config_path: str = field(
        default="https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-config.json",
        metadata={"help": "Path to the model config"},
    )
    repo_path: Optional[str] = field(
        default=None, metadata={"help": "Path to local repository to store the model and optimizer states"}
    )
    repo_url: Optional[str] = field(
        default=None, metadata={"help": "URL of Hugging Face Hub repository to upload the model and optimizer states"}
    )
    upload_interval: Optional[float] = field(
        default=None, metadata={"help": "Frequency (in seconds) of uploading the model to Hub"}
    )
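

# CheckpointHandler periodically re-fetches the averaged training state from worker
# peers and, when a local repo is configured, uploads model/optimizer checkpoints
# to the Hugging Face Hub.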
class CheckpointHandler:
    def __init__(
        self,
        monitor_args: TrainingMonitorArguments,
        optimizer_args: OptimizerArguments,
        averager_args: AveragerArguments,
        dht: hivemind.DHT,
    ):
        self.save_checkpoint_step_interval = monitor_args.save_checkpoint_step_interval
        self.repo_path = monitor_args.repo_path
        self.repo_url = monitor_args.repo_url
        self.upload_interval = monitor_args.upload_interval
        self.previous_step = -1

        config = AlbertConfig.from_pretrained(monitor_args.model_config_path)
        self.model = AlbertForPreTraining(config)

        # Exempt biases and LayerNorm weights from weight decay
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": 0.01,
            },
            {
                "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]

        opt = Lamb(
            optimizer_grouped_parameters,
            lr=0.00176,
            weight_decay=0.01,
            clamp_value=10000.0,
            debias=True,
        )

        self.state_averager = TrainingStateAverager(
            dht=dht,
            optimizer=opt,
            scheduler=get_linear_schedule_with_warmup(opt, num_warmup_steps=5000, num_training_steps=125_000),
            prefix=f"{monitor_args.run_id}_state_averager",
            state_compression=hivemind.Float16Compression(),
            bandwidth=optimizer_args.bandwidth,
            client_mode=optimizer_args.client_mode,
            start=True,
            **asdict(averager_args),
        )
        self.previous_timestamp = time.time()

    def is_time_to_save_state(self, cur_step: int) -> bool:
        if self.save_checkpoint_step_interval is None:
            return False
        return cur_step - self.previous_step >= self.save_checkpoint_step_interval

    def save_state(self, cur_step: int) -> None:
        logger.info("Saving state from peers")
        self.state_averager.load_state_from_peers()
        self.previous_step = cur_step

    def is_time_to_upload(self) -> bool:
        # Uploading requires both a local repo path and an upload interval
        if self.repo_path is None or self.upload_interval is None:
            return False
        return time.time() - self.previous_timestamp >= self.upload_interval

    def upload_checkpoint(self, current_step: int, current_loss: float) -> None:
        logger.info("Saving optimizer")
        torch.save(self.state_averager.optimizer.state_dict(), f"{self.repo_path}/optimizer_state.pt")
        self.previous_timestamp = time.time()
        logger.info("Started uploading to Model Hub")
        self.model.push_to_hub(
            repo_name=self.repo_path,
            repo_url=self.repo_url,
            commit_message=f"Step #{current_step}, loss {current_loss:.3f}",
        )
        logger.info("Finished uploading to Model Hub")
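

# Standalone entry point: parse CLI arguments, join the DHT, then poll metrics forever.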
if __name__ == "__main__":
    parser = HfArgumentParser((TrainingMonitorArguments, OptimizerArguments, AveragerArguments))
    monitor_args, optimizer_args, averager_args = parser.parse_args_into_dataclasses()

    if monitor_args.use_google_dns:
        # Despite the flag name, this queries the api.ipify.org IP-echo service
        request = requests.get("https://api.ipify.org")
        request.raise_for_status()

        address = request.text
        logger.info(f"Received public IP address of this machine: {address}")
        version = ip_address(address).version
        # Announce the discovered address as a libp2p multiaddr (port 0 = any free port)
        monitor_args.announce_maddrs += [f"/ip{version}/{address}/tcp/0"]

    run_id = monitor_args.run_id
    validators, local_public_key = utils.make_validators(run_id)

    dht = hivemind.DHT(
        start=True,
        initial_peers=monitor_args.initial_peers,
        record_validators=validators,
        use_ipfs=monitor_args.use_ipfs,
        host_maddrs=monitor_args.host_maddrs,
        announce_maddrs=monitor_args.announce_maddrs,
        identity_path=monitor_args.identity_path,
    )
    log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=monitor_args.use_ipfs)

    if monitor_args.wandb_project is not None:
        wandb.init(project=monitor_args.wandb_project)

    current_step = 0
    if monitor_args.store_checkpoints:
        checkpoint_handler = CheckpointHandler(monitor_args, optimizer_args, averager_args, dht)

    while True:
        metrics_dict = dht.get(run_id + "_metrics", latest=True)
        if metrics_dict is not None:
            metrics_dict = metrics_dict.value
            metrics = [utils.LocalMetrics.parse_obj(metrics_dict[peer].value) for peer in metrics_dict]
            latest_step = max(item.step for item in metrics)

            if latest_step != current_step:
                logger.debug(f"Got metrics from {len(metrics)} peers")
                for i, metrics_for_peer in enumerate(metrics):
                    logger.debug(f"{i} peer {metrics_for_peer}")

                current_step = latest_step
                alive_peers = 0
                sum_loss = 0
                num_samples = 0
                sum_perf = 0
                sum_mini_steps = 0
                for item in metrics:
                    sum_loss += item.loss
                    alive_peers += 1
                    sum_perf += item.samples_per_second
                    num_samples += item.samples_accumulated
                    sum_mini_steps += item.mini_steps
                # Average the reported loss over all mini-steps accumulated by the peers
                current_loss = sum_loss / sum_mini_steps
                logger.info(f"Step #{current_step}\tloss = {current_loss:.5f}")

                if monitor_args.wandb_project is not None:
                    wandb.log(
                        {
                            "loss": current_loss,
                            "alive peers": alive_peers,
                            "samples": num_samples,
                            "performance": sum_perf,
                            "step": latest_step,
                        }
                    )

                if monitor_args.store_checkpoints:
                    # Upload is only attempted right after a fresh state was fetched
                    if checkpoint_handler.is_time_to_save_state(current_step):
                        checkpoint_handler.save_state(current_step)
                        if checkpoint_handler.is_time_to_upload():
                            checkpoint_handler.upload_checkpoint(current_step, current_loss)

        logger.debug("Peer is still alive...")
        time.sleep(monitor_args.refresh_period)
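

# Example launch with hypothetical values (run_id and initial_peers come from
# BaseTrainingArguments in arguments.py; the remaining flags are defined above):
#
#   python run_training_monitor.py \
#       --run_id my_albert_run \
#       --wandb_project my_wandb_project \
#       --repo_path ./albert-checkpoints \
#       --upload_interval 3600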