浏览代码

option to await a trigger

justheuristic 3 年之前
父节点
当前提交
581155ad63
共有 1 个文件被更改,包括 3 次插入、3 次删除
  1. 3 3
      hivemind/optim/experimental/state_averager.py

+ 3 - 3
hivemind/optim/experimental/state_averager.py

@@ -5,7 +5,6 @@ from concurrent.futures import ThreadPoolExecutor
 from contextlib import nullcontext
 from contextlib import nullcontext
 from itertools import chain
 from itertools import chain
 from threading import Event
 from threading import Event
-from types import NoneType
 from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Sequence, Tuple, Union
 from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Sequence, Tuple, Union
 
 
 import torch
 import torch
@@ -286,7 +285,7 @@ class TrainingStateAverager(DecentralizedAverager):
         delay_optimizer_step: bool = False,
         delay_optimizer_step: bool = False,
         averaging_round: bool = False,
         averaging_round: bool = False,
         delay_averaging: Optional[bool] = None,
         delay_averaging: Optional[bool] = None,
-        wait_for_trigger: Optional[Callable[[], NoneType]] = None,
+        wait_for_trigger: Optional[Callable[[], Any]] = None,
         grad_scaler: Optional[GradScaler] = None,
         grad_scaler: Optional[GradScaler] = None,
         averaging_opts: Optional[Dict[str, Any]] = None,
         averaging_opts: Optional[Dict[str, Any]] = None,
     ):
     ):
@@ -306,6 +305,7 @@ class TrainingStateAverager(DecentralizedAverager):
         :param delay_averaging: if True, perform averaging in background and apply results in a future step
         :param delay_averaging: if True, perform averaging in background and apply results in a future step
           by default, delay averaging if the optimizer step is also delayed. Set to true to delay only this phase.
           by default, delay averaging if the optimizer step is also delayed. Set to true to delay only this phase.
         :param wait_for_trigger: wait for this (non-asyncio) function to finish before running optimizer step
         :param wait_for_trigger: wait for this (non-asyncio) function to finish before running optimizer step
+        :note: if wait_for_trigger fails with any exception, it will abort optimizer step, zero grad and averaging
         :param grad_scaler: when using hivemind.GradScaler, one must forward it to step after calling .unscale_
         :param grad_scaler: when using hivemind.GradScaler, one must forward it to step after calling .unscale_
         :param averaging_opts: a dict of keyword arguments forwarded into averaging round
         :param averaging_opts: a dict of keyword arguments forwarded into averaging round
         """
         """
@@ -398,7 +398,7 @@ class TrainingStateAverager(DecentralizedAverager):
 
 
     def _do(
     def _do(
         self,
         self,
-        wait_for_trigger: Optional[Callable[[], NoneType]],
+        wait_for_trigger: Optional[Callable[[], Any]],
         optimizer_step: bool,
         optimizer_step: bool,
         zero_grad: bool,
         zero_grad: bool,
         averaging_round: bool,
         averaging_round: bool,