#!/usr/bin/env python3

import os
import pickle
import sys
from dataclasses import asdict
from pathlib import Path

import torch
import transformers
from datasets import load_from_disk
from torch.utils.data import DataLoader
from torch_optimizer import Lamb
from transformers import DataCollatorForLanguageModeling, HfArgumentParser, TrainingArguments, set_seed
from transformers.models.albert import AlbertConfig, AlbertForPreTraining, AlbertTokenizerFast
from transformers.optimization import get_linear_schedule_with_warmup
from transformers.trainer import Trainer
from transformers.trainer_utils import is_main_process

from hivemind import DHT, Float16Compression, Optimizer, get_dht_time
from hivemind.utils.logging import get_logger, use_hivemind_log_handler

import utils
from arguments import (
    AlbertTrainingArguments,
    AveragerArguments,
    CollaborationArguments,
    DatasetArguments,
    ProgressTrackerArguments,
)

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__name__)

LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
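# The getattr-with-default above guards against torch versions where the private
# `_LRScheduler` base class is renamed or removed (newer releases expose `LRScheduler`).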

def setup_transformers_logging(process_rank: int):
    if is_main_process(process_rank):
        transformers.utils.logging.set_verbosity_info()
        transformers.utils.logging.disable_default_handler()
        transformers.utils.logging.enable_propagation()


def get_model(training_args, config, tokenizer):
    # Find the latest checkpoint in output_dir
    output_dir = Path(training_args.output_dir)
    logger.info(f'Checkpoint dir {output_dir}, contents {list(output_dir.glob("checkpoint*"))}')
    latest_checkpoint_dir = max(output_dir.glob("checkpoint*"), default=None, key=os.path.getctime)

    if latest_checkpoint_dir is not None:
        logger.info(f"Loading model from {latest_checkpoint_dir}")
        model = AlbertForPreTraining.from_pretrained(latest_checkpoint_dir)
    else:
        logger.info("Training from scratch")
        model = AlbertForPreTraining(config)
        model.resize_token_embeddings(len(tokenizer))

    return model

class CollaborativeCallback(transformers.TrainerCallback):
    """
    This callback monitors and reports collaborative training progress.
    In case of a catastrophic failure, it can also revert training to a backup.
    """

    def __init__(
        self,
        dht: DHT,
        optimizer: Optimizer,
        model: torch.nn.Module,
        local_public_key: bytes,
        statistics_expiration: float,
        backup_every_steps: int,
    ):
        super().__init__()
        self.model = model
        self.dht, self.optimizer = dht, optimizer
        self.local_public_key = local_public_key
        self.statistics_expiration = statistics_expiration
        self.last_reported_collaboration_step = -1
        self.samples = 0
        self.steps = 0
        self.loss = 0
        self.total_samples_processed = 0
        self.backup_every_steps = backup_every_steps
        self.latest_backup = self.backup_state()

    def on_train_begin(
        self, args: TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs
    ):
        logger.info("Loading state from peers")
        self.optimizer.load_state_from_peers()

    def on_step_end(
        self, args: TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs
    ):
        control.should_log = True
        if not self.params_are_finite():
            self.restore_from_backup(self.latest_backup)
            return control

        local_progress = self.optimizer.local_progress

        if state.log_history:
            self.loss += state.log_history[-1]["loss"]
            self.steps += 1

            if self.optimizer.local_epoch != self.last_reported_collaboration_step:
                self.last_reported_collaboration_step = self.optimizer.local_epoch
                self.total_samples_processed += self.samples
                samples_per_second = local_progress.samples_per_second
                statistics = utils.LocalMetrics(
                    step=self.optimizer.local_epoch,
                    samples_per_second=samples_per_second,
                    samples_accumulated=self.samples,
                    loss=self.loss,
                    mini_steps=self.steps,
                )
                logger.info(f"Step #{self.optimizer.local_epoch}")
                logger.info(f"Your current contribution: {self.total_samples_processed} samples")
                logger.info(f"Performance: {samples_per_second:.3f} samples/sec")
                if self.steps:
                    logger.info(f"Local loss: {self.loss / self.steps:.5f}")
                if self.optimizer.local_epoch % self.backup_every_steps == 0:
                    self.latest_backup = self.backup_state()

                self.loss = 0
                self.steps = 0
                if self.optimizer.is_synchronized_with_peers():
                    self.dht.store(
                        key=self.optimizer.run_id + "_metrics",
                        subkey=self.local_public_key,
                        value=statistics.dict(),
                        expiration_time=get_dht_time() + self.statistics_expiration,
                        return_future=True,
                    )

        self.samples = local_progress.samples_accumulated

        return control

    @torch.no_grad()
    def params_are_finite(self):
        for param in self.model.parameters():
            if not torch.all(torch.isfinite(param)):
                return False
        return True

    @torch.no_grad()
    def backup_state(self) -> bytes:
        return pickle.dumps({"model": self.model.state_dict(), "optimizer": self.optimizer.state_dict()})

    @torch.no_grad()
    def restore_from_backup(self, backup: bytes):
        state = pickle.loads(backup)
        self.model.load_state_dict(state["model"])
        self.optimizer.load_state_dict(state["optimizer"])
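# Backups are kept in memory as pickled state dicts; `on_step_end` falls back to the most
# recent one whenever non-finite parameters are detected (e.g. after a diverged step).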

class NoOpScheduler(LRSchedulerBase):
    """Dummy scheduler for transformers.Trainer. The real scheduler is defined in Optimizer.scheduler"""

    def get_lr(self):
        return [group["lr"] for group in self.optimizer.param_groups]

    def print_lr(self, *args, **kwargs):
        if self.optimizer.scheduler:
            return self.optimizer.scheduler.print_lr(*args, **kwargs)

    def step(self):
        self._last_lr = self.get_lr()

    def state_dict(self):
        return {}

    def load_state_dict(self, *args, **kwargs):
        logger.debug("Called NoOpScheduler.load_state_dict")
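# Note: `step()` above deliberately advances nothing. Per the docstring, the real LR
# schedule lives in Optimizer.scheduler, which hivemind advances on each global step.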

def main():
    parser = HfArgumentParser(
        (
            AlbertTrainingArguments,
            DatasetArguments,
            CollaborationArguments,
            AveragerArguments,
            ProgressTrackerArguments,
        )
    )
    training_args, dataset_args, collaboration_args, averager_args, tracker_args = parser.parse_args_into_dataclasses()
    logger.info(f"Found {len(collaboration_args.initial_peers)} initial peers: {collaboration_args.initial_peers}")

    setup_transformers_logging(training_args.local_rank)
    logger.info(f"Training/evaluation parameters:\n{training_args}")

    # Set seed before initializing model.
    set_seed(training_args.seed)

    config = AlbertConfig.from_pretrained(dataset_args.config_path, cache_dir=dataset_args.cache_dir)
    try:
        tokenizer = AlbertTokenizerFast.from_pretrained(dataset_args.tokenizer_path, cache_dir=dataset_args.cache_dir)
    except OSError:
        logger.fatal(
            f"No tokenizer data found in {dataset_args.tokenizer_path}, "
            "please run ./tokenize_wikitext103.py before running this"
        )
        sys.exit(1)
    model = get_model(training_args, config, tokenizer)
    model.to(training_args.device)

    tokenized_datasets = load_from_disk(Path(dataset_args.dataset_path))
    # This data collator will take care of randomly masking the tokens.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer)

    validators, local_public_key = utils.make_validators(collaboration_args.experiment_prefix)

    dht = DHT(
        start=True,
        initial_peers=collaboration_args.initial_peers,
        client_mode=collaboration_args.client_mode,
        record_validators=validators,
        use_ipfs=collaboration_args.use_ipfs,
        host_maddrs=collaboration_args.host_maddrs,
        announce_maddrs=collaboration_args.announce_maddrs,
        identity_path=collaboration_args.identity_path,
    )
    utils.log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=collaboration_args.use_ipfs)
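    # The DHT connects this peer to the swarm: `initial_peers` bootstraps discovery, and
    # a peer in `client_mode` only makes outbound connections (useful behind NAT or a
    # firewall) instead of serving requests for others.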
    total_batch_size_per_step = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
    if torch.cuda.device_count() != 0:
        total_batch_size_per_step *= torch.cuda.device_count()

    adjusted_target_batch_size = collaboration_args.target_batch_size - collaboration_args.batch_size_lead

    # We need to make such a lambda function instead of just an optimizer instance
    # to make hivemind.Optimizer(..., offload_optimizer=True) work
    opt = lambda params: Lamb(
        params,
        lr=training_args.learning_rate,
        betas=(training_args.adam_beta1, training_args.adam_beta2),
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
        clamp_value=training_args.clamp_value,
        debias=True,
    )
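    # LAMB normalizes each layer's update, which keeps training stable at the very large
    # effective batch sizes this swarm accumulates across peers.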
    no_decay = ["bias", "LayerNorm.weight"]
    params = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": training_args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]

    scheduler = lambda opt: get_linear_schedule_with_warmup(
        opt, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.max_steps
    )
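    # The scheduler is a factory for the same reason as `opt` above: with offloading
    # enabled, hivemind.Optimizer instantiates both on its own copy of the parameters.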
    optimizer = Optimizer(
        dht=dht,
        run_id=collaboration_args.experiment_prefix,
        target_batch_size=adjusted_target_batch_size,
        batch_size_per_step=total_batch_size_per_step,
        optimizer=opt,
        params=params,
        scheduler=scheduler,
        matchmaking_time=collaboration_args.matchmaking_time,
        averaging_timeout=collaboration_args.averaging_timeout,
        offload_optimizer=True,
        delay_optimizer_step=True,
        delay_grad_averaging=True,
        client_mode=collaboration_args.client_mode,
        grad_compression=Float16Compression(),
        state_averaging_compression=Float16Compression(),
        averager_opts={"bandwidth": collaboration_args.bandwidth, **asdict(averager_args)},
        tracker_opts=asdict(tracker_args),
        verbose=True,
    )
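    # hivemind.Optimizer counts samples accumulated across the whole swarm; once the swarm
    # jointly reaches target_batch_size, peers average their gradients (compressed to
    # float16 here) and apply one LAMB step. The offload_*/delay_* flags let that averaging
    # and step overlap with the next round of local gradient accumulation.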
    class TrainerWithIndependentShuffling(Trainer):
        def get_train_dataloader(self) -> DataLoader:
            """Shuffle data independently for each peer to avoid duplicating batches [important for quality]"""
            torch.manual_seed(hash(local_public_key))
            return super().get_train_dataloader()

    trainer = TrainerWithIndependentShuffling(
        model=model,
        args=training_args,
        tokenizer=tokenizer,
        data_collator=data_collator,
        train_dataset=tokenized_datasets["train"] if training_args.do_train else None,
        eval_dataset=tokenized_datasets["validation"] if training_args.do_eval else None,
        optimizers=(optimizer, NoOpScheduler(optimizer)),
        callbacks=[
            CollaborativeCallback(
                dht,
                optimizer,
                model,
                local_public_key,
                collaboration_args.statistics_expiration,
                collaboration_args.backup_every_steps,
            )
        ],
    )
    trainer.remove_callback(transformers.trainer_callback.PrinterCallback)
    trainer.remove_callback(transformers.trainer_callback.ProgressCallback)
    # Training
    if training_args.do_train:
        latest_checkpoint_dir = max(
            Path(training_args.output_dir).glob("checkpoint*"), default=None, key=os.path.getctime
        )
        # NOTE: newer transformers releases renamed this argument to `resume_from_checkpoint`;
        # `model_path` matches the transformers version this example targets.
        trainer.train(model_path=latest_checkpoint_dir)


if __name__ == "__main__":
    main()
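# A minimal sketch of how this script might be launched; flag names follow the dataclass
# fields used above (HfArgumentParser derives CLI flags from them), and the peer
# multiaddress is a placeholder:
#
#   python run_trainer.py \
#       --experiment_prefix my_albert_run \
#       --initial_peers /ip4/203.0.113.1/tcp/31337/p2p/QmExamplePeerID \
#       --per_device_train_batch_size 4 \
#       --gradient_accumulation_steps 2 \
#       --output_dir ./outputs --do_train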