
option to await a trigger

justheuristic 3 years ago
parent commit 1dd5cd0a60
1 file changed with 14 additions and 1 deletion

hivemind/optim/experimental/state_averager.py  +14 -1

@@ -5,6 +5,7 @@ from concurrent.futures import ThreadPoolExecutor
 from contextlib import nullcontext
 from itertools import chain
 from threading import Event
+from types import NoneType
 from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Sequence, Tuple, Union
 
 import torch
@@ -285,6 +286,7 @@ class TrainingStateAverager(DecentralizedAverager):
         delay_optimizer_step: bool = False,
         averaging_round: bool = False,
         delay_averaging: Optional[bool] = None,
+        wait_for_trigger: Optional[Callable[[], NoneType]] = None,
         grad_scaler: Optional[GradScaler] = None,
         averaging_opts: Optional[Dict[str, Any]] = None,
     ):
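For illustration, a minimal trigger compatible with the `Callable[[], NoneType]` annotation above might look like the sketch below; the `Event`-based setup is hypothetical and not part of this commit. Note that `from types import NoneType` only resolves on Python 3.10+; on older interpreters the equivalent return annotation is plain `None`.

```python
from threading import Event

# Hypothetical event, set elsewhere by whatever coordinates the training run.
trigger_event = Event()

def wait_for_peers() -> None:
    """Block until the external condition fires, then let the optimizer step proceed."""
    trigger_event.wait()
```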
@@ -296,13 +298,15 @@ class TrainingStateAverager(DecentralizedAverager):
           by default, await delayed updates when scheduling the next optimizer step, otherwise do not update
         :param apply_delayed_updates: apply any averaging rounds that have finished but were not applied yet
         :param increment_epoch: increment .local_epoch and update the learning rate scheduler (if present)
+        :note: if increment_epoch is specified, it is guaranteed that the epoch is incremented immediately, regardless of other options
         :param optimizer_step: perform a single optimizer step and update local parameters (without changing scheduler)
         :param zero_grad: if True, reset local gradients after performing optimizer step
         :param delay_optimizer_step: if True, run optimizer step in background and apply results in a future step
         :param averaging_round: average parameters, chosen optimizer keys and extra tensors with a group of peers
-        :param grad_scaler: when using hivemind.GradScaler, one must forward it to step after calling .unscale_
         :param delay_averaging: if True, perform averaging in background and apply results in a future step
           by default, delay averaging if the optimizer step is also delayed. Set to true to delay only this phase.
+        :param wait_for_trigger: wait for this (non-asyncio) function to finish before running optimizer step
+        :param grad_scaler: when using hivemind.GradScaler, one must forward it to step after calling .unscale_
         :param averaging_opts: a dict of keyword arguments forwarded into averaging round
         """
         if delay_averaging is None:
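A hedged usage sketch of the new option: `state_averager` is a hypothetical `TrainingStateAverager` instance, assumed to be configured so that delayed optimizer steps are permitted and so that the reuse_tensors / custom_gradients check further down passes.

```python
from threading import Event

trigger = Event()

# Schedule the step now; the background worker first blocks in trigger.wait()
# and only then performs the optimizer step.
state_averager.step(
    increment_epoch=True,
    optimizer_step=True,
    zero_grad=True,
    delay_optimizer_step=True,      # run the step in the background executor
    wait_for_trigger=trigger.wait,  # non-asyncio callable, as documented above
)

# ... later, once the coordinating code decides the step may go ahead:
trigger.set()
```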
@@ -317,6 +321,15 @@ class TrainingStateAverager(DecentralizedAverager):
             assert not averaging_round or delay_averaging, "Averaging after delayed optimizer should also be delayed"
         if averaging_opts and not averaging_round:
             logger.warning(f"Averaging parameters not used because averaging_round=False: {averaging_opts}")
+        if wait_for_trigger is not None:
+            if not (self.reuse_tensors or self.custom_gradients):
+                # averager was asked to wait_for_trigger in background, but it is not clear which version of gradients
+                # should be used for optimizer step (e.g. the gradients that were present during the call to .step or
+                # the possibly different gradients when wait_for_trigger has finished).
+                raise ValueError(
+                    "wait_for_trigger is an advanced option that requires manual gradient manipulation. "
+                    "If you know what you're doing, please refer to the comments in the source code for details."
+                )
         output = None
 
         if wait_for_delayed_update:
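The validation added above boils down to requiring at least one of reuse_tensors or custom_gradients whenever a trigger is supplied, since the trigger may return long after `.step(...)` was called and the gradients may have changed in the meantime. A standalone sketch of that guard (the helper itself is illustrative; only the attribute names mirror the diff):

```python
def check_trigger_allowed(reuse_tensors: bool, custom_gradients: bool) -> None:
    # The optimizer step runs only after the trigger returns, so it is unclear which
    # version of the gradients should be used unless the averager shares tensors with
    # the optimizer (reuse_tensors) or the user manages gradients manually
    # (custom_gradients).
    if not (reuse_tensors or custom_gradients):
        raise ValueError("wait_for_trigger requires reuse_tensors or custom_gradients")

check_trigger_allowed(reuse_tensors=True, custom_gradients=False)   # passes
# check_trigger_allowed(reuse_tensors=False, custom_gradients=False)  # would raise
```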