@@ -7,7 +7,7 @@ import hivemind
 import torch
 import transformers
 from dalle_pytorch import DALLE
-from dalle_pytorch.vae import VQGanVAE, download
+from dalle_pytorch.vae import VQGanVAE
 from hivemind import SizeAdaptiveCompression, Float16Compression, Uniform8BitQuantization
 from transformers import DataCollatorWithPadding, T5TokenizerFast, get_linear_schedule_with_warmup
 from torch import nn
@@ -22,10 +22,15 @@ from lib.training.offload import OffloadOptimizer
 
 logger = hivemind.get_logger(__name__)
 
-# VQGAN with downsampling factor f=8, 8192 codebook entries, and Gumbel quantization
-# Note: If you change the URLs below, remove ./cache/* to clear the cache
-VQGAN_VAE_PATH = 'https://heibox.uni-heidelberg.de/f/34a747d5765840b5a99d/?dl=1'
-VQGAN_VAE_CONFIG_PATH = 'https://heibox.uni-heidelberg.de/f/b24d14998a8d4f19a34f/?dl=1'
+
+class VQGanParams(VQGanVAE):
+    def __init__(self, *, num_layers=3, image_size=256, num_tokens=8192, is_gumbel=True):
+        nn.Module.__init__(self)
+
+        self.num_layers = num_layers
+        self.image_size = image_size
+        self.num_tokens = num_tokens
+        self.is_gumbel = is_gumbel
 
 
 class ModelWrapper(nn.Module):
@@ -59,18 +64,13 @@ class TrainingTask:
         if latest_checkpoint_dir is None:
             logger.info(f"Creating model")
 
-            vae = VQGanVAE(
-                vqgan_model_path=download(VQGAN_VAE_PATH, 'vqgan.ckpt', root=peer_args.cache_dir),
-                vqgan_config_path=download(VQGAN_VAE_CONFIG_PATH, 'vqgan_config.yaml', root=peer_args.cache_dir),
-            )
-
             depth = 64
             attn_types = list(islice(cycle(['axial_row', 'axial_col', 'axial_row', 'axial_row']), depth - 1))
             attn_types.append('conv_like')
             shared_layer_ids = list(islice(cycle(range(4)), depth - 1))
             shared_layer_ids.append('w_conv')
             dalle = DALLE(
-                vae=vae,
+                vae=VQGanParams(),
                 num_text_tokens=self.tokenizer.vocab_size,
                 text_seq_len=trainer_args.text_seq_length,
                 dim=1024,
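
Why a weight-free stub works here (an illustration, not part of the patch): dalle_pytorch's DALLE appears to read only the VAE's hyperparameters when the model is constructed, so VQGanParams can stand in for a fully loaded VQGanVAE; the actual VQGAN checkpoint is needed only to encode images into codebook indices or decode them back, which this training code presumably obtains from pre-encoded data. A minimal sketch, assuming that behavior:

    # Sketch only: how DALLE is assumed to size itself from the `vae` argument.
    vae = VQGanParams()                                          # num_layers=3, image_size=256, num_tokens=8192
    image_fmap_size = vae.image_size // (2 ** vae.num_layers)    # 256 // 2**3 = 32 (downsampling factor f=8)
    image_seq_len = image_fmap_size ** 2                         # 32 * 32 = 1024 image tokens per sample
    num_image_tokens = vae.num_tokens                            # 8192 codebook entries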