@@ -98,6 +98,7 @@ def benchmark_optimizer(args: TrainingArguments):
     if args.use_amp and args.reuse_grad_buffers:
         grad_scaler = hivemind.GradScaler()
     else:
+        # check that hivemind.Optimizer supports regular PyTorch grad scaler as well
         grad_scaler = torch.cuda.amp.GradScaler(enabled=args.use_amp)
 
     prev_time = time.perf_counter()
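For context, a minimal sketch of how the selected scaler would be used later in the benchmark loop, assuming a standard AMP training step: since `hivemind.GradScaler` exposes the same interface as `torch.cuda.amp.GradScaler`, the step code is identical for both branches. `model`, `optimizer`, and `batch` below are hypothetical stand-ins, not names from this diff.

```python
import torch

def training_step(model, optimizer, grad_scaler, batch, use_amp: bool):
    # Hypothetical AMP step; either scaler chosen above works here unchanged.
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=use_amp):
        loss = model(**batch)          # stand-in forward pass returning a loss
    grad_scaler.scale(loss).backward()  # scale the loss to avoid fp16 underflow
    grad_scaler.step(optimizer)         # unscales grads, then calls optimizer.step()
    grad_scaler.update()                # adjusts the loss scale for the next step
```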