@@ -3,7 +3,7 @@ from __future__ import annotations
 import logging
 from dataclasses import dataclass
 from threading import Event, Lock, Thread
-from typing import Dict, Iterator, Optional, Any, Callable
+from typing import Dict, Iterator, Optional
 
 import numpy as np
 import torch
@@ -14,11 +14,10 @@ from hivemind.dht import DHT
 from hivemind.dht.crypto import RSASignatureValidator
 from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
 from hivemind.optim.base import DecentralizedOptimizerBase
+from hivemind.optim.grad_scaler import HivemindGradScaler
 from hivemind.optim.performance_ema import PerformanceEMA
 from hivemind.utils import get_dht_time, get_logger
 
-from lib.staging.scaler import HivemindGradScaler
-
 logger = get_logger(__name__)
 
 LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
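
Note (not part of the diff): HivemindGradScaler is now imported from the library module hivemind.optim.grad_scaler rather than the project-internal lib.staging.scaler; callers keep the familiar torch.cuda.amp.GradScaler-style interface. Below is a minimal usage sketch, not the repository's training loop; `model`, `opt` (a hivemind collaborative optimizer), and `make_batches` are hypothetical placeholders introduced for illustration.

    import torch
    from hivemind.optim.grad_scaler import HivemindGradScaler

    scaler = HivemindGradScaler()  # follows the torch.cuda.amp.GradScaler interface

    for inputs, targets in make_batches():  # placeholder data source (not from this diff)
        opt.zero_grad()                      # `opt`: placeholder for the collaborative optimizer
        with torch.cuda.amp.autocast():
            loss = torch.nn.functional.cross_entropy(model(inputs), targets)
        scaler.scale(loss).backward()        # scale the loss to avoid fp16 gradient underflow
        scaler.unscale_(opt)                 # unscale before stepping (e.g. for gradient clipping)
        scaler.step(opt)                     # apply the optimizer step via the scaler
        scaler.update()                      # adjust the loss scale for the next iteration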