@@ -46,8 +46,9 @@ class GradScaler(TorchGradScaler):
 
     def step(self, optimizer: TorchOptimizer, *args, **kwargs) -> bool:
         if self._is_running_global_step:
-            if self.are_grads_finite(optimizer):
-                assert self._per_optimizer_states[id(optimizer.opt)]["stage"] == OptState.UNSCALED
+            assert self._per_optimizer_states[id(optimizer.opt)]["stage"] == OptState.UNSCALED, \
+                "InternalError: Optimizer should have called .unscale internally before invoking grad_scaler.step."
+            if self.are_grads_finite(optimizer, use_cached=True):
                 super().step(optimizer.opt, *args, **kwargs)
             else:
                 logger.warning("Skipping global step due to gradient over/underflow")
@@ -79,8 +80,9 @@ class GradScaler(TorchGradScaler):
         # inspired by: https://github.com/facebookresearch/fairscale/blob/945b9666/fairscale/optim/grad_scaler.py
         return super()._unscale_grads_(optimizer, inv_scale, found_inf, allow_fp16=True)
 
-    def are_grads_finite(self, optimizer: TorchOptimizer) -> bool:
-        return not sum(v.item() for v in self._check_inf_per_device(optimizer).values())
+    def are_grads_finite(self, optimizer: TorchOptimizer, use_cached: bool = False) -> bool:
+        opt_dict = self._found_inf_per_device(optimizer) if use_cached else self._check_inf_per_device(optimizer)
+        return not sum(v.item() for v in opt_dict.values())
 
 
 class HivemindGradScaler(GradScaler):
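
A minimal sketch (not part of the diff) of what the new are_grads_finite() path reduces to, assuming torch's private GradScaler helpers behave as in upstream PyTorch: _found_inf_per_device returns the per-device inf/NaN flags cached by unscale_ (which the new assert guarantees ran first), while _check_inf_per_device recomputes them. The dict below is a hand-built stand-in for either helper, with string keys for simplicity:

import torch

# Stand-in for the per-device dict returned by _found_inf_per_device (cached,
# populated by unscale_) or _check_inf_per_device (recomputed on the spot).
found_inf_per_device = {
    "cuda:0": torch.tensor(0.0),  # no inf/NaN gradients found on this device
    "cuda:1": torch.tensor(1.0),  # overflow detected on this device
}

# Mirrors the patched return statement: grads are finite iff every flag is 0.
grads_are_finite = not sum(v.item() for v in found_inf_per_device.values())
print(grads_are_finite)  # False -> step() logs a warning and skips the global step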