@@ -9,7 +9,7 @@ import transformers
 from dalle_pytorch import DALLE
 from dalle_pytorch.vae import VQGanVAE, download
 from hivemind import SizeAdaptiveCompression, Float16Compression, Uniform8BitQuantization
-from transformers import DataCollatorWithPadding, GPT2TokenizerFast, get_linear_schedule_with_warmup
+from transformers import DataCollatorWithPadding, T5TokenizerFast, get_linear_schedule_with_warmup
 from torch import nn
 
 import utils
@@ -49,7 +49,7 @@ class TrainingTask:
         self.validators, self.local_public_key = utils.make_validators(self.peer_args.experiment_prefix)
         transformers.set_seed(trainer_args.seed)  # seed used for initialization
 
-        self.tokenizer = GPT2TokenizerFast.from_pretrained(peer_args.tokenizer_path)
+        self.tokenizer = T5TokenizerFast.from_pretrained(peer_args.tokenizer_path)
         self.tokenizer.pad_token = self.tokenizer.eos_token
 
         output_dir = Path(trainer_args.output_dir)
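
For reference, a minimal sketch (not part of the patch) of how the swapped-in T5TokenizerFast behaves together with DataCollatorWithPadding; the "t5-small" checkpoint name and the sample captions are assumptions for illustration only, the trainer actually loads whatever checkpoint peer_args.tokenizer_path points to:

from transformers import DataCollatorWithPadding, T5TokenizerFast

# Assumed checkpoint; the real code calls from_pretrained(peer_args.tokenizer_path)
tokenizer = T5TokenizerFast.from_pretrained("t5-small")
# Reuse </s> as the padding token, mirroring the line kept in the diff
tokenizer.pad_token = tokenizer.eos_token

collator = DataCollatorWithPadding(tokenizer=tokenizer)
batch = collator([tokenizer("a cat sitting on a mat"), tokenizer("a dog")])
print(batch["input_ids"].shape)  # both captions padded to the length of the longer one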