
store inner optimizer states in a separate dict that will not be cleared if found inf/nan in local gradients

justheuristic committed 3 years ago
commit 7ae28acecb
1 changed file with 11 additions and 5 deletions

hivemind/optim/grad_scaler.py  (+11, -5)

@@ -35,6 +35,7 @@ class GradScaler(TorchGradScaler):
         super().__init__(*args, **kwargs)
         self._is_running_global_step = False
         self._is_ready_to_update = False
+        self._inner_optimizer_states = {}
         self._optimizer_states_to_reset = set()
         self._lock = threading.RLock()
 
@@ -52,7 +53,7 @@ class GradScaler(TorchGradScaler):
             assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
             if self._is_running_global_step:
                 super().unscale_(optimizer)
-                self._per_optimizer_states[id(optimizer.opt)] = deepcopy(self._per_optimizer_states[id(optimizer)])
+                self._inner_optimizer_states[id(optimizer.opt)] = deepcopy(self._per_optimizer_states[id(optimizer)])
                 return True
             else:
                 self._check_inf_per_device(optimizer)
@@ -62,14 +63,19 @@ class GradScaler(TorchGradScaler):
     def step(self, optimizer: TorchOptimizer, *args, **kwargs) -> bool:
         if self._is_running_global_step and not isinstance(optimizer, hivemind.Optimizer):
             # ^-- invoked privately within hivemind optimizer
+            inner_optimizer = optimizer
             with self._lock:
                 if self._is_ready_to_update:
                     logger.warning("Please call grad_scaler.update() after each step")
+
+                inner_optimizer_state = self._inner_optimizer_states.pop(id(inner_optimizer), None)
+                if inner_optimizer_state is not None:
+                    self._per_optimizer_states[id(inner_optimizer)] = inner_optimizer_state
                 assert (
-                    self._per_optimizer_states[id(optimizer)]["stage"] == OptState.UNSCALED
-                ), "InternalError: Optimizer should have called .unscale internally before invoking grad_scaler.step."
-                if self.are_grads_finite(optimizer, use_cached=True):
-                    super().step(optimizer, *args, **kwargs)
+                    self._per_optimizer_states[id(inner_optimizer)]["stage"] == OptState.UNSCALED
+                ), "InternalError: Optimizer should have called .unscale internally before invoking grad_scaler.step"
+                if self.are_grads_finite(inner_optimizer, use_cached=True):
+                    super().step(inner_optimizer, *args, **kwargs)
                 else:
                     logger.warning("Skipping global step due to gradient over/underflow")
                 self._is_ready_to_update = True
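
For context, below is a minimal usage sketch, not part of this commit, of how hivemind.GradScaler is meant to replace torch.cuda.amp.GradScaler when training with hivemind.Optimizer. The run_id value, batch sizes, toy model, and training loop are illustrative placeholders, the keyword arguments assume the hivemind.Optimizer API from around this period, and a CUDA-capable setup is assumed for autocast:

    import torch
    import torch.nn.functional as F
    import hivemind

    # Toy model and data; real training code would use an actual dataset.
    model = torch.nn.Linear(512, 2)
    inputs, targets = torch.randn(32, 512), torch.randint(0, 2, (32,))

    dht = hivemind.DHT(start=True)
    opt = hivemind.Optimizer(
        dht=dht,
        run_id="demo_run",                    # hypothetical experiment name
        params=model.parameters(),
        optimizer=lambda params: torch.optim.Adam(params, lr=1e-3),  # the "inner" optimizer
        target_batch_size=4096,
        batch_size_per_step=32,
    )
    scaler = hivemind.GradScaler()            # the subclass patched by this commit

    for _ in range(100):
        with torch.cuda.amp.autocast():
            loss = F.cross_entropy(model(inputs), targets)
        scaler.scale(loss).backward()
        scaler.unscale_(opt)                  # unscales local grads, records inf/nan status
        scaler.step(opt)                      # inner optimizer state is stashed in
                                              # _inner_optimizer_states, so a local overflow
                                              # no longer clears it before the global step
        scaler.update()
        opt.zero_grad()

The effect of the change shows up inside scaler.step(opt): hivemind.Optimizer re-invokes the scaler on its wrapped inner optimizer, and the unscaled state saved under id(optimizer.opt) now lives in the separate _inner_optimizer_states dict rather than in torch's _per_optimizer_states, so the bookkeeping reset that follows an inf/nan in local gradients can no longer wipe it before the global step consumes it.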