Переглянути джерело

option to await a trigger

justheuristic 3 роки тому
батько
коміт
581155ad63
1 змінених файлів з 3 додано та 3 видалено
  1. 3 3
      hivemind/optim/experimental/state_averager.py

+ 3 - 3
hivemind/optim/experimental/state_averager.py

@@ -5,7 +5,6 @@ from concurrent.futures import ThreadPoolExecutor
 from contextlib import nullcontext
 from itertools import chain
 from threading import Event
-from types import NoneType
 from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Sequence, Tuple, Union
 
 import torch
@@ -286,7 +285,7 @@ class TrainingStateAverager(DecentralizedAverager):
         delay_optimizer_step: bool = False,
         averaging_round: bool = False,
         delay_averaging: Optional[bool] = None,
-        wait_for_trigger: Optional[Callable[[], NoneType]] = None,
+        wait_for_trigger: Optional[Callable[[], Any]] = None,
         grad_scaler: Optional[GradScaler] = None,
         averaging_opts: Optional[Dict[str, Any]] = None,
     ):
@@ -306,6 +305,7 @@ class TrainingStateAverager(DecentralizedAverager):
         :param delay_averaging: if True, perform averaging in background and apply results in a future step
           by default, delay averaging if the optimizer step is also delayed. Set to true to delay only this phase.
         :param wait_for_trigger: wait for this (non-asyncio) function to finish before running optimizer step
+        :note: if wait_for_trigger fails with any exception, it will abort optimizer step, zero grad and averaging
         :param grad_scaler: when using hivemind.GradScaler, one must forward it to step after calling .unscale_
         :param averaging_opts: a dict of keyword arguments forwarded into averaging round
         """
@@ -398,7 +398,7 @@ class TrainingStateAverager(DecentralizedAverager):
 
     def _do(
         self,
-        wait_for_trigger: Optional[Callable[[], NoneType]],
+        wait_for_trigger: Optional[Callable[[], Any]],
         optimizer_step: bool,
         zero_grad: bool,
         averaging_round: bool,