run_trainer.py

#!/usr/bin/env python
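"""Collaborative ALBERT pre-training script: a transformers Trainer runs local gradient
accumulation while hivemind.CollaborativeOptimizer synchronizes optimizer steps with other
peers discovered through a hivemind DHT."""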

import os
import pickle
from dataclasses import asdict
from pathlib import Path

import torch
import transformers
from datasets import load_from_disk
from torch.utils.data import DataLoader
from torch_optimizer import Lamb
from transformers import DataCollatorForLanguageModeling, HfArgumentParser, TrainingArguments, set_seed
from transformers.models.albert import AlbertConfig, AlbertForPreTraining, AlbertTokenizerFast
from transformers.optimization import get_linear_schedule_with_warmup
from transformers.trainer import Trainer
from transformers.trainer_utils import is_main_process

import hivemind
from hivemind.utils.logging import get_logger, use_hivemind_log_handler

import utils
from arguments import AlbertTrainingArguments, AveragerArguments, CollaborationArguments, DatasetArguments

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__name__)

LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)


def setup_transformers_logging(process_rank: int):
    if is_main_process(process_rank):
        transformers.utils.logging.set_verbosity_info()
        transformers.utils.logging.disable_default_handler()
        transformers.utils.logging.enable_propagation()


def get_model(training_args, config, tokenizer):
    # Find latest checkpoint in output_dir
    output_dir = Path(training_args.output_dir)
    logger.info(f'Checkpoint dir {output_dir}, contents {list(output_dir.glob("checkpoint*"))}')
    latest_checkpoint_dir = max(output_dir.glob("checkpoint*"), default=None, key=os.path.getctime)

    if latest_checkpoint_dir is not None:
        logger.info(f"Loading model from {latest_checkpoint_dir}")
        model = AlbertForPreTraining.from_pretrained(latest_checkpoint_dir)
    else:
        logger.info(f"Training from scratch")
        model = AlbertForPreTraining(config)
        model.resize_token_embeddings(len(tokenizer))

    return model


def get_optimizer_and_scheduler(training_args, model):
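    # Standard BERT/ALBERT-style grouping: bias and LayerNorm weights are trained without weight decay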
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": training_args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]

    opt = Lamb(
        optimizer_grouped_parameters,
        lr=training_args.learning_rate,
        betas=(training_args.adam_beta1, training_args.adam_beta2),
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
        clamp_value=training_args.clamp_value,
        debias=True,
    )

    scheduler = get_linear_schedule_with_warmup(
        opt, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.max_steps
    )

    return opt, scheduler


class CollaborativeCallback(transformers.TrainerCallback):
    """
    This callback monitors and reports collaborative training progress.
    In case of a catastrophic failure, it can also revert training to a backup.
    """

    def __init__(
        self,
        dht: hivemind.DHT,
        optimizer: hivemind.CollaborativeOptimizer,
        model: torch.nn.Module,
        local_public_key: bytes,
        statistics_expiration: float,
        backup_every_steps: int,
    ):
        super().__init__()
        self.model = model
        self.dht, self.collaborative_optimizer = dht, optimizer
        self.local_public_key = local_public_key
        self.statistics_expiration = statistics_expiration
        self.last_reported_collaboration_step = -1
        self.samples = 0
        self.steps = 0
        self.loss = 0
        self.total_samples_processed = 0
        self.backup_every_steps = backup_every_steps
        self.latest_backup = self.backup_state()

    def on_train_begin(
        self, args: TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs
    ):
        logger.info("Loading state from peers")
        self.collaborative_optimizer.load_state_from_peers()

    def on_step_end(
        self, args: TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs
    ):
        control.should_log = True
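        # If any parameter has become NaN/inf (e.g. after a diverged step), roll back to the latest in-memory backup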
        if not self.params_are_finite():
            self.restore_from_backup(self.latest_backup)
            return control

        if state.log_history:
            self.loss += state.log_history[-1]["loss"]
            self.steps += 1
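            # Report metrics at most once per collaborative (global) step; in between, the local loss and
            # mini-step counters keep accumulating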
            if self.collaborative_optimizer.local_step != self.last_reported_collaboration_step:
                self.last_reported_collaboration_step = self.collaborative_optimizer.local_step
                self.total_samples_processed += self.samples
                samples_per_second = self.collaborative_optimizer.performance_ema.samples_per_second
                statistics = utils.LocalMetrics(
                    step=self.collaborative_optimizer.local_step,
                    samples_per_second=samples_per_second,
                    samples_accumulated=self.samples,
                    loss=self.loss,
                    mini_steps=self.steps,
                )
                logger.info(f"Step #{self.collaborative_optimizer.local_step}")
                logger.info(f"Your current contribution: {self.total_samples_processed} samples")
                logger.info(f"Performance: {samples_per_second} samples per second.")
                if self.steps:
                    logger.info(f"Local loss: {self.loss / self.steps}")
                if self.collaborative_optimizer.local_step % self.backup_every_steps == 0:
                    self.latest_backup = self.backup_state()

                self.loss = 0
                self.steps = 0
                if self.collaborative_optimizer.is_synchronized:
                    self.dht.store(
                        key=self.collaborative_optimizer.prefix + "_metrics",
                        subkey=self.local_public_key,
                        value=statistics.dict(),
                        expiration_time=hivemind.get_dht_time() + self.statistics_expiration,
                        return_future=True,
                    )

        self.samples = self.collaborative_optimizer.local_samples_accumulated

        return control

    @torch.no_grad()
    def params_are_finite(self):
        for param in self.model.parameters():
            if not torch.all(torch.isfinite(param)):
                return False
        return True
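    # Backups are plain pickled state dicts kept in memory and restored when parameters stop being finite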

    @torch.no_grad()
    def backup_state(self) -> bytes:
        return pickle.dumps(
            {"model": self.model.state_dict(), "optimizer": self.collaborative_optimizer.opt.state_dict()}
        )

    @torch.no_grad()
    def restore_from_backup(self, backup: bytes):
        state = pickle.loads(backup)
        self.model.load_state_dict(state["model"])
        self.collaborative_optimizer.opt.load_state_dict(state["optimizer"])


class NoOpScheduler(LRSchedulerBase):
    """Dummy scheduler for transformers.Trainer. The real scheduler is defined in CollaborativeOptimizer.scheduler"""

    def get_lr(self):
        return [group["lr"] for group in self.optimizer.param_groups]

    def print_lr(self, *args, **kwargs):
        if self.optimizer.scheduler:
            return self.optimizer.scheduler.print_lr(*args, **kwargs)

    def step(self):
        self._last_lr = self.get_lr()

    def state_dict(self):
        return {}

    def load_state_dict(self, *args, **kwargs):
        logger.debug("Called NoOpScheduler.load_state_dict")


def main():
    parser = HfArgumentParser((AlbertTrainingArguments, DatasetArguments, CollaborationArguments, AveragerArguments))
    training_args, dataset_args, collaboration_args, averager_args = parser.parse_args_into_dataclasses()

    logger.info(f"Found {len(collaboration_args.initial_peers)} initial peers: {collaboration_args.initial_peers}")
    if len(collaboration_args.initial_peers) == 0:
        raise ValueError("Please specify at least one network endpoint in initial peers.")

    setup_transformers_logging(training_args.local_rank)
    logger.info(f"Training/evaluation parameters:\n{training_args}")

    # Set seed before initializing model.
    set_seed(training_args.seed)

    config = AlbertConfig.from_pretrained(dataset_args.config_path, cache_dir=dataset_args.cache_dir)
    tokenizer = AlbertTokenizerFast.from_pretrained(dataset_args.tokenizer_path, cache_dir=dataset_args.cache_dir)
    model = get_model(training_args, config, tokenizer)
    model.to(training_args.device)

    tokenized_datasets = load_from_disk(Path(dataset_args.dataset_path))
    # This data collator will take care of randomly masking the tokens.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer)

    opt, scheduler = get_optimizer_and_scheduler(training_args, model)

    validators, local_public_key = utils.make_validators(collaboration_args.experiment_prefix)
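    # The DHT node provides peer discovery (via initial_peers) and a shared key-value store that both the
    # collaborative optimizer and the metrics callback write to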
    dht = hivemind.DHT(
        start=True,
        initial_peers=collaboration_args.initial_peers,
        client_mode=collaboration_args.client_mode,
        record_validators=validators,
        use_ipfs=collaboration_args.use_ipfs,
        host_maddrs=collaboration_args.host_maddrs,
        announce_maddrs=collaboration_args.announce_maddrs,
        identity_path=collaboration_args.identity_path,
    )
    utils.log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=collaboration_args.use_ipfs)
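    # Samples contributed by this peer per local optimizer step:
    # per-device batch size x gradient accumulation steps x number of GPUs (if any are visible)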
    total_batch_size_per_step = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
    if torch.cuda.device_count() != 0:
        total_batch_size_per_step *= torch.cuda.device_count()
    adjusted_target_batch_size = collaboration_args.target_batch_size - collaboration_args.batch_size_lead
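    # CollaborativeOptimizer accumulates gradients locally and, together with the other peers, performs an
    # averaged optimizer step once roughly adjusted_target_batch_size samples have been processed
    # collaboration-wide; averaging hyperparameters are forwarded from averager_args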
    collaborative_optimizer = hivemind.CollaborativeOptimizer(
        opt=opt,
        dht=dht,
        scheduler=scheduler,
        prefix=collaboration_args.experiment_prefix,
        compression=hivemind.Float16Compression(),
        batch_size_per_step=total_batch_size_per_step,
        bandwidth=collaboration_args.bandwidth,
        target_batch_size=adjusted_target_batch_size,
        client_mode=collaboration_args.client_mode,
        verbose=True,
        start=True,
        **asdict(averager_args),
    )

    class TrainerWithIndependentShuffling(Trainer):
        def get_train_dataloader(self) -> DataLoader:
            """Shuffle data independently for each peer to avoid duplicating batches [important for quality]"""
            torch.manual_seed(hash(local_public_key))
            return super().get_train_dataloader()
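    # The collaborative optimizer is passed to the Trainer in place of a regular optimizer; learning-rate
    # scheduling happens inside it, so the Trainer only receives the NoOpScheduler defined above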
    trainer = TrainerWithIndependentShuffling(
        model=model,
        args=training_args,
        tokenizer=tokenizer,
        data_collator=data_collator,
        train_dataset=tokenized_datasets["train"] if training_args.do_train else None,
        eval_dataset=tokenized_datasets["validation"] if training_args.do_eval else None,
        optimizers=(collaborative_optimizer, NoOpScheduler(collaborative_optimizer)),
        callbacks=[
            CollaborativeCallback(
                dht,
                collaborative_optimizer,
                model,
                local_public_key,
                collaboration_args.statistics_expiration,
                collaboration_args.backup_every_steps,
            )
        ],
    )
    trainer.remove_callback(transformers.trainer_callback.PrinterCallback)
    trainer.remove_callback(transformers.trainer_callback.ProgressCallback)

    # Training
    if training_args.do_train:
        latest_checkpoint_dir = max(
            Path(training_args.output_dir).glob("checkpoint*"), default=None, key=os.path.getctime
        )
        trainer.train(model_path=latest_checkpoint_dir)


if __name__ == "__main__":
    main()
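

# Example launch (a sketch, not a verified command: flag names must match the dataclasses defined in
# arguments.py, and the peer multiaddress below is a placeholder):
#
#   python run_trainer.py \
#       --experiment_prefix my_albert_run \
#       --initial_peers /ip4/203.0.113.1/tcp/31337/p2p/QmPlaceholderPeerID \
#       --output_dir ./outputs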