@@ -53,7 +53,7 @@ class GradScaler(TorchGradScaler):
     def step(self, optimizer: TorchOptimizer, *args, **kwargs) -> bool:
         if self._is_running_global_step:
             with self._lock:
-                if not self._is_ready_to_update:
+                if self._is_ready_to_update:
                     logger.warning("Please call grad_scaler.update() after each step.")
                 assert not isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
                 assert self._per_optimizer_states[id(optimizer)]["stage"] == OptState.UNSCALED, \
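
For context, the inverted condition matches the step/update contract this warning enforces, assuming `_is_ready_to_update` is set at the end of a global `step()` and cleared by `update()`: if the flag is still set when the next `step()` arrives, the caller skipped `update()`. Below is a minimal sketch of that contract using the plain `torch.cuda.amp.GradScaler` that this class extends; the model, optimizer, and tensor shapes are illustrative placeholders, not part of this PR.

```python
import torch

# Illustrative setup; any model/optimizer pair follows the same pattern.
model = torch.nn.Linear(4, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

for _ in range(3):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = model(torch.randn(8, 4, device="cuda")).sum()
    scaler.scale(loss).backward()  # scale the loss to avoid fp16 underflow
    scaler.step(optimizer)         # unscales grads, then runs optimizer.step()
    scaler.update()                # adjusts the scale factor; omitting this
                                   # is the misuse the fixed warning now catches
```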