Prechádzať zdrojové kódy

Add an option to pre-schedule averaging (#398)

Part 1/3 for optimizer overhaul
- implement a way to pre-schedule averager step and add a trigger that manually enables averaging
- add tests for pre-scheduling and averaging trigger

Co-authored-by: Alexander Borzunov <hxrussia@gmail.com>
Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
Co-authored-by: Michael Diskin <yhn1124@gmail.com>
justheuristic 3 rokov pred
rodič
commit
a09df5492f

+ 1 - 1
docs/conf.py

@@ -203,7 +203,7 @@ todo_include_todos = True
 
 
 def setup(app):
-    app.add_stylesheet("fix_rtd.css")
+    app.add_css_file("fix_rtd.css")
     app.add_config_value(
         "recommonmark_config",
         {

+ 92 - 62
hivemind/averaging/averager.py

@@ -16,6 +16,7 @@ import numpy as np
 import torch
 
 from hivemind.averaging.allreduce import AllreduceException, AllReduceRunner, AveragingMode, GroupID
+from hivemind.averaging.control import AveragingStage, StepControl
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.load_balancing import load_balance_peers
 from hivemind.averaging.matchmaking import Matchmaking, MatchmakingException
@@ -28,7 +29,7 @@ from hivemind.compression import (
     serialize_torch_tensor,
 )
 from hivemind.dht import DHT, DHTID
-from hivemind.p2p import P2PContext, P2PHandlerError, PeerID, ServicerBase
+from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
 from hivemind.proto import averaging_pb2
 from hivemind.utils import MPFuture, TensorDescriptor, get_logger
 from hivemind.utils.asyncio import achain, aiter_with_timeout, anext, as_aiter, azip, switch_to_uvloop
@@ -54,8 +55,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     :param prefix: a shared prefix for all group keys
     :param target_group_size: attempts to form groups with up to this many peers (recommended: a power of 2, e.g. 16)
     :param initial_group_bits: a string of bits ('0' and '1') that define the initial group key (bucket index)
-    :param averaging_expiration: attempt to find a group for this many seconds, otherwise try again
-      note - this expiration time only applies to looking for group, passing tensors in allreduce may take more time
+    :param min_matchmaking_time: when looking for group, wait for requests for at least this many seconds
     :param compression: optionally compress tensors with this compression algorithm before running all-reduce
     :param state_compression: a separate compression strategy for load_state_from_peers (default = no compression)
     :param tensor_infos: CompressionInfo for each respective tensor; this determines how the tensor will be comressed
@@ -63,7 +63,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     :param averaging_alpha: optional "learning rate" for averaging. If specified, local parameters will be shifted
       towards the (estimated) average by this coefficient. By default, local parameters are set equal to average.
     :param request_timeout: when looking for group, wait for a response from leader for at most this many seconds.
-    :note: request_timeout must be smaller than averaging_expiration to avoid potential deadlocks.
+    :note: request_timeout must be smaller than min_matchmaking_time to avoid potential deadlocks.
     :param part_size_bytes: tensors for AllReduce are processed in parts of up to this size (after compression)
     :param bandwidth: if specified, this value represents the network bandwidth available to averager.
           By default, the averager is assumed to have the average bandwidth of his group.
@@ -75,6 +75,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
           local tensors for averaging
     :param allow_state_sharing: if set to True, other peers can download this peer's state. Can be overwritten
       with averager.allow_state_sharing = True / False
+    :param declare_state_period: re-declare averager as a donor for load_state_from_peers every this many seconds
     :param shutdown_timeout: when calling .shutdown, wait for up to this many seconds before terminating
 
     Example:
@@ -92,6 +93,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
     _matchmaking: Matchmaking
     _pending_group_assembled: asyncio.Event
+    _state_updated: asyncio.Event
+    _p2p: P2P
     serializer = MSGPackSerializer
 
     def __init__(
@@ -104,8 +107,9 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         target_group_size: int,
         min_group_size: int = 2,
         initial_group_bits: str = "",
-        averaging_expiration: float = 15,
-        request_timeout: float = 3,
+        averaging_expiration: Optional[float] = None,
+        min_matchmaking_time: float = 5.0,
+        request_timeout: float = 3.0,
         averaging_alpha: float = 1.0,
         part_size_bytes: int = DEFAULT_PART_SIZE_BYTES,
         allreduce_timeout: Optional[float] = None,
@@ -116,6 +120,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         min_vector_size: int = 0,
         auxiliary: bool = False,
         allow_state_sharing: Optional[bool] = None,
+        declare_state_period: float = 30,
         client_mode: Optional[bool] = None,
         daemon: bool = True,
         shutdown_timeout: float = 5,
@@ -129,6 +134,11 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         assert all(bit in "01" for bit in initial_group_bits)
         assert not client_mode or not auxiliary, "auxiliary peers must accept incoming connections"
 
+        if averaging_expiration is not None:
+            logger.warning("averaging_expiration is deprecated and will be removed soon, use min_matchmaking_time")
+            assert min_matchmaking_time == 5.0, "Can't set both averaging_expiration and min_matchmaking_time"
+            min_matchmaking_time = averaging_expiration
+
         super().__init__()
         self.dht = dht
         self.prefix = prefix
@@ -163,8 +173,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             initial_group_bits=initial_group_bits,
             target_group_size=target_group_size,
             min_group_size=min_group_size,
-            averaging_expiration=averaging_expiration,
             request_timeout=request_timeout,
+            min_matchmaking_time=min_matchmaking_time,
         )
         self.allreduce_kwargs = dict(
             compression=compression,
@@ -180,6 +190,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         if allow_state_sharing is None:
             allow_state_sharing = not client_mode and not auxiliary
         self.allow_state_sharing = allow_state_sharing
+        self.declare_state_period = declare_state_period
         self.state_compression = state_compression
         self.tensor_infos = tensor_infos
 
@@ -250,6 +261,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                 if not self.client_mode:
                     asyncio.create_task(self._declare_for_download_periodically())
 
+                self._state_updated = asyncio.Event()
                 self._pending_group_assembled = asyncio.Event()
                 self._pending_group_assembled.set()
             except Exception as e:
@@ -294,7 +306,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     def shutdown(self) -> None:
         """Shut down the averager process"""
         if self.is_alive():
-            self._outer_pipe.send(("_shutdown", [None], {}))  # shut down the daemon process
+            self._outer_pipe.send(("_shutdown", [self.shutdown_timeout], {}))  # shut down the daemon process
             self._inner_pipe.send(("_SHUTDOWN", None))  # shut down background thread in master
             self.join(self.shutdown_timeout)
             if self.is_alive():
@@ -303,11 +315,11 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         else:
             logger.exception("Averager shutdown has no effect: the process is already not alive")
 
-    async def _shutdown(self, timeout: Optional[float] = None) -> None:
+    async def _shutdown(self, timeout: Optional[float]) -> None:
         remaining_tasks = set()
         for group in self._running_groups.values():
             remaining_tasks.update(group.finalize(cancel=True))
-        await asyncio.gather(*remaining_tasks)
+        await asyncio.wait_for(asyncio.gather(*remaining_tasks), timeout)
 
     def __del__(self):
         if self._parent_pid == os.getpid() and self.is_alive():
@@ -316,68 +328,81 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     def step(
         self,
         gather: Optional[GatheredData] = None,
+        scheduled_time: Optional[DHTExpiration] = None,
         weight: Optional[float] = None,
         timeout: Optional[float] = None,
         allow_retries: bool = True,
+        require_trigger: bool = False,
         wait: bool = True,
-    ) -> Union[Optional[Dict[PeerID, GatheredData]], MPFuture]:
+    ) -> Union[Optional[Dict[PeerID, GatheredData]], StepControl]:
         """
         Set up the averager to look for a group and run one round of averaging, return True on success, False on failure
 
         :param gather: optionally send this informaton to all peers in the next group and gather it from every groupmate
           (this operation is known as all-gather). The gathered data will be available as the output of this function.
+        :param scheduled_time: when matchmaking, assume that all-reduce will begin at this moment.
+          By default, schedule all-reduce current time plus min_matchmaking_time seconds
         :param weight: averaging weight for this peer, int or float, must be strictly positive
         :param allow_retries: if averager fails to run one round of allreduce, this option will allow it to try again
           within the specified timeout
-        :param timeout: if averager was unable to *find* a group in this many seconds, consider allreduce failedK
-        :param wait: if True (default), return when finished. Otherwise return MPFuture and run in background.
+        :param require_trigger: if True, await for user to call .allow_allreduce() before running all-reduce
+        :param timeout: if averager was unable to *find* a group in this many seconds, consider allreduce failed
+        :param wait: if True (default), return when finished. Otherwise return StepControl and run in background.
         :returns: on success, update averaged_tensors and return group info; on failure, return None
         """
         if self.mode == AveragingMode.AUX and weight is not None:
             logger.warning("Averager is running in auxiliary mode, weight is unused.")
+        if scheduled_time is None:
+            scheduled_time = get_dht_time() + self.matchmaking_kwargs["min_matchmaking_time"]
         if weight is None:
             weight = float(self.mode != AveragingMode.AUX)
+        deadline = get_dht_time() + timeout if timeout is not None else float("inf")
         assert isinstance(weight, (int, float)) and weight >= 0, f"Expected a positive int/float, got {type(weight)}"
-
-        future = MPFuture()
-        gather_binary = self.serializer.dumps(
-            gather
-        )  # serialize here to avoid loading modules in the averager process
-        self._outer_pipe.send(
-            (
-                "_step",
-                [],
-                dict(
-                    future=future,
-                    gather_binary=gather_binary,
-                    weight=weight,
-                    allow_retries=allow_retries,
-                    timeout=timeout,
-                ),
-            )
+        assert not (wait and require_trigger), "Non-asynchronous step cannot wait for trigger (use wait=False)"
+        assert scheduled_time < deadline, "Scheduled start time does not fit within timeout"
+
+        user_data_for_gather = self.serializer.dumps(gather)  # serialize here to avoid imports in the averager process
+        data_for_gather = self.serializer.dumps([self.bandwidth, self.mode.value, user_data_for_gather])
+        step = StepControl(
+            scheduled_time=scheduled_time,
+            deadline=deadline,
+            allow_retries=allow_retries,
+            weight=weight,
+            data_for_gather=data_for_gather,
         )
-        return future.result() if wait else future
 
-    async def _step(
-        self, *, future: MPFuture, gather_binary: bytes, weight: float, allow_retries: bool, timeout: Optional[float]
-    ):
-        start_time = get_dht_time()
+        future_for_trigger = MPFuture()
+        self._outer_pipe.send(("_step", [], dict(step=step, future_for_trigger=future_for_trigger)))
+        step.attach_trigger(future_for_trigger.result())
 
+        if not require_trigger:
+            step.allow_allreduce()
+        return step.result() if wait else step
+
+    async def _step(self, *, step: StepControl, future_for_trigger: MPFuture):
         try:
-            while not future.done():
+            trigger = MPFuture()
+            step.attach_trigger(trigger)
+            future_for_trigger.set_result(trigger)
+
+            while not step.done():
                 try:
                     self._pending_group_assembled.clear()
-                    data_for_gather = self.serializer.dumps([self.bandwidth, self.mode.value, gather_binary])
-                    group_info = await self._matchmaking.look_for_group(
-                        timeout=timeout, data_for_gather=data_for_gather
-                    )
+                    step.stage = AveragingStage.LOOKING_FOR_GROUP
+                    group_info = await self._matchmaking.look_for_group(step)
                     if group_info is None:
                         raise AllreduceException("Averaging step failed: could not find a group.")
 
-                    future.set_result(
+                    if not step.triggered:
+                        step.stage = AveragingStage.AWAITING_TRIGGER
+                        await step.wait_for_trigger()
+
+                    step.stage = AveragingStage.RUNNING_ALLREDUCE
+
+                    step.set_result(
                         await asyncio.wait_for(
                             self._run_allreduce(
-                                group_info, tensor_infos=self.tensor_infos, weight=weight, **self.allreduce_kwargs
+                                group_info, tensor_infos=self.tensor_infos, weight=step.weight, **self.allreduce_kwargs
                             ),
                             timeout=self._allreduce_timeout,
                         )
@@ -393,20 +418,20 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     asyncio.InvalidStateError,
                     P2PHandlerError,
                 ) as e:
-                    time_elapsed = get_dht_time() - start_time
-                    if not allow_retries or (timeout is not None and timeout < time_elapsed):
-                        logger.exception(f"Averager caught {repr(e)}")
-                        future.set_exception(e)
+                    if not step.allow_retries or get_dht_time() >= step.deadline:
+                        logger.exception(e)
+                        step.set_exception(e)
                     else:
-                        logger.warning(f"Averager caught {repr(e)}, retrying")
+                        logger.warning(f"{self.__class__.__name__} caught {repr(e)}, retrying")
 
         except BaseException as e:
-            if not future.done():
-                future.set_exception(e)
+            if not step.done():
+                step.set_exception(e)
             raise
         finally:
-            if not future.done():
-                future.set_exception(
+            step.stage = AveragingStage.FINISHED
+            if not step.done():
+                step.set_exception(
                     RuntimeError(
                         "Internal sanity check failed: averager.step left future pending."
                         " Please report this to hivemind issues."
@@ -416,8 +441,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
         """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
         try:
-            bandwidths, mode_ids, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
-            user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered)))
+            bandwidths, mode_ids, user_gathered_bytes = zip(*map(self.serializer.loads, group_info.gathered))
+            user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered_bytes)))
             modes = tuple(map(AveragingMode, mode_ids))
 
             # compute optimal part sizes from peer bandwidths; TODO: replace with proper load balancing
@@ -447,7 +472,9 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                         async for tensor, update in azip(as_aiter(*local_tensors), allreduce):
                             # all-reduce is performed asynchronously while iterating
                             tensor.add_(update, alpha=self._averaging_alpha)
-                        self.last_updated = get_dht_time()
+                            self.last_updated = get_dht_time()
+                            self._state_updated.set()
+
                     else:
                         async for _ in allreduce:  # trigger all-reduce by iterating
                             raise ValueError("aux peers should not receive averaged tensors")
@@ -477,7 +504,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         """
         with self.lock_averaged_tensors:
             yield self._averaged_tensors
-        self.last_updated = get_dht_time()
 
     @contextlib.asynccontextmanager
     async def get_tensors_async(self) -> Sequence[torch.Tensor]:
@@ -517,19 +543,24 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         download_key = f"{self._matchmaking.group_key_manager.prefix}.all_averagers"
         while True:
             if self.allow_state_sharing:
+                self._state_updated.clear()
+                expiration_time = get_dht_time() + self.declare_state_period
                 asyncio.create_task(
                     asyncio.wait_for(
                         self.dht.store(
                             download_key,
                             subkey=self.peer_id.to_bytes(),
                             value=self.last_updated,
-                            expiration_time=get_dht_time() + self._matchmaking.averaging_expiration,
+                            expiration_time=expiration_time,
                             return_future=True,
                         ),
-                        timeout=self._matchmaking.averaging_expiration,
+                        timeout=expiration_time - self.request_timeout,
                     )
                 )
-            await asyncio.sleep(self._matchmaking.averaging_expiration)
+            try:
+                await asyncio.wait_for(self._state_updated.wait(), self.declare_state_period - self.request_timeout)
+            except asyncio.TimeoutError:
+                pass
 
     async def rpc_download_state(
         self, _request: averaging_pb2.DownloadRequest, _context: P2PContext
@@ -584,10 +615,10 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         The exact contents of both metadata and tensors are determined by get_current_state method
         """
         future = MPFuture()
-        self._outer_pipe.send(("_load_state_from_peers", [], dict(future=future)))
+        self._outer_pipe.send(("_load_state_from_peers", [], dict(timeout=timeout, future=future)))
         return future.result(timeout=timeout) if wait else future
 
-    async def _load_state_from_peers(self, future: MPFuture):
+    async def _load_state_from_peers(self, future: MPFuture, timeout: Optional[float] = None):
         try:
             key_manager = self._matchmaking.group_key_manager
             peer_priority, _ = self.dht.get(f"{key_manager.prefix}.all_averagers", latest=True) or ({}, None)
@@ -611,7 +642,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                         stream = await stub.rpc_download_state(averaging_pb2.DownloadRequest())
                         current_tensor_parts, tensors = [], []
 
-                        async for message in aiter_with_timeout(stream, timeout=self.request_timeout):
+                        async for message in aiter_with_timeout(stream, timeout=timeout or self.request_timeout):
                             if message.metadata:
                                 metadata = self.serializer.loads(message.metadata)
                             if message.tensor_part.dtype and current_tensor_parts:
@@ -628,7 +659,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
                         logger.info(f"Finished downloading state from {peer}")
                         future.set_result((metadata, tensors))
-                        self.last_updated = get_dht_time()
                         return
                     except Exception as e:
                         logger.exception(f"Failed to download state from {peer} - {repr(e)}")

+ 148 - 0
hivemind/averaging/control.py

@@ -0,0 +1,148 @@
+import struct
+from enum import Enum
+from typing import Optional
+
+import numpy as np
+import torch
+
+from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger
+
+logger = get_logger(__name__)
+
+
+class AveragingStage(Enum):
+    IDLE = 0  # still initializing
+    LOOKING_FOR_GROUP = 1  # running decentralized matchmaking, can't run allreduce yet
+    AWAITING_TRIGGER = 2  # waiting for user to set the trigger that allows running allreduce
+    RUNNING_ALLREDUCE = 3  # exchanging tensors with groupmates
+    FINISHED = 4  # either done or failed with exception
+
+
+class StepControl(MPFuture):
+    """
+    An auxiliary data structure that allows user to control stages and track progress in a single averaging step
+
+    :param scheduled_time: estimated time when averaging should begin. Will be used for scheduling
+    :param deadline: if averaging is still in progress at this time, it should be stopped due to TimeoutError
+    :param allow_retries: if True, allow running matchmaking and all-reduce again if previous attempt fails
+    :param weight: averaging weight, can be changed afterwards
+    :param data_for_gather: send this data to all peers in the next group and gather it from groupmates
+    """
+
+    # indices for the shared buffer
+    _SCHEDULED_TIME, _WEIGHT, _STAGE, _BEGAN_ALLREDUCE = slice(0, 8), slice(8, 16), 16, 17
+
+    def __init__(
+        self,
+        scheduled_time: DHTExpiration,
+        deadline: float,
+        allow_retries: bool,
+        weight: float,
+        data_for_gather: bytes,
+    ):
+        super().__init__()
+        self._data_for_gather, self._deadline, self._allow_retries = data_for_gather, deadline, allow_retries
+        self._trigger: Optional[MPFuture] = None
+
+        # Buffer contents:
+        # scheduled_time (double) | weight (double) | stage (AveragingStage, 1 byte) | began_allreduce: (bool, 1 byte)
+        self._shared_buffer = torch.zeros([18], dtype=torch.uint8).share_memory_()
+        self.stage = AveragingStage.IDLE
+        self.scheduled_time = scheduled_time
+        self.weight = weight
+        self.began_allreduce = False
+
+    def attach_trigger(self, trigger: MPFuture):
+        assert self._trigger is None, "Trigger is already attached"
+        self._trigger = trigger
+
+    def allow_allreduce(self):
+        """Allow averager to begin allreduce when it finds a group. Meant to be triggered by user."""
+        assert self._trigger is not None, "StepControl does not have an attached trigger"
+        if self._trigger.done():
+            logger.warning("Trigger is already set")
+        else:
+            self._trigger.set_result(None)
+
+    async def wait_for_trigger(self):
+        assert self._trigger is not None, "StepControl does not have an attached trigger"
+        await self._trigger
+
+    @property
+    def triggered(self) -> bool:
+        assert self._trigger is not None, "StepControl does not have an attached trigger"
+        return self._trigger.done()
+
+    @property
+    def scheduled_time(self) -> DHTExpiration:
+        return struct.unpack("d", self._shared_buffer[StepControl._SCHEDULED_TIME].numpy().data)[0]
+
+    @scheduled_time.setter
+    def scheduled_time(self, scheduled_time):
+        if self.began_allreduce:
+            logger.warning("Changing scheduled time has no effect after all-reduce has already started")
+        if scheduled_time >= self.deadline:
+            logger.warning("Changing scheduled time to after deadline, averaging will likely fail due to timeout.")
+        struct.pack_into("d", self._shared_buffer[StepControl._SCHEDULED_TIME].numpy().data, 0, float(scheduled_time))
+
+    @property
+    def weight(self) -> float:
+        return struct.unpack("d", self._shared_buffer[StepControl._WEIGHT].numpy().data)[0]
+
+    @weight.setter
+    def weight(self, weight: float):
+        assert weight >= 0 and np.isfinite(weight)
+        if self.began_allreduce:
+            logger.warning("Changing weights has no effect after all-reduce has already started")
+        struct.pack_into("d", self._shared_buffer[StepControl._WEIGHT].numpy().data, 0, float(weight))
+
+    @property
+    def stage(self) -> AveragingStage:
+        return AveragingStage(self._shared_buffer[StepControl._STAGE].item())
+
+    @stage.setter
+    def stage(self, stage: AveragingStage):
+        if stage == AveragingStage.RUNNING_ALLREDUCE:
+            self.can_modify = False
+        self._shared_buffer[StepControl._STAGE] = stage.value
+
+    @property
+    def began_allreduce(self) -> bool:
+        return bool(self._shared_buffer[StepControl._BEGAN_ALLREDUCE].item())
+
+    @began_allreduce.setter
+    def began_allreduce(self, value: bool):
+        self._shared_buffer[StepControl._BEGAN_ALLREDUCE] = int(value)
+
+    @property
+    def data_for_gather(self) -> bytes:
+        return self._data_for_gather
+
+    @property
+    def deadline(self) -> DHTExpiration:
+        return self._deadline
+
+    def get_timeout(self) -> Optional[DHTExpiration]:
+        return max(0.0, self.deadline - get_dht_time())
+
+    @property
+    def allow_retries(self) -> bool:
+        return self._allow_retries
+
+    def __getstate__(self):
+        return dict(
+            super().__getstate__(),
+            _trigger=self._trigger,
+            _shared_buffer=self._shared_buffer,
+            immutable_params=(self._data_for_gather, self._deadline, self._allow_retries),
+        )
+
+    def __setstate__(self, state):
+        super().__setstate__(state)
+        self._trigger, self._shared_buffer = state["_trigger"], state["_shared_buffer"]
+        self._data_for_gather, self._deadline, self._allow_retries = state["immutable_params"]
+
+    def cancel(self) -> bool:
+        if self._trigger is not None:
+            self._trigger.cancel()
+        return self.cancel()

+ 53 - 44
hivemind/averaging/matchmaking.py

@@ -9,6 +9,9 @@ import random
 from math import isfinite
 from typing import AsyncIterator, Dict, Optional, Set, Tuple, Type
 
+import numpy as np
+
+from hivemind.averaging.control import StepControl
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.key_manager import GroupKey, GroupKeyManager
 from hivemind.dht import DHT, DHTID, DHTExpiration
@@ -43,16 +46,16 @@ class Matchmaking:
         prefix: str,
         target_group_size: int,
         min_group_size: int,
+        min_matchmaking_time: float,
         request_timeout: float,
         client_mode: bool,
         initial_group_bits: str = "",
-        averaging_expiration: float = 15,
     ):
         assert "." not in prefix, "group prefix must be a string without ."
-        if request_timeout is None or request_timeout >= averaging_expiration:
+        if request_timeout is None or request_timeout >= min_matchmaking_time:
             logger.warning(
-                "It is recommended to use request_timeout smaller than averaging_expiration. Otherwise,"
-                "matchmaking can cause deadlocks in some rare cases. Please see Matchmaking docstring."
+                "It is recommended to use request_timeout smaller than min_matchmaking_time. Otherwise,"
+                " matchmaking can cause deadlocks in some rare cases. Please see Matchmaking docstring."
             )
 
         super().__init__()
@@ -67,7 +70,7 @@ class Matchmaking:
         self.schema_hash = schema_hash
         self.group_key_manager = GroupKeyManager(dht, prefix, initial_group_bits, target_group_size)
         self.target_group_size, self.min_group_size = target_group_size, min_group_size
-        self.averaging_expiration, self.request_timeout = averaging_expiration, request_timeout
+        self.min_matchmaking_time, self.request_timeout = min_matchmaking_time, request_timeout
         self.client_mode = client_mode
 
         self.lock_looking_for_group = asyncio.Lock()
@@ -78,8 +81,16 @@ class Matchmaking:
 
         self.current_leader: Optional[PeerID] = None  # iff i am a follower, this is a link to my current leader
         self.current_followers: Dict[PeerID, averaging_pb2.JoinRequest] = {}  # my current followers excluding myself
-        self.potential_leaders = PotentialLeaders(self.peer_id, averaging_expiration, target_group_size)
-        self.data_for_gather: Optional[bytes] = None
+        self.potential_leaders = PotentialLeaders(self.peer_id, min_matchmaking_time, target_group_size)
+        self.step_control: Optional[StepControl] = None
+
+    @contextlib.asynccontextmanager
+    async def looking_for_group(self, step_control: StepControl):
+        async with self.lock_looking_for_group:
+            assert self.step_control is None
+            self.step_control = step_control
+            yield
+            self.step_control = None
 
     @property
     def is_looking_for_group(self):
@@ -98,10 +109,9 @@ class Matchmaking:
             f" current key = {self.group_key_manager.current_key}, client_mode={self.client_mode})"
         )
 
-    async def look_for_group(self, *, data_for_gather: bytes, timeout: Optional[float] = None) -> Optional[GroupInfo]:
+    async def look_for_group(self, step: StepControl) -> Optional[GroupInfo]:
         """
-        :param data_for_gather: optionally send this data to all peers in the next group and gather it from groupmates
-        :param timeout: maximum time that may be spent looking for group (does not include allreduce itself)
+        :param step: step parameters and user control structure for the current step
         :returns: an assembled group if successful, None if failed; does NOT perform the actual averaging
         Iterate over the averagers from a given group_identifier that have higher leadership priority than yourself.
         """
@@ -110,11 +120,10 @@ class Matchmaking:
                 "Another look_for_group is already in progress. The current run will be scheduled after"
                 " the existing group is either assembled or disbanded."
             )
-        async with self.lock_looking_for_group:
-            self.data_for_gather = data_for_gather
-            request_leaders_task = asyncio.create_task(self._request_join_potential_leaders(timeout))
+        async with self.looking_for_group(step):
+            request_leaders_task = asyncio.create_task(self._request_join_potential_leaders(step))
             try:
-                return await asyncio.wait_for(self.assembled_group, timeout=timeout)
+                return await asyncio.wait_for(self.assembled_group, timeout=step.get_timeout())
             except asyncio.TimeoutError:
                 return None
 
@@ -136,16 +145,16 @@ class Matchmaking:
                 # note: the code above ensures that we send all followers away before creating new future
                 self.assembled_group = asyncio.Future()
                 self.was_accepted_to_group.clear()
-                self.data_for_gather = None
 
-    async def _request_join_potential_leaders(self, timeout: Optional[float]) -> GroupInfo:
+    async def _request_join_potential_leaders(self, step: StepControl) -> GroupInfo:
         """Request leaders from queue until we find the first runner. This coroutine is meant to run in background."""
-        async with self.potential_leaders.begin_search(self.group_key_manager, timeout, declare=not self.client_mode):
+        assert self.is_looking_for_group
+        async with self.potential_leaders.begin_search(step, self.group_key_manager, declare=not self.client_mode):
             while True:
                 try:
                     next_leader = await self.potential_leaders.pop_next_leader()  # throws TimeoutError on expiration
 
-                    group = await self.request_join_group(next_leader, self.potential_leaders.request_expiration_time)
+                    group = await self._request_join_group(next_leader)
                     if group is not None:
                         return group
 
@@ -166,26 +175,25 @@ class Matchmaking:
                         self.assembled_group.set_exception(e)
                     raise e
 
-    async def request_join_group(self, leader: PeerID, expiration_time: DHTExpiration) -> Optional[GroupInfo]:
+    async def _request_join_group(self, leader: PeerID) -> Optional[GroupInfo]:
         """
         :param leader: request this peer to be your leader for allreduce
-        :param expiration_time: inform leader that we intend to begin averaging before this expiration_time
         :returns: if leader leader accepted us and started AllReduce, return that AllReduce. Otherwise, return None
         :note: this function does not guarantee that your group leader is the same as :leader: parameter
           The originally specified leader can disband group and redirect us to a different leader
         """
         assert self.is_looking_for_group and self.current_leader is None
-        stream: AsyncIterator[averaging_pb2.MessageFromLeader] = None
+        stream: Optional[AsyncIterator[averaging_pb2.MessageFromLeader]] = None
         try:
             async with self.lock_request_join_group:
                 leader_stub = self._servicer_type.get_stub(self._p2p, leader, namespace=self._prefix)
-
+                request_expiration_time = self.get_request_expiration_time()
                 stream = await leader_stub.rpc_join_group(
                     averaging_pb2.JoinRequest(
                         schema_hash=self.schema_hash,
-                        expiration=expiration_time,
+                        expiration=request_expiration_time,
                         client_mode=self.client_mode,
-                        gather=self.data_for_gather,
+                        gather=self.step_control.data_for_gather,
                         group_key=self.group_key_manager.current_key,
                     )
                 )
@@ -204,7 +212,7 @@ class Matchmaking:
                 return None
 
             async with self.potential_leaders.pause_search():
-                time_to_expiration = max(expiration_time - get_dht_time(), 0.0)
+                time_to_expiration = max(0.0, request_expiration_time - get_dht_time())
                 message = await asyncio.wait_for(anext(stream), time_to_expiration + self.request_timeout)
 
                 if message.code == averaging_pb2.BEGIN_ALLREDUCE:
@@ -218,7 +226,7 @@ class Matchmaking:
                         logger.debug(f"{self} - leader disbanded group and redirected us to {suggested_leader}")
                         self.current_leader = None
                         await stream.aclose()
-                        return await self.request_join_group(suggested_leader, expiration_time)
+                        return await self._request_join_group(suggested_leader)
                 logger.debug(f"{self} - leader disbanded group")
                 return None
 
@@ -237,6 +245,14 @@ class Matchmaking:
             if stream is not None:
                 await stream.aclose()
 
+    def get_request_expiration_time(self) -> float:
+        """Returns the averager's current expiration time, which is used to send join requests to leaders"""
+        if isfinite(self.potential_leaders.declared_expiration_time):
+            return self.potential_leaders.declared_expiration_time
+        else:
+            scheduled_time = max(self.step_control.scheduled_time, get_dht_time() + self.min_matchmaking_time)
+            return min(scheduled_time, self.potential_leaders.search_end_time)
+
     async def rpc_join_group(
         self, request: averaging_pb2.JoinRequest, context: P2PContext
     ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
@@ -352,7 +368,7 @@ class Matchmaking:
         random.shuffle(ordered_peer_ids)
 
         gathered = tuple(
-            self.data_for_gather if peer_id == self.peer_id else self.current_followers[peer_id].gather
+            self.step_control.data_for_gather if peer_id == self.peer_id else self.current_followers[peer_id].gather
             for peer_id in ordered_peer_ids
         )
 
@@ -388,8 +404,8 @@ class Matchmaking:
 class PotentialLeaders:
     """An utility class that searches for averagers that could become our leaders"""
 
-    def __init__(self, peer_id: PeerID, averaging_expiration: DHTExpiration, target_group_size: Optional[int]):
-        self.peer_id, self.averaging_expiration = peer_id, averaging_expiration
+    def __init__(self, peer_id: PeerID, min_matchmaking_time: DHTExpiration, target_group_size: Optional[int]):
+        self.peer_id, self.min_matchmaking_time = peer_id, min_matchmaking_time
         self.target_group_size = target_group_size
         self.running, self.update_triggered, self.update_finished = asyncio.Event(), asyncio.Event(), asyncio.Event()
         self.declared_expiration, self.lock_search, self.lock_declare = asyncio.Event(), asyncio.Lock(), asyncio.Lock()
@@ -401,13 +417,13 @@ class PotentialLeaders:
         self.search_end_time = float("inf")
 
     @contextlib.asynccontextmanager
-    async def begin_search(self, key_manager: GroupKeyManager, timeout: Optional[float], declare: bool = True):
+    async def begin_search(self, step: StepControl, key_manager: GroupKeyManager, declare: bool = True):
         async with self.lock_search:
             self.running.set()
-            self.search_end_time = get_dht_time() + timeout if timeout is not None else float("inf")
+            self.search_end_time = step.deadline if step.deadline is not None else float("inf")
             update_queue_task = asyncio.create_task(self._update_queue_periodically(key_manager))
             if declare:
-                declare_averager_task = asyncio.create_task(self._declare_averager_periodically(key_manager))
+                declare_averager_task = asyncio.create_task(self._declare_averager_periodically(step, key_manager))
 
             try:
                 yield self
@@ -467,20 +483,12 @@ class PotentialLeaders:
             self.past_attempts.add((maybe_next_leader, entry.expiration_time))
             return maybe_next_leader
 
-    @property
-    def request_expiration_time(self) -> float:
-        """this averager's current expiration time - used to send join requests to leaders"""
-        if isfinite(self.declared_expiration_time):
-            return self.declared_expiration_time
-        else:
-            return min(get_dht_time() + self.averaging_expiration, self.search_end_time)
-
     async def _update_queue_periodically(self, key_manager: GroupKeyManager) -> None:
         DISCREPANCY = timed_storage.MAX_DHT_TIME_DISCREPANCY_SECONDS
         while get_dht_time() < self.search_end_time:
             new_peers = await key_manager.get_averagers(key_manager.current_key, only_active=True)
             self.max_assured_time = max(
-                self.max_assured_time, get_dht_time() + self.averaging_expiration - DISCREPANCY
+                self.max_assured_time, get_dht_time() + self.min_matchmaking_time - DISCREPANCY
             )
 
             self.leader_queue.clear()
@@ -499,13 +507,14 @@ class PotentialLeaders:
             )
             self.update_triggered.clear()
 
-    async def _declare_averager_periodically(self, key_manager: GroupKeyManager) -> None:
+    async def _declare_averager_periodically(self, step: StepControl, key_manager: GroupKeyManager) -> None:
         async with self.lock_declare:
             try:
                 while True:
                     await self.running.wait()
-
-                    new_expiration_time = min(get_dht_time() + self.averaging_expiration, self.search_end_time)
+                    new_expiration_time = float(
+                        min(max(step.scheduled_time, get_dht_time() + self.min_matchmaking_time), self.search_end_time)
+                    )
                     self.declared_group_key = group_key = key_manager.current_key
                     self.declared_expiration_time = new_expiration_time
                     self.declared_expiration.set()

+ 3 - 2
requirements-docs.txt

@@ -1,2 +1,3 @@
-recommonmark
-sphinx_rtd_theme
+recommonmark==0.5.0
+sphinx_rtd_theme==0.4.3
+sphinx==4.2.0

+ 48 - 0
tests/test_averaging.py

@@ -1,4 +1,5 @@
 import random
+import time
 
 import numpy as np
 import pytest
@@ -7,6 +8,7 @@ import torch
 import hivemind
 import hivemind.averaging.averager
 from hivemind.averaging.allreduce import AveragingMode
+from hivemind.averaging.control import AveragingStage
 from hivemind.averaging.key_manager import GroupKeyManager
 from hivemind.averaging.load_balancing import load_balance_peers
 from hivemind.averaging.partition import AllreduceException
@@ -420,6 +422,52 @@ def test_getset_bits():
     assert averager.get_group_bits() == "00101011101010"
 
 
+@pytest.mark.forked
+def test_averaging_trigger():
+    averagers = tuple(
+        hivemind.averaging.DecentralizedAverager(
+            averaged_tensors=[torch.randn(3)],
+            dht=dht,
+            target_group_size=4,
+            min_matchmaking_time=0.5,
+            request_timeout=0.3,
+            prefix="mygroup",
+            initial_group_bits="",
+            start=True,
+        )
+        for dht in launch_dht_instances(4)
+    )
+
+    controls = []
+    for i, averager in enumerate(averagers):
+        controls.append(
+            averager.step(
+                wait=False,
+                scheduled_time=hivemind.get_dht_time() + 0.5,
+                weight=1.0,
+                require_trigger=i in (1, 2),
+            )
+        )
+
+    time.sleep(0.6)
+
+    c0, c1, c2, c3 = controls
+    assert not any(c.done() for c in controls)
+    assert c0.stage == AveragingStage.RUNNING_ALLREDUCE
+    assert c1.stage == AveragingStage.AWAITING_TRIGGER
+    assert c2.stage == AveragingStage.AWAITING_TRIGGER
+    assert c3.stage == AveragingStage.RUNNING_ALLREDUCE
+
+    c1.allow_allreduce()
+    c2.allow_allreduce()
+    time.sleep(0.5)
+    assert all(c.stage == AveragingStage.FINISHED for c in controls)
+    assert all(c.done() for c in controls)
+
+    # check that setting trigger twice does not raise error
+    c0.allow_allreduce()
+
+
 @pytest.mark.forked
 def test_training_averager(n_steps: int = 10, n_dims: int = 16):
     torch.manual_seed(42)