task.py

import os
from dataclasses import asdict
from itertools import cycle, islice
from pathlib import Path

import hivemind
import torch
import transformers
from dalle_pytorch import DALLE
from dalle_pytorch.vae import VQGanVAE
from hivemind import SizeAdaptiveCompression, Float16Compression, Uniform8BitQuantization
from transformers import DataCollatorWithPadding, T5TokenizerFast, get_linear_schedule_with_warmup
from torch import nn

import utils
from arguments import HFTrainerArguments, BasePeerArguments, CollaborativeArguments
from data import make_dataset
from huggingface_auth import authorize_with_huggingface
from lib.training.lamb_8bit import CPULAMB8Bit

logger = hivemind.get_logger(__name__)
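

# VQGanParams deliberately skips VQGanVAE.__init__ (which would load a full pretrained VQGAN) and only
# records the hyperparameters that DALLE reads from its `vae` argument to size the image-token embeddings
# and the image part of the sequence.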
class VQGanParams(VQGanVAE):
    def __init__(self, *, num_layers=3, image_size=256, num_tokens=8192, is_gumbel=True):
        nn.Module.__init__(self)
        self.num_layers = num_layers
        self.image_size = image_size
        self.num_tokens = num_tokens
        self.is_gumbel = is_gumbel
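

# ModelWrapper adapts DALLE to a Trainer-style interface: it accepts the (input_ids, attention_mask, image)
# fields produced by the data collator and returns a dict with a 'loss' key instead of a bare tensor.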
class ModelWrapper(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask, image):
        loss = self.model.forward(text=input_ids, image=image, mask=attention_mask, return_loss=True)
        return {'loss': loss}
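

# The DHT, the collaborative optimizer, and the training dataset are created lazily: the class-level
# sentinels below stay None until the corresponding property is first accessed.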
class TrainingTask:
    """A container that defines the training config, model, tokenizer, optimizer and other local training utilities"""

    _dht = _collaborative_optimizer = _training_dataset = None

    def __init__(
        self, peer_args: BasePeerArguments, trainer_args: HFTrainerArguments, collab_args: CollaborativeArguments):
        self.peer_args, self.trainer_args, self.collab_args = peer_args, trainer_args, collab_args
        self.validators, self.local_public_key = utils.make_validators(self.peer_args.experiment_prefix)
        transformers.set_seed(trainer_args.seed)  # seed used for initialization
        self.tokenizer = T5TokenizerFast.from_pretrained(peer_args.tokenizer_path)
        self.tokenizer.pad_token = self.tokenizer.eos_token

        output_dir = Path(trainer_args.output_dir)
        logger.info(f'Checkpoint dir {output_dir}, contents {list(output_dir.glob("checkpoint*"))}')
        latest_checkpoint_dir = max(output_dir.glob("checkpoint*"), default=None, key=os.path.getctime)

        logger.info("Creating model")
        depth = 64
        attn_types = list(islice(cycle(['axial_row', 'axial_col', 'axial_row', 'axial_row']), depth - 1))
        attn_types.append('conv_like')
        shared_layer_ids = list(islice(cycle(range(4)), depth - 1))
        shared_layer_ids.append('w_conv')
        dalle = DALLE(
            vae=VQGanParams(),
            num_text_tokens=self.tokenizer.vocab_size,
            text_seq_len=trainer_args.text_seq_length,
            dim=1024,
            depth=depth,
            heads=16,
            dim_head=64,
            attn_types=attn_types,
            ff_dropout=0,
            attn_dropout=0,
            shared_attn_ids=shared_layer_ids,
            shared_ff_ids=shared_layer_ids,
            rotary_emb=True,
            reversible=True,
            share_input_output_emb=True,
        )
        logger.info(f"Trainable parameters: "
                    f"{sum(param.numel() for param in dalle.parameters() if param.requires_grad)}")
        self.model = ModelWrapper(dalle)

        if latest_checkpoint_dir is not None:
            # Resume from the most recent local checkpoint by loading its weights into the freshly built model
            logger.info(f"Loading model from {latest_checkpoint_dir}")
            self.model.load_state_dict(torch.load(f"{latest_checkpoint_dir}/model_state.pt"))
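
    # The DHT peer is created on first access. It bootstraps from `initial_peers`, optionally authorizes
    # through Hugging Face, and is what the collaborative optimizer later uses for matchmaking, averaging,
    # and tracking global training progress.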
    @property
    def dht(self):
        if self._dht is None:
            self._dht = hivemind.DHT(
                start=True,
                initial_peers=self.peer_args.initial_peers,
                client_mode=self.peer_args.client_mode,
                host_maddrs=self.peer_args.host_maddrs,
                announce_maddrs=self.peer_args.announce_maddrs,
                use_ipfs=self.peer_args.use_ipfs,
                record_validators=self.validators,
                identity_path=self.peer_args.identity_path,
                authorizer=authorize_with_huggingface() if self.peer_args.authorize else None,
            )
            if self.peer_args.client_mode:
                logger.info(f"Created client mode peer with peer_id={self._dht.peer_id}")
            else:
                utils.log_visible_maddrs(self._dht.get_visible_maddrs(), only_p2p=self.peer_args.use_ipfs)
        return self._dht
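
    # hivemind.Optimizer wraps the local CPULAMB8Bit optimizer with decentralized averaging over the DHT.
    # Roughly, tensors up to 2 ** 16 elements are exchanged as float16 while larger ones are 8-bit quantized;
    # optimizer state is offloaded to CPU, and the optimizer step is delayed so averaging and the step can
    # overlap with ongoing training.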
    @property
    def collaborative_optimizer(self):
        if self._collaborative_optimizer is None:
            params, opt, scheduler = self._get_local_optimizer_and_scheduler(self.trainer_args)
            averaging_compression = SizeAdaptiveCompression(
                threshold=2 ** 16 + 1, less=Float16Compression(), greater_equal=Uniform8BitQuantization())
            self._collaborative_optimizer = hivemind.Optimizer(
                dht=self.dht, run_id=self.peer_args.experiment_prefix,
                params=params, optimizer=opt, scheduler=scheduler,
                offload_optimizer=True, delay_grad_averaging=False, delay_optimizer_step=True,
                batch_size_per_step=self.trainer_args.batch_size_per_step,
                grad_compression=averaging_compression, state_averaging_compression=averaging_compression,
                client_mode=self.peer_args.client_mode, verbose=True,
                **asdict(self.collab_args))
        return self._collaborative_optimizer
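
    # Parameters are split into two groups so that biases and LayerNorm weights are excluded from weight
    # decay. The optimizer and scheduler are returned as factories (callables) rather than instances, which
    # lets hivemind.Optimizer construct them around its offloaded copy of the parameters.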
    def _get_local_optimizer_and_scheduler(self, training_args: HFTrainerArguments):
        no_decay = ["bias", "LayerNorm.weight"]
        params = [
            {
                "params": [p for n, p in self.model.named_parameters()
                           if not any(nd in n for nd in no_decay) and p.requires_grad],
                "weight_decay": training_args.weight_decay,
            },
            {
                "params": [p for n, p in self.model.named_parameters()
                           if any(nd in n for nd in no_decay) and p.requires_grad],
                "weight_decay": 0.0,
            },
        ]

        opt = lambda params: CPULAMB8Bit(
            params,
            lr=training_args.learning_rate,
            betas=(training_args.adam_beta1, training_args.adam_beta2),
            eps=training_args.adam_epsilon,
            weight_decay=training_args.weight_decay,
            max_grad_norm=training_args.max_grad_norm,
            clamp_value=training_args.clamp_value,
            reuse_grad_buffers=True,
        )

        scheduler = lambda opt: get_linear_schedule_with_warmup(
            opt, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.total_steps
        )

        return params, opt, scheduler
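
    # The dataset shuffle seed is derived from this peer's public key, so different peers iterate over
    # the data in different orders.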
    @property
    def training_dataset(self):
        if self._training_dataset is None:
            self._training_dataset = make_dataset(
                self.tokenizer, shuffle_seed=hash(self.local_public_key) % 2 ** 31,
                max_sequence_length=self.trainer_args.text_seq_length
            )
        return self._training_dataset
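
    # Text inputs are padded to the fixed `text_seq_length`, so every batch has the same shape.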
    @property
    def data_collator(self):
        return DataCollatorWithPadding(tokenizer=self.tokenizer,
                                       padding='max_length', max_length=self.trainer_args.text_seq_length)
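

# A minimal usage sketch (hypothetical; the actual entry point, argument parsing, and training loop live
# elsewhere in the repository):
#
#   peer_args, trainer_args, collab_args = ...  # parsed from the dataclasses in arguments.py
#   task = TrainingTask(peer_args, trainer_args, collab_args)
#   loader = torch.utils.data.DataLoader(task.training_dataset, collate_fn=task.data_collator, batch_size=4)
#   for batch in loader:
#       loss = task.model(**batch)['loss']
#       loss.backward()
#       task.collaborative_optimizer.step()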