justheuristic 3 years ago
parent
commit
7cd913df25
1 changed file with 2 additions and 3 deletions
hivemind/optim/experimental/optimizer.py

@@ -191,7 +191,7 @@ class Optimizer(torch.optim.Optimizer):
         self,
         closure: Optional[Callable[[], torch.Tensor]] = None,
         batch_size: Optional[int] = None,
-        grad_scaler: Optional[HivemindGradScaler] = None,
+        grad_scaler: Optional[GradScaler] = None,
         **kwargs,
     ):
         """
@@ -245,7 +245,7 @@ class Optimizer(torch.optim.Optimizer):
             # divide accumulators by local steps to recover the true average grad w.r.t. local_samples_accumulated
             if grad_scaler is not None:
                 with grad_scaler.running_global_step():
-                    assert grad_scaler.unscale_(self.opt)
+                    assert grad_scaler.unscale_(self)
 
             if self.scheduled_round is not None and self.scheduled_round.triggered or self.scheduled_round.done():
                 logger.log(self.status_loglevel, f"Discarding failed matchmaking results: {self.scheduled_round}")
@@ -351,7 +351,6 @@ class Optimizer(torch.optim.Optimizer):
 
     @property
     def opt(self) -> TorchOptimizer:
-        # for compatibility with HivemindGradScaler
         return self.state_averager.optimizer
 
     @property
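
For context, a minimal usage sketch of the changed step() signature, assuming `opt` is an already-constructed hivemind Optimizer, `scaler` is the GradScaler type referenced in the diff (a wrapper over torch.cuda.amp.GradScaler), and `model`/`batch` are ordinary PyTorch objects; `compute_loss` and the loop shape are illustrative and not taken from this commit.

import torch

def compute_loss(model: torch.nn.Module, batch: torch.Tensor) -> torch.Tensor:
    # Toy objective used only to make the sketch self-contained.
    return model(batch).pow(2).mean()

def training_step(model, opt, scaler, batch, batch_size: int):
    # Forward/backward under mixed precision, scaling the loss as usual.
    with torch.cuda.amp.autocast():
        loss = compute_loss(model, batch)
    scaler.scale(loss).backward()
    # step() receives the scaler via the grad_scaler= argument shown in the
    # diff; after this commit it is typed as GradScaler rather than
    # HivemindGradScaler, and unscaling is routed through the hivemind
    # Optimizer itself (unscale_(self)) instead of the inner torch optimizer.
    opt.step(batch_size=batch_size, grad_scaler=scaler)
    opt.zero_grad()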