decouple max steps from scheduler steps

justheuristic, 3 years ago
commit 6adf35a269
examples/albert/arguments.py  (+2 -1)

@@ -127,7 +127,7 @@ class AlbertTrainingArguments(TrainingArguments):
     gradient_accumulation_steps: int = 2
     seq_length: int = 512
 
-    max_steps: int = 125_000  # please note: this affects both number of steps and learning rate schedule
+    total_steps: int = 125_000  # please note: this only affects the learning rate schedule
     learning_rate: float = 0.00176
     warmup_steps: int = 5000
     adam_epsilon: float = 1e-6
@@ -142,5 +142,6 @@ class AlbertTrainingArguments(TrainingArguments):
     logging_steps: int = 100
     save_total_limit: int = 2
     save_steps: int = 500
+    max_steps: int = 10 ** 30
 
     output_dir: str = "outputs"
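
With this change, max_steps (a stock transformers.TrainingArguments field that caps the number of optimizer steps) no longer doubles as the schedule length: it is set to an effectively unbounded 10 ** 30, while the new total_steps field feeds only the learning-rate schedule. A minimal sketch of how the two fields relate after parsing, assuming the dataclass above is importable as arguments.AlbertTrainingArguments:

    from transformers import HfArgumentParser
    from arguments import AlbertTrainingArguments  # the dataclass shown above

    parser = HfArgumentParser(AlbertTrainingArguments)
    (training_args,) = parser.parse_args_into_dataclasses()

    # The Trainer stops at max_steps (effectively never), while the LR schedule
    # is governed solely by total_steps and warmup_steps.
    assert training_args.max_steps > training_args.total_steps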

examples/albert/run_trainer.py  (+1 -1)

@@ -260,7 +260,7 @@ def main():
     ]
 
     scheduler = lambda opt: get_linear_schedule_with_warmup(
-        opt, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.max_steps
+        opt, num_warmup_steps=training_args.warmup_steps, num_training_steps=training_args.total_steps
     )
 
     optimizer = Optimizer(
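
For reference, a minimal sketch of the learning-rate multiplier that get_linear_schedule_with_warmup produces for these values (warmup_steps=5000, total_steps=125_000), assuming transformers' standard linear warmup followed by linear decay; the sketch is illustrative and not part of the commit:

    # Linear warmup from 0 to 1 over warmup_steps, then linear decay to 0 at
    # total_steps; the multiplier stays clamped at 0 for any later step.
    def lr_multiplier(step: int, warmup_steps: int = 5000, total_steps: int = 125_000) -> float:
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))

    assert lr_multiplier(2_500) == 0.5
    assert lr_multiplier(125_000) == 0.0 and lr_multiplier(200_000) == 0.0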

examples/albert/run_training_monitor.py  (+2 -1)

@@ -9,7 +9,7 @@ import requests
 import torch
 import wandb
 from torch_optimizer import Lamb
-from transformers import AlbertConfig, AlbertForPreTraining, HfArgumentParser
+from transformers import AlbertConfig, AlbertForPreTraining, HfArgumentParser, get_linear_schedule_with_warmup
 
 import hivemind
 from hivemind.optim.state_averager import TrainingStateAverager
@@ -99,6 +99,7 @@ class CheckpointHandler:
         self.state_averager = TrainingStateAverager(
             dht=dht,
             optimizer=opt,
+            scheduler=get_linear_schedule_with_warmup(opt, num_warmup_steps=5000, num_training_steps=125_000),
             prefix=f"{run_id}_state_averager",
             state_compression=hivemind.Float16Compression(),
             bandwidth=optimizer_args.bandwidth,
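
The num_warmup_steps=5000 and num_training_steps=125_000 passed here mirror warmup_steps and total_steps from arguments.py, so the checkpointing monitor rebuilds the same schedule the trainers use. A hypothetical helper (not part of the commit) that would derive both values from one place instead of hard-coding them:

    # Hypothetical helper; assumes training_args is an AlbertTrainingArguments
    # instance as defined in arguments.py. Keeps the monitor's scheduler in sync
    # with the one built in run_trainer.py.
    def make_scheduler(opt, training_args):
        return get_linear_schedule_with_warmup(
            opt,
            num_warmup_steps=training_args.warmup_steps,    # 5000 by default
            num_training_steps=training_args.total_steps,   # 125_000 by default
        )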