@@ -11,7 +11,6 @@ import torch
 import hivemind
 from hivemind.averaging import DecentralizedAverager
 from hivemind.compression import CompressionInfo, TensorRole
-from hivemind.optim.grad_scaler import GradScaler
 from hivemind.utils import get_logger, nested_flatten, nested_pack
 
 logger = get_logger(__name__)
@@ -284,7 +283,7 @@ class TrainingStateAverager(DecentralizedAverager):
         delay_optimizer_step: bool = False,
         averaging_round: bool = False,
         delay_averaging: Optional[bool] = None,
-        grad_scaler: Optional[GradScaler] = None,
+        grad_scaler: Optional[hivemind.GradScaler] = None,
         averaging_opts: Optional[Dict[str, Any]] = None,
     ):
         """
@@ -380,7 +379,12 @@ class TrainingStateAverager(DecentralizedAverager):
|
|
|
return output
|
|
|
|
|
|
def _do(
|
|
|
- self, optimizer_step: bool, zero_grad: bool, averaging_round: bool, grad_scaler: Optional[GradScaler], **kwargs
|
|
|
+ self,
|
|
|
+ optimizer_step: bool,
|
|
|
+ zero_grad: bool,
|
|
|
+ averaging_round: bool,
|
|
|
+ grad_scaler: Optional[hivemind.GradScaler],
|
|
|
+ **kwargs,
|
|
|
):
|
|
|
"""
|
|
|
Run the optimizer step, followed by a scheduler step and an averaging round, each stage is optional.
|
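
For reference, a short usage sketch that is not part of the patch itself. It assumes that hivemind re-exports GradScaler at package level (which the new annotations rely on), that the signature in the second hunk belongs to TrainingStateAverager.step, and that a state averager instance is constructed elsewhere; run_local_step is a hypothetical helper, and only keyword arguments visible in the hunks above are used.

# Illustrative sketch, not part of the patch: passing a package-level
# hivemind.GradScaler to the step() call whose annotation is changed above.
from typing import Optional

import hivemind


def run_local_step(state_averager, grad_scaler: Optional[hivemind.GradScaler] = None) -> None:
    # Hypothetical helper: state_averager is assumed to be a TrainingStateAverager
    # built elsewhere. optimizer_step mirrors the parameter shown in the _do
    # signature; averaging_round and grad_scaler appear directly in the second hunk.
    state_averager.step(
        optimizer_step=True,
        averaging_round=True,
        grad_scaler=grad_scaler,
    )


if __name__ == "__main__":
    # hivemind.GradScaler is the attribute the new annotations refer to; constructing
    # it here only checks that the package-level name resolves.
    scaler = hivemind.GradScaler()
    print(type(scaler))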