
defer update to user

justheuristic, 3 years ago
commit 8c2febdba8
2 changed files with 6 additions and 2 deletions
  1. hivemind/optim/experimental/optimizer.py (+0, -1)
  2. hivemind/optim/grad_scaler.py (+6, -1)

hivemind/optim/experimental/optimizer.py (+0, -1)

@@ -246,7 +246,6 @@ class Optimizer(torch.optim.Optimizer):
             if grad_scaler is not None:
                 with grad_scaler.running_global_step():
                     assert grad_scaler.unscale_(self)
-                    assert grad_scaler.update()
 
             if self.scheduled_round is not None and self.scheduled_round.triggered or self.scheduled_round.done():
                 logger.log(self.status_loglevel, f"Discarding failed matchmaking results: {self.scheduled_round}")

hivemind/optim/grad_scaler.py (+6, -1)

@@ -25,6 +25,7 @@ class GradScaler(TorchGradScaler):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self._is_running_global_step = False
+        self._is_ready_to_update = False
         self._optimizer_states_to_reset = set()
         self._lock = threading.RLock()
 
@@ -52,6 +53,8 @@ class GradScaler(TorchGradScaler):
     def step(self, optimizer: TorchOptimizer, *args, **kwargs) -> bool:
         if self._is_running_global_step:
             with self._lock:
+                if self._is_ready_to_update:
+                    logger.warning("Please call grad_scaler.update() after each step.")
                 assert not isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
                 assert self._per_optimizer_states[id(optimizer)]["stage"] == OptState.UNSCALED, \
                     "InternalError: Optimizer should have called .unscale internally before invoking grad_scaler.step."
@@ -59,6 +62,7 @@ class GradScaler(TorchGradScaler):
                     super().step(optimizer, *args, **kwargs)
                 else:
                     logger.warning("Skipping global step due to gradient over/underflow")
+                self._is_ready_to_update = True
                 return True
         else:
             assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
@@ -72,9 +76,10 @@ class GradScaler(TorchGradScaler):
             for optimizer_state in self._per_optimizer_states.values():
                 total_infs += sum(v.item() for v in optimizer_state["found_inf_per_device"].values())
 
-            if self._is_running_global_step or total_infs != 0:
+            if self._is_ready_to_update or total_infs != 0:
                 # note: we update either during actual optimizer step or if we need to reduce scale due to NaN
                 super().update(new_scale)
+                self._is_ready_to_update = False
                 return True
             else:
                 for opt_id in self._optimizer_states_to_reset:
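
Taken together, the two hunks defer the scaler update to the caller: hivemind.Optimizer no longer calls grad_scaler.update() inside its global step, and GradScaler now tracks an _is_ready_to_update flag so it can warn when the caller forgets. A minimal sketch of the resulting training loop, assuming the standard torch.cuda.amp pattern; model, opt (assumed to be a hivemind.Optimizer), data_loader, and compute_loss are hypothetical placeholders, not part of this commit:

    import torch
    from hivemind.optim.grad_scaler import GradScaler  # class modified above

    scaler = GradScaler()

    for batch in data_loader:
        opt.zero_grad()
        with torch.cuda.amp.autocast():
            loss = compute_loss(model, batch)  # placeholder forward pass
        scaler.scale(loss).backward()
        scaler.step(opt)    # runs the (possibly global) optimizer step
        scaler.update()     # now the caller's responsibility after this commit

If update() is skipped, the underlying torch scaler is not advanced and the next global step logs the warning added above.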