
actually unscale

justheuristic · 3 years ago
commit 4932739ca5
1 changed file with 6 insertions and 4 deletions

hivemind/optim/grad_scaler.py  +6 -4

@@ -46,8 +46,9 @@ class GradScaler(TorchGradScaler):
 
     def step(self, optimizer: TorchOptimizer, *args, **kwargs) -> bool:
         if self._is_running_global_step:
-            if self.are_grads_finite(optimizer):
-                assert self._per_optimizer_states[id(optimizer.opt)]["stage"] == OptState.UNSCALED
+            assert self._per_optimizer_states[id(optimizer.opt)]["stage"] == OptState.UNSCALED, \
+                "InternalError: Optimizer should have called .unscale internally before invoking grad_scaler.step."
+            if self.are_grads_finite(optimizer, use_cached=True):
                 super().step(optimizer.opt, *args, **kwargs)
             else:
                 logger.warning("Skipping global step due to gradient over/underflow")
@@ -79,8 +80,9 @@ class GradScaler(TorchGradScaler):
         # inspired by: https://github.com/facebookresearch/fairscale/blob/945b9666/fairscale/optim/grad_scaler.py
         return super()._unscale_grads_(optimizer, inv_scale, found_inf, allow_fp16=True)
 
-    def are_grads_finite(self, optimizer: TorchOptimizer) -> bool:
-        return not sum(v.item() for v in self._check_inf_per_device(optimizer).values())
+    def are_grads_finite(self, optimizer: TorchOptimizer, use_cached: bool = False) -> bool:
+        opt_dict = self._found_inf_per_device(optimizer) if use_cached else self._check_inf_per_device(optimizer)
+        return not sum(v.item() for v in opt_dict.values())
 
 
 class HivemindGradScaler(GradScaler):
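
For context, the change assumes the wrapped optimizer has already unscaled its gradients (putting the scaler's per-optimizer state into OptState.UNSCALED) before grad_scaler.step() runs, so the inf/nan check can reuse the cached _found_inf_per_device values instead of re-scanning gradients via _check_inf_per_device. Below is a minimal sketch of that unscale-then-step order using only the stock torch.cuda.amp.GradScaler API that this class inherits; the model, optimizer, and data are hypothetical placeholders and the sketch assumes a CUDA device.

import torch

model = torch.nn.Linear(16, 4).cuda()
opt = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

for _ in range(3):
    opt.zero_grad()
    with torch.cuda.amp.autocast():
        loss = model(torch.randn(8, 16, device="cuda")).sum()
    scaler.scale(loss).backward()

    # The step this commit's assertion cares about: unscale_() marks the
    # optimizer's state as UNSCALED and caches the per-device found_inf
    # tensors that are_grads_finite(use_cached=True) later reads.
    scaler.unscale_(opt)

    scaler.step(opt)   # internally skipped if an inf/nan was recorded
    scaler.update()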