
actually unscale

justheuristic committed 3 years ago (commit 6acfaff288)
1 changed file with 4 additions and 4 deletions
      hivemind/optim/grad_scaler.py

hivemind/optim/grad_scaler.py  (+4 −4)

@@ -46,13 +46,14 @@ class GradScaler(TorchGradScaler):
     def step(self, optimizer: TorchOptimizer, *args, **kwargs) -> bool:
         if self._is_running_global_step:
             if self.are_grads_finite(optimizer):
-                assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
-                super().step(optimizer.opt, *args, **kwargs)
+                if isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase)):
+                    optimizer = optimizer.opt
+                super().step(optimizer, *args, **kwargs)
             else:
                 logger.warning("Skipping global step due to gradient over/underflow")
             return True
         else:
-            super().step(optimizer)
+            super().step(optimizer, *args, **kwargs)
             self._optimizer_states_to_reset.add(id(optimizer))
             return False
 
@@ -79,7 +80,6 @@ class GradScaler(TorchGradScaler):
         return super()._unscale_grads_(optimizer, inv_scale, found_inf, allow_fp16=True)
 
     def are_grads_finite(self, optimizer: TorchOptimizer) -> bool:
-        assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
         return not sum(v.item() for v in self._check_inf_per_device(optimizer).values())
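
With this change, GradScaler.step no longer asserts that it was handed a hivemind optimizer: when a hivemind.Optimizer or hivemind.DecentralizedOptimizerBase is passed during a global step, it is unwrapped to the inner torch optimizer, and *args/**kwargs are now forwarded in the local-step branch as well. Below is a minimal standalone sketch of that unwrapping pattern, runnable without hivemind; ToyWrapper and step_like_patched_scaler are hypothetical names introduced for illustration, not hivemind code.

import torch
from torch.cuda.amp import GradScaler as TorchGradScaler


class ToyWrapper:
    """Hypothetical stand-in for hivemind.Optimizer: exposes the wrapped torch optimizer as .opt."""

    def __init__(self, opt: torch.optim.Optimizer):
        self.opt = opt


def step_like_patched_scaler(scaler: TorchGradScaler, optimizer, *args, **kwargs):
    # Mirrors the patched branch: unwrap a known wrapper instead of asserting on it,
    # then forward any extra arguments to the underlying torch optimizer's step().
    if isinstance(optimizer, ToyWrapper):
        optimizer = optimizer.opt
    return scaler.step(optimizer, *args, **kwargs)


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = torch.nn.Linear(4, 2).to(device)
    inner = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = TorchGradScaler(enabled=(device == "cuda"))

    loss = model(torch.randn(8, 4, device=device)).sum()
    scaler.scale(loss).backward()

    # A wrapped optimizer is unwrapped and stepped; a plain torch optimizer
    # would pass through the isinstance check unchanged and work the same way.
    step_like_patched_scaler(scaler, ToyWrapper(inner))
    scaler.update()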