run_training_monitor.py

#!/usr/bin/env python
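"""
Training monitor for a collaborative (hivemind) ALBERT run.

This script joins the experiment's DHT as an auxiliary peer, periodically reads the
metrics that training peers publish under the `{experiment_prefix}_metrics` key,
logs the aggregated loss and throughput (optionally to Weights & Biases), and, if
checkpointing is enabled, pulls the latest averaged state from peers and uploads
the model and optimizer state to the Hugging Face Hub.
"""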
import logging
import time
from dataclasses import asdict, dataclass, field
from ipaddress import ip_address
from typing import Optional

import torch
import wandb
from torch_optimizer import Lamb
from transformers import AlbertForPreTraining, AlbertConfig, HfArgumentParser
from whatsmyip.ip import get_ip
from whatsmyip.providers import GoogleDnsProvider

import hivemind
import utils
from arguments import BaseTrainingArguments, CollaborativeOptimizerArguments, AveragerArguments

logger = logging.getLogger(__name__)


@dataclass
class CoordinatorArguments(BaseTrainingArguments):
    """
    Note: you might want to have several initial peers so that, if one dies,
    new workers can still join the collaboration via the remaining initial peers' addresses.
    Specify the initial_peers argument for that purpose.
    """
    use_google_dns: bool = field(
        default=False,
        metadata={"help":
                  "Use Google DNS to determine the public IP address of this machine (and add it to --announce_maddrs)"}
    )
    refresh_period: float = field(
        default=30,
        metadata={"help": "Coordinator will fetch keys from the DHT once in this many seconds"}
    )
    wandb_project: Optional[str] = field(
        default=None,
        metadata={"help": "Weights & Biases project to which learning curves will be published"}
    )
    save_checkpoint_step_interval: int = field(
        default=5,
        metadata={"help": "Coordinator will load and save the state from peers once in this many steps"}
    )
    model_config_path: str = field(
        default='https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-config.json',
        metadata={"help": "Path to the model config"}
    )
    repo_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the Hugging Face repo to which the coordinator will upload the model and optimizer states"}
    )
    repo_url: Optional[str] = field(
        default=None,
        metadata={"help": "URL of the Hugging Face repository to which the coordinator will upload the model and optimizer states"}
    )
    upload_interval: Optional[float] = field(
        default=None,
        metadata={"help": "Coordinator will upload the model once in this many seconds"}
    )
    store_checkpoints: bool = field(
        default=False,
        metadata={"help": "If True, enables CheckpointHandler"}
    )
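

# CheckpointHandler keeps a local copy of the model and optimizer, periodically pulls the
# latest averaged state from the training peers, and uploads checkpoints to the Hugging Face Hub.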
class CheckpointHandler:
    def __init__(self, coordinator_args: CoordinatorArguments, collab_optimizer_args: CollaborativeOptimizerArguments,
                 averager_args: AveragerArguments, dht: hivemind.DHT):
        self.save_checkpoint_step_interval = coordinator_args.save_checkpoint_step_interval
        self.repo_path = coordinator_args.repo_path
        self.repo_url = coordinator_args.repo_url
        self.upload_interval = coordinator_args.upload_interval
        self.previous_step = -1

        # Local copy of the model; its weights are overwritten by the averaged state pulled from peers
        config = AlbertConfig.from_pretrained(coordinator_args.model_config_path)
        self.model = AlbertForPreTraining(config)

        # Reproduce the training peers' parameter grouping and LAMB settings so their optimizer state can be loaded
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": 0.01,
            },
            {
                "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        opt = Lamb(
            optimizer_grouped_parameters,
            lr=0.00176, weight_decay=0.01, clamp_value=10000.0, debias=True,
        )

        adjusted_target_batch_size = collab_optimizer_args.target_batch_size - collab_optimizer_args.batch_size_lead

        self.collaborative_optimizer = hivemind.CollaborativeOptimizer(
            opt=opt, dht=dht, prefix=coordinator_args.experiment_prefix,
            compression_type=hivemind.utils.CompressionType.Value(collab_optimizer_args.compression),
            throughput=collab_optimizer_args.bandwidth,
            target_batch_size=adjusted_target_batch_size, client_mode=collab_optimizer_args.client_mode,
            verbose=True, start=True, **asdict(averager_args)
        )
        self.previous_timestamp = time.time()

    def is_time_to_save_state(self, cur_step):
        if self.save_checkpoint_step_interval is None:
            return False
        elif cur_step - self.previous_step >= self.save_checkpoint_step_interval:
            return True
        else:
            return False

    def save_state(self, cur_step):
        logger.info("Loading the latest state from peers")
        self.collaborative_optimizer.load_state_from_peers()
        self.previous_step = cur_step

    def is_time_to_upload(self):
        if self.repo_path is None:
            return False
        elif time.time() - self.previous_timestamp >= self.upload_interval:
            return True
        else:
            return False

    def upload_checkpoint(self, current_loss):
        logger.info("Saving optimizer")
        torch.save(self.collaborative_optimizer.opt.state_dict(), f"{self.repo_path}/optimizer_state.pt")
        self.previous_timestamp = time.time()
        logger.info('Started uploading model to Hub')
        # `current_step` is the module-level step counter updated in the main loop below
        self.model.push_to_hub(repo_name=self.repo_path, repo_url=self.repo_url,
                               commit_message=f'Step {current_step}, loss {current_loss:.3f}')
        logger.info('Finished uploading model to Hub')
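

# Example launch (illustrative; `--experiment_prefix` is defined in BaseTrainingArguments,
# the remaining flags in the dataclasses above and in arguments.py):
#   python run_training_monitor.py --experiment_prefix albert_experiment \
#       --wandb_project my_wandb_project --use_google_dns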
if __name__ == '__main__':
    parser = HfArgumentParser((CoordinatorArguments, CollaborativeOptimizerArguments, AveragerArguments))
    coordinator_args, collab_optimizer_args, averager_args = parser.parse_args_into_dataclasses()

    if coordinator_args.use_google_dns:
        # Discover this machine's public IP and announce it on both TCP and QUIC multiaddrs
        address = get_ip(GoogleDnsProvider)
        logger.info(f"Received public IP address of this machine from Google DNS: {address}")
        version = ip_address(address).version
        coordinator_args.announce_maddrs += [f'/ip{version}/{address}/tcp/0', f'/ip{version}/{address}/udp/0/quic']

    experiment_prefix = coordinator_args.experiment_prefix
    validators, local_public_key = utils.make_validators(experiment_prefix)
    dht = hivemind.DHT(start=True,
                       initial_peers=coordinator_args.initial_peers,
                       record_validators=validators,
                       use_ipfs=coordinator_args.use_ipfs,
                       host_maddrs=coordinator_args.host_maddrs,
                       announce_maddrs=coordinator_args.announce_maddrs)
    utils.log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=coordinator_args.use_ipfs)

    if coordinator_args.wandb_project is not None:
        wandb.init(project=coordinator_args.wandb_project)

    current_step = 0

    if coordinator_args.store_checkpoints:
        checkpoint_handler = CheckpointHandler(coordinator_args, collab_optimizer_args, averager_args, dht)
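
    # Main monitoring loop: fetch per-peer metrics from the DHT, aggregate them, report the
    # result, and (optionally) save/upload a checkpoint whenever the global step advances.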
    while True:
        metrics_dict = dht.get(experiment_prefix + '_metrics', latest=True)
        if metrics_dict is not None:
            metrics_dict = metrics_dict.value
            metrics = [utils.LocalMetrics.parse_obj(metrics_dict[peer].value)
                       for peer in metrics_dict]
            latest_step = max(item.step for item in metrics)
            if latest_step != current_step:
                logger.debug(f"Got metrics from {len(metrics)} peers")

                for i, metrics_for_peer in enumerate(metrics):
                    logger.debug(f"{i} peer {metrics_for_peer}")

                current_step = latest_step
                alive_peers = 0
                num_batches = 0
                sum_loss = 0
                num_samples = 0
                sum_perf = 0
                sum_mini_steps = 0

                for item in metrics:
                    sum_loss += item.loss
                    alive_peers += 1
                    sum_perf += item.samples_per_second
                    num_samples += item.samples_accumulated
                    sum_mini_steps += item.mini_steps
                current_loss = sum_loss / sum_mini_steps
                logger.info(f"Step #{current_step}\tloss = {current_loss:.5f}")

                if coordinator_args.wandb_project is not None:
                    wandb.log({
                        "loss": current_loss,
                        "alive peers": alive_peers,
                        "samples": num_samples,
                        "performance": sum_perf,
                        "step": latest_step
                    })

                if coordinator_args.store_checkpoints:
                    if checkpoint_handler.is_time_to_save_state(current_step):
                        checkpoint_handler.save_state(current_step)
                        if checkpoint_handler.is_time_to_upload():
                            checkpoint_handler.upload_checkpoint(current_loss)

        logger.debug("Peer is still alive...")
        time.sleep(coordinator_args.refresh_period)