task.py

import os
from dataclasses import asdict
from itertools import cycle, islice
from pathlib import Path

import hivemind
import torch
import transformers
from dalle_pytorch import DALLE
from dalle_pytorch.vae import VQGanVAE
from hivemind import SizeAdaptiveCompression, Float16Compression, Uniform8BitQuantization
from torch import nn
from transformers import DataCollatorWithPadding, T5TokenizerFast, get_linear_schedule_with_warmup

import utils
from arguments import HFTrainerArguments, BasePeerArguments, CollaborativeArguments
from data import make_dataset
from huggingface_auth import authorize_with_huggingface
from lib.training.lamb_8bit import CPULAMB8Bit

logger = hivemind.get_logger(__name__)
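
# Stand-in for a pretrained VQGAN: it exposes the hyperparameters DALLE reads from its
# `vae` argument (num_layers, image_size, num_tokens) but calls nn.Module.__init__ directly,
# so no VQGAN weights are downloaded or loaded at construction time.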
class VQGanParams(VQGanVAE):
    def __init__(self, *, num_layers=3, image_size=256, num_tokens=8192, is_gumbel=True):
        nn.Module.__init__(self)
        self.num_layers = num_layers
        self.image_size = image_size
        self.num_tokens = num_tokens
        self.is_gumbel = is_gumbel
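
# Adapts DALLE to the interface expected by a transformers-style training loop:
# batches arrive as (input_ids, attention_mask, image) and the loss is returned in a dict.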
class ModelWrapper(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask, image):
        loss = self.model.forward(text=input_ids, image=image, mask=attention_mask, return_loss=True)
        return {'loss': loss}
class TrainingTask:
    """A container that defines the training config, model, tokenizer, optimizer and other local training utilities"""

    _dht = _collaborative_optimizer = _training_dataset = None

    def __init__(
        self, peer_args: BasePeerArguments, trainer_args: HFTrainerArguments, collab_args: CollaborativeArguments
    ):
        self.peer_args, self.trainer_args, self.collab_args = peer_args, trainer_args, collab_args
        self.validators, self.local_public_key = utils.make_validators(self.peer_args.experiment_prefix)
        transformers.set_seed(trainer_args.seed)  # seed used for initialization
        self.tokenizer = T5TokenizerFast.from_pretrained(peer_args.tokenizer_path)
        self.tokenizer.pad_token = self.tokenizer.eos_token

        output_dir = Path(trainer_args.output_dir)
        logger.info(f'Checkpoint dir {output_dir}, contents {list(output_dir.glob("checkpoint*"))}')
        latest_checkpoint_dir = max(output_dir.glob("checkpoint*"), default=None, key=os.path.getctime)

        logger.info("Creating model")
        depth = 64
        attn_types = list(islice(cycle(['axial_row', 'axial_col', 'axial_row', 'axial_row']), depth - 1))
        attn_types.append('conv_like')
        shared_layer_ids = list(islice(cycle(range(4)), depth - 1))
        shared_layer_ids.append('w_conv')
        dalle = DALLE(
            vae=VQGanParams(),
            num_text_tokens=self.tokenizer.vocab_size,
            text_seq_len=trainer_args.text_seq_length,
            dim=1024,
            depth=depth,
            heads=16,
            dim_head=64,
            attn_types=attn_types,
            ff_dropout=0,
            attn_dropout=0,
            shared_attn_ids=shared_layer_ids,
            shared_ff_ids=shared_layer_ids,
            rotary_emb=True,
            reversible=True,
            share_input_output_emb=True,
        )
        logger.info(f"Trainable parameters: "
                    f"{sum(param.numel() for param in dalle.parameters() if param.requires_grad)}")
        self.model = ModelWrapper(dalle)

        # If a local checkpoint exists, restore its weights on top of the freshly built model.
        if latest_checkpoint_dir is not None:
            logger.info(f"Loading model from {latest_checkpoint_dir}")
            self.model.load_state_dict(torch.load(f"{latest_checkpoint_dir}/model_state.pt"))
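
    # The DHT is created lazily on first access: the peer either bootstraps a new collaboration
    # or joins an existing one via `initial_peers`, optionally authenticating through Hugging Face.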
    @property
    def dht(self):
        if self._dht is None:
            if self.peer_args.authorize:
                authorizer = authorize_with_huggingface()
                self.trainer_args.run_name = authorizer.username  # For wandb
            else:
                authorizer = None
            self._dht = hivemind.DHT(
                start=True,
                initial_peers=self.peer_args.initial_peers,
                client_mode=self.peer_args.client_mode,
                host_maddrs=self.peer_args.host_maddrs,
                announce_maddrs=self.peer_args.announce_maddrs,
                use_ipfs=self.peer_args.use_ipfs,
                record_validators=self.validators,
                identity_path=self.peer_args.identity_path,
                authorizer=authorizer,
            )
            if self.peer_args.client_mode:
                logger.info(f"Created client mode peer with peer_id={self._dht.peer_id}")
            else:
                utils.log_visible_maddrs(self._dht.get_visible_maddrs(), only_p2p=self.peer_args.use_ipfs)
        return self._dht
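
    # hivemind.Optimizer wraps the local CPULAMB8Bit optimizer: peers accumulate gradients
    # locally and average them (and, periodically, the optimizer state) over the DHT.
    # Tensors at or above the size threshold are sent 8-bit quantized, smaller ones as float16.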
    @property
    def collaborative_optimizer(self):
        if self._collaborative_optimizer is None:
            params, opt, scheduler = self._get_local_optimizer_and_scheduler(self.trainer_args)
            averaging_compression = SizeAdaptiveCompression(
                threshold=2 ** 16 + 1, less=Float16Compression(), greater_equal=Uniform8BitQuantization())
            self._collaborative_optimizer = hivemind.Optimizer(
                dht=self.dht, run_id=self.peer_args.experiment_prefix,
                params=params, optimizer=opt, scheduler=scheduler,
                offload_optimizer=True, delay_grad_averaging=False, delay_optimizer_step=True,
                batch_size_per_step=self.trainer_args.batch_size_per_step,
                grad_compression=averaging_compression, state_averaging_compression=averaging_compression,
                client_mode=self.peer_args.client_mode, verbose=True,
                **asdict(self.collab_args))
        return self._collaborative_optimizer
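
    # Biases and LayerNorm weights are excluded from weight decay. The optimizer and scheduler
    # are returned as factories because hivemind.Optimizer instantiates them itself
    # (on the offloaded parameter copy when offload_optimizer=True).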
    def _get_local_optimizer_and_scheduler(self, training_args: HFTrainerArguments):
        no_decay = ["bias", "LayerNorm.weight"]
        params = [
            {
                "params": [p for n, p in self.model.named_parameters()
                           if not any(nd in n for nd in no_decay) and p.requires_grad],
                "weight_decay": training_args.weight_decay,
            },
            {
                "params": [p for n, p in self.model.named_parameters()
                           if any(nd in n for nd in no_decay) and p.requires_grad],
                "weight_decay": 0.0,
            },
        ]

        opt = lambda params: CPULAMB8Bit(
            params,
            lr=training_args.learning_rate,
            betas=(training_args.adam_beta1, training_args.adam_beta2),
            eps=training_args.adam_epsilon,
            weight_decay=training_args.weight_decay,
            max_grad_norm=training_args.max_grad_norm,
            clamp_value=training_args.clamp_value,
            reuse_grad_buffers=True,
        )

        scheduler = lambda opt: get_linear_schedule_with_warmup(
            opt, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.total_steps
        )

        return params, opt, scheduler
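
    # Each peer shuffles the dataset with a seed derived from its own public key,
    # so different peers walk the data in different orders.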
    @property
    def training_dataset(self):
        if self._training_dataset is None:
            self._training_dataset = make_dataset(
                self.tokenizer, shuffle_seed=hash(self.local_public_key) % 2 ** 31,
                max_sequence_length=self.trainer_args.text_seq_length
            )
        return self._training_dataset
    @property
    def data_collator(self):
        return DataCollatorWithPadding(tokenizer=self.tokenizer,
                                       padding='max_length', max_length=self.trainer_args.text_seq_length)
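
# Usage sketch (illustrative only, not part of this module; the entry point that parses
# BasePeerArguments / HFTrainerArguments / CollaborativeArguments lives elsewhere):
#
#   task = TrainingTask(peer_args, trainer_args, collab_args)
#   task.dht                          # join (or bootstrap) the collaboration
#   opt = task.collaborative_optimizer
#   # feed task.training_dataset through task.data_collator, call task.model(**batch),
#   # backpropagate the returned {'loss': ...} and call opt.step(batch_size=...)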