|
@@ -29,6 +29,7 @@ class TrainingArguments:
|
|
|
reuse_grad_buffers: bool = True
|
|
|
delay_grad_averaging: bool = True
|
|
|
delay_optimizer_step: bool = True
|
|
|
+ average_state_every: int = 3
|
|
|
use_amp: bool = False
|
|
|
|
|
|
lr_base: float = 0.1
|
|
@@ -89,6 +90,7 @@ def benchmark_optimizer(args: TrainingArguments):
|
|
|
reuse_grad_buffers=args.reuse_grad_buffers,
|
|
|
delay_grad_averaging=args.delay_grad_averaging,
|
|
|
delay_optimizer_step=args.delay_optimizer_step,
|
|
|
+ average_state_every=args.average_state_every,
|
|
|
client_mode=client_mode,
|
|
|
verbose=verbose,
|
|
|
)
|