@@ -1,11 +1,11 @@
 import contextlib
+import threading
 from copy import deepcopy
 from typing import Dict, Optional
-import threading
 
 import torch
 from torch.cuda.amp import GradScaler as TorchGradScaler
-from torch.cuda.amp.grad_scaler import _refresh_per_optimizer_state, OptState
+from torch.cuda.amp.grad_scaler import OptState, _refresh_per_optimizer_state
 from torch.optim import Optimizer as TorchOptimizer
 
 import hivemind
@@ -56,8 +56,9 @@ class GradScaler(TorchGradScaler):
             if self._is_ready_to_update:
                 logger.warning("Please call grad_scaler.update() after each step.")
             assert not isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
-            assert self._per_optimizer_states[id(optimizer)]["stage"] == OptState.UNSCALED, \
-                "InternalError: Optimizer should have called .unscale internally before invoking grad_scaler.step."
+            assert (
+                self._per_optimizer_states[id(optimizer)]["stage"] == OptState.UNSCALED
+            ), "InternalError: Optimizer should have called .unscale internally before invoking grad_scaler.step."
             if self.are_grads_finite(optimizer, use_cached=True):
                 super().step(optimizer, *args, **kwargs)
             else:
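Both hunks appear to be purely cosmetic, consistent with an isort/black-style formatting pass: the first reorders the standard-library imports and alphabetizes the names imported from torch.cuda.amp.grad_scaler, while the second reflows the backslash-continued assertion in GradScaler.step into a parenthesized form. As a standalone sketch (not hivemind code) of why the second change is behavior-preserving, the two spellings below check the same condition and raise AssertionError with the same message; the local OptState enum and stage variable are stand-ins for torch.cuda.amp.grad_scaler.OptState and self._per_optimizer_states[id(optimizer)]["stage"].

# Standalone sketch: the rewritten assert is equivalent to the original form.
# OptState and stage are hypothetical stand-ins, not the real scaler state.
from enum import Enum


class OptState(Enum):
    READY = 0
    UNSCALED = 1
    STEPPED = 2


stage = OptState.UNSCALED

# Old spelling: backslash line continuation.
assert stage == OptState.UNSCALED, \
    "InternalError: Optimizer should have called .unscale internally before invoking grad_scaler.step."

# New spelling: parenthesized condition, same condition and same message.
assert (
    stage == OptState.UNSCALED
), "InternalError: Optimizer should have called .unscale internally before invoking grad_scaler.step."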