run_first_peer.py

#!/usr/bin/env python
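"""Coordinator ("first peer") for collaborative ALBERT pretraining with hivemind.

This script hosts the root DHT node that other peers join, periodically fetches the
training metrics that workers publish to the DHT, optionally reports them to
Weights & Biases, and (if enabled) saves and uploads model and optimizer checkpoints.
"""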
from dataclasses import dataclass, field, asdict
import subprocess
import time
from typing import Optional

import torch
from torch_optimizer import Lamb
from transformers import AlbertForPreTraining, AlbertConfig, HfArgumentParser
import wandb
from whatsmyip.providers import GoogleDnsProvider
from whatsmyip.ip import get_ip

from arguments import BaseTrainingArguments, CollaborativeOptimizerArguments, AveragerArguments
import hivemind
from hivemind.utils.logging import get_logger
import metrics_utils

logger = get_logger(__name__)

@dataclass
class CoordinatorArguments(BaseTrainingArguments):
    """
    Note: you might want to have several initial peers so that if one dies,
    new workers can still join the collaboration via the addresses of the remaining initial peers.
    Specify the initial_peers argument for that purpose.
    """
    address: Optional[str] = field(
        default=None,
        metadata={"help": "This machine's network address. Use public IP for global experiments, "
                          "local address for private runs"}
    )
    refresh_period: float = field(
        default=30,
        metadata={"help": "Coordinator will fetch keys from DHT once in this many seconds"}
    )
    wandb_project: Optional[str] = field(
        default=None,
        metadata={"help": "Learning curves will be published there"}
    )
    save_checkpoint_step_interval: int = field(
        default=5,
        metadata={"help": "Coordinator will load and save state from peers once in this many steps"}
    )
    model_config_path: str = field(
        default='https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-config.json',
        metadata={"help": "Path to the model config"}
    )
    repo_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to HuggingFace repo in which coordinator will upload the model and optimizer states"}
    )
    upload_interval: Optional[float] = field(
        default=None,
        metadata={"help": "Coordinator will upload model once in this many seconds"}
    )
    store_checkpoints: bool = field(
        default=False,
        metadata={"help": "If True, enables CheckpointHandler"}
    )
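
# CheckpointHandler periodically pulls the latest collaborative state from peers and,
# if a HuggingFace repo path is configured, commits and pushes the model and optimizer
# state to that repo.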
class CheckpointHandler:
    def __init__(self, coordinator_args: CoordinatorArguments, collab_optimizer_args: CollaborativeOptimizerArguments,
                 averager_args: AveragerArguments, dht: hivemind.DHT):
        self.save_checkpoint_step_interval = coordinator_args.save_checkpoint_step_interval
        self.repo_path = coordinator_args.repo_path
        self.upload_interval = coordinator_args.upload_interval
        self.previous_step = -1

        config = AlbertConfig.from_pretrained(coordinator_args.model_config_path)
        self.model = AlbertForPreTraining(config)

        # Exclude biases and LayerNorm weights from weight decay, as in the standard BERT/ALBERT recipe
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": 0.01,
            },
            {
                "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        opt = Lamb(
            optimizer_grouped_parameters,
            lr=0.00176, weight_decay=0.01, clamp_value=10000.0, debias=True,
        )

        adjusted_target_batch_size = collab_optimizer_args.target_batch_size - collab_optimizer_args.batch_size_lead

        # Note: experiment_prefix is the module-level variable assigned in the __main__ block below
        self.collaborative_optimizer = hivemind.CollaborativeOptimizer(
            opt=opt, dht=dht, prefix=experiment_prefix,
            compression_type=hivemind.utils.CompressionType.Value(collab_optimizer_args.compression),
            throughput=collab_optimizer_args.bandwidth,
            target_batch_size=adjusted_target_batch_size, client_mode=collab_optimizer_args.client_mode,
            verbose=True, start=True, **asdict(averager_args)
        )
        self.previous_timestamp = time.time()
    def is_time_to_save_state(self, cur_step):
        if self.save_checkpoint_step_interval is None:
            return False
        elif cur_step - self.previous_step >= self.save_checkpoint_step_interval:
            return True
        else:
            return False

    def save_state(self, cur_step):
        # Pull the newest averaged parameters and optimizer state from the collaboration
        self.collaborative_optimizer.load_state_from_peers()
        self.previous_step = cur_step

    def is_time_to_upload(self):
        if self.repo_path is None:
            return False
        elif time.time() - self.previous_timestamp >= self.upload_interval:
            return True
        else:
            return False
    def upload_checkpoint(self, current_loss):
        self.model.save_pretrained(self.repo_path)
        torch.save(self.collaborative_optimizer.opt.state_dict(), f"{self.repo_path}/optimizer_state.pt")
        self.previous_timestamp = time.time()
        try:
            subprocess.run("git add --all", shell=True, check=True, cwd=self.repo_path)
            current_step = self.collaborative_optimizer.collaboration_state.optimizer_step
            subprocess.run(f"git commit -m 'Step {current_step}, loss {current_loss:.3f}'",
                           shell=True, check=True, cwd=self.repo_path)
            subprocess.run("git push", shell=True, check=True, cwd=self.repo_path)
        except subprocess.CalledProcessError as e:
            logger.warning(f"Error while uploading model: {e.output}")
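
# Entry point: start the root DHT node, then loop forever, aggregating the metrics that
# worker peers publish under "<experiment_prefix>_metrics" and reporting them.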
if __name__ == '__main__':
    parser = HfArgumentParser((CoordinatorArguments, CollaborativeOptimizerArguments, AveragerArguments))
    coordinator_args, collab_optimizer_args, averager_args = parser.parse_args_into_dataclasses()

    if coordinator_args.address is None:
        logger.warning("No address specified. Attempting to infer address from DNS.")
        coordinator_args.address = get_ip(GoogleDnsProvider)

    experiment_prefix = coordinator_args.experiment_prefix
    validators, local_public_key = metrics_utils.make_validators(experiment_prefix)
    dht = hivemind.DHT(start=True, listen_on=coordinator_args.dht_listen_on,
                       endpoint=f"{coordinator_args.address}:*", initial_peers=coordinator_args.initial_peers,
                       record_validators=validators)
    logger.info(f"Running DHT root at {coordinator_args.address}:{dht.port}")

    if coordinator_args.wandb_project is not None:
        wandb.init(project=coordinator_args.wandb_project)

    current_step = 0
    if coordinator_args.store_checkpoints:
        checkpoint_handler = CheckpointHandler(coordinator_args, collab_optimizer_args, averager_args, dht)

    while True:
        metrics_dict = dht.get(experiment_prefix + '_metrics', latest=True)
        if metrics_dict is not None:
            metrics_dict = metrics_dict.value
            metrics = [metrics_utils.LocalMetrics.parse_obj(metrics_dict[peer].value)
                       for peer in metrics_dict]
            latest_step = max(item.step for item in metrics)
            if latest_step != current_step:
                logger.debug(f"Got metrics from {len(metrics)} peers")
                for i, metrics_for_peer in enumerate(metrics):
                    logger.debug(f"{i} peer {metrics_for_peer}")
                current_step = latest_step

                # Aggregate per-peer metrics into collaboration-wide statistics
                alive_peers = 0
                num_batches = 0
                sum_loss = 0
                num_samples = 0
                sum_perf = 0
                sum_mini_steps = 0
                for item in metrics:
                    sum_loss += item.loss
                    alive_peers += 1
                    sum_perf += item.samples_per_second
                    num_samples += item.samples_accumulated
                    sum_mini_steps += item.mini_steps
                current_loss = sum_loss / sum_mini_steps

                if coordinator_args.wandb_project is not None:
                    wandb.log({
                        "loss": current_loss,
                        "alive peers": alive_peers,
                        "samples": num_samples,
                        "performance": sum_perf,
                        "step": latest_step
                    })
                if coordinator_args.store_checkpoints:
                    if checkpoint_handler.is_time_to_save_state(current_step):
                        checkpoint_handler.save_state(current_step)
                        if checkpoint_handler.is_time_to_upload():
                            checkpoint_handler.upload_checkpoint(current_loss)
                logger.info(f"Step #{current_step}\tloss = {current_loss:.5f}")
        logger.debug("Peer is still alive...")
        time.sleep(coordinator_args.refresh_period)
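
# Illustrative launch command (flags are generated by HfArgumentParser from the dataclass
# fields above and from arguments.py; adjust values to your own setup):
#   python run_first_peer.py --experiment_prefix albert_experiment \
#       --address 203.0.113.7 --wandb_project my_wandb_project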