@@ -17,6 +17,7 @@ from hivemind.optim import (
DecentralizedOptimizerBase,
DecentralizedSGD,
TrainingAverager,
+ GradScaler,
Optimizer
)
from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, PeerInfo
@@ -6,7 +6,8 @@ from torch.cuda.amp import GradScaler as TorchGradScaler
from torch.cuda.amp.grad_scaler import _refresh_per_optimizer_state
from torch.optim import Optimizer as TorchOptimizer
-from hivemind.optim import DecentralizedOptimizerBase, Optimizer
+from hivemind.optim.base import DecentralizedOptimizerBase
+from hivemind.optim.experimental.optimizer import Optimizer
from hivemind.utils.logging import get_logger
logger = get_logger(__name__)