|
@@ -25,6 +25,7 @@ class GradScaler(TorchGradScaler):
|
|
|
def __init__(self, *args, **kwargs):
    """Build the scaler and set up the extra state used by ``step`` and ``update``.

    All positional and keyword arguments are passed through unchanged to the
    parent class constructor.
    """
    super().__init__(*args, **kwargs)
    # Re-entrant lock; ``step`` acquires it while a global step is in progress.
    self._lock = threading.RLock()
    # Starts out empty; ``update`` iterates over this set.
    self._optimizer_states_to_reset = set()
    # Raised at the end of a global ``step`` and lowered again once ``update``
    # has delegated to the parent ``update``.
    self._is_ready_to_update = False
    # Selects which branch ``step`` takes (global step vs. regular call).
    self._is_running_global_step = False
|
|
|
|
|
@@ -52,6 +53,8 @@ class GradScaler(TorchGradScaler):
|
|
|
def step(self, optimizer: TorchOptimizer, *args, **kwargs) -> bool:
|
|
|
if self._is_running_global_step:
|
|
|
with self._lock:
|
|
|
+ if not self._is_ready_to_update:
|
|
|
+ logger.warning("Please call grad_scaler.update() after each step.")
|
|
|
assert not isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
|
|
|
assert self._per_optimizer_states[id(optimizer)]["stage"] == OptState.UNSCALED, \
|
|
|
"InternalError: Optimizer should have called .unscale internally before invoking grad_scaler.step."
|
|
@@ -59,6 +62,7 @@ class GradScaler(TorchGradScaler):
|
|
|
super().step(optimizer, *args, **kwargs)
|
|
|
else:
|
|
|
logger.warning("Skipping global step due to gradient over/underflow")
|
|
|
+ self._is_ready_to_update = True
|
|
|
return True
|
|
|
else:
|
|
|
assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
|
|
@@ -72,9 +76,10 @@ class GradScaler(TorchGradScaler):
|
|
|
for optimizer_state in self._per_optimizer_states.values():
|
|
|
total_infs += sum(v.item() for v in optimizer_state["found_inf_per_device"].values())
|
|
|
|
|
|
- if self._is_running_global_step or total_infs != 0:
|
|
|
+ if self._is_ready_to_update or total_infs != 0:
|
|
|
# note: we update either during actual optimizer step or if we need to reduce scale due to NaN
|
|
|
super().update(new_scale)
|
|
|
+ self._is_ready_to_update = False
|
|
|
return True
|
|
|
else:
|
|
|
for opt_id in self._optimizer_states_to_reset:
|