@@ -6,7 +6,7 @@ from torch.cuda.amp import GradScaler as TorchGradScaler
 from torch.cuda.amp.grad_scaler import _refresh_per_optimizer_state
 from torch.optim import Optimizer as TorchOptimizer
 
-from hivemind.optim.base import DecentralizedOptimizerBase
+from hivemind.optim import DecentralizedOptimizerBase, Optimizer
 from hivemind.utils.logging import get_logger
 
 logger = get_logger(__name__)
@@ -34,7 +34,7 @@ class GradScaler(TorchGradScaler):
             self._is_running_global_step = was_running
 
     def unscale_(self, optimizer: TorchOptimizer) -> bool:
-        assert isinstance(optimizer, DecentralizedOptimizerBase)
+        assert isinstance(optimizer, (Optimizer, DecentralizedOptimizerBase))
         if self._is_running_global_step:
             super().unscale_(optimizer.opt)
             return True
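
For context, here is a minimal usage sketch (not part of this patch) showing why the `isinstance` check is widened: `unscale_` may now receive either a legacy `DecentralizedOptimizerBase` wrapper or the newer `hivemind.Optimizer`. The run id, toy model, and batch-size settings below are placeholders, and the exact step/update sequence may differ across hivemind versions.

```python
import torch
import hivemind
from hivemind.optim.grad_scaler import GradScaler

# Placeholder setup: a single-node DHT, a toy model, and a hivemind.Optimizer
# wrapping a local SGD optimizer (all names and hyperparameters are illustrative).
dht = hivemind.DHT(start=True)
model = torch.nn.Linear(16, 1)
opt = hivemind.Optimizer(
    dht=dht,
    run_id="amp_demo",  # placeholder experiment name
    params=model.parameters(),
    optimizer=lambda params: torch.optim.SGD(params, lr=0.1),
    target_batch_size=256,
    batch_size_per_step=32,
)

scaler = GradScaler()
loss = model(torch.randn(4, 16)).sum()
scaler.scale(loss).backward()

# Before this patch, passing a hivemind.Optimizer here tripped the assert, which
# only allowed DecentralizedOptimizerBase subclasses; with the widened check,
# both optimizer families are accepted.
scaler.unscale_(opt)
```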