@@ -2,6 +2,7 @@
 import logging
 from asyncio import Future
 from concurrent.futures import ThreadPoolExecutor
+from contextlib import nullcontext
 from itertools import chain
 from threading import Event
 from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Sequence, Tuple, Union
@@ -393,12 +394,14 @@ class TrainingStateAverager(DecentralizedAverager):
         """
         try:
             if optimizer_step:
-                logger.log(self.status_loglevel, f"Running optimizer step")
-                if grad_scaler is None:
-                    self.optimizer.step()
-                else:
-                    with grad_scaler.running_global_step():
-                        assert grad_scaler.step(self.optimizer)
+                with self.lock_averaged_tensors if self.offload_optimizer or self.reuse_tensors else nullcontext():
+                    logger.log(self.status_loglevel, f"Running optimizer step")
+                    if grad_scaler is None:
+                        self.optimizer.step()
+                    else:
+                        with grad_scaler.running_global_step():
+                            assert grad_scaler.step(self.optimizer)
+
             self._update_scheduler()
 
         if zero_grad:
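
For context on the new `with ... else nullcontext()` line: it selects a context manager at runtime, taking the averaged-tensor lock only when the optimizer state may alias the averaged tensors (offloaded optimizer or reused tensors), and falling back to the no-op `nullcontext()` otherwise. A minimal standalone sketch of the same pattern, assuming a plain `threading.Lock` stands in for `self.lock_averaged_tensors`; `run_step` and `step_fn` are hypothetical names, not part of this codebase:

    import threading
    from contextlib import nullcontext

    lock_averaged_tensors = threading.Lock()        # stand-in for self.lock_averaged_tensors
    offload_optimizer, reuse_tensors = True, False  # stand-ins for the averager flags

    def run_step(step_fn):
        # Acquire the lock only when optimizer state shares storage with the
        # averaged tensors; nullcontext() is a no-op context manager otherwise.
        with lock_averaged_tensors if (offload_optimizer or reuse_tensors) else nullcontext():
            step_fn()

    run_step(lambda: print("optimizer step"))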