@@ -1,23 +1,41 @@
 import os
 from dataclasses import asdict
+from itertools import cycle, islice
 from pathlib import Path
 
 import hivemind
+import torch
 import transformers
+from dalle_pytorch import DALLE
+from dalle_pytorch.vae import VQGanVAE, download
 from hivemind import SizeAdaptiveCompression, Float16Compression, Uniform8BitQuantization
-from transformers import AlbertTokenizerFast, get_linear_schedule_with_warmup, DataCollatorForLanguageModeling
+from transformers import DataCollatorWithPadding, GPT2TokenizerFast, get_linear_schedule_with_warmup
+from torch import nn
 
 import utils
 from arguments import HFTrainerArguments, BasePeerArguments, CollaborativeArguments
 from data import make_dataset
 from huggingface_auth import authorize_with_huggingface
-from lib import LeanAlbertConfig, LeanAlbertForPreTraining
-from lib.staging.collaborative import CollaborativeOptimizer
 from lib.training.clipped_lamb import LambWithGradientClipping
 from lib.training.offload import OffloadOptimizer
 
-hivemind.use_hivemind_log_handler("in_root_logger")
-logger = hivemind.get_logger()
+
+logger = hivemind.get_logger(__name__)
+
+# VQGAN with downsampling factor f=8, 8192 codebook entries, and Gumbel quantization
+# Note: If you change the URLs below, remove ./cache/* to clear the cache
+VQGAN_VAE_PATH = 'https://heibox.uni-heidelberg.de/f/34a747d5765840b5a99d/?dl=1'
+VQGAN_VAE_CONFIG_PATH = 'https://heibox.uni-heidelberg.de/f/b24d14998a8d4f19a34f/?dl=1'
+
+
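+# Thin adapter so the Hugging Face Trainer can drive dalle_pytorch.DALLE: it consumes the collator's
+# keyword batch (input_ids, attention_mask, image) and returns the loss in the dict format the Trainer expects.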
+class ModelWrapper(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+
+    def forward(self, input_ids, attention_mask, image):
+        loss = self.model.forward(text=input_ids, image=image, mask=attention_mask, return_loss=True)
+        return {'loss': loss}
 
 
 class TrainingTask:
@@ -30,8 +48,9 @@ class TrainingTask:
         self.peer_args, self.trainer_args, self.collab_args = peer_args, trainer_args, collab_args
         self.validators, self.local_public_key = utils.make_validators(self.peer_args.experiment_prefix)
         transformers.set_seed(trainer_args.seed)  # seed used for initialization
-        self.config = LeanAlbertConfig.from_pretrained(peer_args.model_config_path)
-        self.tokenizer = AlbertTokenizerFast.from_pretrained(peer_args.tokenizer_path, cache_dir=peer_args.cache_dir)
+
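+        # Use the GPT-2 BPE tokenizer for captions; it has no pad token, so reuse the end-of-text token for padding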
+        self.tokenizer = GPT2TokenizerFast.from_pretrained(peer_args.tokenizer_path)
+        self.tokenizer.pad_token = self.tokenizer.eos_token
 
         output_dir = Path(trainer_args.output_dir)
         logger.info(f'Checkpoint dir {output_dir}, contents {list(output_dir.glob("checkpoint*"))}')
@@ -39,11 +58,37 @@
 
         if latest_checkpoint_dir is None:
             logger.info(f"Creating model")
-            self.model = LeanAlbertForPreTraining(self.config)
-            self.model.resize_token_embeddings(len(self.tokenizer))
+
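+            # Fetch the pretrained VQGAN checkpoint and config into the peer's cache dir; it serves as DALLE's image tokenizer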
+            vae = VQGanVAE(
+                vqgan_model_path=download(VQGAN_VAE_PATH, 'vqgan.ckpt', root=peer_args.cache_dir),
+                vqgan_config_path=download(VQGAN_VAE_CONFIG_PATH, 'vqgan_config.yaml', root=peer_args.cache_dir),
+            )
+
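+            # 64-layer transformer whose attention alternates axial row/column patterns (row, col, row, row),
+            # with a convolution-like attention in the final layer, following the DALL-E paper's sparse attention schedule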
+            depth = 64
+            attn_types = list(islice(cycle(['axial_row', 'axial_col', 'axial_row', 'axial_row']), depth - 1))
+            attn_types.append('conv_like')
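+            # Layers with the same id share attention/FF weights: 4 unique blocks are cycled through the first 63 layers,
+            # plus a separate 'w_conv' block for the last one, which keeps the trainable parameter count manageable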
+            shared_layer_ids = list(islice(cycle(range(4)), depth - 1))
+            shared_layer_ids.append('w_conv')
+            dalle = DALLE(
+                vae=vae,
+                num_text_tokens=self.tokenizer.vocab_size,
+                text_seq_len=trainer_args.text_seq_length,
+                dim=1024,
+                depth=depth,
+                heads=16,
+                dim_head=64,
+                attn_types=attn_types,
+                ff_dropout=0,
+                attn_dropout=0,
+                shared_attn_ids=shared_layer_ids,
+                shared_ff_ids=shared_layer_ids,
+                rotary_emb=False,  # FIXME: Fix RuntimeError when True
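+                # reversible residual blocks recompute activations during backward, keeping activation memory roughly constant in depth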
+                reversible=True,
+            )
+            self.model = ModelWrapper(dalle)
         else:
             logger.info(f"Loading model from {latest_checkpoint_dir}")
-            self.model = LeanAlbertForPreTraining.from_pretrained(latest_checkpoint_dir)
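+            # Checkpoints now store a plain PyTorch state dict (model_state.pt) instead of a Hugging Face from_pretrained directory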
+            self.model.load_state_dict(torch.load(f"{latest_checkpoint_dir}/model_state.pt"))
 
     @property
     def dht(self):
@@ -72,7 +117,7 @@ class TrainingTask:
             averaging_compression = SizeAdaptiveCompression(
                 threshold=2 ** 16 + 1, less=Float16Compression(), greater_equal=Uniform8BitQuantization())
             state_compression = hivemind.Float16Compression()
-            self._collaborative_optimizer = CollaborativeOptimizer(
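+            # Switch to the CollaborativeOptimizer shipped with hivemind instead of the local copy from lib.staging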
+            self._collaborative_optimizer = hivemind.CollaborativeOptimizer(
                 dht=self.dht, opt=opt, scheduler=scheduler, prefix=self.peer_args.experiment_prefix,
                 batch_size_per_step=self.trainer_args.batch_size_per_step,
                 compression=averaging_compression, state_compression=state_compression,
@@ -83,11 +128,13 @@
         no_decay = ["bias", "LayerNorm.weight"]
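+        # Apply weight decay to all trainable weights except biases and LayerNorm parameters; parameters with
+        # requires_grad=False (e.g. the frozen VQGAN VAE inside DALLE) are excluded from the optimizer entirely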
         optimizer_grouped_parameters = [
             {
-                "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
+                "params": [p for n, p in self.model.named_parameters()
+                           if not any(nd in n for nd in no_decay) and p.requires_grad],
                 "weight_decay": training_args.weight_decay,
             },
             {
-                "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
+                "params": [p for n, p in self.model.named_parameters()
+                           if any(nd in n for nd in no_decay) and p.requires_grad],
                 "weight_decay": 0.0,
             },
         ]
@@ -115,12 +162,11 @@
         if self._training_dataset is None:
             self._training_dataset = make_dataset(
                 self.tokenizer, shuffle_seed=hash(self.local_public_key) % 2 ** 31,
-                max_sequence_length=self.trainer_args.seq_length
+                max_sequence_length=self.trainer_args.text_seq_length
             )
         return self._training_dataset
 
     @property
     def data_collator(self):
-        return DataCollatorForLanguageModeling(
-            tokenizer=self.tokenizer, pad_to_multiple_of=self.trainer_args.pad_to_multiple_of
-        )
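+        # DALLE consumes fixed-length text, so pad every caption to text_seq_length instead of using the masked-LM collator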
+        return DataCollatorWithPadding(tokenizer=self.tokenizer,
+                                       padding='max_length', max_length=self.trainer_args.text_seq_length)