justheuristic 3 years ago
parent
commit
59bcda18c7
1 changed file with 39 additions and 34 deletions

+ 39 - 34
hivemind/optim/grad_scaler.py

@@ -1,6 +1,7 @@
 import contextlib
 from copy import deepcopy
 from typing import Dict, Optional
+import threading
 
 import torch
 from torch.cuda.amp import GradScaler as TorchGradScaler
@@ -25,6 +26,7 @@ class GradScaler(TorchGradScaler):
         super().__init__(*args, **kwargs)
         self._is_running_global_step = False
         self._optimizer_states_to_reset = set()
+        self._lock = threading.RLock()
 
     @contextlib.contextmanager
     def running_global_step(self):
@@ -35,46 +37,49 @@ class GradScaler(TorchGradScaler):
             self._is_running_global_step = was_running
 
     def unscale_(self, optimizer: TorchOptimizer) -> bool:
-        assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
-        if self._is_running_global_step:
-            super().unscale_(optimizer)
-            self._per_optimizer_states[id(optimizer.opt)] = deepcopy(self._per_optimizer_states[id(optimizer)])
-            return True
-        else:
-            self._check_inf_per_device(optimizer)
-            self._optimizer_states_to_reset.add(id(optimizer))
-            return False
+        with self._lock:
+            assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
+            if self._is_running_global_step:
+                super().unscale_(optimizer)
+                self._per_optimizer_states[id(optimizer.opt)] = deepcopy(self._per_optimizer_states[id(optimizer)])
+                return True
+            else:
+                self._check_inf_per_device(optimizer)
+                self._optimizer_states_to_reset.add(id(optimizer))
+                return False
 
     def step(self, optimizer: TorchOptimizer, *args, **kwargs) -> bool:
-        if self._is_running_global_step:
-            assert not isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
-            assert self._per_optimizer_states[id(optimizer)]["stage"] == OptState.UNSCALED, \
-                "InternalError: Optimizer should have called .unscale internally before invoking grad_scaler.step."
-            if self.are_grads_finite(optimizer, use_cached=True):
-                super().step(optimizer, *args, **kwargs)
+        with self._lock:
+            if self._is_running_global_step:
+                assert not isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
+                assert self._per_optimizer_states[id(optimizer)]["stage"] == OptState.UNSCALED, \
+                    "InternalError: Optimizer should have called .unscale internally before invoking grad_scaler.step."
+                if self.are_grads_finite(optimizer, use_cached=True):
+                    super().step(optimizer, *args, **kwargs)
+                else:
+                    logger.warning("Skipping global step due to gradient over/underflow")
+                return True
             else:
-                logger.warning("Skipping global step due to gradient over/underflow")
-            return True
-        else:
-            assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
-            super().step(optimizer)
-            self._optimizer_states_to_reset.add(id(optimizer))
-            return False
+                assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
+                super().step(optimizer)
+                self._optimizer_states_to_reset.add(id(optimizer))
+                return False
 
     def update(self, new_scale: Optional[float] = None) -> bool:
-        total_infs = 0
-        for optimizer_state in self._per_optimizer_states.values():
-            total_infs += sum(v.item() for v in optimizer_state["found_inf_per_device"].values())
+        with self._lock:
+            total_infs = 0
+            for optimizer_state in self._per_optimizer_states.values():
+                total_infs += sum(v.item() for v in optimizer_state["found_inf_per_device"].values())
 
-        if self._is_running_global_step or total_infs != 0:
-            # note: we update either during actual optimizer step or if we need to reduce scale due to NaN
-            super().update(new_scale)
-            return True
-        else:
-            for opt_id in self._optimizer_states_to_reset:
-                self._per_optimizer_states[opt_id] = _refresh_per_optimizer_state()
-            self._optimizer_states_to_reset.clear()
-            return False
+            if self._is_running_global_step or total_infs != 0:
+                # note: we update either during actual optimizer step or if we need to reduce scale due to NaN
+                super().update(new_scale)
+                return True
+            else:
+                for opt_id in self._optimizer_states_to_reset:
+                    self._per_optimizer_states[opt_id] = _refresh_per_optimizer_state()
+                self._optimizer_states_to_reset.clear()
+                return False
 
     def _unscale_grads_(
         self, optimizer: TorchOptimizer, inv_scale: torch.Tensor, found_inf: torch.Tensor, allow_fp16: bool
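
Note on the choice of `threading.RLock` over a plain `threading.Lock`: `torch.cuda.amp.GradScaler.step()` may call `self.unscale_()` internally when the optimizer state is still READY, so one guarded method can re-enter another guarded method on the same thread. The sketch below is illustrative only (not hivemind code) and shows why a reentrant lock avoids the deadlock a plain lock would cause in that situation.

    # Illustrative sketch of the re-entrancy pattern, assuming the parent
    # class may call back into a method that we also guard with the lock.
    import threading


    class Base:
        def step(self):
            self.unscale_()  # parent re-enters a method the subclass guards

        def unscale_(self):
            pass


    class Guarded(Base):
        def __init__(self):
            self._lock = threading.RLock()  # reentrant: same thread may re-acquire

        def unscale_(self):
            with self._lock:
                super().unscale_()

        def step(self):
            with self._lock:    # lock acquired here...
                super().step()  # ...and re-acquired inside self.unscale_()


    Guarded().step()  # completes; with threading.Lock() this would deadlock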