task.py 5.8 KB

import os
from dataclasses import asdict
from pathlib import Path

import hivemind
import transformers
from hivemind import SizeAdaptiveCompression, Float16Compression, Uniform8BitQuantization
from transformers import AlbertTokenizerFast, get_linear_schedule_with_warmup, DataCollatorForLanguageModeling

import utils
from arguments import HFTrainerArguments, BasePeerArguments, CollaborativeArguments
from data import make_dataset
from huggingface_auth import authorize_with_huggingface
from lib import LeanAlbertConfig, LeanAlbertForPreTraining
from lib.staging.collaborative import CollaborativeOptimizer
from lib.training.clipped_lamb import LambWithGradientClipping
from lib.training.offload import OffloadOptimizer

hivemind.use_hivemind_log_handler("in_root_logger")
logger = hivemind.get_logger()


class TrainingTask:
    """A container that defines the training config, model, tokenizer, optimizer and other local training utilities"""

    _dht = _collaborative_optimizer = _training_dataset = None

    def __init__(
        self, peer_args: BasePeerArguments, trainer_args: HFTrainerArguments, collab_args: CollaborativeArguments):
        self.peer_args, self.trainer_args, self.collab_args = peer_args, trainer_args, collab_args
        self.validators, self.local_public_key = utils.make_validators(self.peer_args.experiment_prefix)
        transformers.set_seed(trainer_args.seed)  # seed used for initialization

        self.config = LeanAlbertConfig.from_pretrained(peer_args.model_config_path)
        self.tokenizer = AlbertTokenizerFast.from_pretrained(peer_args.tokenizer_path, cache_dir=peer_args.cache_dir)

        # Resume from the most recent local checkpoint if one exists, otherwise initialize a fresh model
        output_dir = Path(trainer_args.output_dir)
        logger.info(f'Checkpoint dir {output_dir}, contents {list(output_dir.glob("checkpoint*"))}')
        latest_checkpoint_dir = max(output_dir.glob("checkpoint*"), default=None, key=os.path.getctime)

        if latest_checkpoint_dir is None:
            logger.info("Creating model")
            self.model = LeanAlbertForPreTraining(self.config)
            self.model.resize_token_embeddings(len(self.tokenizer))
        else:
            logger.info(f"Loading model from {latest_checkpoint_dir}")
            self.model = LeanAlbertForPreTraining.from_pretrained(latest_checkpoint_dir)

    @property
    def dht(self):
        # Lazily create the hivemind DHT node on first access and reuse it afterwards
        if self._dht is None:
            self._dht = hivemind.DHT(
                start=True,
                initial_peers=self.peer_args.initial_peers,
                client_mode=self.peer_args.client_mode,
                host_maddrs=self.peer_args.host_maddrs,
                announce_maddrs=self.peer_args.announce_maddrs,
                use_ipfs=self.peer_args.use_ipfs,
                record_validators=self.validators,
                identity_path=self.peer_args.identity_path,
                authorizer=authorize_with_huggingface() if self.peer_args.authorize else None,
            )
            if self.peer_args.client_mode:
                logger.info(f"Created client mode peer with peer_id={self._dht.peer_id}")
            else:
                utils.log_visible_maddrs(self._dht.get_visible_maddrs(), only_p2p=self.peer_args.use_ipfs)
        return self._dht

    @property
    def collaborative_optimizer(self):
        if self._collaborative_optimizer is None:
            opt, scheduler = self._get_local_optimizer_and_scheduler(self.trainer_args)
            # Averaged tensors use float16 below the size threshold and 8-bit quantization at or above it;
            # optimizer state exchanged with other peers is compressed to float16
            averaging_compression = SizeAdaptiveCompression(
                threshold=2 ** 16 + 1, less=Float16Compression(), greater_equal=Uniform8BitQuantization()
            )
            state_compression = hivemind.Float16Compression()
            self._collaborative_optimizer = CollaborativeOptimizer(
                dht=self.dht, opt=opt, scheduler=scheduler, prefix=self.peer_args.experiment_prefix,
                batch_size_per_step=self.trainer_args.batch_size_per_step,
                compression=averaging_compression, state_compression=state_compression,
                client_mode=self.peer_args.client_mode, verbose=True, start=True, **asdict(self.collab_args))
        return self._collaborative_optimizer

    def _get_local_optimizer_and_scheduler(self, training_args: HFTrainerArguments):
        # Exclude biases and LayerNorm weights from weight decay, as is standard for transformer pretraining
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": training_args.weight_decay,
            },
            {
                "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]

        # Inner optimizer: LAMB with gradient clipping, constructed through the OffloadOptimizer wrapper
        opt = OffloadOptimizer(
            optimizer_grouped_parameters,
            optim_cls=LambWithGradientClipping,
            lr=training_args.learning_rate,
            betas=(training_args.adam_beta1, training_args.adam_beta2),
            eps=training_args.adam_epsilon,
            weight_decay=training_args.weight_decay,
            max_grad_norm=training_args.max_grad_norm,
            clamp_value=training_args.clamp_value,
            debias=True,
        )

        scheduler = get_linear_schedule_with_warmup(
            opt, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.total_steps
        )
        return opt, scheduler

    @property
    def training_dataset(self):
        if self._training_dataset is None:
            self._training_dataset = make_dataset(
                self.tokenizer, shuffle_seed=hash(self.local_public_key) % 2 ** 31,
                max_sequence_length=self.trainer_args.seq_length
            )
        return self._training_dataset

    @property
    def data_collator(self):
        return DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer, pad_to_multiple_of=self.trainer_args.pad_to_multiple_of
        )
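

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of task.py). A run script would
# typically build the three argument dataclasses and then construct a
# TrainingTask, whose lazy properties spin up the DHT and the collaborative
# optimizer on first access. The HfArgumentParser wiring below is an
# assumption about the surrounding repo, not code taken from this file.
#
#   from transformers import HfArgumentParser
#   from arguments import BasePeerArguments, HFTrainerArguments, CollaborativeArguments
#   from task import TrainingTask
#
#   parser = HfArgumentParser((BasePeerArguments, HFTrainerArguments, CollaborativeArguments))
#   peer_args, trainer_args, collab_args = parser.parse_args_into_dataclasses()
#
#   task = TrainingTask(peer_args, trainer_args, collab_args)
#   optimizer = task.collaborative_optimizer  # lazily creates the DHT and CollaborativeOptimizer
#   dataset, collator = task.training_dataset, task.data_collator
# ---------------------------------------------------------------------------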