@@ -35,6 +35,7 @@ class GradScaler(TorchGradScaler):
         super().__init__(*args, **kwargs)
         self._is_running_global_step = False
         self._is_ready_to_update = False
+        self._inner_optimizer_states = {}
         self._optimizer_states_to_reset = set()
         self._lock = threading.RLock()
 
@@ -52,7 +53,7 @@ class GradScaler(TorchGradScaler):
             assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
             if self._is_running_global_step:
                 super().unscale_(optimizer)
-                self._per_optimizer_states[id(optimizer.opt)] = deepcopy(self._per_optimizer_states[id(optimizer)])
+                self._inner_optimizer_states[id(optimizer.opt)] = deepcopy(self._per_optimizer_states[id(optimizer)])
                 return True
             else:
                 self._check_inf_per_device(optimizer)
@@ -62,14 +63,19 @@ class GradScaler(TorchGradScaler):
     def step(self, optimizer: TorchOptimizer, *args, **kwargs) -> bool:
         if self._is_running_global_step and not isinstance(optimizer, hivemind.Optimizer):
             # ^-- invoked privately within hivemind optimizer
+            inner_optimizer = optimizer
             with self._lock:
                 if self._is_ready_to_update:
                     logger.warning("Please call grad_scaler.update() after each step")
+
+                inner_optimizer_state = self._inner_optimizer_states.pop(id(inner_optimizer), None)
+                if inner_optimizer_state is not None:
+                    self._per_optimizer_states[id(inner_optimizer)] = inner_optimizer_state
                 assert (
-                    self._per_optimizer_states[id(optimizer)]["stage"] == OptState.UNSCALED
-                ), "InternalError: Optimizer should have called .unscale internally before invoking grad_scaler.step."
-                if self.are_grads_finite(optimizer, use_cached=True):
-                    super().step(optimizer, *args, **kwargs)
+                    self._per_optimizer_states[id(inner_optimizer)]["stage"] == OptState.UNSCALED
+                ), "InternalError: Optimizer should have called .unscale internally before invoking grad_scaler.step"
+                if self.are_grads_finite(inner_optimizer, use_cached=True):
+                    super().step(inner_optimizer, *args, **kwargs)
                 else:
                     logger.warning("Skipping global step due to gradient over/underflow")
                 self._is_ready_to_update = True
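For context, a minimal self-contained sketch of the state hand-off that the code above performs: `unscale_()` stashes the unscaled state under the inner optimizer's id, and `step()` restores it before checking the UNSCALED stage. `_ToyScaler`, `wrapper_optimizer`, and `inner_optimizer` are hypothetical stand-ins, not hivemind's actual classes.

```python
from copy import deepcopy


class _ToyScaler:
    def __init__(self):
        self._per_optimizer_states = {}    # analogous to torch amp's per-optimizer states
        self._inner_optimizer_states = {}  # stash keyed by id() of the inner optimizer

    def unscale_(self, wrapper_optimizer, inner_optimizer):
        # pretend the parent class recorded the unscaled state for the wrapper optimizer
        self._per_optimizer_states[id(wrapper_optimizer)] = {"stage": "UNSCALED"}
        # copy it under the inner optimizer's id so a later step() can find it
        self._inner_optimizer_states[id(inner_optimizer)] = deepcopy(
            self._per_optimizer_states[id(wrapper_optimizer)]
        )

    def step(self, inner_optimizer):
        # restore the stashed state so the stage check passes for the inner optimizer
        state = self._inner_optimizer_states.pop(id(inner_optimizer), None)
        if state is not None:
            self._per_optimizer_states[id(inner_optimizer)] = state
        assert self._per_optimizer_states[id(inner_optimizer)]["stage"] == "UNSCALED"
        return True


wrapper, inner = object(), object()
scaler = _ToyScaler()
scaler.unscale_(wrapper, inner)
assert scaler.step(inner)
```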