justheuristic 3 years ago
parent revision 6f4d16430d

+ 0 - 1
hivemind/optim/experimental/state_averager.py

@@ -1,5 +1,4 @@
 """ An extension of averager that supports common optimization use cases. """
-from __future__ import annotations
 import logging
 from asyncio import Future
 from concurrent.futures import ThreadPoolExecutor
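For context on the removal above: `from __future__ import annotations` (PEP 563) makes every annotation in the module a lazily evaluated string, which among other things lets type hints name a class before it is defined. With it gone, any forward reference in `state_averager.py` has to be quoted explicitly. A minimal sketch of the difference, using a hypothetical class rather than the module's real contents:

    # Without the __future__ import, a hint that refers to the enclosing class
    # must be written as a string, because the class object does not exist yet
    # while its body is being evaluated:
    class Averager:
        def fork(self) -> "Averager":  # quoted forward reference: always valid
            ...

    # With "from __future__ import annotations" in effect, the bare name would
    # also have worked, since annotations are stored unevaluated:
    #     def fork(self) -> Averager: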

+ 2 - 3
hivemind/optim/grad_scaler.py

@@ -6,8 +6,7 @@ from torch.cuda.amp import GradScaler as TorchGradScaler
 from torch.cuda.amp.grad_scaler import _refresh_per_optimizer_state
 from torch.optim import Optimizer as TorchOptimizer
 
-from hivemind.optim.base import DecentralizedOptimizerBase
-from hivemind.optim.experimental.optimizer import Optimizer
+import hivemind
 from hivemind.utils.logging import get_logger
 
 logger = get_logger(__name__)
@@ -35,7 +34,7 @@ class GradScaler(TorchGradScaler):
             self._is_running_global_step = was_running
 
     def unscale_(self, optimizer: TorchOptimizer) -> bool:
-        assert isinstance(optimizer, (Optimizer, DecentralizedOptimizerBase))
+        assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
         if self._is_running_global_step:
             super().unscale_(optimizer.opt)
             return True