
option to await a trigger

justheuristic committed 3 years ago
commit 1dd5cd0a60

1 changed file with 14 additions and 1 deletion

hivemind/optim/experimental/state_averager.py  +14 −1

@@ -5,6 +5,7 @@ from concurrent.futures import ThreadPoolExecutor
 from contextlib import nullcontext
 from itertools import chain
 from threading import Event
+from types import NoneType
 from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Sequence, Tuple, Union
 
 import torch
@@ -285,6 +286,7 @@ class TrainingStateAverager(DecentralizedAverager):
         delay_optimizer_step: bool = False,
         averaging_round: bool = False,
         delay_averaging: Optional[bool] = None,
+        wait_for_trigger: Optional[Callable[[], NoneType]] = None,
         grad_scaler: Optional[GradScaler] = None,
         averaging_opts: Optional[Dict[str, Any]] = None,
     ):
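
For reference, the new keyword expects an ordinary blocking callable (a "non-asyncio" function per the docstring below) that returns None. A minimal, hypothetical sketch of such a trigger built on threading.Event; the helper names are illustrative and not part of this commit:

import threading

# Hypothetical helper: step() will block, possibly in a background thread,
# until some other thread calls notify_ready().
_ready = threading.Event()

def wait_for_gradients_ready() -> None:
    _ready.wait()  # returns once notify_ready() has been called

def notify_ready() -> None:
    _ready.set()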
@@ -296,13 +298,15 @@ class TrainingStateAverager(DecentralizedAverager):
           by default, await delayed updates when scheduling the next optimizer step, otherwise do not update
         :param apply_delayed_updates: apply any averaging rounds that have finished but were not applied yet
         :param increment_epoch: increment .local_epoch and update the learning rate scheduler (if present)
+        :note: if specified, it is guaranteed that epoch is incremented immediately regardless of other options
         :param optimizer_step: perform a single optimizer step and update local parameters (without changing scheduler)
         :param zero_grad: if True, reset local gradients after performing optimizer step
         :param delay_optimizer_step: if True, run optimizer step in background and apply results in a future step
         :param averaging_round: average parameters, chosen optimizer keys and extra tensors with a group of peers
-        :param grad_scaler: when using hivemind.GradScaler, one must forward it to step after calling .unscale_
         :param delay_averaging: if True, perform averaging in background and apply results in a future step
           by default, delay averaging if the optimizer step is also delayed. Set to true to delay only this phase.
+        :param wait_for_trigger: wait for this (non-asyncio) function to finish before running optimizer step
+        :param grad_scaler: when using hivemind.GradScaler, one must forward it to step after calling .unscale_
         :param averaging_opts: a dict of keyword arguments forwarded into averaging round
         """
         if delay_averaging is None:
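
To show how the documented options fit together with the new parameter, here is a hedged usage sketch: state_averager is assumed to be an existing TrainingStateAverager, wait_for_gradients_ready is the trigger sketched above, and the exact keyword combination is illustrative rather than a verified recipe:

# Schedule a delayed optimizer step plus averaging; the background step will
# start only after wait_for_gradients_ready() returns.
state_averager.step(
    increment_epoch=True,
    optimizer_step=True,
    zero_grad=True,
    delay_optimizer_step=True,
    averaging_round=True,
    delay_averaging=True,  # averaging after a delayed optimizer step must also be delayed
    wait_for_trigger=wait_for_gradients_ready,
)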
@@ -317,6 +321,15 @@ class TrainingStateAverager(DecentralizedAverager):
             assert not averaging_round or delay_averaging, "Averaging after delayed optimizer should also be delayed"
         if averaging_opts and not averaging_round:
             logger.warning(f"Averaging parameters not used because averaging_round=False: {averaging_opts}")
+        if wait_for_trigger is not None:
+            if not (self.reuse_tensors or self.custom_gradients):
+                # averager was asked to wait_for_trigger in background, but it is not clear which version of gradients
+                # should be used for optimizer step (e.g. the gradients that were present during the call to .step or
+                # the possibly different gradients when wait_for_trigger has finished).
+                raise ValueError(
+                    "wait_for_trigger is an advanced option that requires manual gradient manipulation. "
+                    "If you know what you're doing, please refer to the comments in the source code for details."
+                )
         output = None
 
         if wait_for_delayed_update:
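
The new check only permits a trigger when the averager either shares its tensors with the optimizer (reuse_tensors) or leaves gradient handling to the caller (custom_gradients); both appear above as instance attributes, and how they are configured is outside this diff. A hypothetical sketch of the failure mode on an averager with neither enabled:

# With neither reuse_tensors nor custom_gradients, passing a trigger fails fast:
try:
    state_averager.step(optimizer_step=True, wait_for_trigger=wait_for_gradients_ready)
except ValueError as err:
    print(err)  # "wait_for_trigger is an advanced option that requires manual gradient manipulation. ..."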