#!/usr/bin/env python3

import os
import pickle
import sys
from dataclasses import asdict
from pathlib import Path

import torch
import transformers
from datasets import load_from_disk
from torch.utils.data import DataLoader
from torch_optimizer import Lamb
from transformers import DataCollatorForLanguageModeling, HfArgumentParser, TrainingArguments, set_seed
from transformers.models.albert import AlbertConfig, AlbertForPreTraining, AlbertTokenizerFast
from transformers.optimization import get_linear_schedule_with_warmup
from transformers.trainer import Trainer
from transformers.trainer_utils import is_main_process

from hivemind import DHT, Float16Compression, Optimizer, get_dht_time
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from hivemind.utils.networking import log_visible_maddrs

import utils
from arguments import (
    AlbertTrainingArguments,
    AveragerArguments,
    CollaborationArguments,
    DatasetArguments,
    ProgressTrackerArguments,
)

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__name__)

# _LRScheduler is a private torch API, so it is looked up defensively here;
# NoOpScheduler below inherits from it to satisfy transformers.Trainer's scheduler interface.
LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)


def setup_transformers_logging(process_rank: int):
    if is_main_process(process_rank):
        transformers.utils.logging.set_verbosity_info()
        transformers.utils.logging.disable_default_handler()
        transformers.utils.logging.enable_propagation()


def get_model(training_args, config, tokenizer):
    # Find the latest checkpoint in output_dir
    output_dir = Path(training_args.output_dir)
    logger.info(f'Checkpoint dir {output_dir}, contents {list(output_dir.glob("checkpoint*"))}')
    latest_checkpoint_dir = max(output_dir.glob("checkpoint*"), default=None, key=os.path.getctime)

    if latest_checkpoint_dir is not None:
        logger.info(f"Loading model from {latest_checkpoint_dir}")
        model = AlbertForPreTraining.from_pretrained(latest_checkpoint_dir)
    else:
        logger.info("Training from scratch")
        model = AlbertForPreTraining(config)
        model.resize_token_embeddings(len(tokenizer))

    return model


class CollaborativeCallback(transformers.TrainerCallback):
    """
    This callback monitors and reports collaborative training progress.
    In case of a catastrophic failure, it can also revert training to a backup.
    """

    def __init__(
        self,
        dht: DHT,
        optimizer: Optimizer,
        model: torch.nn.Module,
        local_public_key: bytes,
        statistics_expiration: float,
        backup_every_steps: int,
    ):
        super().__init__()
        self.model = model
        self.dht, self.optimizer = dht, optimizer
        self.local_public_key = local_public_key
        self.statistics_expiration = statistics_expiration
        self.last_reported_collaboration_step = -1
        self.samples = 0
        self.steps = 0
        self.loss = 0
        self.total_samples_processed = 0
        self.backup_every_steps = backup_every_steps
        self.latest_backup = self.backup_state()

    def on_train_begin(
        self, args: TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs
    ):
        logger.info("Loading state from peers")
        self.optimizer.load_state_from_peers()

    def on_step_end(
        self, args: TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs
    ):
        control.should_log = True
        if not self.params_are_finite():
            self.restore_from_backup(self.latest_backup)
            return control

        local_progress = self.optimizer.local_progress

        if state.log_history:
            self.loss += state.log_history[-1]["loss"]
            self.steps += 1

            if self.optimizer.local_epoch != self.last_reported_collaboration_step:
                self.last_reported_collaboration_step = self.optimizer.local_epoch
                self.total_samples_processed += self.samples
                samples_per_second = local_progress.samples_per_second
                statistics = utils.LocalMetrics(
                    step=self.optimizer.local_epoch,
                    samples_per_second=samples_per_second,
                    samples_accumulated=self.samples,
                    loss=self.loss,
                    mini_steps=self.steps,
                )
                logger.info(f"Step #{self.optimizer.local_epoch}")
                logger.info(f"Your current contribution: {self.total_samples_processed} samples")
                logger.info(f"Performance: {samples_per_second:.3f} samples/sec")
                if self.steps:
                    logger.info(f"Local loss: {self.loss / self.steps:.5f}")
                if self.optimizer.local_epoch % self.backup_every_steps == 0:
                    self.latest_backup = self.backup_state()

                self.loss = 0
                self.steps = 0
                if self.optimizer.is_synchronized_with_peers():
                    self.dht.store(
                        key=self.optimizer.run_id + "_metrics",
                        subkey=self.local_public_key,
                        value=statistics.dict(),
                        expiration_time=get_dht_time() + self.statistics_expiration,
                        return_future=True,
                    )

        self.samples = local_progress.samples_accumulated

        return control

    @torch.no_grad()
    def params_are_finite(self):
        for param in self.model.parameters():
            if not torch.all(torch.isfinite(param)):
                return False
        return True

    @torch.no_grad()
    def backup_state(self) -> bytes:
        return pickle.dumps({"model": self.model.state_dict(), "optimizer": self.optimizer.state_dict()})

    @torch.no_grad()
    def restore_from_backup(self, backup: bytes):
        state = pickle.loads(backup)
        self.model.load_state_dict(state["model"])
        self.optimizer.load_state_dict(state["optimizer"])
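
# Note: backup_state() / restore_from_backup() above keep a pickled copy of the model
# and optimizer state in local memory; if parameters ever become non-finite,
# on_step_end rolls the peer back to that copy instead of aborting the collaborative run.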


class NoOpScheduler(LRSchedulerBase):
    """Dummy scheduler for transformers.Trainer. The real scheduler is defined in Optimizer.scheduler"""

    def get_lr(self):
        return [group["lr"] for group in self.optimizer.param_groups]

    def print_lr(self, *args, **kwargs):
        if self.optimizer.scheduler:
            return self.optimizer.scheduler.print_lr(*args, **kwargs)

    def step(self):
        self._last_lr = self.get_lr()

    def state_dict(self):
        return {}

    def load_state_dict(self, *args, **kwargs):
        logger.debug("Called NoOpScheduler.load_state_dict")
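
# transformers.Trainer expects to own an LR scheduler, but the real schedule
# (get_linear_schedule_with_warmup) is built inside hivemind.Optimizer via the
# `scheduler` factory passed in main(); NoOpScheduler above only mirrors the current
# learning rate so that Trainer's logging keeps working.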


def main():
    parser = HfArgumentParser(
        (
            AlbertTrainingArguments,
            DatasetArguments,
            CollaborationArguments,
            AveragerArguments,
            ProgressTrackerArguments,
        )
    )
    training_args, dataset_args, collaboration_args, averager_args, tracker_args = parser.parse_args_into_dataclasses()
    logger.info(f"Found {len(collaboration_args.initial_peers)} initial peers: {collaboration_args.initial_peers}")

    setup_transformers_logging(training_args.local_rank)
    logger.info(f"Training/evaluation parameters:\n{training_args}")

    # Set seed before initializing model.
    set_seed(training_args.seed)

    config = AlbertConfig.from_pretrained(dataset_args.config_path, cache_dir=dataset_args.cache_dir)

    try:
        tokenizer = AlbertTokenizerFast.from_pretrained(dataset_args.tokenizer_path, cache_dir=dataset_args.cache_dir)
    except OSError:
        logger.fatal(
            f"No tokenizer data found in {dataset_args.tokenizer_path}, "
            f"please run ./tokenize_wikitext103.py before running this"
        )
        sys.exit(1)

    model = get_model(training_args, config, tokenizer)
    model.to(training_args.device)

    tokenized_datasets = load_from_disk(Path(dataset_args.dataset_path))
    # This data collator will take care of randomly masking the tokens.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer)

    validators, local_public_key = utils.make_validators(collaboration_args.run_id)

    dht = DHT(
        start=True,
        initial_peers=collaboration_args.initial_peers,
        client_mode=collaboration_args.client_mode,
        record_validators=validators,
        use_ipfs=collaboration_args.use_ipfs,
        host_maddrs=collaboration_args.host_maddrs,
        announce_maddrs=collaboration_args.announce_maddrs,
        identity_path=collaboration_args.identity_path,
    )
    log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=collaboration_args.use_ipfs)

    total_batch_size_per_step = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
    if torch.cuda.device_count() != 0:
        total_batch_size_per_step *= torch.cuda.device_count()

    adjusted_target_batch_size = collaboration_args.target_batch_size - collaboration_args.batch_size_lead

    # We pass a factory (lambda) rather than a ready optimizer instance
    # so that hivemind.Optimizer(..., offload_optimizer=True) can construct the optimizer itself
    opt = lambda params: Lamb(
        params,
        lr=training_args.learning_rate,
        betas=(training_args.adam_beta1, training_args.adam_beta2),
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
        clamp_value=training_args.clamp_value,
        debias=True,
    )

    no_decay = ["bias", "LayerNorm.weight"]
    params = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": training_args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]

    scheduler = lambda opt: get_linear_schedule_with_warmup(
        opt, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.total_steps
    )

    optimizer = Optimizer(
        dht=dht,
        run_id=collaboration_args.run_id,
        target_batch_size=adjusted_target_batch_size,
        batch_size_per_step=total_batch_size_per_step,
        optimizer=opt,
        params=params,
        scheduler=scheduler,
        matchmaking_time=collaboration_args.matchmaking_time,
        averaging_timeout=collaboration_args.averaging_timeout,
        offload_optimizer=True,
        delay_optimizer_step=True,
        delay_grad_averaging=True,
        client_mode=collaboration_args.client_mode,
        grad_compression=Float16Compression(),
        state_averaging_compression=Float16Compression(),
        averager_opts={"bandwidth": collaboration_args.bandwidth, **asdict(averager_args)},
        tracker_opts=asdict(tracker_args),
        verbose=True,
    )
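
    # A rough sketch of what hivemind.Optimizer does with the settings above: each peer
    # reports its locally processed samples to the DHT and keeps accumulating gradients
    # until the whole run has gathered `target_batch_size` samples, then peers average
    # gradients (and periodically state) with Float16 compression and each applies one
    # LAMB step. offload_optimizer keeps the LAMB state in host memory, while
    # delay_grad_averaging / delay_optimizer_step run averaging and the step in the
    # background so the training loop is not blocked.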

    class TrainerWithIndependentShuffling(Trainer):
        def get_train_dataloader(self) -> DataLoader:
            """Shuffle data independently for each peer to avoid duplicating batches [important for quality]"""
            torch.manual_seed(hash(local_public_key))
            return super().get_train_dataloader()

    trainer = TrainerWithIndependentShuffling(
        model=model,
        args=training_args,
        tokenizer=tokenizer,
        data_collator=data_collator,
        train_dataset=tokenized_datasets["train"] if training_args.do_train else None,
        eval_dataset=tokenized_datasets["validation"] if training_args.do_eval else None,
        optimizers=(optimizer, NoOpScheduler(optimizer)),
        callbacks=[
            CollaborativeCallback(
                dht,
                optimizer,
                model,
                local_public_key,
                collaboration_args.statistics_expiration,
                collaboration_args.backup_every_steps,
            )
        ],
    )
    trainer.remove_callback(transformers.trainer_callback.PrinterCallback)
    trainer.remove_callback(transformers.trainer_callback.ProgressCallback)
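    # Trainer's built-in PrinterCallback / ProgressCallback are removed above so that
    # console output comes from the hivemind log handler and CollaborativeCallback.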

    # Training
    if training_args.do_train:
        latest_checkpoint_dir = max(
            Path(training_args.output_dir).glob("checkpoint*"), default=None, key=os.path.getctime
        )

        trainer.train(model_path=latest_checkpoint_dir)


if __name__ == "__main__":
    main()
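
# Example launch (a sketch: the exact flag names come from the dataclasses in
# arguments.py and transformers.TrainingArguments, so adjust them to your setup):
#
#   python run_trainer.py \
#       --output_dir ./outputs \
#       --run_id albert_demo \
#       --initial_peers /ip4/<peer_ip>/tcp/<port>/p2p/<peer_id> \
#       --per_device_train_batch_size 4 \
#       --gradient_accumulation_steps 2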