Aleksandr Borzunov 3 年之前
父节点
当前提交
b9e8db1ad2
共有 1 个文件被更改,包括 1 次插入和 0 次删除
  1. 1 0
      benchmarks/benchmark_optimizer.py

+ 1 - 0
benchmarks/benchmark_optimizer.py

@@ -98,6 +98,7 @@ def benchmark_optimizer(args: TrainingArguments):
         if args.use_amp and args.reuse_grad_buffers:
         if args.use_amp and args.reuse_grad_buffers:
             grad_scaler = hivemind.GradScaler()
             grad_scaler = hivemind.GradScaler()
         else:
         else:
+            # check that hivemind.Optimizer supports regular PyTorch grad scaler as well
             grad_scaler = torch.cuda.amp.GradScaler(enabled=args.use_amp)
             grad_scaler = torch.cuda.amp.GradScaler(enabled=args.use_amp)
 
 
         prev_time = time.perf_counter()
         prev_time = time.perf_counter()