Michael Diskin 4 vuotta sitten
vanhempi
commit
043310148c
1 muutettua tiedostoa jossa 3 lisäystä ja 3 poistoa
  1. examples/albert/run_trainer.py (+3 −3)

+ 3 - 3
examples/albert/run_trainer.py

@@ -11,7 +11,7 @@ import transformers
 from datasets import load_from_disk
 from torch.utils.data import DataLoader
 from torch_optimizer import Lamb
-from transformers import DataCollatorForLanguageModeling, HfArgumentParser, TrainingArguments, set_seed, Adafactor
+from transformers import DataCollatorForLanguageModeling, HfArgumentParser, TrainingArguments, set_seed, Adafactor, AdafactorSchedule
 from transformers.models.albert import AlbertConfig, AlbertForPreTraining, AlbertTokenizerFast
 from transformers.optimization import get_linear_schedule_with_warmup
 from transformers.trainer import Trainer
@@ -83,8 +83,8 @@ def get_optimizer_and_scheduler(training_args, model):
         weight_decay=training_args.weight_decay,
     )
 
-    scheduler = get_linear_schedule_with_warmup(
-        opt, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.max_steps
+    scheduler = AdafactorSchedule(
+        opt
     )
 
     return opt, scheduler