run_trainer.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. #!/usr/bin/env python
  2. import logging
  3. import os
  4. from dataclasses import asdict
  5. from pathlib import Path
  6. from typing import Dict, Any
  7. import torch
  8. import transformers
  9. from datasets import load_from_disk
  10. from torch.utils.data import DataLoader
  11. from torch_optimizer import Lamb
  12. from transformers import set_seed, HfArgumentParser, TrainingArguments, DataCollatorForLanguageModeling
  13. from transformers.models.albert import AlbertTokenizerFast, AlbertConfig, AlbertForPreTraining
  14. from transformers.optimization import get_linear_schedule_with_warmup
  15. from transformers.trainer import Trainer
  16. from transformers.trainer_utils import is_main_process
  17. import hivemind
  18. import utils
  19. from arguments import CollaborationArguments, DatasetArguments, AlbertTrainingArguments, AveragerArguments
  20. logger = logging.getLogger(__name__)
  21. LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
  22. def setup_logging(training_args):
  23. logging.basicConfig(
  24. format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
  25. datefmt="%m/%d/%Y %H:%M:%S",
  26. level=logging.INFO if is_main_process(training_args.local_rank) else logging.WARN,
  27. )
  28. # Log on each process the small summary:
  29. logger.warning(
  30. f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
  31. + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
  32. )
  33. # Set the verbosity to info of the Transformers logger (on main process only):
  34. if is_main_process(training_args.local_rank):
  35. transformers.utils.logging.set_verbosity_info()
  36. transformers.utils.logging.enable_default_handler()
  37. transformers.utils.logging.enable_explicit_format()
  38. logger.info("Training/evaluation parameters %s", training_args)
  39. def get_model(training_args, config, tokenizer):
  40. # Find latest checkpoint in output_dir
  41. output_dir = Path(training_args.output_dir)
  42. logger.info(f'Checkpoint dir {output_dir}, contents {list(output_dir.glob("checkpoint*"))}')
  43. latest_checkpoint_dir = max(output_dir.glob("checkpoint*"), default=None, key=os.path.getctime)
  44. if latest_checkpoint_dir is not None:
  45. logger.info(f"Loading model from {latest_checkpoint_dir}")
  46. model = AlbertForPreTraining.from_pretrained(latest_checkpoint_dir)
  47. else:
  48. logger.info(f"Training from scratch")
  49. model = AlbertForPreTraining(config)
  50. model.resize_token_embeddings(len(tokenizer))
  51. return model
  52. def get_optimizer_and_scheduler(training_args, model):
  53. no_decay = ["bias", "LayerNorm.weight"]
  54. optimizer_grouped_parameters = [
  55. {
  56. "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
  57. "weight_decay": training_args.weight_decay,
  58. },
  59. {
  60. "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
  61. "weight_decay": 0.0,
  62. },
  63. ]
  64. opt = Lamb(
  65. optimizer_grouped_parameters,
  66. lr=training_args.learning_rate,
  67. betas=(training_args.adam_beta1, training_args.adam_beta2),
  68. eps=training_args.adam_epsilon,
  69. weight_decay=training_args.weight_decay,
  70. clamp_value=training_args.clamp_value,
  71. debias=True,
  72. )
  73. scheduler = get_linear_schedule_with_warmup(
  74. opt, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.max_steps
  75. )
  76. return opt, scheduler
  77. class CollaborativeCallback(transformers.TrainerCallback):
  78. def __init__(
  79. self,
  80. dht: hivemind.DHT,
  81. optimizer: hivemind.CollaborativeOptimizer,
  82. model: torch.nn.Module,
  83. local_public_key: bytes,
  84. statistics_expiration: float,
  85. ):
  86. super().__init__()
  87. self.model = model
  88. self.dht, self.collaborative_optimizer = dht, optimizer
  89. self.local_public_key = local_public_key
  90. self.statistics_expiration = statistics_expiration
  91. self.last_reported_collaboration_step = -1
  92. self.previous_state = self.get_current_state()
  93. self.samples = 0
  94. self.steps = 0
  95. self.loss = 0
  96. self.total_samples_processed = 0
  97. def on_train_begin(
  98. self, args: TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs
  99. ):
  100. logger.info("Loading state from peers")
  101. self.collaborative_optimizer.load_state_from_peers()
  102. def on_step_end(
  103. self, args: TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs
  104. ):
  105. control.should_log = True
  106. if not self.params_are_finite():
  107. self.load_from_state(self.previous_state)
  108. return control
  109. self.previous_state = self.get_current_state()
  110. if state.log_history:
  111. self.loss += state.log_history[-1]["loss"]
  112. self.steps += 1
  113. if self.collaborative_optimizer.local_step != self.last_reported_collaboration_step:
  114. self.last_reported_collaboration_step = self.collaborative_optimizer.local_step
  115. self.total_samples_processed += self.samples
  116. samples_per_second = self.collaborative_optimizer.performance_ema.samples_per_second
  117. statistics = utils.LocalMetrics(
  118. step=self.collaborative_optimizer.local_step,
  119. samples_per_second=samples_per_second,
  120. samples_accumulated=self.samples,
  121. loss=self.loss,
  122. mini_steps=self.steps,
  123. )
  124. logger.info(f"Step {self.collaborative_optimizer.local_step}")
  125. logger.info(f"Your current contribution: {self.total_samples_processed} samples")
  126. if self.steps:
  127. logger.info(f"Local loss: {self.loss / self.steps}")
  128. self.loss = 0
  129. self.steps = 0
  130. if self.collaborative_optimizer.is_synchronized:
  131. self.dht.store(
  132. key=self.collaborative_optimizer.prefix + "_metrics",
  133. subkey=self.local_public_key,
  134. value=statistics.dict(),
  135. expiration_time=hivemind.get_dht_time() + self.statistics_expiration,
  136. return_future=True,
  137. )
  138. self.samples = self.collaborative_optimizer.local_samples_accumulated
  139. return control
  140. @torch.no_grad()
  141. def get_current_state(self) -> Dict[str, Any]:
  142. return {"model": self.model.state_dict(), "opt": self.collaborative_optimizer.opt.state_dict()}
  143. @torch.no_grad()
  144. def load_from_state(self, state):
  145. self.model.load_state_dict(state["model"])
  146. self.collaborative_optimizer.opt.load_state_dict(state["opt"])
  147. @torch.no_grad()
  148. def params_are_finite(self):
  149. for param in self.model.parameters():
  150. if not torch.all(torch.isfinite(param)):
  151. return False
  152. return True
  153. class NoOpScheduler(LRSchedulerBase):
  154. """Dummy scheduler for transformers.Trainer. The real scheduler is defined in CollaborativeOptimizer.scheduler"""
  155. def get_lr(self):
  156. return [group["lr"] for group in self.optimizer.param_groups]
  157. def print_lr(self, *args, **kwargs):
  158. if self.optimizer.scheduler:
  159. return self.optimizer.scheduler.print_lr(*args, **kwargs)
  160. def step(self):
  161. logger.debug("Called NoOpScheduler.step")
  162. self._last_lr = self.get_lr()
  163. def state_dict(self):
  164. return {}
  165. def load_state_dict(self, *args, **kwargs):
  166. logger.debug("Called NoOpScheduler.load_state_dict")
  167. def main():
  168. parser = HfArgumentParser((AlbertTrainingArguments, DatasetArguments, CollaborationArguments, AveragerArguments))
  169. training_args, dataset_args, collaboration_args, averager_args = parser.parse_args_into_dataclasses()
  170. logger.info(f"Found {len(collaboration_args.initial_peers)} initial peers: {collaboration_args.initial_peers}")
  171. if len(collaboration_args.initial_peers) == 0:
  172. raise ValueError("Please specify at least one network endpoint in initial peers.")
  173. setup_logging(training_args)
  174. # Set seed before initializing model.
  175. set_seed(training_args.seed)
  176. config = AlbertConfig.from_pretrained(dataset_args.config_path, cache_dir=dataset_args.cache_dir)
  177. tokenizer = AlbertTokenizerFast.from_pretrained(dataset_args.tokenizer_path, cache_dir=dataset_args.cache_dir)
  178. model = get_model(training_args, config, tokenizer)
  179. model.to(training_args.device)
  180. tokenized_datasets = load_from_disk(Path(dataset_args.dataset_path))
  181. # This data collator will take care of randomly masking the tokens.
  182. data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer)
  183. opt, scheduler = get_optimizer_and_scheduler(training_args, model)
  184. validators, local_public_key = utils.make_validators(collaboration_args.experiment_prefix)
  185. dht = hivemind.DHT(
  186. start=True,
  187. initial_peers=collaboration_args.initial_peers,
  188. listen=not collaboration_args.client_mode,
  189. record_validators=validators,
  190. use_ipfs=collaboration_args.use_ipfs,
  191. host_maddrs=collaboration_args.host_maddrs,
  192. announce_maddrs=collaboration_args.announce_maddrs,
  193. )
  194. utils.log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=collaboration_args.use_ipfs)
  195. total_batch_size_per_step = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
  196. if torch.cuda.device_count() != 0:
  197. total_batch_size_per_step *= torch.cuda.device_count()
  198. adjusted_target_batch_size = collaboration_args.target_batch_size - collaboration_args.batch_size_lead
  199. collaborative_optimizer = hivemind.CollaborativeOptimizer(
  200. opt=opt,
  201. dht=dht,
  202. scheduler=scheduler,
  203. prefix=collaboration_args.experiment_prefix,
  204. compression_type=hivemind.utils.CompressionType.Value(collaboration_args.compression),
  205. batch_size_per_step=total_batch_size_per_step,
  206. throughput=collaboration_args.bandwidth,
  207. target_batch_size=adjusted_target_batch_size,
  208. client_mode=collaboration_args.client_mode,
  209. verbose=True,
  210. start=True,
  211. **asdict(averager_args),
  212. )
  213. class TrainerWithIndependentShuffling(Trainer):
  214. def get_train_dataloader(self) -> DataLoader:
  215. """Shuffle data independently for each peer to avoid duplicating batches [important for quality]"""
  216. torch.manual_seed(hash(local_public_key))
  217. return super().get_train_dataloader()
  218. trainer = TrainerWithIndependentShuffling(
  219. model=model,
  220. args=training_args,
  221. tokenizer=tokenizer,
  222. data_collator=data_collator,
  223. train_dataset=tokenized_datasets["train"] if training_args.do_train else None,
  224. eval_dataset=tokenized_datasets["validation"] if training_args.do_eval else None,
  225. optimizers=(collaborative_optimizer, NoOpScheduler(collaborative_optimizer)),
  226. callbacks=[
  227. CollaborativeCallback(
  228. dht, collaborative_optimizer, model, local_public_key, collaboration_args.statistics_expiration
  229. )
  230. ],
  231. )
  232. trainer.remove_callback(transformers.trainer_callback.PrinterCallback)
  233. trainer.remove_callback(transformers.trainer_callback.ProgressCallback)
  234. # Training
  235. if training_args.do_train:
  236. latest_checkpoint_dir = max(
  237. Path(training_args.output_dir).glob("checkpoint*"), default=None, key=os.path.getctime
  238. )
  239. trainer.train(model_path=latest_checkpoint_dir)
  240. if __name__ == "__main__":
  241. main()