@@ -9,7 +9,7 @@ import requests
 import torch
 import wandb
 from torch_optimizer import Lamb
-from transformers import AlbertConfig, AlbertForPreTraining, HfArgumentParser
+from transformers import AlbertConfig, AlbertForPreTraining, HfArgumentParser, get_linear_schedule_with_warmup
 
 import hivemind
 from hivemind.optim.state_averager import TrainingStateAverager
@@ -99,6 +99,7 @@ class CheckpointHandler:
         self.state_averager = TrainingStateAverager(
             dht=dht,
             optimizer=opt,
+            scheduler=get_linear_schedule_with_warmup(opt, num_warmup_steps=5000, num_training_steps=125_000),
             prefix=f"{run_id}_state_averager",
             state_compression=hivemind.Float16Compression(),
             bandwidth=optimizer_args.bandwidth,
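A note on the schedule added above: `get_linear_schedule_with_warmup` returns a standard PyTorch `LambdaLR` that ramps the learning rate from 0 up to the optimizer's base LR over the first `num_warmup_steps` steps, then decays it linearly back to 0 at `num_training_steps`. Below is a minimal standalone sketch of that behavior; the `Lamb` optimizer and the 5,000/125,000 step counts come from the diff, while the dummy parameters and the learning rate are illustrative placeholders only.

```python
import torch
from torch_optimizer import Lamb
from transformers import get_linear_schedule_with_warmup

# Illustrative parameters and LR; the real script builds the optimizer
# from the ALBERT model's parameters.
params = [torch.nn.Parameter(torch.zeros(10))]
opt = Lamb(params, lr=1e-3)

# Same schedule as in the diff: linear warmup for the first 5k steps,
# then linear decay to zero at step 125k.
scheduler = get_linear_schedule_with_warmup(opt, num_warmup_steps=5000, num_training_steps=125_000)

for step in range(10):
    opt.step()        # scheduler.step() should follow optimizer.step()
    scheduler.step()
    print(step, scheduler.get_last_lr())
```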