@@ -49,8 +49,8 @@ class GradScaler(TorchGradScaler):
                 return False

     def step(self, optimizer: TorchOptimizer, *args, **kwargs) -> bool:
-        with self._lock:
-            if self._is_running_global_step:
+        if self._is_running_global_step:
+            with self._lock:
                 assert not isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
                 assert self._per_optimizer_states[id(optimizer)]["stage"] == OptState.UNSCALED, \
                     "InternalError: Optimizer should have called .unscale internally before invoking grad_scaler.step."
@@ -59,11 +59,11 @@ class GradScaler(TorchGradScaler):
                 else:
                     logger.warning("Skipping global step due to gradient over/underflow")
                 return True
-            else:
-                assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
-                super().step(optimizer)
-                self._optimizer_states_to_reset.add(id(optimizer))
-                return False
+        else:
+            assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
+            super().step(optimizer)
+            self._optimizer_states_to_reset.add(id(optimizer))
+            return False

     def update(self, new_scale: Optional[float] = None) -> bool:
         with self._lock:
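
For context: the surrounding lines show that update() also acquires self._lock, and before this change every step() call took that lock as well; after it, only the global-step branch does, so local steps no longer wait behind a concurrent update() or global step. Below is a minimal standalone sketch of the new lock scope (a simplified stand-in class for illustration, not hivemind's actual GradScaler; the unscale assertions and the delegated PyTorch scaler step are elided):

    import threading

    class LockScopeSketch:
        """Toy stand-in showing the narrower lock scope used by the patched step()."""

        def __init__(self) -> None:
            self._lock = threading.RLock()
            self._is_running_global_step = False
            self._optimizer_states_to_reset = set()

        def step(self, optimizer) -> bool:
            if self._is_running_global_step:
                # Only the global-step branch holds the lock, as in the hunk above.
                with self._lock:
                    # ... unscale checks and the real optimizer step go here ...
                    return True
            else:
                # Local steps skip the lock and only record which per-optimizer
                # state should be reset later.
                self._optimizer_states_to_reset.add(id(optimizer))
                return False

The sketch only illustrates the control-flow change; the real class also validates the optimizer type in both branches and delegates to the underlying TorchGradScaler.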