
Use t5-small tokenizer

Aleksandr Borzunov, 3 years ago
parent commit c61c61b20d
3 changed files with 4 additions and 4 deletions
  1. arguments.py  +1 -1
  2. data.py       +1 -1
  3. task.py       +2 -2

+ 1 - 1
arguments.py

@@ -109,7 +109,7 @@ class CollaborativeArguments:
 class BasePeerArguments:
     """Base arguments that are used for both trainers and for auxiliary peers such as training monitor"""
     experiment_prefix: str = field(default="my-model", metadata={"help": "A unique experiment name, used as prefix for all DHT keys"})
-    tokenizer_path: Optional[str] = field(default="gpt2", metadata={"help": "Path to the tokenizer"})
+    tokenizer_path: Optional[str] = field(default="t5-small", metadata={"help": "Path to the tokenizer"})
     cache_dir: Optional[str] = field(default="./cache", metadata={"help": "Path to the cache"})
 
     authorize: bool = field(default=False, metadata={"help": "Whether or not to use HF authorizer"})
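
For context, a minimal sketch of how this new default would be consumed (assumed usage: HfArgumentParser is the standard transformers helper, BasePeerArguments is the dataclass edited above, and the import path is hypothetical):

# Hypothetical usage sketch: parse BasePeerArguments and load whatever
# tokenizer_path defaults to ("t5-small" after this commit).
from transformers import HfArgumentParser, T5TokenizerFast
from arguments import BasePeerArguments  # assumed module path in this repo

peer_args, = HfArgumentParser(BasePeerArguments).parse_args_into_dataclasses()
tokenizer = T5TokenizerFast.from_pretrained(peer_args.tokenizer_path, cache_dir=peer_args.cache_dir)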

+ 1 - 1
data.py

@@ -22,7 +22,7 @@ def preprocess_batch(batch, tokenizer, max_sequence_length: int):
 
     if any(mask):
         result = tokenizer(list(itertools.compress(batch['caption'], mask)),
-                           truncation=True, max_length=max_sequence_length)
+                           add_special_tokens=False, max_length=max_sequence_length, truncation=True)
     else:
         # This branch is necessary because tokenizer([]) raises IndexError
         result = {'input_ids': [], 'attention_mask': []}
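
The practical effect of adding add_special_tokens=False is that the T5 tokenizer no longer appends its </s> end-of-sequence token to each caption. A small illustrative sketch (the caption string is invented):

# Compare tokenization with and without special tokens for the t5-small tokenizer.
from transformers import T5TokenizerFast

tokenizer = T5TokenizerFast.from_pretrained("t5-small")
with_eos = tokenizer("a photo of a corgi", truncation=True, max_length=256)
without_eos = tokenizer("a photo of a corgi", add_special_tokens=False,
                        max_length=256, truncation=True)
print(with_eos["input_ids"][-1] == tokenizer.eos_token_id)     # True: </s> appended
print(without_eos["input_ids"][-1] == tokenizer.eos_token_id)  # False: raw caption ids only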

+ 2 - 2
task.py

@@ -9,7 +9,7 @@ import transformers
 from dalle_pytorch import DALLE
 from dalle_pytorch.vae import VQGanVAE, download
 from hivemind import SizeAdaptiveCompression, Float16Compression, Uniform8BitQuantization
-from transformers import DataCollatorWithPadding, GPT2TokenizerFast, get_linear_schedule_with_warmup
+from transformers import DataCollatorWithPadding, T5TokenizerFast, get_linear_schedule_with_warmup
 from torch import nn
 
 import utils
@@ -49,7 +49,7 @@ class TrainingTask:
         self.validators, self.local_public_key = utils.make_validators(self.peer_args.experiment_prefix)
         transformers.set_seed(trainer_args.seed)  # seed used for initialization
 
-        self.tokenizer = GPT2TokenizerFast.from_pretrained(peer_args.tokenizer_path)
+        self.tokenizer = T5TokenizerFast.from_pretrained(peer_args.tokenizer_path)
         self.tokenizer.pad_token = self.tokenizer.eos_token
 
         output_dir = Path(trainer_args.output_dir)
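
A standalone sketch of the tokenizer setup changed here, using only imports already present in task.py plus invented example captions. Note that, unlike GPT-2, T5TokenizerFast already ships with a <pad> token, so the pad_token = eos_token assignment now overrides it with </s>:

# Mirror of the new tokenizer setup, outside of TrainingTask.
from transformers import DataCollatorWithPadding, T5TokenizerFast

tokenizer = T5TokenizerFast.from_pretrained("t5-small")
tokenizer.pad_token = tokenizer.eos_token  # pad with </s> instead of T5's default <pad>

collator = DataCollatorWithPadding(tokenizer=tokenizer)
features = [tokenizer(text, add_special_tokens=False)
            for text in ("a corgi on a skateboard", "an oil painting of a lighthouse")]
batch = collator(features)  # padded PyTorch tensors by default
print(batch["input_ids"].shape, batch["attention_mask"].shape)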