justheuristic 3 years ago
parent
commit
33e4afe360

+ 1 - 1
hivemind/optim/__init__.py

@@ -1,5 +1,5 @@
 from hivemind.optim.adaptive import CollaborativeAdaptiveOptimizer
 from hivemind.optim.base import DecentralizedOptimizerBase
 from hivemind.optim.collaborative import CollaborativeOptimizer
-from hivemind.optim.simple import DecentralizedAdam, DecentralizedOptimizer, DecentralizedSGD
 from hivemind.optim.grad_scaler import HivemindGradScaler
+from hivemind.optim.simple import DecentralizedAdam, DecentralizedOptimizer, DecentralizedSGD
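
For reference, the resulting import block in hivemind/optim/__init__.py after this hunk is applied, reproduced verbatim from the lines above and now in alphabetical module order (consistent with isort's default sorting):

from hivemind.optim.adaptive import CollaborativeAdaptiveOptimizer
from hivemind.optim.base import DecentralizedOptimizerBase
from hivemind.optim.collaborative import CollaborativeOptimizer
from hivemind.optim.grad_scaler import HivemindGradScaler
from hivemind.optim.simple import DecentralizedAdam, DecentralizedOptimizer, DecentralizedSGD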

+ 2 - 3
hivemind/optim/collaborative.py

@@ -3,7 +3,7 @@ from __future__ import annotations
 import logging
 from dataclasses import dataclass
 from threading import Event, Lock, Thread
-from typing import Dict, Iterator, Optional, Any, Callable
+from typing import Dict, Iterator, Optional
 
 import numpy as np
 import torch
@@ -14,11 +14,10 @@ from hivemind.dht import DHT
 from hivemind.dht.crypto import RSASignatureValidator
 from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
 from hivemind.optim.base import DecentralizedOptimizerBase
+from hivemind.optim.grad_scaler import HivemindGradScaler
 from hivemind.optim.performance_ema import PerformanceEMA
 from hivemind.utils import get_dht_time, get_logger
 
-from lib.staging.scaler import HivemindGradScaler
-
 logger = get_logger(__name__)
 LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
 

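This hunk drops the unused Any and Callable imports and, more importantly, replaces the out-of-tree lib.staging.scaler import with the in-tree hivemind.optim.grad_scaler module, so collaborative.py no longer depends on a staging path that does not ship with the package. Since HivemindGradScaler builds on torch.cuda.amp.GradScaler (see the next file), a minimal mixed-precision loop would follow the standard AMP pattern. The sketch below is an illustration only: model, dataloader, and compute_loss are hypothetical placeholders, opt stands for a CollaborativeOptimizer instance, and the exact hand-off between the scaler and the collaborative averaging step is defined inside collaborative.py and grad_scaler.py rather than here.

import torch
from hivemind.optim.grad_scaler import HivemindGradScaler

scaler = HivemindGradScaler()  # used in place of torch.cuda.amp.GradScaler

for batch in dataloader:                   # placeholder data source
    with torch.cuda.amp.autocast():
        loss = compute_loss(model, batch)  # placeholder forward pass
    scaler.scale(loss).backward()          # backward pass on the scaled loss
    scaler.step(opt)                       # opt: the CollaborativeOptimizer wrapper
    scaler.update()                        # adjust the loss scale
    # gradient accumulation toward the collaborative target batch size is
    # assumed to be handled by the optimizer itself; see collaborative.py
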
+ 1 - 1
hivemind/optim/grad_scaler.py

@@ -2,11 +2,11 @@ import contextlib
 from typing import Dict
 
 import torch
-from hivemind import DecentralizedOptimizerBase, get_logger
 from torch.cuda.amp import GradScaler as TorchGradScaler
 from torch.cuda.amp.grad_scaler import _refresh_per_optimizer_state
 from torch.optim import Optimizer
 
+from hivemind import DecentralizedOptimizerBase, get_logger
 
 logger = get_logger(__name__)
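
Here the only change is import ordering: the first-party hivemind import moves below the third-party torch imports, giving the standard-library / third-party / first-party grouping that PEP 8 recommends (and tools like isort enforce). The resulting import section of grad_scaler.py, reproduced from the hunk above:

import contextlib
from typing import Dict

import torch
from torch.cuda.amp import GradScaler as TorchGradScaler
from torch.cuda.amp.grad_scaler import _refresh_per_optimizer_state
from torch.optim import Optimizer

from hivemind import DecentralizedOptimizerBase, get_logger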