瀏覽代碼

actually unscale

justheuristic 3 年之前
父節點
當前提交
844f63e41e
共有 1 個文件被更改,包括 3 次插入和 6 次刪除
  1. 3 6
      hivemind/optim/grad_scaler.py

+ 3 - 6
hivemind/optim/grad_scaler.py

@@ -45,12 +45,9 @@ class GradScaler(TorchGradScaler):
 
     def step(self, optimizer: TorchOptimizer, *args, **kwargs) -> bool:
         if self._is_running_global_step:
-            if self.are_grads_finite(optimizer):
-                if isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase)):
-                    optimizer = optimizer.opt
-                super().step(optimizer, *args, **kwargs)
-            else:
-                logger.warning("Skipping global step due to gradient over/underflow")
+            if isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase)):
+                optimizer = optimizer.opt
+            super().step(optimizer, *args, **kwargs)
             return True
         else:
             super().step(optimizer, *args, **kwargs)