task.py

import os
from dataclasses import asdict
from itertools import cycle, islice
from pathlib import Path

import hivemind
import torch
import transformers
from dalle_pytorch import DALLE
from dalle_pytorch.vae import VQGanVAE
from hivemind import SizeAdaptiveCompression, Float16Compression, Uniform8BitQuantization
from transformers import DataCollatorWithPadding, T5TokenizerFast, get_linear_schedule_with_warmup
from torch import nn
from transformers import training_args

import utils
from arguments import HFTrainerArguments, BasePeerArguments, CollaborativeArguments
from data import make_dataset
from huggingface_auth import authorize_with_huggingface
from lib.training.lamb_8bit import CPULAMB8Bit

logger = hivemind.get_logger(__name__)


class VQGanParams(VQGanVAE):
    """Holds the VQGAN hyperparameters that DALLE needs, without loading the actual VAE weights."""

    def __init__(self, *, num_layers=3, image_size=256, num_tokens=8192, is_gumbel=True):
        # Deliberately skip VQGanVAE.__init__ (which would load the full VQGAN model)
        nn.Module.__init__(self)
        self.num_layers = num_layers
        self.image_size = image_size
        self.num_tokens = num_tokens
        self.is_gumbel = is_gumbel


class ModelWrapper(nn.Module):
    """Adapts DALLE's forward signature to the (input_ids, attention_mask, image) batches used in training."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def tie_weights(self):
        pass  # no-op stub, kept so code expecting a transformers-style model can call it safely

    def forward(self, input_ids, attention_mask, image):
        loss = self.model.forward(text=input_ids, image=image, mask=attention_mask, return_loss=True)
        return {'loss': loss}


class TrainingTask:
    """A container that defines the training config, model, tokenizer, optimizer and other local training utilities"""

    _dht = _collaborative_optimizer = _training_dataset = None

    def __init__(
        self, peer_args: BasePeerArguments, trainer_args: HFTrainerArguments, collab_args: CollaborativeArguments
    ):
        self.peer_args, self.trainer_args, self.collab_args = peer_args, trainer_args, collab_args
        self.validators, self.local_public_key = utils.make_validators(self.peer_args.experiment_prefix)
        transformers.set_seed(trainer_args.seed)  # seed used for initialization

        self.tokenizer = T5TokenizerFast.from_pretrained(peer_args.tokenizer_path)
        self.tokenizer.pad_token = self.tokenizer.eos_token

        output_dir = Path(trainer_args.output_dir)
        logger.info(f'Checkpoint dir {output_dir}, contents {list(output_dir.glob("checkpoint*"))}')
        latest_checkpoint_dir = max(output_dir.glob("checkpoint*"), default=None, key=os.path.getctime)

        # Always build the model; if a checkpoint exists, its weights are loaded into it below
        logger.info("Creating model")
        depth = 16  # TODO
        attn_types = list(islice(cycle(['axial_row', 'axial_col', 'axial_row', 'axial_row']), depth - 1))
        attn_types.append('conv_like')
        shared_layer_ids = list(islice(cycle(range(4)), depth - 1))
        shared_layer_ids.append('w_conv')
        dalle = DALLE(
            vae=VQGanParams(),
            num_text_tokens=self.tokenizer.vocab_size,
            text_seq_len=trainer_args.text_seq_length,
            dim=1024,
            depth=depth,
            heads=16,
            dim_head=64,
            attn_types=attn_types,
            ff_dropout=0,
            attn_dropout=0,
            shared_attn_ids=shared_layer_ids,
            shared_ff_ids=shared_layer_ids,
            rotary_emb=True,
            reversible=True,
            share_input_output_emb=True,
        )
        logger.info(f"Trainable parameters: "
                    f"{sum(param.numel() for param in dalle.parameters() if param.requires_grad)}")
        self.model = ModelWrapper(dalle)

        if latest_checkpoint_dir is not None:
            logger.info(f"Loading model from {latest_checkpoint_dir}")
            self.model.load_state_dict(torch.load(f"{latest_checkpoint_dir}/model_state.pt"))

    @property
    def dht(self):
        if self._dht is None:
            if self.peer_args.authorize:
                authorizer = authorize_with_huggingface()
                self.trainer_args.run_name = authorizer.username  # For wandb
            else:
                authorizer = None

            self._dht = hivemind.DHT(
                start=True,
                initial_peers=self.peer_args.initial_peers,
                client_mode=self.peer_args.client_mode,
                host_maddrs=self.peer_args.host_maddrs,
                announce_maddrs=self.peer_args.announce_maddrs,
                use_ipfs=self.peer_args.use_ipfs,
                record_validators=self.validators,
                identity_path=self.peer_args.identity_path,
                authorizer=authorizer,
            )
            if self.peer_args.client_mode:
                logger.info(f"Created client mode peer with peer_id={self._dht.peer_id}")
            else:
                utils.log_visible_maddrs(self._dht.get_visible_maddrs(), only_p2p=self.peer_args.use_ipfs)
        return self._dht
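
    # Note (illustrative, not from this file): follower peers are expected to join the same DHT by
    # passing one of the multiaddrs logged above via their initial_peers setting, e.g.
    # /ip4/203.0.113.5/tcp/31337/p2p/<peer_id>, before their collaborative optimizer is created.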

    @property
    def collaborative_optimizer(self):
        if self._collaborative_optimizer is None:
            params, opt, scheduler = self._get_local_optimizer_and_scheduler(self.trainer_args)
            averaging_compression = SizeAdaptiveCompression(
                threshold=2 ** 16 + 1, less=Float16Compression(), greater_equal=Uniform8BitQuantization())
            self._collaborative_optimizer = hivemind.Optimizer(
                dht=self.dht, run_id=self.peer_args.experiment_prefix,
                params=params, optimizer=opt, scheduler=scheduler,
                offload_optimizer=True, delay_grad_averaging=False, delay_optimizer_step=True,
                batch_size_per_step=self.trainer_args.batch_size_per_step,
                grad_compression=averaging_compression, state_averaging_compression=averaging_compression,
                client_mode=self.peer_args.client_mode, verbose=True,
                **asdict(self.collab_args))
        return self._collaborative_optimizer
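
    # Sketch of how collaborative_optimizer is meant to be driven (the actual training loop lives
    # elsewhere; names below are illustrative):
    #   task = TrainingTask(peer_args, trainer_args, collab_args)
    #   for batch in dataloader:
    #       loss = task.model(**batch)["loss"]
    #       loss.backward()
    #       task.collaborative_optimizer.step()  # averages with peers once the target global batch size is reached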

    def _get_local_optimizer_and_scheduler(self, training_args: HFTrainerArguments):
        no_decay = ["bias", "LayerNorm.weight"]
        params = [
            {
                "params": [p for n, p in self.model.named_parameters()
                           if not any(nd in n for nd in no_decay) and p.requires_grad],
                "weight_decay": training_args.weight_decay,
            },
            {
                "params": [p for n, p in self.model.named_parameters()
                           if any(nd in n for nd in no_decay) and p.requires_grad],
                "weight_decay": 0.0,
            },
        ]

        opt = lambda params: CPULAMB8Bit(
            params,
            lr=training_args.learning_rate,
            betas=(training_args.adam_beta1, training_args.adam_beta2),
            eps=training_args.adam_epsilon,
            weight_decay=training_args.weight_decay,
            max_grad_norm=training_args.max_grad_norm,
            clamp_value=training_args.clamp_value,
            reuse_grad_buffers=True,
        )

        scheduler = lambda opt: get_linear_schedule_with_warmup(
            opt, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.total_steps
        )
        return params, opt, scheduler
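
    # Note: `opt` and `scheduler` are returned as factories (callables) rather than instances;
    # hivemind.Optimizer accepts callables so that it can construct them over its own, possibly
    # CPU-offloaded, copy of the parameters (offload_optimizer=True above).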

    @property
    def training_dataset(self):
        if self._training_dataset is None:
            self._training_dataset = make_dataset(
                self.tokenizer, shuffle_seed=hash(self.local_public_key) % 2 ** 31,
                max_sequence_length=self.trainer_args.text_seq_length
            )
        return self._training_dataset

    @property
    def data_collator(self):
        return DataCollatorWithPadding(tokenizer=self.tokenizer,
                                       padding='max_length', max_length=self.trainer_args.text_seq_length)
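
# Minimal usage sketch (assumes the argument dataclasses are parsed elsewhere; the parser call
# below uses transformers.HfArgumentParser and is illustrative, not part of this file):
#   parser = transformers.HfArgumentParser((BasePeerArguments, HFTrainerArguments, CollaborativeArguments))
#   peer_args, trainer_args, collab_args = parser.parse_args_into_dataclasses()
#   task = TrainingTask(peer_args, trainer_args, collab_args)
#   task.dht                      # lazily starts / joins the hivemind DHT
#   task.collaborative_optimizer  # lazily builds the hivemind.Optimizer on top of it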