run_trainer.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. #!/usr/bin/env python
  2. import logging
  3. import os
  4. import pickle
  5. from dataclasses import asdict
  6. from pathlib import Path
  7. import torch
  8. import transformers
  9. from datasets import load_from_disk
  10. from torch.utils.data import DataLoader
  11. from torch_optimizer import Lamb
  12. from transformers import DataCollatorForLanguageModeling, HfArgumentParser, TrainingArguments, set_seed
  13. from transformers.models.albert import AlbertConfig, AlbertForPreTraining, AlbertTokenizerFast
  14. from transformers.optimization import get_linear_schedule_with_warmup
  15. from transformers.trainer import Trainer
  16. from transformers.trainer_utils import is_main_process
  17. import hivemind
  18. import utils
  19. from arguments import AlbertTrainingArguments, AveragerArguments, CollaborationArguments, DatasetArguments
  20. logger = logging.getLogger(__name__)
  21. LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
  22. def setup_logging(training_args):
  23. logging.basicConfig(
  24. format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
  25. datefmt="%m/%d/%Y %H:%M:%S",
  26. level=logging.INFO if is_main_process(training_args.local_rank) else logging.WARN,
  27. )
  28. # Log on each process the small summary:
  29. logger.warning(
  30. f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
  31. + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
  32. )
  33. # Set the verbosity to info of the Transformers logger (on main process only):
  34. if is_main_process(training_args.local_rank):
  35. transformers.utils.logging.set_verbosity_info()
  36. transformers.utils.logging.enable_default_handler()
  37. transformers.utils.logging.enable_explicit_format()
  38. logger.info("Training/evaluation parameters %s", training_args)
  39. def get_model(training_args, config, tokenizer):
  40. # Find latest checkpoint in output_dir
  41. output_dir = Path(training_args.output_dir)
  42. logger.info(f'Checkpoint dir {output_dir}, contents {list(output_dir.glob("checkpoint*"))}')
  43. latest_checkpoint_dir = max(output_dir.glob("checkpoint*"), default=None, key=os.path.getctime)
  44. if latest_checkpoint_dir is not None:
  45. logger.info(f"Loading model from {latest_checkpoint_dir}")
  46. model = AlbertForPreTraining.from_pretrained(latest_checkpoint_dir)
  47. else:
  48. logger.info(f"Training from scratch")
  49. model = AlbertForPreTraining(config)
  50. model.resize_token_embeddings(len(tokenizer))
  51. return model
  52. def get_optimizer_and_scheduler(training_args, model):
  53. no_decay = ["bias", "LayerNorm.weight"]
  54. optimizer_grouped_parameters = [
  55. {
  56. "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
  57. "weight_decay": training_args.weight_decay,
  58. },
  59. {
  60. "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
  61. "weight_decay": 0.0,
  62. },
  63. ]
  64. opt = Lamb(
  65. optimizer_grouped_parameters,
  66. lr=training_args.learning_rate,
  67. betas=(training_args.adam_beta1, training_args.adam_beta2),
  68. eps=training_args.adam_epsilon,
  69. weight_decay=training_args.weight_decay,
  70. clamp_value=training_args.clamp_value,
  71. debias=True,
  72. )
  73. scheduler = get_linear_schedule_with_warmup(
  74. opt, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.max_steps
  75. )
  76. return opt, scheduler
  77. class CollaborativeCallback(transformers.TrainerCallback):
  78. """
  79. This callback monitors and reports collaborative training progress.
  80. In case of a catastrophic failure, it can also revert training to a backup.
  81. """
  82. def __init__(
  83. self,
  84. dht: hivemind.DHT,
  85. optimizer: hivemind.CollaborativeOptimizer,
  86. model: torch.nn.Module,
  87. local_public_key: bytes,
  88. statistics_expiration: float,
  89. backup_every_steps: int,
  90. ):
  91. super().__init__()
  92. self.model = model
  93. self.dht, self.collaborative_optimizer = dht, optimizer
  94. self.local_public_key = local_public_key
  95. self.statistics_expiration = statistics_expiration
  96. self.last_reported_collaboration_step = -1
  97. self.samples = 0
  98. self.steps = 0
  99. self.loss = 0
  100. self.total_samples_processed = 0
  101. self.backup_every_steps = backup_every_steps
  102. self.latest_backup = self.backup_state()
  103. def on_train_begin(
  104. self, args: TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs
  105. ):
  106. logger.info("Loading state from peers")
  107. self.collaborative_optimizer.load_state_from_peers()
  108. def on_step_end(
  109. self, args: TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs
  110. ):
  111. control.should_log = True
  112. if not self.params_are_finite():
  113. self.restore_from_backup(self.latest_backup)
  114. return control
  115. if state.log_history:
  116. self.loss += state.log_history[-1]["loss"]
  117. self.steps += 1
  118. if self.collaborative_optimizer.local_step != self.last_reported_collaboration_step:
  119. self.last_reported_collaboration_step = self.collaborative_optimizer.local_step
  120. self.total_samples_processed += self.samples
  121. samples_per_second = self.collaborative_optimizer.performance_ema.samples_per_second
  122. statistics = utils.LocalMetrics(
  123. step=self.collaborative_optimizer.local_step,
  124. samples_per_second=samples_per_second,
  125. samples_accumulated=self.samples,
  126. loss=self.loss,
  127. mini_steps=self.steps,
  128. )
  129. logger.info(f"Step {self.collaborative_optimizer.local_step}")
  130. logger.info(f"Your current contribution: {self.total_samples_processed} samples")
  131. logger.info(f"Performance: {samples_per_second} samples per second.")
  132. if self.steps:
  133. logger.info(f"Local loss: {self.loss / self.steps}")
  134. if self.collaborative_optimizer.local_step % self.backup_every_steps == 0:
  135. self.latest_backup = self.backup_state()
  136. self.loss = 0
  137. self.steps = 0
  138. if self.collaborative_optimizer.is_synchronized:
  139. self.dht.store(
  140. key=self.collaborative_optimizer.prefix + "_metrics",
  141. subkey=self.local_public_key,
  142. value=statistics.dict(),
  143. expiration_time=hivemind.get_dht_time() + self.statistics_expiration,
  144. return_future=True,
  145. )
  146. self.samples = self.collaborative_optimizer.local_samples_accumulated
  147. return control
  148. @torch.no_grad()
  149. def params_are_finite(self):
  150. for param in self.model.parameters():
  151. if not torch.all(torch.isfinite(param)):
  152. return False
  153. return True
  154. @torch.no_grad()
  155. def backup_state(self) -> bytes:
  156. return pickle.dumps(
  157. {"model": self.model.state_dict(), "optimizer": self.collaborative_optimizer.opt.state_dict()}
  158. )
  159. @torch.no_grad()
  160. def restore_from_backup(self, backup: bytes):
  161. state = pickle.loads(backup)
  162. self.model.load_state_dict(state["model"])
  163. self.collaborative_optimizer.opt.load_state_dict(state["optimizer"])
  164. class NoOpScheduler(LRSchedulerBase):
  165. """Dummy scheduler for transformers.Trainer. The real scheduler is defined in CollaborativeOptimizer.scheduler"""
  166. def get_lr(self):
  167. return [group["lr"] for group in self.optimizer.param_groups]
  168. def print_lr(self, *args, **kwargs):
  169. if self.optimizer.scheduler:
  170. return self.optimizer.scheduler.print_lr(*args, **kwargs)
  171. def step(self):
  172. self._last_lr = self.get_lr()
  173. def state_dict(self):
  174. return {}
  175. def load_state_dict(self, *args, **kwargs):
  176. logger.debug("Called NoOpScheduler.load_state_dict")
  177. def main():
  178. parser = HfArgumentParser((AlbertTrainingArguments, DatasetArguments, CollaborationArguments, AveragerArguments))
  179. training_args, dataset_args, collaboration_args, averager_args = parser.parse_args_into_dataclasses()
  180. logger.info(f"Found {len(collaboration_args.initial_peers)} initial peers: {collaboration_args.initial_peers}")
  181. if len(collaboration_args.initial_peers) == 0:
  182. raise ValueError("Please specify at least one network endpoint in initial peers.")
  183. setup_logging(training_args)
  184. # Set seed before initializing model.
  185. set_seed(training_args.seed)
  186. config = AlbertConfig.from_pretrained(dataset_args.config_path, cache_dir=dataset_args.cache_dir)
  187. tokenizer = AlbertTokenizerFast.from_pretrained(dataset_args.tokenizer_path, cache_dir=dataset_args.cache_dir)
  188. model = get_model(training_args, config, tokenizer)
  189. model.to(training_args.device)
  190. tokenized_datasets = load_from_disk(Path(dataset_args.dataset_path))
  191. # This data collator will take care of randomly masking the tokens.
  192. data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer)
  193. opt, scheduler = get_optimizer_and_scheduler(training_args, model)
  194. validators, local_public_key = utils.make_validators(collaboration_args.experiment_prefix)
  195. dht = hivemind.DHT(
  196. start=True,
  197. initial_peers=collaboration_args.initial_peers,
  198. client_mode=collaboration_args.client_mode,
  199. record_validators=validators,
  200. use_ipfs=collaboration_args.use_ipfs,
  201. host_maddrs=collaboration_args.host_maddrs,
  202. announce_maddrs=collaboration_args.announce_maddrs,
  203. identity_path=collaboration_args.identity_path,
  204. )
  205. utils.log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=collaboration_args.use_ipfs)
  206. total_batch_size_per_step = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
  207. if torch.cuda.device_count() != 0:
  208. total_batch_size_per_step *= torch.cuda.device_count()
  209. adjusted_target_batch_size = collaboration_args.target_batch_size - collaboration_args.batch_size_lead
  210. collaborative_optimizer = hivemind.CollaborativeOptimizer(
  211. opt=opt,
  212. dht=dht,
  213. scheduler=scheduler,
  214. prefix=collaboration_args.experiment_prefix,
  215. compression=hivemind.Float16Compression(),
  216. batch_size_per_step=total_batch_size_per_step,
  217. bandwidth=collaboration_args.bandwidth,
  218. target_batch_size=adjusted_target_batch_size,
  219. client_mode=collaboration_args.client_mode,
  220. verbose=True,
  221. start=True,
  222. **asdict(averager_args),
  223. )
  224. class TrainerWithIndependentShuffling(Trainer):
  225. def get_train_dataloader(self) -> DataLoader:
  226. """Shuffle data independently for each peer to avoid duplicating batches [important for quality]"""
  227. torch.manual_seed(hash(local_public_key))
  228. return super().get_train_dataloader()
  229. trainer = TrainerWithIndependentShuffling(
  230. model=model,
  231. args=training_args,
  232. tokenizer=tokenizer,
  233. data_collator=data_collator,
  234. train_dataset=tokenized_datasets["train"] if training_args.do_train else None,
  235. eval_dataset=tokenized_datasets["validation"] if training_args.do_eval else None,
  236. optimizers=(collaborative_optimizer, NoOpScheduler(collaborative_optimizer)),
  237. callbacks=[
  238. CollaborativeCallback(
  239. dht,
  240. collaborative_optimizer,
  241. model,
  242. local_public_key,
  243. collaboration_args.statistics_expiration,
  244. collaboration_args.backup_every_steps,
  245. )
  246. ],
  247. )
  248. trainer.remove_callback(transformers.trainer_callback.PrinterCallback)
  249. trainer.remove_callback(transformers.trainer_callback.ProgressCallback)
  250. # Training
  251. if training_args.do_train:
  252. latest_checkpoint_dir = max(
  253. Path(training_args.output_dir).glob("checkpoint*"), default=None, key=os.path.getctime
  254. )
  255. trainer.train(model_path=latest_checkpoint_dir)
# Start training only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()