run_training_monitor.py

#!/usr/bin/env python
"""
Training monitor ("coordinator") for a collaborative hivemind run: it joins the DHT,
periodically aggregates and logs the metrics reported by training peers, and can
optionally checkpoint and upload the model via CheckpointHandler.
"""

import logging
import time
from dataclasses import asdict, dataclass, field
from ipaddress import ip_address
from typing import Optional

import torch
import wandb
from torch_optimizer import Lamb
from transformers import AlbertForPreTraining, AlbertConfig, HfArgumentParser
from whatsmyip.ip import get_ip
from whatsmyip.providers import GoogleDnsProvider

import hivemind

# Local modules expected alongside this script (they define the shared argument
# dataclasses, the metrics schema, and the DHT record validators).
import utils
from arguments import BaseTrainingArguments, CollaborativeOptimizerArguments, AveragerArguments

logger = logging.getLogger(__name__)


@dataclass
class CoordinatorArguments(BaseTrainingArguments):
    """
    Note: you might want to have several initial peers so that if one dies,
    new workers can still join the collaboration via the alive initial peers' addresses.
    Specify the initial_peers argument for that purpose.
    """

    use_google_dns: bool = field(
        default=False,
        metadata={
            "help": "Use Google DNS to determine the public IP address of this machine (and add it to --announce_maddrs)"
        },
    )
    refresh_period: float = field(
        default=30, metadata={"help": "Coordinator will fetch keys from the DHT once in this many seconds"}
    )
    wandb_project: Optional[str] = field(
        default=None, metadata={"help": "Weights & Biases project where learning curves will be published"}
    )
    save_checkpoint_step_interval: int = field(
        default=5, metadata={"help": "Coordinator will load and save state from peers once every this many steps"}
    )
    model_config_path: str = field(
        default="https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-config.json",
        metadata={"help": "Path to the model config"},
    )
    repo_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the Hugging Face repo to which the coordinator will upload the model and optimizer states"},
    )
    repo_url: Optional[str] = field(
        default=None,
        metadata={
            "help": "URL of the Hugging Face repository to which the coordinator will upload the model and optimizer states"
        },
    )
    upload_interval: Optional[float] = field(
        default=None, metadata={"help": "Coordinator will upload the model once in this many seconds"}
    )
    store_checkpoints: bool = field(default=False, metadata={"help": "If True, enables CheckpointHandler"})
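
# Example launch (a sketch, not a verified command: --experiment_prefix and
# --initial_peers are defined in BaseTrainingArguments inside arguments.py, which
# is not shown here, so the exact flag names and value formats may differ):
#
#   python run_training_monitor.py \
#       --experiment_prefix albert_experiment \
#       --initial_peers /ip4/1.2.3.4/tcp/1337/p2p/<peer_id> \
#       --wandb_project my_wandb_project \
#       --store_checkpoints \
#       --repo_path ./albert-checkpoint \
#       --upload_interval 3600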


class CheckpointHandler:
    def __init__(
        self,
        coordinator_args: CoordinatorArguments,
        collab_optimizer_args: CollaborativeOptimizerArguments,
        averager_args: AveragerArguments,
        dht: hivemind.DHT,
    ):
        self.save_checkpoint_step_interval = coordinator_args.save_checkpoint_step_interval
        self.repo_path = coordinator_args.repo_path
        self.repo_url = coordinator_args.repo_url
        self.upload_interval = coordinator_args.upload_interval
        self.previous_step = -1

        config = AlbertConfig.from_pretrained(coordinator_args.model_config_path)
        self.model = AlbertForPreTraining(config)

        # Exclude biases and LayerNorm weights from weight decay, as is standard for BERT-like models
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": 0.01,
            },
            {
                "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]

        opt = Lamb(
            optimizer_grouped_parameters,
            lr=0.00176,
            weight_decay=0.01,
            clamp_value=10000.0,
            debias=True,
        )

        adjusted_target_batch_size = collab_optimizer_args.target_batch_size - collab_optimizer_args.batch_size_lead

        self.collaborative_optimizer = hivemind.CollaborativeOptimizer(
            opt=opt,
            dht=dht,
            # Use the same DHT prefix as the training peers so their state can be fetched
            prefix=coordinator_args.experiment_prefix,
            compression_type=hivemind.utils.CompressionType.Value(collab_optimizer_args.compression),
            throughput=collab_optimizer_args.bandwidth,
            target_batch_size=adjusted_target_batch_size,
            client_mode=collab_optimizer_args.client_mode,
            verbose=True,
            start=True,
            **asdict(averager_args),
        )
        self.previous_timestamp = time.time()
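
    # The coordinator mirrors the trainers' model and optimizer so that
    # load_state_from_peers() below can pull the latest collaboratively averaged
    # parameters and optimizer statistics into this process before checkpointing.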

    def is_time_to_save_state(self, cur_step):
        if self.save_checkpoint_step_interval is None:
            return False
        elif cur_step - self.previous_step >= self.save_checkpoint_step_interval:
            return True
        else:
            return False

    def save_state(self, cur_step):
        logger.info("Saving state from peers")
        self.collaborative_optimizer.load_state_from_peers()
        self.previous_step = cur_step

    def is_time_to_upload(self):
        if self.repo_path is None:
            return False
        elif time.time() - self.previous_timestamp >= self.upload_interval:
            return True
        else:
            return False

    def upload_checkpoint(self, current_loss):
        logger.info("Saving optimizer")
        torch.save(self.collaborative_optimizer.opt.state_dict(), f"{self.repo_path}/optimizer_state.pt")
        self.previous_timestamp = time.time()
        logger.info("Started uploading model to Hub")
        self.model.push_to_hub(
            repo_name=self.repo_path,
            repo_url=self.repo_url,
            # self.previous_step holds the step that was just saved in save_state()
            commit_message=f"Step {self.previous_step}, loss {current_loss:.3f}",
        )
        logger.info("Finished uploading model to Hub")
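
# A minimal sketch of restoring from a checkpoint produced by CheckpointHandler
# (the repository name below is hypothetical; optimizer_state.pt is the file
# written by upload_checkpoint above):
#
#   model = AlbertForPreTraining.from_pretrained("<username>/<repo_name>")
#   opt.load_state_dict(torch.load("optimizer_state.pt"))  # `opt` built as in __init__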


if __name__ == "__main__":
    parser = HfArgumentParser((CoordinatorArguments, CollaborativeOptimizerArguments, AveragerArguments))
    coordinator_args, collab_optimizer_args, averager_args = parser.parse_args_into_dataclasses()

    if coordinator_args.use_google_dns:
        address = get_ip(GoogleDnsProvider)
        logger.info(f"Received public IP address of this machine from Google DNS: {address}")
        version = ip_address(address).version
        coordinator_args.announce_maddrs += [f"/ip{version}/{address}/tcp/0", f"/ip{version}/{address}/udp/0/quic"]

    experiment_prefix = coordinator_args.experiment_prefix
    validators, local_public_key = utils.make_validators(experiment_prefix)
    dht = hivemind.DHT(
        start=True,
        initial_peers=coordinator_args.initial_peers,
        record_validators=validators,
        use_ipfs=coordinator_args.use_ipfs,
        host_maddrs=coordinator_args.host_maddrs,
        announce_maddrs=coordinator_args.announce_maddrs,
    )
    utils.log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=coordinator_args.use_ipfs)

    if coordinator_args.wandb_project is not None:
        wandb.init(project=coordinator_args.wandb_project)

    current_step = 0

    if coordinator_args.store_checkpoints:
        checkpoint_handler = CheckpointHandler(coordinator_args, collab_optimizer_args, averager_args, dht)
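
    # The loop below aggregates metrics that training peers periodically publish to
    # the DHT under the "<experiment_prefix>_metrics" key (one subkey per peer). A
    # rough sketch of the publishing side, which lives in the trainer script and is
    # not shown here, so the exact call may differ:
    #
    #   dht.store(
    #       key=experiment_prefix + "_metrics",
    #       subkey=local_public_key,
    #       value=local_metrics.dict(),  # utils.LocalMetrics, a pydantic model
    #       expiration_time=hivemind.get_dht_time() + metrics_expiration,  # hypothetical interval
    #   )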

    while True:
        metrics_dict = dht.get(experiment_prefix + "_metrics", latest=True)
        if metrics_dict is not None:
            metrics_dict = metrics_dict.value
            metrics = [utils.LocalMetrics.parse_obj(metrics_dict[peer].value) for peer in metrics_dict]
            latest_step = max(item.step for item in metrics)

            if latest_step != current_step:
                logger.debug(f"Got metrics from {len(metrics)} peers")

                for i, metrics_for_peer in enumerate(metrics):
                    logger.debug(f"{i} peer {metrics_for_peer}")

                current_step = latest_step
                alive_peers = 0
                num_batches = 0
                sum_loss = 0
                num_samples = 0
                sum_perf = 0
                sum_mini_steps = 0

                for item in metrics:
                    sum_loss += item.loss
                    alive_peers += 1
                    sum_perf += item.samples_per_second
                    num_samples += item.samples_accumulated
                    sum_mini_steps += item.mini_steps

                # Average loss over the total number of local optimizer mini-steps reported by all peers
                current_loss = sum_loss / sum_mini_steps
                logger.info(f"Step #{current_step}\tloss = {current_loss:.5f}")

                if coordinator_args.wandb_project is not None:
                    wandb.log(
                        {
                            "loss": current_loss,
                            "alive peers": alive_peers,
                            "samples": num_samples,
                            "performance": sum_perf,
                            "step": latest_step,
                        }
                    )

                if coordinator_args.store_checkpoints:
                    if checkpoint_handler.is_time_to_save_state(current_step):
                        checkpoint_handler.save_state(current_step)
                        if checkpoint_handler.is_time_to_upload():
                            checkpoint_handler.upload_checkpoint(current_loss)

        logger.debug("Peer is still alive...")
        time.sleep(coordinator_args.refresh_period)