justheuristic, 3 years ago
parent
commit
6f063c6c5f
1 changed file with 4 additions and 2 deletions

--- a/hivemind/optim/grad_scaler.py
+++ b/hivemind/optim/grad_scaler.py

@@ -46,14 +46,16 @@ class GradScaler(TorchGradScaler):
 
     def step(self, optimizer: TorchOptimizer, *args, **kwargs) -> bool:
         if self._is_running_global_step:
-            assert self._per_optimizer_states[id(optimizer.opt)]["stage"] == OptState.UNSCALED, \
+            assert not isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
+            assert self._per_optimizer_states[id(optimizer)]["stage"] == OptState.UNSCALED, \
                 "InternalError: Optimizer should have called .unscale internally before invoking grad_scaler.step."
             if self.are_grads_finite(optimizer, use_cached=True):
-                super().step(optimizer.opt, *args, **kwargs)
+                super().step(optimizer, *args, **kwargs)
             else:
                 logger.warning("Skipping global step due to gradient over/underflow")
             return True
         else:
+            assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
             super().step(optimizer)
             self._optimizer_states_to_reset.add(id(optimizer))
             return False
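
In short: GradScaler.step() now expects the raw inner torch optimizer only while a global step is running, and the hivemind wrapper (hivemind.Optimizer or hivemind.DecentralizedOptimizerBase) in every other call; the two new isinstance asserts make that contract explicit. The per-optimizer state lookup accordingly switches from id(optimizer.opt) to id(optimizer), since the caller now passes the inner optimizer directly rather than the wrapper.

A minimal usage sketch of the convention this enforces, assuming hivemind's documented Optimizer/GradScaler API; the model, run_id, and batch sizes below are illustrative and not part of this commit:

    import torch
    import hivemind

    model = torch.nn.Linear(4, 4)
    dht = hivemind.DHT(start=True)
    opt = hivemind.Optimizer(
        dht=dht,
        run_id="demo_run",            # illustrative run id
        target_batch_size=1024,       # illustrative global batch size
        batch_size_per_step=32,
        optimizer=torch.optim.SGD(model.parameters(), lr=0.1),
    )
    scaler = hivemind.GradScaler()

    for _ in range(10):
        loss = model(torch.randn(32, 4)).pow(2).mean()
        scaler.scale(loss).backward()
        scaler.step(opt)   # wrapper passed in: takes the else-branch above, returns False
        scaler.update()
        opt.zero_grad()

    # When enough samples have been accumulated across peers, hivemind re-enters
    # the scaler with _is_running_global_step set and passes the wrapped torch
    # optimizer (formerly reached via optimizer.opt) -- the path guarded by the
    # first assert above.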