@@ -19,7 +19,7 @@ from hivemind.optim.experimental.state_averager import (
     TorchOptimizer,
     TrainingStateAverager,
 )
-from hivemind.optim.grad_scaler import HivemindGradScaler
+from hivemind.optim.grad_scaler import GradScaler
 from hivemind.utils import get_dht_time, get_logger

 logger = get_logger(__name__)
@@ -202,8 +202,8 @@ class Optimizer(torch.optim.Optimizer):
         :param grad_scaler: if amp is enabled, this **must** be a hivemind-aware gradient scaler
         :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
         """
-        if grad_scaler is not None and not isinstance(grad_scaler, HivemindGradScaler):
-            raise ValueError("CollaborativeOptimizer requires a hivemind-aware gradient scaler (HivemindGradScaler).")
+        if grad_scaler is not None and not isinstance(grad_scaler, GradScaler):
+            raise ValueError("hivemind.Optimizer requires a hivemind-aware gradient scaler (hivemind.GradScaler).")
         if self.batch_size_per_step is None and batch_size is None:
             raise ValueError("Please either set batch_size_per_step parameter at init or when calling .step")
         batch_size = batch_size if batch_size is not None else self.batch_size_per_step
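
For context, a minimal usage sketch of the renamed scaler with the checks added above: it builds a hivemind.Optimizer, passes a hivemind-aware GradScaler to .step, and sets batch_size_per_step at init so neither ValueError fires. The constructor arguments (run_id, params, optimizer factory, target_batch_size) and the single-peer setup are assumptions about the surrounding hivemind API, not part of this diff; the AMP calls assume a CUDA device.

import torch
import torch.nn.functional as F

import hivemind
from hivemind.optim.grad_scaler import GradScaler

dht = hivemind.DHT(start=True)
model = torch.nn.Linear(16, 2).cuda()

opt = hivemind.Optimizer(
    dht=dht,
    run_id="demo_run",                  # assumed name of the experiment-prefix argument
    params=model.parameters(),
    optimizer=lambda params: torch.optim.SGD(params, lr=0.1),
    target_batch_size=256,              # swarm-wide samples accumulated per global step
    batch_size_per_step=32,             # local samples per .step(); satisfies the second check
)

scaler = GradScaler()                   # hivemind-aware counterpart of torch.cuda.amp.GradScaler

x = torch.randn(32, 16, device="cuda")
y = torch.randint(0, 2, (32,), device="cuda")
with torch.cuda.amp.autocast():
    loss = F.cross_entropy(model(x), y)

scaler.scale(loss).backward()
opt.step(grad_scaler=scaler)            # passes the isinstance(grad_scaler, GradScaler) check
scaler.update()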