# task.py

import os
from dataclasses import asdict
from itertools import cycle, islice
from pathlib import Path

import hivemind
import torch
import transformers
from dalle_pytorch import DALLE
from dalle_pytorch.vae import VQGanVAE
from hivemind import SizeAdaptiveCompression, Float16Compression, Uniform8BitQuantization
from transformers import DataCollatorWithPadding, T5TokenizerFast, get_linear_schedule_with_warmup
from torch import nn

import utils
from arguments import HFTrainerArguments, BasePeerArguments, CollaborativeArguments
from data import make_dataset
from huggingface_auth import authorize_with_huggingface
from lib.training.clipped_lamb import LambWithGradientClipping
from lib.training.offload import OffloadOptimizer

logger = hivemind.get_logger(__name__)


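# A lightweight stand-in for the VQGAN image tokenizer: it only records the geometry
# (number of layers, image size, codebook size) that DALLE uses to size its image embeddings,
# without downloading or loading the actual VQGAN weights.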
class VQGanParams(VQGanVAE):
    def __init__(self, *, num_layers=3, image_size=256, num_tokens=8192, is_gumbel=True):
        nn.Module.__init__(self)
        self.num_layers = num_layers
        self.image_size = image_size
        self.num_tokens = num_tokens
        self.is_gumbel = is_gumbel


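# Adapts DALLE to the keyword-argument interface and dict-style output expected by a
# transformers.Trainer-compatible training loop.
#
# Rough usage sketch (tensor names and shapes are illustrative, not taken from this repo):
#   wrapper = ModelWrapper(dalle)
#   out = wrapper(input_ids=text_tokens, attention_mask=text_mask, image=images)
#   out['loss'].backward()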
class ModelWrapper(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask, image):
        loss = self.model.forward(text=input_ids, image=image, mask=attention_mask, return_loss=True)
        return {'loss': loss}


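# The DHT node, collaborative optimizer, and dataset are created lazily (see the properties below),
# so constructing a TrainingTask stays cheap until those components are first accessed.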
class TrainingTask:
    """A container that defines the training config, model, tokenizer, optimizer and other local training utilities"""

    _dht = _collaborative_optimizer = _training_dataset = None

    def __init__(
        self, peer_args: BasePeerArguments, trainer_args: HFTrainerArguments, collab_args: CollaborativeArguments):
        self.peer_args, self.trainer_args, self.collab_args = peer_args, trainer_args, collab_args
        self.validators, self.local_public_key = utils.make_validators(self.peer_args.experiment_prefix)
        transformers.set_seed(trainer_args.seed)  # seed used for initialization

        self.tokenizer = T5TokenizerFast.from_pretrained(peer_args.tokenizer_path)
        self.tokenizer.pad_token = self.tokenizer.eos_token

        output_dir = Path(trainer_args.output_dir)
        logger.info(f'Checkpoint dir {output_dir}, contents {list(output_dir.glob("checkpoint*"))}')
        latest_checkpoint_dir = max(output_dir.glob("checkpoint*"), default=None, key=os.path.getctime)

        # The architecture is always built from scratch; if a local checkpoint exists,
        # its weights are loaded into the freshly built model below.
        logger.info("Creating model")
        depth = 64
        # Alternate axial row/column attention for the first `depth - 1` layers, end with a conv-like layer
        attn_types = list(islice(cycle(['axial_row', 'axial_col', 'axial_row', 'axial_row']), depth - 1))
        attn_types.append('conv_like')
        # Layer ids repeat every 4 layers, so attention/FF weights are shared between layers with the same id
        shared_layer_ids = list(islice(cycle(range(4)), depth - 1))
        shared_layer_ids.append('w_conv')
        dalle = DALLE(
            vae=VQGanParams(),
            num_text_tokens=self.tokenizer.vocab_size,
            text_seq_len=trainer_args.text_seq_length,
            dim=1024,
            depth=depth,
            heads=16,
            dim_head=64,
            attn_types=attn_types,
            ff_dropout=0,
            attn_dropout=0,
            shared_attn_ids=shared_layer_ids,
            shared_ff_ids=shared_layer_ids,
            rotary_emb=True,
            reversible=True,
            share_input_output_emb=True,
        )
        logger.info(f"Trainable parameters: "
                    f"{sum(param.numel() for param in dalle.parameters() if param.requires_grad)}")
        self.model = ModelWrapper(dalle)

        if latest_checkpoint_dir is not None:
            logger.info(f"Loading model from {latest_checkpoint_dir}")
            self.model.load_state_dict(torch.load(f"{latest_checkpoint_dir}/model_state.pt"))

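    # Lazily spawned hivemind DHT node used for peer discovery and coordination between collaborators.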
    @property
    def dht(self):
        if self._dht is None:
            self._dht = hivemind.DHT(
                start=True,
                initial_peers=self.peer_args.initial_peers,
                client_mode=self.peer_args.client_mode,
                host_maddrs=self.peer_args.host_maddrs,
                announce_maddrs=self.peer_args.announce_maddrs,
                use_ipfs=self.peer_args.use_ipfs,
                record_validators=self.validators,
                identity_path=self.peer_args.identity_path,
                authorizer=authorize_with_huggingface() if self.peer_args.authorize else None,
            )
            if self.peer_args.client_mode:
                logger.info(f"Created client mode peer with peer_id={self._dht.peer_id}")
            else:
                utils.log_visible_maddrs(self._dht.get_visible_maddrs(), only_p2p=self.peer_args.use_ipfs)
        return self._dht

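    # Wraps the local optimizer in hivemind.CollaborativeOptimizer, which accumulates gradients until the
    # collaboration reaches its target batch size and then averages them across peers. Averaged gradients use
    # size-adaptive compression (float16 for small tensors, 8-bit quantization for large ones), while the
    # optimizer state is exchanged in float16.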
    @property
    def collaborative_optimizer(self):
        if self._collaborative_optimizer is None:
            opt, scheduler = self._get_local_optimizer_and_scheduler(self.trainer_args)
            averaging_compression = SizeAdaptiveCompression(
                threshold=2 ** 16 + 1, less=Float16Compression(), greater_equal=Uniform8BitQuantization())
            state_compression = hivemind.Float16Compression()
            self._collaborative_optimizer = hivemind.CollaborativeOptimizer(
                dht=self.dht, opt=opt, scheduler=scheduler, prefix=self.peer_args.experiment_prefix,
                batch_size_per_step=self.trainer_args.batch_size_per_step,
                compression=averaging_compression, state_compression=state_compression,
                client_mode=self.peer_args.client_mode, verbose=True, start=True, **asdict(self.collab_args))
        return self._collaborative_optimizer

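    # Builds the local LAMB optimizer with two parameter groups: weight decay is applied to everything except
    # biases and LayerNorm weights, following the usual transformer recipe. The optimizer is wrapped in
    # OffloadOptimizer (lib.training.offload) so its state can be kept outside GPU memory.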
    def _get_local_optimizer_and_scheduler(self, training_args: HFTrainerArguments):
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in self.model.named_parameters()
                           if not any(nd in n for nd in no_decay) and p.requires_grad],
                "weight_decay": training_args.weight_decay,
            },
            {
                "params": [p for n, p in self.model.named_parameters()
                           if any(nd in n for nd in no_decay) and p.requires_grad],
                "weight_decay": 0.0,
            },
        ]

        opt = OffloadOptimizer(
            optimizer_grouped_parameters,
            optim_cls=LambWithGradientClipping,
            lr=training_args.learning_rate,
            betas=(training_args.adam_beta1, training_args.adam_beta2),
            eps=training_args.adam_epsilon,
            weight_decay=training_args.weight_decay,
            max_grad_norm=training_args.max_grad_norm,
            clamp_value=training_args.clamp_value,
            debias=True,
        )

        scheduler = get_linear_schedule_with_warmup(
            opt, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.total_steps
        )
        return opt, scheduler

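    # Training dataset; each peer shuffles with a seed derived from its own public key,
    # so different peers traverse the data in a different order.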
    @property
    def training_dataset(self):
        if self._training_dataset is None:
            self._training_dataset = make_dataset(
                self.tokenizer, shuffle_seed=hash(self.local_public_key) % 2 ** 31,
                max_sequence_length=self.trainer_args.text_seq_length
            )
        return self._training_dataset

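    # Pads text fields to the fixed text_seq_length rather than to the longest sequence in the batch.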
    @property
    def data_collator(self):
        return DataCollatorWithPadding(tokenizer=self.tokenizer,
                                       padding='max_length', max_length=self.trainer_args.text_seq_length)
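

# Hedged usage sketch (hypothetical wiring; in this repository the class is consumed by the actual training
# entry point, which also handles device placement, checkpointing and metrics reporting):
#
#   task = TrainingTask(peer_args, trainer_args, collab_args)
#   model, opt = task.model, task.collaborative_optimizer  # the DHT node is spawned on first access
#   # iterate over task.training_dataset, collate batches with task.data_collator, feed them to `model`,
#   # call backward() on the returned loss, then opt.step() to take part in collaborative averaging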