@@ -30,11 +30,12 @@ class GradScaler(TorchGradScaler):
 
     @contextlib.contextmanager
     def running_global_step(self):
-        was_running, self._is_running_global_step = self._is_running_global_step, True
-        try:
-            yield
-        finally:
-            self._is_running_global_step = was_running
+        with self._lock:
+            was_running, self._is_running_global_step = self._is_running_global_step, True
+            try:
+                yield
+            finally:
+                self._is_running_global_step = was_running
 
     def unscale_(self, optimizer: TorchOptimizer) -> bool:
         with self._lock:
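
For context, a minimal standalone sketch of the pattern this hunk introduces: the context manager now acquires the scaler's lock before toggling `_is_running_global_step`, saving the previous value and restoring it on exit. The `_FlagGuard` class, the choice of `threading.RLock`, and the usage below are illustrative assumptions, not the library's actual `GradScaler`.

```python
import contextlib
import threading


class _FlagGuard:
    """Illustrative stand-in for the flag handling in the diff (not the real class)."""

    def __init__(self):
        # Assumption: a re-entrant lock, so other lock-guarded methods
        # (e.g. unscale_) can still run inside the block on the same thread.
        self._lock = threading.RLock()
        self._is_running_global_step = False

    @contextlib.contextmanager
    def running_global_step(self):
        # The lock is held for the whole block; the previous flag value
        # is restored even if the body raises.
        with self._lock:
            was_running, self._is_running_global_step = self._is_running_global_step, True
            try:
                yield
            finally:
                self._is_running_global_step = was_running


guard = _FlagGuard()
with guard.running_global_step():
    assert guard._is_running_global_step  # set only for the duration of the block
assert not guard._is_running_global_step
```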