|
@@ -8,9 +8,9 @@ from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Sequence,
|
|
|
|
|
|
import torch
|
|
|
|
|
|
-import hivemind
|
|
|
from hivemind.averaging import DecentralizedAverager
|
|
|
from hivemind.compression import CompressionInfo, TensorRole
|
|
|
+from hivemind.optim.grad_scaler import GradScaler
|
|
|
from hivemind.utils import get_logger, nested_flatten, nested_pack
|
|
|
|
|
|
logger = get_logger(__name__)
|
|
@@ -283,7 +283,7 @@ class TrainingStateAverager(DecentralizedAverager):
|
|
|
delay_optimizer_step: bool = False,
|
|
|
averaging_round: bool = False,
|
|
|
delay_averaging: Optional[bool] = None,
|
|
|
- grad_scaler: Optional[hivemind.GradScaler] = None,
|
|
|
+ grad_scaler: Optional[GradScaler] = None,
|
|
|
averaging_opts: Optional[Dict[str, Any]] = None,
|
|
|
):
|
|
|
"""
|
|
@@ -383,7 +383,7 @@ class TrainingStateAverager(DecentralizedAverager):
|
|
|
optimizer_step: bool,
|
|
|
zero_grad: bool,
|
|
|
averaging_round: bool,
|
|
|
- grad_scaler: Optional[hivemind.GradScaler],
|
|
|
+ grad_scaler: Optional[GradScaler],
|
|
|
**kwargs,
|
|
|
):
|
|
|
"""
|