justheuristic 3 年之前
父節點
當前提交
b35b6f6d4c
共有 1 個文件被更改,包括 6 次插入5 次删除
  1. 6 5
      hivemind/optim/grad_scaler.py

+ 6 - 5
hivemind/optim/grad_scaler.py

@@ -30,11 +30,12 @@ class GradScaler(TorchGradScaler):
 
     @contextlib.contextmanager
     def running_global_step(self):
-        was_running, self._is_running_global_step = self._is_running_global_step, True
-        try:
-            yield
-        finally:
-            self._is_running_global_step = was_running
+        with self._lock:
+            was_running, self._is_running_global_step = self._is_running_global_step, True
+            try:
+                yield
+            finally:
+                self._is_running_global_step = was_running
 
     def unscale_(self, optimizer: TorchOptimizer) -> bool:
         with self._lock: