justheuristic 3 tahun lalu
induk
commit
94038bf4eb
1 file diubah dengan 7 penambahan dan 7 penghapusan
  1. 7 7
      hivemind/optim/grad_scaler.py

+ 7 - 7
hivemind/optim/grad_scaler.py

@@ -49,8 +49,8 @@ class GradScaler(TorchGradScaler):
                 return False
 
     def step(self, optimizer: TorchOptimizer, *args, **kwargs) -> bool:
-        with self._lock:
-            if self._is_running_global_step:
+        if self._is_running_global_step:
+            with self._lock:
                 assert not isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
                 assert self._per_optimizer_states[id(optimizer)]["stage"] == OptState.UNSCALED, \
                     "InternalError: Optimizer should have called .unscale internally before invoking grad_scaler.step."
@@ -59,11 +59,11 @@ class GradScaler(TorchGradScaler):
                 else:
                     logger.warning("Skipping global step due to gradient over/underflow")
                 return True
-            else:
-                assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
-                super().step(optimizer)
-                self._optimizer_states_to_reset.add(id(optimizer))
-                return False
+        else:
+            assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
+            super().step(optimizer)
+            self._optimizer_states_to_reset.add(id(optimizer))
+            return False
 
     def update(self, new_scale: Optional[float] = None) -> bool:
         with self._lock: