
better defaults

justheuristic · 3 years ago · commit 49c14435ab
2 files changed, 11 insertions(+), 5 deletions(-)
  1. benchmarks/benchmark_optimizer.py  (+9 −3)
  2. hivemind/optim/experimental/state_averager.py  (+2 −2)

benchmarks/benchmark_optimizer.py  (+9 −3)

@@ -27,8 +27,9 @@ class TrainingArguments:
     num_clients: int = 3
     target_batch_size: int = 128
     reuse_grad_buffers: bool = True
-    delay_optimizer_step: bool = False
-    use_amp: bool = True
+    delay_grad_averaging: bool = True
+    delay_optimizer_step: bool = True
+    use_amp: bool = False
 
     lr_base: float = 0.1
     lr_gamma: int = 0.1
@@ -86,12 +87,13 @@ def benchmark_optimizer(args: TrainingArguments):
             matchmaking_time=args.matchmaking_time,
             averaging_timeout=args.averaging_timeout,
             reuse_grad_buffers=args.reuse_grad_buffers,
+            delay_grad_averaging=args.delay_grad_averaging,
             delay_optimizer_step=args.delay_optimizer_step,
             client_mode=client_mode,
             verbose=verbose,
         )
 
-        if args.reuse_grad_buffers:
+        if args.use_amp and args.reuse_grad_buffers:
             grad_scaler = hivemind.GradScaler()
         else:
             grad_scaler = torch.cuda.amp.GradScaler(enabled=args.use_amp)
@@ -152,3 +154,7 @@ def benchmark_optimizer(args: TrainingArguments):
     finally:
         for peer in peers[1:]:
             peer.kill()
+
+
+if __name__ == '__main__':
+    benchmark_optimizer(TrainingArguments())
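With the new __main__ entry point, the benchmark can be run directly as a script. For reference, a minimal usage sketch (not part of this commit) of overriding one of the new defaults; TrainingArguments, benchmark_optimizer, and the use_amp field are the names from the diff above, and the import assumes the script's directory is on the Python path:

    from benchmark_optimizer import TrainingArguments, benchmark_optimizer

    # run with the new defaults: delayed grad averaging / optimizer step, AMP disabled
    benchmark_optimizer(TrainingArguments())

    # or override a single field, e.g. re-enable mixed precision
    benchmark_optimizer(TrainingArguments(use_amp=True))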

hivemind/optim/experimental/state_averager.py  (+2 −2)

@@ -322,12 +322,12 @@ class TrainingStateAverager(DecentralizedAverager):
         if averaging_opts and not averaging_round:
             logger.warning(f"Averaging parameters not used because averaging_round=False: {averaging_opts}")
         if wait_for_trigger is not None:
-            if not self.reuse_tensors or self.custom_gradients:
+            if not (self.reuse_tensors or self.custom_gradients):
                 # averager was asked to wait_for_trigger in background, but it is not clear which version of gradients
                 # should be used for optimizer step (e.g. the gradients that were present during the call to .step or
                 # the possibly different gradients when wait_for_trigger has finished).
                 raise ValueError(
-                    "wait_for_trigger is an advanced option that requires manual gradient manipulation. "
+                    "wait_for_trigger is a low-level option that requires manual gradient manipulation. "
                     "If you know what you're doing, please refer to the comments in the source code for details."
                 )
         output = None
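The condition change above is an operator-precedence fix: previously the error was raised whenever reuse_tensors was False or custom_gradients was True; after the fix it is raised only when neither flag is set. A small standalone illustration (flag values are hypothetical, names taken from the diff):

    reuse_tensors, custom_gradients = True, True
    old = not reuse_tensors or custom_gradients    # parsed as (not reuse_tensors) or custom_gradients -> True, would raise
    new = not (reuse_tensors or custom_gradients)  # -> False, wait_for_trigger is permitted
    assert old is True and new is False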