@@ -5,7 +5,6 @@ from concurrent.futures import ThreadPoolExecutor
 from contextlib import nullcontext
 from itertools import chain
 from threading import Event
-from types import NoneType
 from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Sequence, Tuple, Union
 
 import torch
@@ -286,7 +285,7 @@ class TrainingStateAverager(DecentralizedAverager):
         delay_optimizer_step: bool = False,
         averaging_round: bool = False,
         delay_averaging: Optional[bool] = None,
-        wait_for_trigger: Optional[Callable[[], NoneType]] = None,
+        wait_for_trigger: Optional[Callable[[], Any]] = None,
         grad_scaler: Optional[GradScaler] = None,
         averaging_opts: Optional[Dict[str, Any]] = None,
     ):
@@ -306,6 +305,7 @@ class TrainingStateAverager(DecentralizedAverager):
         :param delay_averaging: if True, perform averaging in background and apply results in a future step
           by default, delay averaging if the optimizer step is also delayed. Set to true to delay only this phase.
         :param wait_for_trigger: wait for this (non-asyncio) function to finish before running optimizer step
+        :note: if wait_for_trigger raises an exception, it will abort the optimizer step, zero_grad and averaging
         :param grad_scaler: when using hivemind.GradScaler, one must forward it to step after calling .unscale_
         :param averaging_opts: a dict of keyword arguments forwarded into averaging round
         """
@@ -398,7 +398,7 @@ class TrainingStateAverager(DecentralizedAverager):
 
     def _do(
         self,
-        wait_for_trigger: Optional[Callable[[], NoneType]],
+        wait_for_trigger: Optional[Callable[[], Any]],
         optimizer_step: bool,
         zero_grad: bool,
         averaging_round: bool,
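
For reviewers, a minimal sketch of a callback that satisfies the relaxed `Callable[[], Any]` annotation. The `peers_ready` event, the `trigger` function, and the commented-out `state_averager.step(...)` call are illustrative assumptions, not part of this patch; only the `wait_for_trigger` parameter and its abort-on-exception behavior come from the diff above.

```python
from threading import Event

# Assumed setup: `state_averager` is an existing TrainingStateAverager instance
# elsewhere in the program; this event is set by another thread when peers are ready.
peers_ready = Event()

def trigger() -> None:
    # Block the scheduled step until another thread calls peers_ready.set().
    # Raising here (e.g. on timeout) aborts the optimizer step, zero_grad and
    # the averaging round, per the :note: added in the docstring.
    if not peers_ready.wait(timeout=60.0):
        raise TimeoutError("trigger was not fired within 60 seconds")

# With Callable[[], Any], the callback's return value is simply ignored,
# so functions returning a value are now accepted as well:
# state_averager.step(wait_for_trigger=trigger, ...)
```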