
Use dalle-pytorch instead of LeanAlbert

Aleksandr Borzunov 3 years ago
parent
commit
20e2a3aab2
9 changed files with 112 additions and 191 deletions
  1. arguments.py (+9 -7)
  2. data.py (+25 -0)
  3. lib/__init__.py (+0 -2)
  4. lib/training/hf_trainer.py (+4 -6)
  5. run_aux_peer.py (+6 -2)
  6. run_trainer.py (+5 -4)
  7. task.py (+63 -17)
  8. tests/test_ffn.py (+0 -83)
  9. tests/test_rotary.py (+0 -70)

+ 9 - 7
arguments.py

@@ -12,15 +12,18 @@ class HFTrainerArguments(TrainingArguments):
     per_device_train_batch_size: int = 1
     per_device_eval_batch_size: int = 1
     gradient_accumulation_steps: int = 1
-    seq_length: int = 512
-    pad_to_multiple_of: int = 8
+    text_seq_length: int = 256
+
+    # DALLE-specific params
+    learning_rate: float = 0.003535
+    adam_beta1: float = 0.9
+    adam_beta2: float = 0.96
+    max_grad_norm: float = 4.0
+    weight_decay: float = 0.045
 
-    learning_rate: float = 0.0025
     total_steps: int = 31250  # total number of collaborative SGD updates, used for learning rate schedule
     warmup_steps: int = 3125
     adam_epsilon: float = 1e-6
-    weight_decay: float = 0.01
-    max_grad_norm: float = 1.0
     clamp_value: float = 10000.0
 
     fp16: bool = False
@@ -103,8 +106,7 @@ class CollaborativeArguments:
 class BasePeerArguments:
     """Base arguments that are used for both trainers and for auxiliary peers such as training monitor"""
     experiment_prefix: str = field(default="my-model", metadata={"help": "A unique experiment name, used as prefix for all DHT keys"})
-    model_config_path: Optional[str] = field(default="./model.json", metadata={"help": "Path to the model config"})
-    tokenizer_path: Optional[str] = field(default="./tokenizer", metadata={"help": "Path to the tokenizer"})
+    tokenizer_path: Optional[str] = field(default="gpt2", metadata={"help": "Path to the tokenizer"})
     cache_dir: Optional[str] = field(default="./cache", metadata={"help": "Path to the cache"})
 
     authorize: bool = field(default=False, metadata={"help": "Whether or not to use HF authorizer"})
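
For context, a minimal sketch of how these dataclasses are typically consumed, mirroring the HfArgumentParser pattern already used in run_trainer.py; the exact grouping of dataclasses and the CLI flags shown are assumptions for illustration only.

# Hedged sketch: parse the dataclasses above from the command line.
# TrainingArguments requires an output dir, so run e.g. as:
#   python parse_args_demo.py --output_dir ./outputs
from transformers import HfArgumentParser

from arguments import BasePeerArguments, HFTrainerArguments

parser = HfArgumentParser((BasePeerArguments, HFTrainerArguments))
peer_args, trainer_args = parser.parse_args_into_dataclasses()

print(trainer_args.learning_rate)    # 0.003535 by default after this change
print(trainer_args.text_seq_length)  # 256 by default
print(peer_args.tokenizer_path)      # "gpt2" by default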

+ 25 - 0
data.py

@@ -0,0 +1,25 @@
+from typing import Optional
+
+import hivemind
+import numpy as np
+from datasets import load_dataset
+
+logger = hivemind.get_logger(__name__)
+
+
+def make_dataset(
+    tokenizer,
+    *,
+    shuffle_buffer_size: int = 10 ** 4,
+    shuffle_seed: Optional[int],
+    preprocessing_batch_size: int = 256,
+    max_sequence_length: int,
+):
+    ds = load_dataset('laion/laion_100m_vqgan_f8', split='train', streaming=True)
+    ds = ds.shuffle(shuffle_buffer_size, seed=shuffle_seed)
+    ds = ds.map(lambda item: dict(
+        tokenizer(item['caption'], truncation=True, max_length=max_sequence_length),
+        image=np.stack([np.frombuffer(encoded, np.int16).astype(np.int64) for encoded in item['code']]),
+    ), batched=True, batch_size=preprocessing_batch_size)
+    ds = ds.with_format('torch')
+    return ds
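
An illustrative usage sketch for the new make_dataset helper; the tokenizer setup mirrors the GPT2TokenizerFast configuration introduced in task.py in this commit, while the seed and sequence length values are assumptions.

# Hedged sketch: build and peek at the streaming LAION dataset.
from transformers import GPT2TokenizerFast

from data import make_dataset

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

ds = make_dataset(tokenizer, shuffle_seed=0, max_sequence_length=256)
sample = next(iter(ds))  # streaming: items arrive lazily from laion/laion_100m_vqgan_f8
# 'input_ids'/'attention_mask' come from the tokenizer; 'image' holds the VQGAN codes as int64
print(sample.keys())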

+ 0 - 2
lib/__init__.py

@@ -1,2 +0,0 @@
-from .modules import *
-from .models import *

+ 4 - 6
lib/training/hf_trainer.py

@@ -1,15 +1,12 @@
 """A catch-all module for the dirty hacks required to make HF Trainer work with collaborative training"""
-from typing import Optional
-
 import torch
 from torch import nn
 from torch.utils.data import DataLoader
 from transformers.trainer import Trainer
+from hivemind import CollaborativeOptimizer
+from hivemind.optim import HivemindGradScaler
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
-from lib.staging.collaborative import CollaborativeOptimizer
-from lib.staging.scaler import HivemindGradScaler
-
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger()
 LRSchedulerBase = getattr(torch.optim.lr_scheduler, '_LRScheduler', None)
@@ -74,7 +71,8 @@ class IgnoreGradManipulations(nn.Module):
         return self.module.forward(*args, **kwargs)
 
     def zero_grad(self, set_to_none: bool = False) -> None:
-        if self.override_zero_grad and all(param.grad.isfinite().all() for param in self.parameters()):
+        if self.override_zero_grad and \
+                all(param.grad.isfinite().all() for param in self.parameters() if param.requires_grad):
             logger.debug("Successfully bypassed zero_grad")
         else:
             self.module.zero_grad(set_to_none=set_to_none)
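
The zero_grad change above skips clearing gradients as long as all trainable gradients are finite, so accumulated gradients survive for the collaborative optimizer. A standalone sketch of the same idea follows; it is not the repo's full IgnoreGradManipulations class, and the class name and None-guard are illustrative.

# Simplified sketch of the zero_grad bypass shown above.
from torch import nn


class SkipZeroGradIfFinite(nn.Module):
    """Keep accumulated gradients unless a trainable gradient became non-finite."""

    def __init__(self, module: nn.Module, override_zero_grad: bool = True):
        super().__init__()
        self.module = module
        self.override_zero_grad = override_zero_grad

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

    def zero_grad(self, set_to_none: bool = False) -> None:
        grads_finite = all(
            p.grad.isfinite().all()
            for p in self.parameters()
            if p.requires_grad and p.grad is not None
        )
        if self.override_zero_grad and grads_finite:
            return  # bypass: the collaborative optimizer still needs these gradients
        self.module.zero_grad(set_to_none=set_to_none)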

+ 6 - 2
run_aux_peer.py

@@ -4,6 +4,7 @@ import time
 
 import torch
 import wandb
+import transformers
 from transformers import HfArgumentParser
 from huggingface_hub import HfFolder, Repository
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
@@ -12,8 +13,11 @@ import utils
 from arguments import AuxiliaryPeerArguments, CollaborativeArguments, HFTrainerArguments
 from task import TrainingTask
 
+
+transformers.utils.logging.disable_default_handler()
+transformers.utils.logging.enable_propagation()
 use_hivemind_log_handler("in_root_logger")
-logger = get_logger()
+logger = get_logger(__name__)
 
 
 class CheckpointHandler:
@@ -56,7 +60,7 @@ class CheckpointHandler:
 
     def upload_checkpoint(self, current_loss):
         logger.info("Saving model")
-        self.task.model.save_pretrained(self.local_path)
+        torch.save(self.task.model.state_dict(), f"{self.local_path}/model_state.pt")
         logger.info("Saving optimizer")
         torch.save(self.task.collaborative_optimizer.opt.state_dict(), f"{self.local_path}/optimizer_state.pt")
         self.previous_timestamp = time.time()
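
For symmetry, a hedged sketch of restoring what upload_checkpoint now writes as plain state dicts; the function name and the idea of passing the inner optimizer (collaborative_optimizer.opt) directly are assumptions for illustration.

# Hedged sketch: load the files written by upload_checkpoint above.
import torch
from torch import nn


def restore_checkpoint(model: nn.Module, optimizer: torch.optim.Optimizer, checkpoint_dir: str) -> None:
    """Load the model/optimizer state saved as model_state.pt and optimizer_state.pt."""
    model.load_state_dict(torch.load(f"{checkpoint_dir}/model_state.pt", map_location="cpu"))
    optimizer.load_state_dict(torch.load(f"{checkpoint_dir}/optimizer_state.pt", map_location="cpu"))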

+ 5 - 4
run_trainer.py

@@ -3,7 +3,6 @@
 import os
 from pathlib import Path
 
-
 import transformers
 from transformers import HfArgumentParser
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
@@ -16,8 +15,10 @@ from arguments import TrainingPeerArguments, HFTrainerArguments, CollaborativeAr
 from task import TrainingTask
 
 
+transformers.utils.logging.disable_default_handler()
+transformers.utils.logging.enable_propagation()
 use_hivemind_log_handler("in_root_logger")
-logger = get_logger()
+logger = get_logger(__name__)
 
 
 def main():
@@ -25,8 +26,8 @@ def main():
     training_peer_args, trainer_args, collab_args = parser.parse_args_into_dataclasses()
 
     logger.info(f"Trying {len(training_peer_args.initial_peers)} initial peers: {training_peer_args.initial_peers}")
-    if len(training_peer_args.initial_peers) == 0:
-        logger.warning("Please specify at least one network endpoint in initial peers.")
+    # if len(training_peer_args.initial_peers) == 0:
+    #     logger.warning("Please specify at least one network endpoint in initial peers.")
 
     utils.setup_logging(trainer_args)
     task = TrainingTask(training_peer_args, trainer_args, collab_args)

+ 63 - 17
task.py

@@ -1,23 +1,41 @@
 import os
 from dataclasses import asdict
+from itertools import cycle, islice
 from pathlib import Path
 
 import hivemind
+import torch
 import transformers
+from dalle_pytorch import DALLE
+from dalle_pytorch.vae import VQGanVAE, download
 from hivemind import SizeAdaptiveCompression, Float16Compression, Uniform8BitQuantization
-from transformers import AlbertTokenizerFast, get_linear_schedule_with_warmup, DataCollatorForLanguageModeling
+from transformers import DataCollatorWithPadding, GPT2TokenizerFast, get_linear_schedule_with_warmup
+from torch import nn
 
 import utils
 from arguments import HFTrainerArguments, BasePeerArguments, CollaborativeArguments
 from data import make_dataset
 from huggingface_auth import authorize_with_huggingface
-from lib import LeanAlbertConfig, LeanAlbertForPreTraining
-from lib.staging.collaborative import CollaborativeOptimizer
 from lib.training.clipped_lamb import LambWithGradientClipping
 from lib.training.offload import OffloadOptimizer
 
-hivemind.use_hivemind_log_handler("in_root_logger")
-logger = hivemind.get_logger()
+
+logger = hivemind.get_logger(__name__)
+
+# VQGAN with downsampling factor f=8, 8192 codebook entries, and Gumbel quantization
+# Note: If you change the URLs below, remove ./cache/* to clear the cache
+VQGAN_VAE_PATH = 'https://heibox.uni-heidelberg.de/f/34a747d5765840b5a99d/?dl=1'
+VQGAN_VAE_CONFIG_PATH = 'https://heibox.uni-heidelberg.de/f/b24d14998a8d4f19a34f/?dl=1'
+
+
+class ModelWrapper(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+
+    def forward(self, input_ids, attention_mask, image):
+        loss = self.model.forward(text=input_ids, image=image, mask=attention_mask, return_loss=True)
+        return {'loss': loss}
 
 
 class TrainingTask:
@@ -30,8 +48,9 @@ class TrainingTask:
         self.peer_args, self.trainer_args, self.collab_args = peer_args, trainer_args, collab_args
         self.validators, self.local_public_key = utils.make_validators(self.peer_args.experiment_prefix)
         transformers.set_seed(trainer_args.seed)  # seed used for initialization
-        self.config = LeanAlbertConfig.from_pretrained(peer_args.model_config_path)
-        self.tokenizer = AlbertTokenizerFast.from_pretrained(peer_args.tokenizer_path, cache_dir=peer_args.cache_dir)
+
+        self.tokenizer = GPT2TokenizerFast.from_pretrained(peer_args.tokenizer_path)
+        self.tokenizer.pad_token = self.tokenizer.eos_token
 
         output_dir = Path(trainer_args.output_dir)
         logger.info(f'Checkpoint dir {output_dir}, contents {list(output_dir.glob("checkpoint*"))}')
@@ -39,11 +58,37 @@ class TrainingTask:
 
         if latest_checkpoint_dir is None:
             logger.info(f"Creating model")
-            self.model = LeanAlbertForPreTraining(self.config)
-            self.model.resize_token_embeddings(len(self.tokenizer))
+
+            vae = VQGanVAE(
+                vqgan_model_path=download(VQGAN_VAE_PATH, 'vqgan.ckpt', root=peer_args.cache_dir),
+                vqgan_config_path=download(VQGAN_VAE_CONFIG_PATH, 'vqgan_config.yaml', root=peer_args.cache_dir),
+            )
+
+            depth = 64
+            attn_types = list(islice(cycle(['axial_row', 'axial_col', 'axial_row', 'axial_row']), depth - 1))
+            attn_types.append('conv_like')
+            shared_layer_ids = list(islice(cycle(range(4)), depth - 1))
+            shared_layer_ids.append('w_conv')
+            dalle = DALLE(
+                vae=vae,
+                num_text_tokens=self.tokenizer.vocab_size,
+                text_seq_len=trainer_args.text_seq_length,
+                dim=1024,
+                depth=depth,
+                heads=16,
+                dim_head=64,
+                attn_types=attn_types,
+                ff_dropout=0,
+                attn_dropout=0,
+                shared_attn_ids=shared_layer_ids,
+                shared_ff_ids=shared_layer_ids,
+                rotary_emb=False,  # FIXME: Fix RuntimeError when True
+                reversible=True,
+            )
+            self.model = ModelWrapper(dalle)
         else:
             logger.info(f"Loading model from {latest_checkpoint_dir}")
-            self.model = LeanAlbertForPreTraining.from_pretrained(latest_checkpoint_dir)
+            self.model.load_state_dict(torch.load(f"{latest_checkpoint_dir}/model_state.pt"))
 
     @property
     def dht(self):
@@ -72,7 +117,7 @@ class TrainingTask:
             averaging_compression = SizeAdaptiveCompression(
                 threshold=2 ** 16 + 1, less=Float16Compression(), greater_equal=Uniform8BitQuantization())
             state_compression = hivemind.Float16Compression()
-            self._collaborative_optimizer = CollaborativeOptimizer(
+            self._collaborative_optimizer = hivemind.CollaborativeOptimizer(
                 dht=self.dht, opt=opt, scheduler=scheduler, prefix=self.peer_args.experiment_prefix,
                 batch_size_per_step=self.trainer_args.batch_size_per_step,
                 compression=averaging_compression, state_compression=state_compression,
@@ -83,11 +128,13 @@ class TrainingTask:
         no_decay = ["bias", "LayerNorm.weight"]
         optimizer_grouped_parameters = [
             {
-                "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
+                "params": [p for n, p in self.model.named_parameters()
+                           if not any(nd in n for nd in no_decay) and p.requires_grad],
                 "weight_decay": training_args.weight_decay,
             },
             {
-                "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
+                "params": [p for n, p in self.model.named_parameters()
+                           if any(nd in n for nd in no_decay) and p.requires_grad],
                 "weight_decay": 0.0,
             },
         ]
@@ -115,12 +162,11 @@ class TrainingTask:
         if self._training_dataset is None:
             self._training_dataset = make_dataset(
                 self.tokenizer, shuffle_seed=hash(self.local_public_key) % 2 ** 31,
-                max_sequence_length=self.trainer_args.seq_length
+                max_sequence_length=self.trainer_args.text_seq_length
             )
         return self._training_dataset
 
     @property
     def data_collator(self):
-        return DataCollatorForLanguageModeling(
-            tokenizer=self.tokenizer, pad_to_multiple_of=self.trainer_args.pad_to_multiple_of
-        )
+        return DataCollatorWithPadding(tokenizer=self.tokenizer,
+                                       padding='max_length', max_length=self.trainer_args.text_seq_length)
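
As a quick check of the attention-type and layer-sharing pattern built in task.py above, here is the same islice(cycle(...)) construction evaluated at a small depth; depth=4 is chosen purely for readability, whereas the commit uses depth=64.

# Same construction as in task.py, at depth=4 for illustration.
from itertools import cycle, islice

depth = 4
attn_types = list(islice(cycle(['axial_row', 'axial_col', 'axial_row', 'axial_row']), depth - 1))
attn_types.append('conv_like')
shared_layer_ids = list(islice(cycle(range(4)), depth - 1))
shared_layer_ids.append('w_conv')

print(attn_types)        # ['axial_row', 'axial_col', 'axial_row', 'conv_like']
print(shared_layer_ids)  # [0, 1, 2, 'w_conv']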

+ 0 - 83
tests/test_ffn.py

@@ -1,83 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from lib.modules.ffn import LeanFFN
-
-
-class ReferenceFFN(nn.Module):
-
-    def __init__(self,
-                 hidden_size: int,
-                 intermediate_size: int,
-                 activation=F.gelu,
-                 layer_norm_eps=1e-12,
-                 dropout: float = 0.0):
-        super().__init__()
-        self.dense_i2h = nn.Linear(hidden_size, intermediate_size)
-        self.dense_h2o = nn.Linear(intermediate_size, hidden_size)
-        self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
-        self.activation = activation
-        self.dropout = dropout
-
-    def forward(self, input):
-        output = self.dense_i2h(self.layer_norm(input))
-        output = self.activation(output)
-        output = self.dense_h2o(output)
-        output = F.dropout(output, self.dropout)
-        return output + input
-
-
-def test_ffn_exact_match():
-    torch.use_deterministic_algorithms(True)
-
-    batch_size = 4
-    seq_len = 128
-    dim = 32
-    num_layers = 4
-
-    baseline_ffn = ReferenceFFN(dim, 4 * dim)
-    our_ffn = LeanFFN(dim, 4 * dim)
-
-    assert our_ffn.load_state_dict(baseline_ffn.state_dict())
-
-    x = torch.rand(batch_size, seq_len, dim, device='cpu', requires_grad=True)
-
-    # test outputs
-    out_ref = x
-    for i in range(num_layers):
-        out_ref = baseline_ffn.forward(out_ref)
-
-    out_our = x
-    for i in range(num_layers):
-        out_our = our_ffn(out_our)
-
-    assert torch.allclose(out_our, out_ref)
-
-    # test grad inputs
-    obj = (out_ref * (out_ref + 1)).square().mean()
-    grad_ref, = torch.autograd.grad(obj, x)
-
-    obj = (out_our * (out_our + 1)).square().mean()
-    grad_our, = torch.autograd.grad(obj, x)
-    assert torch.allclose(grad_ref, grad_our)
-
-    # test grad params
-    x = torch.rand(batch_size, seq_len, dim, device='cpu', requires_grad=True)
-
-    out_ref = x
-    for i in range(num_layers):
-        out_ref = baseline_ffn.forward(out_ref)
-
-    out_our = x
-    for i in range(num_layers):
-        out_our = our_ffn(out_our)
-
-    obj = (out_ref * (out_ref + 1)).square().mean()
-    grad_params_ref = torch.autograd.grad(obj, list(baseline_ffn.parameters()))
-
-    obj = (out_our * (out_our + 1)).square().mean()
-    grad_params_our = torch.autograd.grad(obj, list(our_ffn.parameters()))
-
-    for grad_ref, grad_our in zip(grad_params_ref, grad_params_our):
-        assert torch.allclose(grad_ref, grad_our)

+ 0 - 70
tests/test_rotary.py

@@ -1,70 +0,0 @@
-import torch
-
-from lib.modules.rotary import get_auxiliary_tensors, RotaryEmbeddings
-
-
-def test_rotary_embeddings():
-    batch_size = 11
-    seq_len = 50
-    nhead = 4
-    dim = 1024
-    dtype = torch.float32
-    device = torch.device('cpu')
-    base = 10 ** 4
-
-    torch.use_deterministic_algorithms(True)
-
-    # auxiliary tensors
-    a, b = get_auxiliary_tensors(seq_len, dim, dtype, device, base)
-    x, y = Rotary(dim, base).forward(torch.randn(1, seq_len, dim, device=device))
-    assert torch.allclose(a.view_as(x), x, atol=1e-4, rtol=0)
-    assert torch.allclose(b.view_as(y), y, atol=1e-4, rtol=0)
-
-    # full layer outputs
-    ref_layer = Rotary(dim, base)
-    our_layer = RotaryEmbeddings(dim, base)
-    q = torch.randn(batch_size, seq_len, nhead, dim, dtype=dtype, device=device)
-    k = torch.randn(batch_size, seq_len, nhead, dim, dtype=dtype, device=device)
-
-    q_ref, k_ref = apply_rotary_pos_emb(q.permute(1, 0, 2, 3), k.permute(1, 0, 2, 3), *ref_layer(k))
-    q_our, k_our = our_layer(q), our_layer(k)
-    assert torch.allclose(q_ref.permute(1, 0, 2, 3), q_our, atol=1e-4, rtol=0)
-    assert torch.allclose(k_ref.permute(1, 0, 2, 3), k_our, atol=1e-4, rtol=0)
-
-    # check rotation equivariance of dot product
-    original_dot = (q[0, :, 0] * k[0, :, 0]).sum(-1)
-    rotated_dot = (our_layer(q)[0, :, 0] * our_layer(k)[0, :, 0]).sum(-1)
-    assert torch.allclose(original_dot, rotated_dot, atol=1e-4, rtol=0)
-
-
-class Rotary(torch.nn.Module):
-    """ Reference implementation by ElutherAI """
-    def __init__(self, dim, base=10000):
-        super().__init__()
-        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
-        self.register_buffer("inv_freq", inv_freq)
-        self.seq_len_cached = None
-        self.cos_cached = None
-        self.sin_cached = None
-
-    def forward(self, x, seq_dim=1):
-        seq_len = x.shape[seq_dim]
-        if seq_len != self.seq_len_cached:
-            self.seq_len_cached = seq_len
-            t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
-            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
-            self.cos_cached = emb.cos()[:, None, None, :]
-            self.sin_cached = emb.sin()[:, None, None, :]
-        return self.cos_cached, self.sin_cached
-
-
-def rotate_half(x):
-    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
-    return torch.cat(
-        (-x2, x1), dim=x1.ndim - 1
-    )  # dim=-1 triggers a bug in torch < 1.8.0
-
-
-def apply_rotary_pos_emb(q, k, cos, sin):
-    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)