task.py

import os
from dataclasses import asdict
from itertools import cycle, islice
from pathlib import Path

import hivemind
import torch
import transformers
from dalle_pytorch import DALLE
from dalle_pytorch.vae import VQGanVAE, download
from hivemind import SizeAdaptiveCompression, Float16Compression, Uniform8BitQuantization
from transformers import DataCollatorWithPadding, GPT2TokenizerFast, get_linear_schedule_with_warmup
from torch import nn

import utils
from arguments import HFTrainerArguments, BasePeerArguments, CollaborativeArguments
from data import make_dataset
from huggingface_auth import authorize_with_huggingface
from lib.training.clipped_lamb import LambWithGradientClipping
from lib.training.offload import OffloadOptimizer

logger = hivemind.get_logger(__name__)

# VQGAN with downsampling factor f=8, 8192 codebook entries, and Gumbel quantization
# Note: If you change the URLs below, remove ./cache/* to clear the cache
VQGAN_VAE_PATH = 'https://heibox.uni-heidelberg.de/f/34a747d5765840b5a99d/?dl=1'
VQGAN_VAE_CONFIG_PATH = 'https://heibox.uni-heidelberg.de/f/b24d14998a8d4f19a34f/?dl=1'


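# Wraps the raw dalle_pytorch.DALLE model so that it accepts keyword batches produced by the
# tokenizer/collator (input_ids, attention_mask, image) and returns a {'loss': ...} dict,
# matching the output format that transformers.Trainer-style training loops expect.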
class ModelWrapper(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask, image):
        loss = self.model.forward(text=input_ids, image=image, mask=attention_mask, return_loss=True)
        return {'loss': loss}


class TrainingTask:
    """A container that defines the training config, model, tokenizer, optimizer and other local training utilities"""

    _dht = _collaborative_optimizer = _training_dataset = None

    def __init__(
        self, peer_args: BasePeerArguments, trainer_args: HFTrainerArguments, collab_args: CollaborativeArguments):
        self.peer_args, self.trainer_args, self.collab_args = peer_args, trainer_args, collab_args
        self.validators, self.local_public_key = utils.make_validators(self.peer_args.experiment_prefix)
        transformers.set_seed(trainer_args.seed)  # seed used for initialization

        self.tokenizer = GPT2TokenizerFast.from_pretrained(peer_args.tokenizer_path)
        self.tokenizer.pad_token = self.tokenizer.eos_token

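        # Resume from the newest local checkpoint directory, if one exists; checkpoints are
        # expected as "checkpoint*" subdirectories of output_dir containing model_state.pt.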
        output_dir = Path(trainer_args.output_dir)
        logger.info(f'Checkpoint dir {output_dir}, contents {list(output_dir.glob("checkpoint*"))}')
        latest_checkpoint_dir = max(output_dir.glob("checkpoint*"), default=None, key=os.path.getctime)

        logger.info("Creating model")
        vae = VQGanVAE(
            vqgan_model_path=download(VQGAN_VAE_PATH, 'vqgan.ckpt', root=peer_args.cache_dir),
            vqgan_config_path=download(VQGAN_VAE_CONFIG_PATH, 'vqgan_config.yaml', root=peer_args.cache_dir),
        )

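        # Sparse attention schedule: mostly axial row attention with interleaved axial column
        # attention over the first depth - 1 layers, ending with a convolution-like layer.
        # Cycling the same four ids through shared_attn_ids / shared_ff_ids makes dalle_pytorch
        # reuse attention and feed-forward weights across layers, keeping the parameter count
        # far below that of 64 independent layers.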
        depth = 64
        attn_types = list(islice(cycle(['axial_row', 'axial_col', 'axial_row', 'axial_row']), depth - 1))
        attn_types.append('conv_like')
        shared_layer_ids = list(islice(cycle(range(4)), depth - 1))
        shared_layer_ids.append('w_conv')

        dalle = DALLE(
            vae=vae,
            num_text_tokens=self.tokenizer.vocab_size,
            text_seq_len=trainer_args.text_seq_length,
            dim=1024,
            depth=depth,
            heads=16,
            dim_head=64,
            attn_types=attn_types,
            ff_dropout=0,
            attn_dropout=0,
            shared_attn_ids=shared_layer_ids,
            shared_ff_ids=shared_layer_ids,
            rotary_emb=False,  # FIXME: Fix RuntimeError when True
            reversible=True,
        )
        self.model = ModelWrapper(dalle)

        if latest_checkpoint_dir is not None:
            logger.info(f"Loading model from {latest_checkpoint_dir}")
            self.model.load_state_dict(torch.load(f"{latest_checkpoint_dir}/model_state.pt"))

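    # A hivemind DHT node used for peer discovery and for coordinating the collaborative
    # optimizer with other peers.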
    @property
    def dht(self):
        if self._dht is None:
            self._dht = hivemind.DHT(
                start=True,
                initial_peers=self.peer_args.initial_peers,
                client_mode=self.peer_args.client_mode,
                host_maddrs=self.peer_args.host_maddrs,
                announce_maddrs=self.peer_args.announce_maddrs,
                use_ipfs=self.peer_args.use_ipfs,
                record_validators=self.validators,
                identity_path=self.peer_args.identity_path,
                authorizer=authorize_with_huggingface() if self.peer_args.authorize else None,
            )
            if self.peer_args.client_mode:
                logger.info(f"Created client mode peer with peer_id={self._dht.peer_id}")
            else:
                utils.log_visible_maddrs(self._dht.get_visible_maddrs(), only_p2p=self.peer_args.use_ipfs)
        return self._dht

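    # Gradient averaging uses size-adaptive compression: tensors with at most 2**16 elements are
    # exchanged in float16, larger ones with uniform 8-bit quantization; optimizer state that
    # newly joined peers download is compressed to float16.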
    @property
    def collaborative_optimizer(self):
        if self._collaborative_optimizer is None:
            opt, scheduler = self._get_local_optimizer_and_scheduler(self.trainer_args)
            averaging_compression = SizeAdaptiveCompression(
                threshold=2 ** 16 + 1, less=Float16Compression(), greater_equal=Uniform8BitQuantization())
            state_compression = hivemind.Float16Compression()
            self._collaborative_optimizer = hivemind.CollaborativeOptimizer(
                dht=self.dht, opt=opt, scheduler=scheduler, prefix=self.peer_args.experiment_prefix,
                batch_size_per_step=self.trainer_args.batch_size_per_step,
                compression=averaging_compression, state_compression=state_compression,
                client_mode=self.peer_args.client_mode, verbose=True, start=True, **asdict(self.collab_args))
        return self._collaborative_optimizer

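    # LAMB with gradient clipping, wrapped in OffloadOptimizer (a project-local wrapper that
    # presumably keeps optimizer state outside GPU memory); bias and LayerNorm weights are
    # excluded from weight decay, as is conventional for transformers.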
    def _get_local_optimizer_and_scheduler(self, training_args: HFTrainerArguments):
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in self.model.named_parameters()
                           if not any(nd in n for nd in no_decay) and p.requires_grad],
                "weight_decay": training_args.weight_decay,
            },
            {
                "params": [p for n, p in self.model.named_parameters()
                           if any(nd in n for nd in no_decay) and p.requires_grad],
                "weight_decay": 0.0,
            },
        ]

        opt = OffloadOptimizer(
            optimizer_grouped_parameters,
            optim_cls=LambWithGradientClipping,
            lr=training_args.learning_rate,
            betas=(training_args.adam_beta1, training_args.adam_beta2),
            eps=training_args.adam_epsilon,
            weight_decay=training_args.weight_decay,
            max_grad_norm=training_args.max_grad_norm,
            clamp_value=training_args.clamp_value,
            debias=True,
        )

        scheduler = get_linear_schedule_with_warmup(
            opt, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.total_steps
        )
        return opt, scheduler

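    # The shuffle seed is derived from this peer's public key so that different peers are likely
    # to stream the training data in different orders.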
    @property
    def training_dataset(self):
        if self._training_dataset is None:
            self._training_dataset = make_dataset(
                self.tokenizer, shuffle_seed=hash(self.local_public_key) % 2 ** 31,
                max_sequence_length=self.trainer_args.text_seq_length
            )
        return self._training_dataset

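    # Pads every text to the fixed text_seq_length so that batch shapes match the DALLE
    # model's text_seq_len.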
    @property
    def data_collator(self):
        return DataCollatorWithPadding(tokenizer=self.tokenizer,
                                       padding='max_length', max_length=self.trainer_args.text_seq_length)
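

# Usage sketch (illustrative; the argument-dataclass instance names are assumptions): the
# surrounding training script is expected to parse the argument dataclasses and consume the
# task roughly as
#
#     task = TrainingTask(peer_args, trainer_args, collab_args)
#     model = task.model.to(trainer_args.device)
#     # ...train with task.training_dataset, task.data_collator and task.collaborative_optimizer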