Selaa lähdekoodia

Don't download VQGAN weights

Aleksandr Borzunov 3 vuotta sitten
vanhempi
commit
3b184c57da
1 muutettua tiedostoa jossa 11 lisäystä ja 11 poistoa
  1. 11 11
      task.py

+ 11 - 11
task.py

@@ -7,7 +7,7 @@ import hivemind
 import torch
 import transformers
 from dalle_pytorch import DALLE
-from dalle_pytorch.vae import VQGanVAE, download
+from dalle_pytorch.vae import VQGanVAE
 from hivemind import SizeAdaptiveCompression, Float16Compression, Uniform8BitQuantization
 from transformers import DataCollatorWithPadding, T5TokenizerFast, get_linear_schedule_with_warmup
 from torch import nn
@@ -22,10 +22,15 @@ from lib.training.offload import OffloadOptimizer
 
 logger = hivemind.get_logger(__name__)
 
-# VQGAN with downsampling factor f=8, 8192 codebook entries, and Gumbel quantization
-# Note: If you change the URLs below, remove ./cache/* to clear the cache
-VQGAN_VAE_PATH = 'https://heibox.uni-heidelberg.de/f/34a747d5765840b5a99d/?dl=1'
-VQGAN_VAE_CONFIG_PATH = 'https://heibox.uni-heidelberg.de/f/b24d14998a8d4f19a34f/?dl=1'
+
+class VQGanParams(VQGanVAE):
+    def __init__(self, *, num_layers=3, image_size=256, num_tokens=8192, is_gumbel=True):
+        nn.Module.__init__(self)
+
+        self.num_layers = num_layers
+        self.image_size = image_size
+        self.num_tokens = num_tokens
+        self.is_gumbel = is_gumbel
 
 
 class ModelWrapper(nn.Module):
@@ -59,18 +64,13 @@ class TrainingTask:
         if latest_checkpoint_dir is None:
             logger.info(f"Creating model")
 
-            vae = VQGanVAE(
-                vqgan_model_path=download(VQGAN_VAE_PATH, 'vqgan.ckpt', root=peer_args.cache_dir),
-                vqgan_config_path=download(VQGAN_VAE_CONFIG_PATH, 'vqgan_config.yaml', root=peer_args.cache_dir),
-            )
-
             depth = 64
             attn_types = list(islice(cycle(['axial_row', 'axial_col', 'axial_row', 'axial_row']), depth - 1))
             attn_types.append('conv_like')
             shared_layer_ids = list(islice(cycle(range(4)), depth - 1))
             shared_layer_ids.append('w_conv')
             dalle = DALLE(
-                vae=vae,
+                vae=VQGanParams(),
                 num_text_tokens=self.tokenizer.vocab_size,
                 text_seq_len=trainer_args.text_seq_length,
                 dim=1024,