Browse source

update GradScaler

justheuristic 3 years ago
parent
commit
86326064de
2 changed files with 3 additions and 3 deletions
  1. benchmarks/benchmark_optimizer.py (+1 -1)
  2. hivemind/optim/grad_scaler.py (+2 -2)

+ 1 - 1
benchmarks/benchmark_optimizer.py

@@ -51,7 +51,7 @@ class TrainingArguments:
     )
 
 
-def _run_training_with_swarm(args: TrainingArguments):
+def benchmark_optimizer(args: TrainingArguments):
     random.seed(args.seed)
     torch.manual_seed(args.seed)
     torch.set_num_threads(1)

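For reference, a minimal sketch of how the renamed entry point would be invoked. This is hypothetical: the diff only shows the rename and the seeding lines, and the sketch assumes the TrainingArguments dataclass (defined earlier in the file) has default values for its other fields.

    # Hypothetical usage of the renamed entry point; only the `seed` field
    # is confirmed by the hunk above.
    from benchmarks.benchmark_optimizer import TrainingArguments, benchmark_optimizer

    args = TrainingArguments(seed=42)  # seeds random and torch.manual_seed
    benchmark_optimizer(args)          # formerly _run_training_with_swarm
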
+ 2 - 2
hivemind/optim/grad_scaler.py

@@ -6,7 +6,7 @@ from torch.cuda.amp import GradScaler as TorchGradScaler
 from torch.cuda.amp.grad_scaler import _refresh_per_optimizer_state
 from torch.optim import Optimizer as TorchOptimizer
 
-from hivemind.optim.base import DecentralizedOptimizerBase
+from hivemind.optim import DecentralizedOptimizerBase, Optimizer
 from hivemind.utils.logging import get_logger
 
 logger = get_logger(__name__)
@@ -34,7 +34,7 @@ class GradScaler(TorchGradScaler):
             self._is_running_global_step = was_running
 
     def unscale_(self, optimizer: TorchOptimizer) -> bool:
-        assert isinstance(optimizer, DecentralizedOptimizerBase)
+        assert isinstance(optimizer, (Optimizer, DecentralizedOptimizerBase))
         if self._is_running_global_step:
             super().unscale_(optimizer.opt)
             return True
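
To make the second hunk concrete, here is a self-contained sketch of the pattern it modifies: a GradScaler subclass that unscales gradients only while a global step is running, delegating to the wrapped torch optimizer via optimizer.opt. The toy wrapper class and the return False branch are illustrative assumptions, not hivemind's actual classes.

    import torch
    from torch.cuda.amp import GradScaler as TorchGradScaler
    from torch.optim import SGD, Optimizer as TorchOptimizer


    class ToyWrappedOptimizer:
        # Stand-in for hivemind's optimizers: exposes the inner torch optimizer as .opt
        def __init__(self, opt: TorchOptimizer):
            self.opt = opt


    class ToyGradScaler(TorchGradScaler):
        # Mirrors the pattern in hivemind/optim/grad_scaler.py: unscale only
        # during a global step. The commit widens the accepted optimizer types
        # in the corresponding isinstance check upstream.
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self._is_running_global_step = False

        def unscale_(self, optimizer) -> bool:
            assert isinstance(optimizer, ToyWrappedOptimizer)
            if self._is_running_global_step:
                super().unscale_(optimizer.opt)  # delegate to the wrapped torch optimizer
                return True
            return False  # local step: keep gradients scaled (simplification)


    model = torch.nn.Linear(4, 1)
    wrapped = ToyWrappedOptimizer(SGD(model.parameters(), lr=0.1))
    scaler = ToyGradScaler(enabled=False)  # enabled=False keeps the sketch CPU-safe

    loss = model(torch.randn(2, 4)).sum()
    scaler.scale(loss).backward()
    scaler._is_running_global_step = True  # simulate entering a global step
    scaler.unscale_(wrapped)               # unscales the gradients held by wrapped.opt

The widened assertion presumably lets the same scaler serve both the legacy DecentralizedOptimizerBase classes and the newer hivemind Optimizer; the unchanged super().unscale_(optimizer.opt) call suggests both expose the underlying torch optimizer as .opt.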