Browse Source

option to await a trigger

justheuristic 3 years ago
parent
commit
cac9c5bf84
1 changed files with 4 additions and 0 deletions
  1. 4 0
      hivemind/optim/experimental/state_averager.py

+ 4 - 0
hivemind/optim/experimental/state_averager.py

@@ -366,6 +366,7 @@ class TrainingStateAverager(DecentralizedAverager):
 
             self.pending_update = self.step_executor.submit(
                 self._do,
+                wait_for_trigger,
                 optimizer_step,
                 zero_grad,
                 averaging_round,
@@ -397,6 +398,7 @@ class TrainingStateAverager(DecentralizedAverager):
 
     def _do(
         self,
+        wait_for_trigger: Optional[Callable[[], NoneType]],
         optimizer_step: bool,
         zero_grad: bool,
         averaging_round: bool,
@@ -408,6 +410,8 @@ class TrainingStateAverager(DecentralizedAverager):
         This method is meant to be called in the background executor.
         """
         try:
+            if wait_for_trigger is not None:
+                wait_for_trigger()
             if optimizer_step:
                 with self.lock_averaged_tensors if self.offload_optimizer or self.reuse_tensors else nullcontext():
                     logger.log(self.status_loglevel, f"Running optimizer step")