@@ -11,7 +11,7 @@ import transformers
 from datasets import load_from_disk
 from torch.utils.data import DataLoader
 from torch_optimizer import Lamb
-from transformers import DataCollatorForLanguageModeling, HfArgumentParser, TrainingArguments, set_seed
+from transformers import DataCollatorForLanguageModeling, HfArgumentParser, TrainingArguments, set_seed, AdamW
 from transformers.models.albert import AlbertConfig, AlbertForPreTraining, AlbertTokenizerFast
 from transformers.optimization import get_linear_schedule_with_warmup
 from transformers.trainer import Trainer
@@ -77,7 +77,7 @@ def get_optimizer_and_scheduler(training_args, model):
         },
     ]
 
-    opt = Lamb(
+    opt = AdamW(
         optimizer_grouped_parameters,
         lr=training_args.learning_rate,
         betas=(training_args.adam_beta1, training_args.adam_beta2),
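For context, here is a minimal sketch of what `get_optimizer_and_scheduler` plausibly looks like after this patch. Only the lines visible in the hunks are certain; the parameter grouping and the scheduler wiring follow the standard `transformers` pattern suggested by the imports and the hunk header, and the extra `AdamW` keyword argument is an assumption drawn from stock `TrainingArguments` fields. Note the diff leaves the now-unused `from torch_optimizer import Lamb` import in place.

```python
# Hedged reconstruction: everything outside the diff hunks is an assumption
# based on the standard Hugging Face optimizer-grouping pattern.
from transformers import AdamW
from transformers.optimization import get_linear_schedule_with_warmup


def get_optimizer_and_scheduler(training_args, model):
    # Conventional split: no weight decay for biases and LayerNorm weights.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters()
                       if not any(nd in n for nd in no_decay)],
            "weight_decay": training_args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters()
                       if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]

    opt = AdamW(  # was Lamb from torch_optimizer before this patch
        optimizer_grouped_parameters,
        lr=training_args.learning_rate,
        betas=(training_args.adam_beta1, training_args.adam_beta2),
        eps=training_args.adam_epsilon,  # assumed; not shown in the hunk
    )

    # Linear warmup then linear decay, matching the imported helper.
    scheduler = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=training_args.warmup_steps,
        num_training_steps=training_args.max_steps,
    )
    return opt, scheduler
```

As a design note, LAMB is primarily aimed at very large batch sizes, while AdamW is the usual default for `transformers` training loops; the diff itself gives no rationale for the swap.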