
Use t5-small tokenizer

Aleksandr Borzunov, 3 years ago
commit c61c61b20d
3 changed files with 4 additions and 4 deletions:
  1. arguments.py (+1, -1)
  2. data.py (+1, -1)
  3. task.py (+2, -2)

arguments.py (+1, -1)

@@ -109,7 +109,7 @@ class CollaborativeArguments:
 class BasePeerArguments:
     """Base arguments that are used for both trainers and for auxiliary peers such as training monitor"""
     experiment_prefix: str = field(default="my-model", metadata={"help": "A unique experiment name, used as prefix for all DHT keys"})
-    tokenizer_path: Optional[str] = field(default="gpt2", metadata={"help": "Path to the tokenizer"})
+    tokenizer_path: Optional[str] = field(default="t5-small", metadata={"help": "Path to the tokenizer"})
     cache_dir: Optional[str] = field(default="./cache", metadata={"help": "Path to the cache"})
 
     authorize: bool = field(default=False, metadata={"help": "Whether or not to use HF authorizer"})
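
With this change, peers that do not pass --tokenizer_path explicitly now load t5-small from the Hugging Face hub. A minimal sketch of what the new default resolves to (the printed values assume the stock t5-small checkpoint):

from transformers import T5TokenizerFast

# Load the tokenizer exactly as the new default would resolve it.
tokenizer = T5TokenizerFast.from_pretrained("t5-small")

# Unlike GPT-2, T5 ships with both an EOS token and a dedicated pad token.
print(tokenizer.eos_token, tokenizer.pad_token)  # </s> <pad>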

data.py (+1, -1)

@@ -22,7 +22,7 @@ def preprocess_batch(batch, tokenizer, max_sequence_length: int):
 
     if any(mask):
         result = tokenizer(list(itertools.compress(batch['caption'], mask)),
-                           truncation=True, max_length=max_sequence_length)
+                           add_special_tokens=False, max_length=max_sequence_length, truncation=True)
     else:
         # This branch is necessary because tokenizer([]) raises IndexError
         result = {'input_ids': [], 'attention_mask': []}
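
The added add_special_tokens=False is what makes this hunk more than a reordering of keyword arguments: T5's fast tokenizer appends </s> to every encoded sequence by default, while GPT-2's appends nothing, so disabling special tokens keeps the caption ids shaped as they were before the switch. A short sketch of the difference, assuming the t5-small vocabulary:

from transformers import T5TokenizerFast

tokenizer = T5TokenizerFast.from_pretrained("t5-small")

with_eos = tokenizer("a cat", truncation=True, max_length=16)["input_ids"]
plain = tokenizer("a cat", add_special_tokens=False, truncation=True, max_length=16)["input_ids"]

# By default the T5 post-processor appends </s>; with the flag it does not.
assert with_eos[-1] == tokenizer.eos_token_id
assert plain == with_eos[:-1]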

task.py (+2, -2)

@@ -9,7 +9,7 @@ import transformers
 from dalle_pytorch import DALLE
 from dalle_pytorch.vae import VQGanVAE, download
 from hivemind import SizeAdaptiveCompression, Float16Compression, Uniform8BitQuantization
-from transformers import DataCollatorWithPadding, GPT2TokenizerFast, get_linear_schedule_with_warmup
+from transformers import DataCollatorWithPadding, T5TokenizerFast, get_linear_schedule_with_warmup
 from torch import nn
 
 import utils
@@ -49,7 +49,7 @@ class TrainingTask:
         self.validators, self.local_public_key = utils.make_validators(self.peer_args.experiment_prefix)
         transformers.set_seed(trainer_args.seed)  # seed used for initialization
 
-        self.tokenizer = GPT2TokenizerFast.from_pretrained(peer_args.tokenizer_path)
+        self.tokenizer = T5TokenizerFast.from_pretrained(peer_args.tokenizer_path)
         self.tokenizer.pad_token = self.tokenizer.eos_token
 
         output_dir = Path(trainer_args.output_dir)
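
The pad_token = eos_token remap is kept from the GPT-2 setup, where it was required because GPT-2 defines no pad token at all; T5 already ships <pad>, so after this line DataCollatorWithPadding pads captions with </s> instead. A hedged sketch of that downstream effect (the example captions and the printed shape are illustrative only):

from transformers import DataCollatorWithPadding, T5TokenizerFast

tokenizer = T5TokenizerFast.from_pretrained("t5-small")
tokenizer.pad_token = tokenizer.eos_token  # as in the diff: pad with </s>, not T5's default <pad>

collator = DataCollatorWithPadding(tokenizer)
batch = collator([
    tokenizer("a cat", add_special_tokens=False),
    tokenizer("a dog on a bench", add_special_tokens=False),
])

# The shorter caption is right-padded with eos_token_id; attention_mask marks real tokens.
print(batch["input_ids"].shape)  # e.g. torch.Size([2, 5])
print(tokenizer.pad_token_id == tokenizer.eos_token_id)  # True after the remap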