@@ -46,13 +46,14 @@ class GradScaler(TorchGradScaler):
     def step(self, optimizer: TorchOptimizer, *args, **kwargs) -> bool:
         if self._is_running_global_step:
             if self.are_grads_finite(optimizer):
-                assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
-                super().step(optimizer.opt, *args, **kwargs)
+                if isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase)):
+                    optimizer = optimizer.opt
+                super().step(optimizer, *args, **kwargs)
             else:
                 logger.warning("Skipping global step due to gradient over/underflow")
             return True
         else:
-            super().step(optimizer)
+            super().step(optimizer, *args, **kwargs)
             self._optimizer_states_to_reset.add(id(optimizer))
             return False

@@ -79,7 +80,6 @@ class GradScaler(TorchGradScaler):
         return super()._unscale_grads_(optimizer, inv_scale, found_inf, allow_fp16=True)

     def are_grads_finite(self, optimizer: TorchOptimizer) -> bool:
-        assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
         return not sum(v.item() for v in self._check_inf_per_device(optimizer).values())
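
For illustration only, a minimal sketch of the dispatch the new hunk implements: unwrap the optimizer only when it is one of the known hivemind wrappers, and accept a plain torch optimizer as-is (the removed assert used to reject it). DummyWrapper and unwrap_optimizer are hypothetical names used for this sketch, not hivemind APIs.

import torch

class DummyWrapper:
    # Hypothetical stand-in for hivemind.Optimizer: exposes the inner torch optimizer as `.opt`.
    def __init__(self, opt: torch.optim.Optimizer):
        self.opt = opt

def unwrap_optimizer(optimizer, wrapper_types=(DummyWrapper,)):
    # Mirrors the patched branch: unwrap known wrappers, pass plain optimizers through unchanged.
    return optimizer.opt if isinstance(optimizer, wrapper_types) else optimizer

model = torch.nn.Linear(4, 2)
sgd = torch.optim.SGD(model.parameters(), lr=0.1)

assert unwrap_optimizer(DummyWrapper(sgd)) is sgd  # wrapped: step() delegates to the inner optimizer
assert unwrap_optimizer(sgd) is sgd                # plain torch optimizer: now accepted instead of asserting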