@@ -3,7 +3,7 @@ from typing import Dict, Optional
 
 import torch
 from torch.cuda.amp import GradScaler as TorchGradScaler
-from torch.cuda.amp.grad_scaler import _refresh_per_optimizer_state
+from torch.cuda.amp.grad_scaler import _refresh_per_optimizer_state, OptState
 from torch.optim import Optimizer as TorchOptimizer
 
 import hivemind
@@ -37,6 +37,7 @@ class GradScaler(TorchGradScaler):
         assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
         if self._is_running_global_step:
             super().unscale_(optimizer)
+            self._per_optimizer_states[id(optimizer.opt)] = self._per_optimizer_states[id(optimizer)]
             return True
         else:
             self._check_inf_per_device(optimizer)
@@ -45,12 +46,14 @@ class GradScaler(TorchGradScaler):
 
     def step(self, optimizer: TorchOptimizer, *args, **kwargs) -> bool:
         if self._is_running_global_step:
-            if isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase)):
-                optimizer = optimizer.opt
-            super().step(optimizer, *args, **kwargs)
+            if self.are_grads_finite(optimizer):
+                assert self._per_optimizer_states[id(optimizer.opt)]["stage"] == OptState.UNSCALED
+                super().step(optimizer.opt, *args, **kwargs)
+            else:
+                logger.warning("Skipping global step due to gradient over/underflow")
             return True
         else:
-            super().step(optimizer, *args, **kwargs)
+            super().step(optimizer)
             self._optimizer_states_to_reset.add(id(optimizer))
             return False
 
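For orientation (not part of the patch): mirroring the per-optimizer state onto id(optimizer.opt) in unscale_ is what lets the global-step branch of step() assert that the wrapped optimizer's gradients are already in the UNSCALED stage before calling super().step(optimizer.opt, ...). Below is a minimal sketch of how this scaler is driven from a standard AMP training loop; model, batches, compute_loss, and the pre-built hivemind.Optimizer instance opt are assumptions, and the top-level GradScaler import path may differ between hivemind versions.

    import torch
    import hivemind
    from hivemind import GradScaler  # top-level import path assumed

    scaler = GradScaler()
    # `opt` is assumed to be a hivemind.Optimizer wrapping a local torch optimizer (opt.opt).

    for batch in batches:
        with torch.cuda.amp.autocast():
            loss = compute_loss(model, batch)  # assumed helper returning a scalar loss
        scaler.scale(loss).backward()
        # On regular (local) calls, step() delegates to the hivemind optimizer and returns False;
        # when hivemind runs the collaborative global step, the scaler instead steps the wrapped
        # local optimizer (opt.opt), or skips the update entirely on gradient over/underflow.
        scaler.step(opt)
        scaler.update()
        opt.zero_grad()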