from torch.cuda.amp import GradScaler as TorchGradScaler
from torch.cuda.amp.grad_scaler import _refresh_per_optimizer_state
from torch.optim import Optimizer

from hivemind.optim.base import DecentralizedOptimizerBase
from hivemind.utils.logging import get_logger

# Module-level logger, namespaced to this module via __name__.
logger = get_logger(__name__)