
better defaults

justheuristic committed 3 years ago
commit 49c14435ab
2 changed files with 11 additions and 5 deletions
  1. benchmarks/benchmark_optimizer.py (+9 −3)
  2. hivemind/optim/experimental/state_averager.py (+2 −2)

+ 9 - 3
benchmarks/benchmark_optimizer.py

@@ -27,8 +27,9 @@ class TrainingArguments:
     num_clients: int = 3
     target_batch_size: int = 128
     reuse_grad_buffers: bool = True
-    delay_optimizer_step: bool = False
-    use_amp: bool = True
+    delay_grad_averaging: bool = True
+    delay_optimizer_step: bool = True
+    use_amp: bool = False
 
     lr_base: float = 0.1
     lr_gamma: int = 0.1
@@ -86,12 +87,13 @@ def benchmark_optimizer(args: TrainingArguments):
             matchmaking_time=args.matchmaking_time,
             averaging_timeout=args.averaging_timeout,
             reuse_grad_buffers=args.reuse_grad_buffers,
+            delay_grad_averaging=args.delay_grad_averaging,
             delay_optimizer_step=args.delay_optimizer_step,
             client_mode=client_mode,
             verbose=verbose,
         )
 
-        if args.reuse_grad_buffers:
+        if args.use_amp and args.reuse_grad_buffers:
             grad_scaler = hivemind.GradScaler()
         else:
             grad_scaler = torch.cuda.amp.GradScaler(enabled=args.use_amp)
@@ -152,3 +154,7 @@ def benchmark_optimizer(args: TrainingArguments):
     finally:
         for peer in peers[1:]:
             peer.kill()
+
+
+if __name__ == '__main__':
+    benchmark_optimizer(TrainingArguments())
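
Note on the grad-scaler change above: the benchmark now constructs hivemind.GradScaler only when AMP is enabled and gradient buffers are reused by hivemind.Optimizer; in every other case a plain torch.cuda.amp.GradScaler suffices, and with enabled=False it degrades to a no-op so the training loop can call it unconditionally. A minimal sketch of that selection follows (make_grad_scaler is a hypothetical helper for illustration, not part of the library):

import hivemind
import torch

def make_grad_scaler(use_amp: bool, reuse_grad_buffers: bool):
    # Mirror of the benchmark's selection logic (sketch, not a library API).
    if use_amp and reuse_grad_buffers:
        # hivemind.GradScaler coordinates unscaling with hivemind.Optimizer,
        # which is needed when gradient buffers are reused between steps.
        return hivemind.GradScaler()
    # Plain torch scaler; with enabled=False its scale()/step()/update() calls
    # become transparent no-ops, so the training loop needs no special casing.
    return torch.cuda.amp.GradScaler(enabled=use_amp)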

+ 2 - 2
hivemind/optim/experimental/state_averager.py

@@ -322,12 +322,12 @@ class TrainingStateAverager(DecentralizedAverager):
         if averaging_opts and not averaging_round:
             logger.warning(f"Averaging parameters not used because averaging_round=False: {averaging_opts}")
         if wait_for_trigger is not None:
-            if not self.reuse_tensors or self.custom_gradients:
+            if not (self.reuse_tensors or self.custom_gradients):
                 # averager was asked to wait_for_trigger in background, but it is not clear which version of gradients
                 # should be used for optimizer step (e.g. the gradients that were present during the call to .step or
                 # the possibly different gradients when wait_for_trigger has finished).
                 raise ValueError(
-                    "wait_for_trigger is an advanced option that requires manual gradient manipulation. "
+                    "wait_for_trigger is a low-level option that requires manual gradient manipulation. "
                     "If you know what you're doing, please refer to the comments in the source code for details."
                 )
         output = None
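
The state_averager change is an operator-precedence fix: "not self.reuse_tensors or self.custom_gradients" parses as "(not self.reuse_tensors) or self.custom_gradients", so the old guard raised whenever custom_gradients was set, even though, per the fix, either flag satisfies the precondition for wait_for_trigger. The corrected form raises only when neither flag is set. A quick truth-table check (plain Python with stand-in booleans for the two attributes):

for reuse_tensors in (False, True):
    for custom_gradients in (False, True):
        old_guard = not reuse_tensors or custom_gradients    # (not reuse_tensors) or custom_gradients
        new_guard = not (reuse_tensors or custom_gradients)  # True only when neither flag is set
        print(reuse_tensors, custom_gradients, old_guard, new_guard)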