# task.py

import os
from dataclasses import asdict
from itertools import cycle, islice
from pathlib import Path

import hivemind
import torch
import torch.nn as nn
import transformers
from dalle_pytorch import DALLE
from dalle_pytorch.vae import VQGanVAE
from hivemind import SizeAdaptiveCompression, Float16Compression, Uniform8BitQuantization
from transformers import DataCollatorWithPadding, T5TokenizerFast, get_linear_schedule_with_warmup

import utils
from arguments import HFTrainerArguments, BasePeerArguments, CollaborativeArguments
from data import make_dataset
from huggingface_auth import authorize_with_huggingface
from lib.training.lamb_8bit import CPULAMB8Bit

logger = hivemind.get_logger(__name__)


class VQGanParams(VQGanVAE):
    def __init__(self, *, num_layers=3, image_size=256, num_tokens=8192, is_gumbel=True):
        # Keep only the hyperparameters DALLE reads from the VAE; nn.Module.__init__ is called
        # directly instead of VQGanVAE.__init__, so no pretrained VQGAN weights are loaded here.
        nn.Module.__init__(self)
        self.num_layers = num_layers
        self.image_size = image_size
        self.num_tokens = num_tokens
        self.is_gumbel = is_gumbel


class ModelWrapper(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask, image):
        loss = self.model.forward(text=input_ids, image=image, mask=attention_mask, return_loss=True)
        return {'loss': loss}
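
# Note: the argument names in ModelWrapper.forward (input_ids, attention_mask, image) are expected to
# match the columns produced by TrainingTask.training_dataset and padded by TrainingTask.data_collator,
# and returning a dict with a 'loss' key follows the convention that transformers-style training loops
# read outputs["loss"].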


class TrainingTask:
    """A container that defines the training config, model, tokenizer, optimizer and other local training utilities"""

    _authorizer = _dht = _collaborative_optimizer = _training_dataset = None

    def __init__(
        self, peer_args: BasePeerArguments, trainer_args: HFTrainerArguments, collab_args: CollaborativeArguments
    ):
        self.peer_args, self.trainer_args, self.collab_args = peer_args, trainer_args, collab_args
        if self.authorizer is not None:  # authorization is optional, so only set the run name when it is available
            self.trainer_args.run_name = self.authorizer.username  # For wandb
        self.validators, self.local_public_key = utils.make_validators(self.peer_args.experiment_prefix)

        transformers.set_seed(trainer_args.seed)  # seed used for initialization
        self.tokenizer = T5TokenizerFast.from_pretrained(peer_args.tokenizer_path)
        self.tokenizer.pad_token = self.tokenizer.eos_token

        logger.info("Creating model")
        depth = 64
        attn_types = list(islice(cycle(['axial_row', 'axial_col', 'axial_row', 'axial_row']), depth - 1))
        attn_types.append('conv_like')
        shared_layer_ids = list(islice(cycle(range(4)), depth - 1))
        shared_layer_ids.append('w_conv')
        dalle = DALLE(
            vae=VQGanParams(),
            num_text_tokens=self.tokenizer.vocab_size,
            text_seq_len=trainer_args.text_seq_length,
            dim=1024,
            depth=depth,
            heads=16,
            dim_head=64,
            attn_types=attn_types,
            ff_dropout=0,
            attn_dropout=0,
            shared_attn_ids=shared_layer_ids,
            shared_ff_ids=shared_layer_ids,
            rotary_emb=True,
            reversible=True,
            share_input_output_emb=True,
        )
        logger.info(f"Trainable parameters: "
                    f"{sum(param.numel() for param in dalle.parameters() if param.requires_grad)}")

        self.model = ModelWrapper(dalle)
        output_dir = Path(trainer_args.output_dir)
        logger.info(f'Checkpoint dir {output_dir}, contents {list(output_dir.glob("checkpoint*"))}')
        latest_checkpoint_dir = max(output_dir.glob("checkpoint*"), default=None, key=os.path.getctime)
        if latest_checkpoint_dir is not None:
            logger.info(f"Loading model from {latest_checkpoint_dir}")
            self.model.load_state_dict(torch.load(f"{latest_checkpoint_dir}/model_state.pt"))

    @property
    def authorizer(self):
        if self._authorizer is None and self.peer_args.authorize:
            self._authorizer = authorize_with_huggingface()
        return self._authorizer

    @property
    def dht(self):
        if self._dht is None:
            self._dht = hivemind.DHT(
                start=True,
                initial_peers=self.peer_args.initial_peers,
                client_mode=self.peer_args.client_mode,
                host_maddrs=self.peer_args.host_maddrs,
                announce_maddrs=self.peer_args.announce_maddrs,
                use_ipfs=self.peer_args.use_ipfs,
                record_validators=self.validators,
                identity_path=self.peer_args.identity_path,
                authorizer=self.authorizer,
            )
            if self.peer_args.client_mode:
                logger.info(f"Created client mode peer with peer_id={self._dht.peer_id}")
            else:
                utils.log_visible_maddrs(self._dht.get_visible_maddrs(), only_p2p=self.peer_args.use_ipfs)
        return self._dht

    @property
    def collaborative_optimizer(self):
        if self._collaborative_optimizer is None:
            params, opt, scheduler = self._get_local_optimizer_and_scheduler(self.trainer_args)
            averaging_compression = SizeAdaptiveCompression(
                threshold=2 ** 16 + 1, less=Float16Compression(), greater_equal=Uniform8BitQuantization())
            self._collaborative_optimizer = hivemind.Optimizer(
                dht=self.dht, run_id=self.peer_args.experiment_prefix,
                params=params, optimizer=opt, scheduler=scheduler,
                offload_optimizer=True, delay_grad_averaging=False, delay_optimizer_step=True,
                batch_size_per_step=self.trainer_args.batch_size_per_step,
                grad_compression=averaging_compression, state_averaging_compression=averaging_compression,
                client_mode=self.peer_args.client_mode, verbose=True,
                **asdict(self.collab_args))
        return self._collaborative_optimizer

    def _get_local_optimizer_and_scheduler(self, training_args: HFTrainerArguments):
        no_decay = ["bias", "LayerNorm.weight"]
        params = [
            {
                "params": [p for n, p in self.model.named_parameters()
                           if not any(nd in n for nd in no_decay) and p.requires_grad],
                "weight_decay": training_args.weight_decay,
            },
            {
                "params": [p for n, p in self.model.named_parameters()
                           if any(nd in n for nd in no_decay) and p.requires_grad],
                "weight_decay": 0.0,
            },
        ]

        opt = lambda params: CPULAMB8Bit(
            params,
            lr=training_args.learning_rate,
            betas=(training_args.adam_beta1, training_args.adam_beta2),
            eps=training_args.adam_epsilon,
            weight_decay=training_args.weight_decay,
            max_grad_norm=training_args.max_grad_norm,
            clamp_value=training_args.clamp_value,
            reuse_grad_buffers=True,
        )

        scheduler = lambda opt: get_linear_schedule_with_warmup(
            opt, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.total_steps
        )

        return params, opt, scheduler
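
    # Note: `opt` and `scheduler` are returned as factories (callables) rather than instances.
    # hivemind.Optimizer accepts a callable(params) optimizer and a callable(optimizer) scheduler and
    # instantiates them itself, which is what lets offload_optimizer=True in `collaborative_optimizer`
    # manage the optimizer over its own copy of the parameters.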

    @property
    def training_dataset(self):
        if self._training_dataset is None:
            self._training_dataset = make_dataset(
                self.tokenizer, shuffle_seed=hash(self.local_public_key) % 2 ** 31,
                max_sequence_length=self.trainer_args.text_seq_length
            )
        return self._training_dataset

    @property
    def data_collator(self):
        return DataCollatorWithPadding(tokenizer=self.tokenizer,
                                       padding='max_length', max_length=self.trainer_args.text_seq_length)
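

# Hedged usage sketch (illustration only, not the project's actual entry point; the real training
# script is assumed to live in a separate run script). It assumes that BasePeerArguments,
# HFTrainerArguments and CollaborativeArguments are dataclasses compatible with
# transformers.HfArgumentParser; everything else comes from this file.
if __name__ == "__main__":
    from transformers import HfArgumentParser

    parser = HfArgumentParser((BasePeerArguments, HFTrainerArguments, CollaborativeArguments))
    peer_args, trainer_args, collab_args = parser.parse_args_into_dataclasses()

    task = TrainingTask(peer_args, trainer_args, collab_args)

    # At this point a Trainer-like loop would iterate over batches produced by task.training_dataset
    # and task.data_collator, call task.model(**batch)['loss'].backward(), and step
    # task.collaborative_optimizer with batch_size=trainer_args.batch_size_per_step.
    logger.info(f"Prepared task for run {peer_args.experiment_prefix}: "
                f"{sum(p.numel() for p in task.model.parameters()):,} parameters")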