@@ -45,12 +45,9 @@ class GradScaler(TorchGradScaler):

     def step(self, optimizer: TorchOptimizer, *args, **kwargs) -> bool:
         if self._is_running_global_step:
-            if self.are_grads_finite(optimizer):
-                if isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase)):
-                    optimizer = optimizer.opt
-                super().step(optimizer, *args, **kwargs)
-            else:
-                logger.warning("Skipping global step due to gradient over/underflow")
+            if isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase)):
+                optimizer = optimizer.opt
+            super().step(optimizer, *args, **kwargs)
             return True
         else:
             super().step(optimizer, *args, **kwargs)
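
For context, a minimal sketch of the calling convention this class inherits from torch.cuda.amp.GradScaler; it is not part of the patch. The model, optimizer, and training loop are hypothetical stand-ins, the top-level GradScaler import from hivemind is an assumption about how the patched class is exported, and whether plain local calls are supported end-to-end depends on parts of the class this hunk does not show.

import torch
from hivemind import GradScaler  # assumption: the patched class is exported at the package root

device = torch.device("cuda")                         # AMP grad scaling assumes a CUDA device
model = torch.nn.Linear(8, 2).to(device)              # illustrative model, not from the patch
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler()

for _ in range(3):                                    # illustrative local training steps
    inputs = torch.randn(4, 8, device=device)
    targets = torch.randint(0, 2, (4,), device=device)
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = torch.nn.functional.cross_entropy(model(inputs), targets)
    scaler.scale(loss).backward()
    # During hivemind's global step, _is_running_global_step is set and the patched
    # step() unwraps a hivemind.Optimizer to its inner torch optimizer before
    # delegating; plain local calls like this one fall through to the parent class.
    scaler.step(optimizer)
    scaler.update()

Note that the hunk drops the are_grads_finite guard from the global-step path, so step() itself no longer skips the step or logs an overflow warning there; any gradient over/underflow handling presumably happens elsewhere.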