@@ -5,6 +5,7 @@ from concurrent.futures import ThreadPoolExecutor
 from contextlib import nullcontext
 from itertools import chain
 from threading import Event
+from types import NoneType
 from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Sequence, Tuple, Union
 
 import torch
@@ -285,6 +286,7 @@ class TrainingStateAverager(DecentralizedAverager):
         delay_optimizer_step: bool = False,
         averaging_round: bool = False,
         delay_averaging: Optional[bool] = None,
+        wait_for_trigger: Optional[Callable[[], NoneType]] = None,
         grad_scaler: Optional[GradScaler] = None,
         averaging_opts: Optional[Dict[str, Any]] = None,
     ):
@@ -296,13 +298,15 @@ class TrainingStateAverager(DecentralizedAverager):
           by default, await delayed updates when scheduling the next optimizer step, otherwise do not update
         :param apply_delayed_updates: apply any averaging rounds that have finished but were not applied yet
         :param increment_epoch: increment .local_epoch and update the learning rate scheduler (if present)
+          :note: if specified, it is guaranteed that epoch is incremented immediately regardless of other options
         :param optimizer_step: perform a single optimizer step and update local parameters (without changing scheduler)
         :param zero_grad: if True, reset local gradients after performing optimizer step
         :param delay_optimizer_step: if True, run optimizer step in background and apply results in a future step
         :param averaging_round: average parameters, chosen optimizer keys and extra tensors with a group of peers
-        :param grad_scaler: when using hivemind.GradScaler, one must forward it to step after calling .unscale_
         :param delay_averaging: if True, perform averaging in background and apply results in a future step
           by default, delay averaging if the optimizer step is also delayed. Set to true to delay only this phase.
+        :param wait_for_trigger: wait for this (non-asyncio) function to finish before running optimizer step
+        :param grad_scaler: when using hivemind.GradScaler, one must forward it to step after calling .unscale_
         :param averaging_opts: a dict of keyword arguments forwarded into averaging round
         """
         if delay_averaging is None:
@@ -317,6 +321,15 @@ class TrainingStateAverager(DecentralizedAverager):
         assert not averaging_round or delay_averaging, "Averaging after delayed optimizer should also be delayed"
         if averaging_opts and not averaging_round:
             logger.warning(f"Averaging parameters not used because averaging_round=False: {averaging_opts}")
+        if wait_for_trigger is not None:
+            if not self.reuse_tensors or self.custom_gradients:
+                # averager was asked to wait_for_trigger in background, but it is not clear which version of gradients
+                # should be used for optimizer step (e.g. the gradients that were present during the call to .step or
+                # the possibly different gradients when wait_for_trigger has finished).
+                raise ValueError(
+                    "wait_for_trigger is an advanced option that requires manual gradient manipulation. "
+                    "If you know what you're doing, please refer to the comments in the source code for details."
+                )
         output = None
 
         if wait_for_delayed_update:
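
A minimal usage sketch of the new option, for context: it shows how a caller might pass a blocking callable as wait_for_trigger so that a delayed optimizer step only runs once an external condition fires. The state_averager object, the Event-based trigger, and the exact combination of flags are illustrative assumptions, not part of this patch; per the check added above, the averager is assumed to have been constructed with reuse_tensors=True and without custom_gradients.

import threading

trigger = threading.Event()  # set elsewhere, e.g. once peers finish accumulating gradients

def wait_for_trigger() -> None:
    # blocking, non-asyncio callable: the optimizer step starts only after this returns
    trigger.wait()

# `state_averager` is an assumed, already-constructed TrainingStateAverager with reuse_tensors=True,
# so the background step reads the same gradient tensors that were present when .step() was called.
state_averager.step(
    increment_epoch=True,
    optimizer_step=True,
    zero_grad=True,
    delay_optimizer_step=True,           # run the optimizer step in the background...
    wait_for_trigger=wait_for_trigger,   # ...but only after the trigger callable returns
)

trigger.set()  # later: release the delayed step once the external condition is met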