run_trainer.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313
  1. #!/usr/bin/env python
  2. import logging
  3. import os
  4. import pickle
  5. from dataclasses import asdict
  6. from pathlib import Path
  7. from typing import Any
  8. import torch
  9. import transformers
  10. from datasets import load_from_disk
  11. from torch.utils.data import DataLoader
  12. from torch_optimizer import Lamb
  13. from transformers import set_seed, HfArgumentParser, TrainingArguments, DataCollatorForLanguageModeling
  14. from transformers.models.albert import AlbertTokenizerFast, AlbertConfig, AlbertForPreTraining
  15. from transformers.optimization import get_linear_schedule_with_warmup
  16. from transformers.trainer import Trainer
  17. from transformers.trainer_utils import is_main_process
  18. import hivemind
  19. from hivemind.utils.compression import CompressionType
  20. import utils
  21. from arguments import CollaborationArguments, DatasetArguments, AlbertTrainingArguments, AveragerArguments
  22. logger = logging.getLogger(__name__)
  23. LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
  24. def setup_logging(training_args):
  25. logging.basicConfig(
  26. format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
  27. datefmt="%m/%d/%Y %H:%M:%S",
  28. level=logging.INFO if is_main_process(training_args.local_rank) else logging.WARN,
  29. )
  30. # Log on each process the small summary:
  31. logger.warning(
  32. f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
  33. + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
  34. )
  35. # Set the verbosity to info of the Transformers logger (on main process only):
  36. if is_main_process(training_args.local_rank):
  37. transformers.utils.logging.set_verbosity_info()
  38. transformers.utils.logging.enable_default_handler()
  39. transformers.utils.logging.enable_explicit_format()
  40. logger.info("Training/evaluation parameters %s", training_args)
  41. def get_model(training_args, config, tokenizer):
  42. # Find latest checkpoint in output_dir
  43. output_dir = Path(training_args.output_dir)
  44. logger.info(f'Checkpoint dir {output_dir}, contents {list(output_dir.glob("checkpoint*"))}')
  45. latest_checkpoint_dir = max(output_dir.glob("checkpoint*"), default=None, key=os.path.getctime)
  46. if latest_checkpoint_dir is not None:
  47. logger.info(f"Loading model from {latest_checkpoint_dir}")
  48. model = AlbertForPreTraining.from_pretrained(latest_checkpoint_dir)
  49. else:
  50. logger.info(f"Training from scratch")
  51. model = AlbertForPreTraining(config)
  52. model.resize_token_embeddings(len(tokenizer))
  53. return model
  54. def get_optimizer_and_scheduler(training_args, model):
  55. no_decay = ["bias", "LayerNorm.weight"]
  56. optimizer_grouped_parameters = [
  57. {
  58. "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
  59. "weight_decay": training_args.weight_decay,
  60. },
  61. {
  62. "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
  63. "weight_decay": 0.0,
  64. },
  65. ]
  66. opt = Lamb(
  67. optimizer_grouped_parameters,
  68. lr=training_args.learning_rate,
  69. betas=(training_args.adam_beta1, training_args.adam_beta2),
  70. eps=training_args.adam_epsilon,
  71. weight_decay=training_args.weight_decay,
  72. clamp_value=training_args.clamp_value,
  73. debias=True,
  74. )
  75. scheduler = get_linear_schedule_with_warmup(
  76. opt, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.max_steps
  77. )
  78. return opt, scheduler
  79. class CollaborativeCallback(transformers.TrainerCallback):
  80. """
  81. This callback monitors and reports collaborative training progress,
  82. In case of a catastrophic failure, it can also revert training to a backup
  83. """
  84. def __init__(
  85. self,
  86. dht: hivemind.DHT,
  87. optimizer: hivemind.CollaborativeOptimizer,
  88. model: torch.nn.Module,
  89. local_public_key: bytes,
  90. statistics_expiration: float,
  91. backup_every_steps: int,
  92. ):
  93. super().__init__()
  94. self.model = model
  95. self.dht, self.collaborative_optimizer = dht, optimizer
  96. self.local_public_key = local_public_key
  97. self.statistics_expiration = statistics_expiration
  98. self.last_reported_collaboration_step = -1
  99. self.samples = 0
  100. self.steps = 0
  101. self.loss = 0
  102. self.total_samples_processed = 0
  103. self.backup_every_steps = backup_every_steps
  104. self.latest_backup = self.backup_state()
  105. def on_train_begin(
  106. self, args: TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs
  107. ):
  108. logger.info("Loading state from peers")
  109. self.collaborative_optimizer.load_state_from_peers()
  110. def on_step_end(
  111. self, args: TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs
  112. ):
  113. control.should_log = True
  114. if not self.params_are_finite():
  115. self.restore_from_backup(self.latest_backup)
  116. return control
  117. if state.log_history:
  118. self.loss += state.log_history[-1]["loss"]
  119. self.steps += 1
  120. if self.collaborative_optimizer.local_step != self.last_reported_collaboration_step:
  121. self.last_reported_collaboration_step = self.collaborative_optimizer.local_step
  122. self.total_samples_processed += self.samples
  123. samples_per_second = self.collaborative_optimizer.performance_ema.samples_per_second
  124. statistics = utils.LocalMetrics(
  125. step=self.collaborative_optimizer.local_step,
  126. samples_per_second=samples_per_second,
  127. samples_accumulated=self.samples,
  128. loss=self.loss,
  129. mini_steps=self.steps,
  130. )
  131. logger.info(f"Step {self.collaborative_optimizer.local_step}")
  132. logger.info(f"Your current contribution: {self.total_samples_processed} samples")
  133. if self.steps:
  134. logger.info(f"Local loss: {self.loss / self.steps}")
  135. if self.collaborative_optimizer.local_step % self.backup_every_steps == 0:
  136. self.latest_backup = self.backup_state()
  137. self.loss = 0
  138. self.steps = 0
  139. if self.collaborative_optimizer.is_synchronized:
  140. self.dht.store(
  141. key=self.collaborative_optimizer.prefix + "_metrics",
  142. subkey=self.local_public_key,
  143. value=statistics.dict(),
  144. expiration_time=hivemind.get_dht_time() + self.statistics_expiration,
  145. return_future=True,
  146. )
  147. self.samples = self.collaborative_optimizer.local_samples_accumulated
  148. return control
  149. @torch.no_grad()
  150. def params_are_finite(self):
  151. for param in self.model.parameters():
  152. if not torch.all(torch.isfinite(param)):
  153. return False
  154. return True
  155. @torch.no_grad()
  156. def backup_state(self) -> Any:
  157. return pickle.dumps(
  158. {"model": self.model.state_dict(), "training": self.collaborative_optimizer.opt.state_dict()}
  159. )
  160. @torch.no_grad()
  161. def restore_from_backup(self, backup):
  162. state = pickle.loads(backup)
  163. self.model.load_state_dict(state["model"])
  164. self.collaborative_optimizer.opt.load_state_dict(state["training"])
  165. class NoOpScheduler(LRSchedulerBase):
  166. """Dummy scheduler for transformers.Trainer. The real scheduler is defined in CollaborativeOptimizer.scheduler"""
  167. def get_lr(self):
  168. return [group["lr"] for group in self.optimizer.param_groups]
  169. def print_lr(self, *args, **kwargs):
  170. if self.optimizer.scheduler:
  171. return self.optimizer.scheduler.print_lr(*args, **kwargs)
  172. def step(self):
  173. logger.debug("Called NoOpScheduler.step")
  174. self._last_lr = self.get_lr()
  175. def state_dict(self):
  176. return {}
  177. def load_state_dict(self, *args, **kwargs):
  178. logger.debug("Called NoOpScheduler.load_state_dict")
  179. def main():
  180. parser = HfArgumentParser((AlbertTrainingArguments, DatasetArguments, CollaborationArguments, AveragerArguments))
  181. training_args, dataset_args, collaboration_args, averager_args = parser.parse_args_into_dataclasses()
  182. logger.info(f"Found {len(collaboration_args.initial_peers)} initial peers: {collaboration_args.initial_peers}")
  183. if len(collaboration_args.initial_peers) == 0:
  184. raise ValueError("Please specify at least one network endpoint in initial peers.")
  185. setup_logging(training_args)
  186. # Set seed before initializing model.
  187. set_seed(training_args.seed)
  188. config = AlbertConfig.from_pretrained(dataset_args.config_path, cache_dir=dataset_args.cache_dir)
  189. tokenizer = AlbertTokenizerFast.from_pretrained(dataset_args.tokenizer_path, cache_dir=dataset_args.cache_dir)
  190. model = get_model(training_args, config, tokenizer)
  191. model.to(training_args.device)
  192. tokenized_datasets = load_from_disk(Path(dataset_args.dataset_path))
  193. # This data collator will take care of randomly masking the tokens.
  194. data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer)
  195. opt, scheduler = get_optimizer_and_scheduler(training_args, model)
  196. validators, local_public_key = utils.make_validators(collaboration_args.experiment_prefix)
  197. dht = hivemind.DHT(
  198. start=True,
  199. initial_peers=collaboration_args.initial_peers,
  200. client_mode=collaboration_args.client_mode,
  201. record_validators=validators,
  202. use_ipfs=collaboration_args.use_ipfs,
  203. host_maddrs=collaboration_args.host_maddrs,
  204. announce_maddrs=collaboration_args.announce_maddrs,
  205. )
  206. utils.log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=collaboration_args.use_ipfs)
  207. total_batch_size_per_step = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
  208. if torch.cuda.device_count() != 0:
  209. total_batch_size_per_step *= torch.cuda.device_count()
  210. adjusted_target_batch_size = collaboration_args.target_batch_size - collaboration_args.batch_size_lead
  211. collaborative_optimizer = hivemind.CollaborativeOptimizer(
  212. opt=opt,
  213. dht=dht,
  214. scheduler=scheduler,
  215. prefix=collaboration_args.experiment_prefix,
  216. compression_type=CompressionType.Value(collaboration_args.compression),
  217. batch_size_per_step=total_batch_size_per_step,
  218. bandwidth=collaboration_args.bandwidth,
  219. target_batch_size=adjusted_target_batch_size,
  220. client_mode=collaboration_args.client_mode,
  221. verbose=True,
  222. start=True,
  223. **asdict(averager_args),
  224. )
  225. class TrainerWithIndependentShuffling(Trainer):
  226. def get_train_dataloader(self) -> DataLoader:
  227. """Shuffle data independently for each peer to avoid duplicating batches [important for quality]"""
  228. torch.manual_seed(hash(local_public_key))
  229. return super().get_train_dataloader()
  230. trainer = TrainerWithIndependentShuffling(
  231. model=model,
  232. args=training_args,
  233. tokenizer=tokenizer,
  234. data_collator=data_collator,
  235. train_dataset=tokenized_datasets["train"] if training_args.do_train else None,
  236. eval_dataset=tokenized_datasets["validation"] if training_args.do_eval else None,
  237. optimizers=(collaborative_optimizer, NoOpScheduler(collaborative_optimizer)),
  238. callbacks=[
  239. CollaborativeCallback(
  240. dht,
  241. collaborative_optimizer,
  242. model,
  243. local_public_key,
  244. collaboration_args.statistics_expiration,
  245. collaboration_args.backup_every_steps,
  246. )
  247. ],
  248. )
  249. trainer.remove_callback(transformers.trainer_callback.PrinterCallback)
  250. trainer.remove_callback(transformers.trainer_callback.ProgressCallback)
  251. # Training
  252. if training_args.do_train:
  253. latest_checkpoint_dir = max(
  254. Path(training_args.output_dir).glob("checkpoint*"), default=None, key=os.path.getctime
  255. )
  256. trainer.train(model_path=latest_checkpoint_dir)
  257. if __name__ == "__main__":
  258. main()