justheuristic 3 vuotta sitten
vanhempi
commit
3fe38095a4
1 muutettua tiedostoa jossa 9 lisäystä ja 6 poistoa
  1. 9 6
      hivemind/optim/experimental/state_averager.py

+ 9 - 6
hivemind/optim/experimental/state_averager.py

@@ -2,6 +2,7 @@
 import logging
 from asyncio import Future
 from concurrent.futures import ThreadPoolExecutor
+from contextlib import nullcontext
 from itertools import chain
 from threading import Event
 from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Sequence, Tuple, Union
@@ -393,12 +394,14 @@ class TrainingStateAverager(DecentralizedAverager):
         """
         try:
             if optimizer_step:
-                logger.log(self.status_loglevel, f"Running optimizer step")
-                if grad_scaler is None:
-                    self.optimizer.step()
-                else:
-                    with grad_scaler.running_global_step():
-                        assert grad_scaler.step(self.optimizer)
+                with self.lock_averaged_tensors if self.offload_optimizer or self.reuse_tensors else nullcontext():
+                    logger.log(self.status_loglevel, f"Running optimizer step")
+                    if grad_scaler is None:
+                        self.optimizer.step()
+                    else:
+                        with grad_scaler.running_global_step():
+                            assert grad_scaler.step(self.optimizer)
+
             self._update_scheduler()
 
             if zero_grad: