run_trainer.py

#!/usr/bin/env python

import os
import pickle
from dataclasses import asdict
from pathlib import Path

import torch
import transformers
from datasets import load_from_disk
from torch.utils.data import DataLoader
from torch_optimizer import Lamb
from transformers import DataCollatorForLanguageModeling, HfArgumentParser, TrainingArguments, set_seed
from transformers.models.albert import AlbertConfig, AlbertForPreTraining, AlbertTokenizerFast
from transformers.optimization import get_linear_schedule_with_warmup
from transformers.trainer import Trainer
from transformers.trainer_utils import is_main_process

from hivemind import DHT, Float16Compression, Optimizer, get_dht_time
from hivemind.utils.logging import get_logger, use_hivemind_log_handler

import utils
from arguments import (
    AlbertTrainingArguments,
    AveragerArguments,
    CollaborationArguments,
    DatasetArguments,
    ProgressTrackerArguments,
)

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__name__)

# Base LR scheduler class, resolved via getattr because its name and visibility vary across torch versions
LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)


def setup_transformers_logging(process_rank: int):
    if is_main_process(process_rank):
        transformers.utils.logging.set_verbosity_info()
        transformers.utils.logging.disable_default_handler()
        transformers.utils.logging.enable_propagation()


def get_model(training_args, config, tokenizer):
    # Find the latest checkpoint in output_dir
    output_dir = Path(training_args.output_dir)
    logger.info(f'Checkpoint dir {output_dir}, contents {list(output_dir.glob("checkpoint*"))}')
    latest_checkpoint_dir = max(output_dir.glob("checkpoint*"), default=None, key=os.path.getctime)

    if latest_checkpoint_dir is not None:
        logger.info(f"Loading model from {latest_checkpoint_dir}")
        model = AlbertForPreTraining.from_pretrained(latest_checkpoint_dir)
    else:
        logger.info("Training from scratch")
        model = AlbertForPreTraining(config)
        model.resize_token_embeddings(len(tokenizer))

    return model


def get_optimizer_and_scheduler(training_args, model):
    # Apply weight decay to all parameters except biases and LayerNorm weights
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": training_args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]

    opt = Lamb(
        optimizer_grouped_parameters,
        lr=training_args.learning_rate,
        betas=(training_args.adam_beta1, training_args.adam_beta2),
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
        clamp_value=training_args.clamp_value,
        debias=True,
    )

    scheduler = get_linear_schedule_with_warmup(
        opt, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.max_steps
    )

    return opt, scheduler


class CollaborativeCallback(transformers.TrainerCallback):
    """
    This callback monitors and reports collaborative training progress.
    In case of a catastrophic failure, it can also revert training to a backup.
    """

    def __init__(
        self,
        dht: DHT,
        optimizer: Optimizer,
        model: torch.nn.Module,
        local_public_key: bytes,
        statistics_expiration: float,
        backup_every_steps: int,
    ):
        super().__init__()
        self.model = model
        self.dht, self.optimizer = dht, optimizer
        self.local_public_key = local_public_key
        self.statistics_expiration = statistics_expiration
        self.last_reported_collaboration_step = -1
        self.samples = 0
        self.steps = 0
        self.loss = 0
        self.total_samples_processed = 0
        self.backup_every_steps = backup_every_steps
        self.latest_backup = self.backup_state()

    def on_train_begin(
        self, args: TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs
    ):
        logger.info("Loading state from peers")
        self.optimizer.load_state_from_peers()

    def on_step_end(
        self, args: TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs
    ):
        control.should_log = True
        if not self.params_are_finite():
            # Parameters diverged (NaN/inf): roll back to the latest backup
            self.restore_from_backup(self.latest_backup)
            return control

        local_progress = self.optimizer.local_progress

        if state.log_history:
            self.loss += state.log_history[-1]["loss"]
            self.steps += 1

            if self.optimizer.local_epoch != self.last_reported_collaboration_step:
                self.last_reported_collaboration_step = self.optimizer.local_epoch
                self.total_samples_processed += self.samples
                samples_per_second = local_progress.samples_per_second
                statistics = utils.LocalMetrics(
                    step=self.optimizer.local_epoch,
                    samples_per_second=samples_per_second,
                    samples_accumulated=self.samples,
                    loss=self.loss,
                    mini_steps=self.steps,
                )
                logger.info(f"Step #{self.optimizer.local_epoch}")
                logger.info(f"Your current contribution: {self.total_samples_processed} samples")
                logger.info(f"Performance: {samples_per_second} samples per second.")
                if self.steps:
                    logger.info(f"Local loss: {self.loss / self.steps}")
                if self.optimizer.local_epoch % self.backup_every_steps == 0:
                    self.latest_backup = self.backup_state()

                self.loss = 0
                self.steps = 0
                if self.optimizer.is_synchronized_with_peers():
                    self.dht.store(
                        key=self.optimizer.run_id + "_metrics",
                        subkey=self.local_public_key,
                        value=statistics.dict(),
                        expiration_time=get_dht_time() + self.statistics_expiration,
                        return_future=True,
                    )

        self.samples = local_progress.samples_accumulated

        return control

    @torch.no_grad()
    def params_are_finite(self):
        for param in self.model.parameters():
            if not torch.all(torch.isfinite(param)):
                return False
        return True

    @torch.no_grad()
    def backup_state(self) -> bytes:
        return pickle.dumps({"model": self.model.state_dict(), "optimizer": self.optimizer.state_dict()})

    @torch.no_grad()
    def restore_from_backup(self, backup: bytes):
        state = pickle.loads(backup)
        self.model.load_state_dict(state["model"])
        self.optimizer.load_state_dict(state["optimizer"])


class NoOpScheduler(LRSchedulerBase):
    """Dummy scheduler for transformers.Trainer. The real scheduler is defined in Optimizer.scheduler"""

    def get_lr(self):
        return [group["lr"] for group in self.optimizer.param_groups]

    def print_lr(self, *args, **kwargs):
        if self.optimizer.scheduler:
            return self.optimizer.scheduler.print_lr(*args, **kwargs)

    def step(self):
        self._last_lr = self.get_lr()

    def state_dict(self):
        return {}

    def load_state_dict(self, *args, **kwargs):
        logger.debug("Called NoOpScheduler.load_state_dict")


def main():
    parser = HfArgumentParser(
        (
            AlbertTrainingArguments,
            DatasetArguments,
            CollaborationArguments,
            AveragerArguments,
            ProgressTrackerArguments,
        )
    )
    training_args, dataset_args, collaboration_args, averager_args, tracker_args = parser.parse_args_into_dataclasses()

    logger.info(f"Found {len(collaboration_args.initial_peers)} initial peers: {collaboration_args.initial_peers}")
    if len(collaboration_args.initial_peers) == 0:
        raise ValueError("Please specify at least one network endpoint in initial peers.")

    setup_transformers_logging(training_args.local_rank)
    logger.info(f"Training/evaluation parameters:\n{training_args}")

    # Set seed before initializing model.
    set_seed(training_args.seed)

    config = AlbertConfig.from_pretrained(dataset_args.config_path, cache_dir=dataset_args.cache_dir)
    tokenizer = AlbertTokenizerFast.from_pretrained(dataset_args.tokenizer_path, cache_dir=dataset_args.cache_dir)

    model = get_model(training_args, config, tokenizer)
    model.to(training_args.device)

    tokenized_datasets = load_from_disk(Path(dataset_args.dataset_path))
    # This data collator will take care of randomly masking the tokens.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer)

    opt, scheduler = get_optimizer_and_scheduler(training_args, model)

    validators, local_public_key = utils.make_validators(collaboration_args.experiment_prefix)

    dht = DHT(
        start=True,
        initial_peers=collaboration_args.initial_peers,
        client_mode=collaboration_args.client_mode,
        record_validators=validators,
        use_ipfs=collaboration_args.use_ipfs,
        host_maddrs=collaboration_args.host_maddrs,
        announce_maddrs=collaboration_args.announce_maddrs,
        identity_path=collaboration_args.identity_path,
    )
    utils.log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=collaboration_args.use_ipfs)
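
    # Number of samples this peer processes per local step: per-device batch size times
    # gradient accumulation steps, times the number of local GPUs (if any). This value is
    # reported to hivemind.Optimizer for global progress tracking.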
    total_batch_size_per_step = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
    if torch.cuda.device_count() != 0:
        total_batch_size_per_step *= torch.cuda.device_count()

    adjusted_target_batch_size = collaboration_args.target_batch_size - collaboration_args.batch_size_lead
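
    # hivemind.Optimizer wraps the local LAMB optimizer: peers report progress via the DHT and
    # run an averaging round once roughly target_batch_size samples have been accumulated
    # collectively. offload_optimizer and the delay_* flags run gradient averaging and the
    # optimizer step in the background, overlapping communication with the next forward/backward passes.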
    optimizer = Optimizer(
        dht=dht,
        run_id=collaboration_args.experiment_prefix,
        target_batch_size=adjusted_target_batch_size,
        batch_size_per_step=total_batch_size_per_step,
        optimizer=opt,
        scheduler=scheduler,
        matchmaking_time=collaboration_args.matchmaking_time,
        averaging_timeout=collaboration_args.averaging_timeout,
        offload_optimizer=True,
        delay_optimizer_step=True,
        delay_grad_averaging=True,
        client_mode=collaboration_args.client_mode,
        grad_compression=Float16Compression(),
        state_averaging_compression=Float16Compression(),
        averager_opts={"bandwidth": collaboration_args.bandwidth, **asdict(averager_args)},
        tracker_opts=asdict(tracker_args),
        verbose=True,
    )

    class TrainerWithIndependentShuffling(Trainer):
        def get_train_dataloader(self) -> DataLoader:
            """Shuffle data independently for each peer to avoid duplicating batches [important for quality]"""
            torch.manual_seed(hash(local_public_key))
            return super().get_train_dataloader()
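
    # The hivemind Optimizer is passed to the Trainer in place of a regular torch optimizer;
    # NoOpScheduler is a placeholder because the actual LR schedule is stepped inside hivemind.Optimizer.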
    trainer = TrainerWithIndependentShuffling(
        model=model,
        args=training_args,
        tokenizer=tokenizer,
        data_collator=data_collator,
        train_dataset=tokenized_datasets["train"] if training_args.do_train else None,
        eval_dataset=tokenized_datasets["validation"] if training_args.do_eval else None,
        optimizers=(optimizer, NoOpScheduler(optimizer)),
        callbacks=[
            CollaborativeCallback(
                dht,
                optimizer,
                model,
                local_public_key,
                collaboration_args.statistics_expiration,
                collaboration_args.backup_every_steps,
            )
        ],
    )
    trainer.remove_callback(transformers.trainer_callback.PrinterCallback)
    trainer.remove_callback(transformers.trainer_callback.ProgressCallback)

    # Training
    if training_args.do_train:
        latest_checkpoint_dir = max(
            Path(training_args.output_dir).glob("checkpoint*"), default=None, key=os.path.getctime
        )

        trainer.train(model_path=latest_checkpoint_dir)


if __name__ == "__main__":
    main()