
lock tensors

justheuristic 3 years ago
parent commit 3fe38095a4
1 changed file with 9 additions and 6 deletions

+ 9 - 6
hivemind/optim/experimental/state_averager.py

@@ -2,6 +2,7 @@
 import logging
 from asyncio import Future
 from concurrent.futures import ThreadPoolExecutor
+from contextlib import nullcontext
 from itertools import chain
 from threading import Event
 from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Sequence, Tuple, Union
@@ -393,12 +394,14 @@ class TrainingStateAverager(DecentralizedAverager):
         """
         """
         try:
         try:
             if optimizer_step:
             if optimizer_step:
-                logger.log(self.status_loglevel, f"Running optimizer step")
-                if grad_scaler is None:
-                    self.optimizer.step()
-                else:
-                    with grad_scaler.running_global_step():
-                        assert grad_scaler.step(self.optimizer)
+                with self.lock_averaged_tensors if self.offload_optimizer or self.reuse_tensors else nullcontext():
+                    logger.log(self.status_loglevel, f"Running optimizer step")
+                    if grad_scaler is None:
+                        self.optimizer.step()
+                    else:
+                        with grad_scaler.running_global_step():
+                            assert grad_scaler.step(self.optimizer)
+
             self._update_scheduler()

             if zero_grad:
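
The change wraps the optimizer step in the averager's tensor lock, but only when the optimizer is offloaded or the averaged tensors are reused as model parameters; in those modes optimizer.step() mutates the same memory a concurrent averaging round may read, while in all other modes a no-op nullcontext keeps the step lock-free. A minimal sketch of that conditional-locking pattern, using a standalone threading.Lock and plain booleans as stand-ins for the averager's lock_averaged_tensors, offload_optimizer, and reuse_tensors attributes (the actual types in hivemind may differ):

from contextlib import nullcontext
from threading import Lock

# Stand-ins for the averager's attributes (assumptions, not hivemind's API).
lock_averaged_tensors = Lock()
offload_optimizer, reuse_tensors = True, False

def run_optimizer_step(step_fn):
    # Hold the lock only when optimizer state shares memory with the
    # averaged tensors; otherwise enter a no-op context manager.
    context = lock_averaged_tensors if offload_optimizer or reuse_tensors else nullcontext()
    with context:
        step_fn()

run_optimizer_step(lambda: print("optimizer.step() would run here"))

Using nullcontext() for the fast path means the two branches share one `with` statement, so training is never serialized against averaging unless the tensors actually alias.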