@@ -11,7 +11,7 @@ import transformers
 from datasets import load_from_disk
 from torch.utils.data import DataLoader
 from torch_optimizer import Lamb
-from transformers import DataCollatorForLanguageModeling, HfArgumentParser, TrainingArguments, set_seed, Adafactor
+from transformers import DataCollatorForLanguageModeling, HfArgumentParser, TrainingArguments, set_seed, Adafactor, AdafactorSchedule
 from transformers.models.albert import AlbertConfig, AlbertForPreTraining, AlbertTokenizerFast
 from transformers.optimization import get_linear_schedule_with_warmup
 from transformers.trainer import Trainer
@@ -83,8 +83,8 @@ def get_optimizer_and_scheduler(training_args, model):
         weight_decay=training_args.weight_decay,
     )

-    scheduler = get_linear_schedule_with_warmup(
-        opt, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.max_steps
+    scheduler = AdafactorSchedule(
+        opt
     )

     return opt, scheduler
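Note on the scheduler swap: `AdafactorSchedule` is a proxy rather than a real schedule. When `Adafactor` runs with `relative_step=True`, it derives its learning rate internally each step, and the schedule object exists mainly so that code expecting a scheduler (e.g. `Trainer` lr logging) still has something to read. Below is a minimal sketch of the pairing this diff relies on, following the settings recommended in the transformers docs; the `nn.Linear` toy model and the single dummy training step are illustration stand-ins, not part of this script.

```python
import torch
import torch.nn as nn
from transformers.optimization import Adafactor, AdafactorSchedule

model = nn.Linear(8, 2)  # stand-in for the ALBERT model in this script

# With relative_step=True, Adafactor computes its own lr internally,
# so lr must be left as None.
opt = Adafactor(
    model.parameters(),
    scale_parameter=True,
    relative_step=True,
    warmup_init=True,
    lr=None,
)
scheduler = AdafactorSchedule(opt)

# One dummy step so the optimizer state (and hence the internal lr)
# is populated before we read it back.
loss = model(torch.randn(4, 8)).sum()
loss.backward()
opt.step()
scheduler.step()  # no-op for the actual lr; kept for Trainer compatibility

print(scheduler.get_lr())  # reports the lr Adafactor used on the last step
```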