@@ -27,8 +27,9 @@ class TrainingArguments:
     num_clients: int = 3
     target_batch_size: int = 128
     reuse_grad_buffers: bool = True
-    delay_optimizer_step: bool = False
-    use_amp: bool = True
+    delay_grad_averaging: bool = True
+    delay_optimizer_step: bool = True
+    use_amp: bool = False
 
     lr_base: float = 0.1
     lr_gamma: int = 0.1
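A side note on the unchanged context: `lr_gamma: int = 0.1` carries an `int` annotation with a float default; that mismatch predates this diff. As for the new defaults, `delay_grad_averaging` and `delay_optimizer_step` are hivemind.Optimizer options that run gradient all-reduce and the optimizer step in the background so they overlap with the next forward/backward pass, and per the hivemind docstring `delay_grad_averaging` requires `delay_optimizer_step`, which is why both defaults flip to True together. A minimal sketch of overriding these defaults for a single run, assuming only the dataclass above (values illustrative, not from the diff):

    # Hypothetical run with the delayed path disabled and mixed precision on.
    args = TrainingArguments(delay_grad_averaging=False, delay_optimizer_step=False, use_amp=True)
    benchmark_optimizer(args)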
@@ -86,12 +87,13 @@ def benchmark_optimizer(args: TrainingArguments):
         matchmaking_time=args.matchmaking_time,
         averaging_timeout=args.averaging_timeout,
         reuse_grad_buffers=args.reuse_grad_buffers,
+        delay_grad_averaging=args.delay_grad_averaging,
         delay_optimizer_step=args.delay_optimizer_step,
         client_mode=client_mode,
         verbose=verbose,
     )
 
-    if args.reuse_grad_buffers:
+    if args.use_amp and args.reuse_grad_buffers:
         grad_scaler = hivemind.GradScaler()
     else:
         grad_scaler = torch.cuda.amp.GradScaler(enabled=args.use_amp)
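The tightened condition matters: hivemind.GradScaler is only needed when mixed precision is actually on, and with AMP off a disabled torch.cuda.amp.GradScaler is a pass-through no-op. A rough sketch of how either scaler is driven in a training step; `model`, `criterion`, and `batch` are placeholders, not names from this diff:

    with torch.cuda.amp.autocast(enabled=args.use_amp):
        loss = criterion(model(batch))
    grad_scaler.scale(loss).backward()   # scaling is a no-op when the scaler is disabled
    grad_scaler.step(optimizer)          # reduces to optimizer.step() with use_amp=False
    grad_scaler.update()

hivemind.GradScaler mirrors this torch API but coordinates gradient unscaling with hivemind.Optimizer, which is why it is the required choice when AMP and reuse_grad_buffers are combined.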
@@ -152,3 +154,7 @@ def benchmark_optimizer(args: TrainingArguments):
     finally:
         for peer in peers[1:]:
             peer.kill()
+
+
+if __name__ == '__main__':
+    benchmark_optimizer(TrainingArguments())
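The new guard makes the file directly runnable as a script while keeping `benchmark_optimizer` importable for parameterized runs, for instance:

    # Hypothetical import path; the module name depends on where the file lives.
    from benchmark_optimizer import TrainingArguments, benchmark_optimizer

    benchmark_optimizer(TrainingArguments(num_clients=0, use_amp=True))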