
undo scaler changes

justheuristic 3 years ago
parent
commit
0f2cd9400a
1 file changed, 3 insertions and 3 deletions

hivemind/optim/experimental/optimizer.py  +3 -3

@@ -19,7 +19,7 @@ from hivemind.optim.experimental.state_averager import (
     TorchOptimizer,
     TrainingStateAverager,
 )
-from hivemind.optim.grad_scaler import HivemindGradScaler
+from hivemind.optim.grad_scaler import GradScaler
 from hivemind.utils import get_dht_time, get_logger
 
 logger = get_logger(__name__)
@@ -202,8 +202,8 @@ class Optimizer(torch.optim.Optimizer):
         :param grad_scaler: if amp is enabled, this **must** be a hivemind-aware gradient scaler
         :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
         """
-        if grad_scaler is not None and not isinstance(grad_scaler, HivemindGradScaler):
-            raise ValueError("CollaborativeOptimizer requires a hivemind-aware gradient scaler (HivemindGradScaler).")
+        if grad_scaler is not None and not isinstance(grad_scaler, GradScaler):
+            raise ValueError("hivemind.Optimizer requires a hivemind-aware gradient scaler (hivemind.GradScaler).")
         if self.batch_size_per_step is None and batch_size is None:
             raise ValueError("Please either set batch_size_per_step parameter at init or when calling .step")
         batch_size = batch_size if batch_size is not None else self.batch_size_per_step
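
For mixed-precision training, the check above means callers construct hivemind.GradScaler (a hivemind-aware counterpart to torch.cuda.amp.GradScaler) and hand it to Optimizer.step. Below is a minimal sketch of that path; the constructor arguments (run_id, target_batch_size, the optimizer factory) are illustrative and may differ between hivemind versions, and data_loader stands in for your own data pipeline.

    import torch
    import hivemind

    dht = hivemind.DHT(start=True)
    model = torch.nn.Linear(512, 10).cuda()
    opt = hivemind.Optimizer(
        dht=dht,
        run_id="demo_run",                   # hypothetical experiment name
        params=model.parameters(),
        optimizer=lambda params: torch.optim.Adam(params),
        batch_size_per_step=32,              # or pass batch_size=... to each .step() call
        target_batch_size=4096,
    )
    grad_scaler = hivemind.GradScaler()      # hivemind-aware scaler; passes the isinstance check above

    for inputs, targets in data_loader:      # your own data loader
        with torch.cuda.amp.autocast():
            loss = torch.nn.functional.cross_entropy(model(inputs.cuda()), targets.cuda())
        grad_scaler.scale(loss).backward()
        opt.step(grad_scaler=grad_scaler)    # a plain torch GradScaler here would trigger the ValueError above
        grad_scaler.update()
        opt.zero_grad()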