Add an option to pre-schedule averaging (#398)

Part 1/3 for optimizer overhaul
- implement a way to pre-schedule averager step and add a trigger that manually enables averaging
- add tests for pre-scheduling and averaging trigger
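A rough sketch of the resulting control flow (illustrative only: the DHT setup and tensor values below are placeholders, and a group only forms once enough peers run a matching step):

import torch
import hivemind

dht = hivemind.DHT(start=True)
averager = hivemind.averaging.DecentralizedAverager(
    averaged_tensors=[torch.zeros(3)], dht=dht, prefix="mygroup",
    target_group_size=4, min_matchmaking_time=1.0, start=True,
)

# Pre-schedule: matchmaking starts immediately, all-reduce is planned for
# ~2 seconds from now, and is held back until the trigger is set.
control = averager.step(
    wait=False,
    scheduled_time=hivemind.get_dht_time() + 2.0,
    require_trigger=True,
)
# ... compute gradients or do other useful work while matchmaking runs ...
control.allow_allreduce()    # the manual trigger added in this commit
gathered = control.result()  # blocks until this averaging round finishes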

Co-authored-by: Alexander Borzunov <hxrussia@gmail.com>
Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
Co-authored-by: Michael Diskin <yhn1124@gmail.com>
justheuristic, 3 years ago
commit a09df5492f

+ 1 - 1
docs/conf.py

@@ -203,7 +203,7 @@ todo_include_todos = True
 
 
 def setup(app):
-    app.add_stylesheet("fix_rtd.css")
+    app.add_css_file("fix_rtd.css")
     app.add_config_value(
         "recommonmark_config",
         {
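Context for this one-line change: Sphinx deprecated app.add_stylesheet in version 1.8 and removed it in 4.0; app.add_css_file is the direct replacement, which is presumably also why requirements-docs.txt below pins sphinx==4.2.0.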

+ 92 - 62
hivemind/averaging/averager.py

@@ -16,6 +16,7 @@ import numpy as np
 import torch
 
 from hivemind.averaging.allreduce import AllreduceException, AllReduceRunner, AveragingMode, GroupID
+from hivemind.averaging.control import AveragingStage, StepControl
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.load_balancing import load_balance_peers
 from hivemind.averaging.matchmaking import Matchmaking, MatchmakingException
@@ -28,7 +29,7 @@ from hivemind.compression import (
     serialize_torch_tensor,
 )
 from hivemind.dht import DHT, DHTID
-from hivemind.p2p import P2PContext, P2PHandlerError, PeerID, ServicerBase
+from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
 from hivemind.proto import averaging_pb2
 from hivemind.utils import MPFuture, TensorDescriptor, get_logger
 from hivemind.utils.asyncio import achain, aiter_with_timeout, anext, as_aiter, azip, switch_to_uvloop
@@ -54,8 +55,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     :param prefix: a shared prefix for all group keys
     :param target_group_size: attempts to form groups with up to this many peers (recommended: a power of 2, e.g. 16)
     :param initial_group_bits: a string of bits ('0' and '1') that define the initial group key (bucket index)
-    :param averaging_expiration: attempt to find a group for this many seconds, otherwise try again
-      note - this expiration time only applies to looking for group, passing tensors in allreduce may take more time
+    :param min_matchmaking_time: when looking for group, wait for requests for at least this many seconds
     :param compression: optionally compress tensors with this compression algorithm before running all-reduce
     :param state_compression: a separate compression strategy for load_state_from_peers (default = no compression)
     :param tensor_infos: CompressionInfo for each respective tensor; this determines how the tensor will be compressed
@@ -63,7 +63,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     :param averaging_alpha: optional "learning rate" for averaging. If specified, local parameters will be shifted
       towards the (estimated) average by this coefficient. By default, local parameters are set equal to average.
     :param request_timeout: when looking for group, wait for a response from leader for at most this many seconds.
-    :note: request_timeout must be smaller than averaging_expiration to avoid potential deadlocks.
+    :note: request_timeout must be smaller than min_matchmaking_time to avoid potential deadlocks.
     :param part_size_bytes: tensors for AllReduce are processed in parts of up to this size (after compression)
     :param bandwidth: if specified, this value represents the network bandwidth available to averager.
           By default, the averager is assumed to have the average bandwidth of its group.
@@ -75,6 +75,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
           local tensors for averaging
     :param allow_state_sharing: if set to True, other peers can download this peer's state. Can be overwritten
       with averager.allow_state_sharing = True / False
+    :param declare_state_period: re-declare averager as a donor for load_state_from_peers every this many seconds
     :param shutdown_timeout: when calling .shutdown, wait for up to this many seconds before terminating
 
     Example:
@@ -92,6 +93,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
     _matchmaking: Matchmaking
     _pending_group_assembled: asyncio.Event
+    _state_updated: asyncio.Event
+    _p2p: P2P
     serializer = MSGPackSerializer
 
     def __init__(
@@ -104,8 +107,9 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         target_group_size: int,
         min_group_size: int = 2,
         initial_group_bits: str = "",
-        averaging_expiration: float = 15,
-        request_timeout: float = 3,
+        averaging_expiration: Optional[float] = None,
+        min_matchmaking_time: float = 5.0,
+        request_timeout: float = 3.0,
         averaging_alpha: float = 1.0,
         part_size_bytes: int = DEFAULT_PART_SIZE_BYTES,
         allreduce_timeout: Optional[float] = None,
@@ -116,6 +120,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         min_vector_size: int = 0,
         auxiliary: bool = False,
         allow_state_sharing: Optional[bool] = None,
+        declare_state_period: float = 30,
         client_mode: Optional[bool] = None,
         daemon: bool = True,
         shutdown_timeout: float = 5,
@@ -129,6 +134,11 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         assert all(bit in "01" for bit in initial_group_bits)
         assert not client_mode or not auxiliary, "auxiliary peers must accept incoming connections"
 
+        if averaging_expiration is not None:
+            logger.warning("averaging_expiration is deprecated and will be removed soon, use min_matchmaking_time")
+            assert min_matchmaking_time == 5.0, "Can't set both averaging_expiration and min_matchmaking_time"
+            min_matchmaking_time = averaging_expiration
+
         super().__init__()
         self.dht = dht
         self.prefix = prefix
@@ -163,8 +173,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             initial_group_bits=initial_group_bits,
             target_group_size=target_group_size,
             min_group_size=min_group_size,
-            averaging_expiration=averaging_expiration,
             request_timeout=request_timeout,
+            min_matchmaking_time=min_matchmaking_time,
         )
         self.allreduce_kwargs = dict(
             compression=compression,
@@ -180,6 +190,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         if allow_state_sharing is None:
             allow_state_sharing = not client_mode and not auxiliary
         self.allow_state_sharing = allow_state_sharing
+        self.declare_state_period = declare_state_period
         self.state_compression = state_compression
         self.tensor_infos = tensor_infos
 
@@ -250,6 +261,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                 if not self.client_mode:
                     asyncio.create_task(self._declare_for_download_periodically())
 
+                self._state_updated = asyncio.Event()
                 self._pending_group_assembled = asyncio.Event()
                 self._pending_group_assembled.set()
             except Exception as e:
@@ -294,7 +306,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     def shutdown(self) -> None:
         """Shut down the averager process"""
         if self.is_alive():
-            self._outer_pipe.send(("_shutdown", [None], {}))  # shut down the daemon process
+            self._outer_pipe.send(("_shutdown", [self.shutdown_timeout], {}))  # shut down the daemon process
             self._inner_pipe.send(("_SHUTDOWN", None))  # shut down background thread in master
             self.join(self.shutdown_timeout)
             if self.is_alive():
@@ -303,11 +315,11 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         else:
             logger.exception("Averager shutdown has no effect: the process is already not alive")
 
-    async def _shutdown(self, timeout: Optional[float] = None) -> None:
+    async def _shutdown(self, timeout: Optional[float]) -> None:
         remaining_tasks = set()
         for group in self._running_groups.values():
             remaining_tasks.update(group.finalize(cancel=True))
-        await asyncio.gather(*remaining_tasks)
+        await asyncio.wait_for(asyncio.gather(*remaining_tasks), timeout)
 
     def __del__(self):
         if self._parent_pid == os.getpid() and self.is_alive():
@@ -316,68 +328,81 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     def step(
         self,
         gather: Optional[GatheredData] = None,
+        scheduled_time: Optional[DHTExpiration] = None,
         weight: Optional[float] = None,
         timeout: Optional[float] = None,
         allow_retries: bool = True,
+        require_trigger: bool = False,
         wait: bool = True,
-    ) -> Union[Optional[Dict[PeerID, GatheredData]], MPFuture]:
+    ) -> Union[Optional[Dict[PeerID, GatheredData]], StepControl]:
         """
         Set up the averager to look for a group and run one round of averaging, return True on success, False on failure
 
         :param gather: optionally send this information to all peers in the next group and gather it from every groupmate
           (this operation is known as all-gather). The gathered data will be available as the output of this function.
+        :param scheduled_time: when matchmaking, assume that all-reduce will begin at this moment.
+          By default, schedule all-reduce at current time plus min_matchmaking_time seconds
         :param weight: averaging weight for this peer, int or float, must be strictly positive
         :param allow_retries: if averager fails to run one round of allreduce, this option will allow it to try again
           within the specified timeout
-        :param timeout: if averager was unable to *find* a group in this many seconds, consider allreduce failedK
-        :param wait: if True (default), return when finished. Otherwise return MPFuture and run in background.
+        :param require_trigger: if True, wait for the user to call .allow_allreduce() before running all-reduce
+        :param timeout: if averager was unable to *find* a group in this many seconds, consider allreduce failed
+        :param wait: if True (default), return when finished. Otherwise return StepControl and run in background.
         :returns: on success, update averaged_tensors and return group info; on failure, return None
         """
         if self.mode == AveragingMode.AUX and weight is not None:
             logger.warning("Averager is running in auxiliary mode, weight is unused.")
+        if scheduled_time is None:
+            scheduled_time = get_dht_time() + self.matchmaking_kwargs["min_matchmaking_time"]
         if weight is None:
             weight = float(self.mode != AveragingMode.AUX)
+        deadline = get_dht_time() + timeout if timeout is not None else float("inf")
         assert isinstance(weight, (int, float)) and weight >= 0, f"Expected a positive int/float, got {type(weight)}"
-
-        future = MPFuture()
-        gather_binary = self.serializer.dumps(
-            gather
-        )  # serialize here to avoid loading modules in the averager process
-        self._outer_pipe.send(
-            (
-                "_step",
-                [],
-                dict(
-                    future=future,
-                    gather_binary=gather_binary,
-                    weight=weight,
-                    allow_retries=allow_retries,
-                    timeout=timeout,
-                ),
-            )
+        assert not (wait and require_trigger), "Non-asynchronous step cannot wait for trigger (use wait=False)"
+        assert scheduled_time < deadline, "Scheduled start time does not fit within timeout"
+
+        user_data_for_gather = self.serializer.dumps(gather)  # serialize here to avoid imports in the averager process
+        data_for_gather = self.serializer.dumps([self.bandwidth, self.mode.value, user_data_for_gather])
+        step = StepControl(
+            scheduled_time=scheduled_time,
+            deadline=deadline,
+            allow_retries=allow_retries,
+            weight=weight,
+            data_for_gather=data_for_gather,
         )
-        return future.result() if wait else future
 
-    async def _step(
-        self, *, future: MPFuture, gather_binary: bytes, weight: float, allow_retries: bool, timeout: Optional[float]
-    ):
-        start_time = get_dht_time()
+        future_for_trigger = MPFuture()
+        self._outer_pipe.send(("_step", [], dict(step=step, future_for_trigger=future_for_trigger)))
+        step.attach_trigger(future_for_trigger.result())
 
+        if not require_trigger:
+            step.allow_allreduce()
+        return step.result() if wait else step
+
+    async def _step(self, *, step: StepControl, future_for_trigger: MPFuture):
         try:
-            while not future.done():
+            trigger = MPFuture()
+            step.attach_trigger(trigger)
+            future_for_trigger.set_result(trigger)
+
+            while not step.done():
                 try:
                     self._pending_group_assembled.clear()
-                    data_for_gather = self.serializer.dumps([self.bandwidth, self.mode.value, gather_binary])
-                    group_info = await self._matchmaking.look_for_group(
-                        timeout=timeout, data_for_gather=data_for_gather
-                    )
+                    step.stage = AveragingStage.LOOKING_FOR_GROUP
+                    group_info = await self._matchmaking.look_for_group(step)
                     if group_info is None:
                         raise AllreduceException("Averaging step failed: could not find a group.")
 
-                    future.set_result(
+                    if not step.triggered:
+                        step.stage = AveragingStage.AWAITING_TRIGGER
+                        await step.wait_for_trigger()
+
+                    step.stage = AveragingStage.RUNNING_ALLREDUCE
+
+                    step.set_result(
                         await asyncio.wait_for(
                             self._run_allreduce(
-                                group_info, tensor_infos=self.tensor_infos, weight=weight, **self.allreduce_kwargs
+                                group_info, tensor_infos=self.tensor_infos, weight=step.weight, **self.allreduce_kwargs
                             ),
                             timeout=self._allreduce_timeout,
                         )
@@ -393,20 +418,20 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     asyncio.InvalidStateError,
                     P2PHandlerError,
                 ) as e:
-                    time_elapsed = get_dht_time() - start_time
-                    if not allow_retries or (timeout is not None and timeout < time_elapsed):
-                        logger.exception(f"Averager caught {repr(e)}")
-                        future.set_exception(e)
+                    if not step.allow_retries or get_dht_time() >= step.deadline:
+                        logger.exception(e)
+                        step.set_exception(e)
                     else:
-                        logger.warning(f"Averager caught {repr(e)}, retrying")
+                        logger.warning(f"{self.__class__.__name__} caught {repr(e)}, retrying")
 
         except BaseException as e:
-            if not future.done():
-                future.set_exception(e)
+            if not step.done():
+                step.set_exception(e)
             raise
         finally:
-            if not future.done():
-                future.set_exception(
+            step.stage = AveragingStage.FINISHED
+            if not step.done():
+                step.set_exception(
                     RuntimeError(
                         "Internal sanity check failed: averager.step left future pending."
                         " Please report this to hivemind issues."
@@ -416,8 +441,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
         """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
         try:
-            bandwidths, mode_ids, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
-            user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered)))
+            bandwidths, mode_ids, user_gathered_bytes = zip(*map(self.serializer.loads, group_info.gathered))
+            user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered_bytes)))
             modes = tuple(map(AveragingMode, mode_ids))
 
             # compute optimal part sizes from peer bandwidths; TODO: replace with proper load balancing
@@ -447,7 +472,9 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                         async for tensor, update in azip(as_aiter(*local_tensors), allreduce):
                             # all-reduce is performed asynchronously while iterating
                             tensor.add_(update, alpha=self._averaging_alpha)
-                        self.last_updated = get_dht_time()
+                            self.last_updated = get_dht_time()
+                            self._state_updated.set()
+
                     else:
                         async for _ in allreduce:  # trigger all-reduce by iterating
                             raise ValueError("aux peers should not receive averaged tensors")
@@ -477,7 +504,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         """
         """
         with self.lock_averaged_tensors:
         with self.lock_averaged_tensors:
             yield self._averaged_tensors
             yield self._averaged_tensors
-        self.last_updated = get_dht_time()
 
 
     @contextlib.asynccontextmanager
     @contextlib.asynccontextmanager
     async def get_tensors_async(self) -> Sequence[torch.Tensor]:
     async def get_tensors_async(self) -> Sequence[torch.Tensor]:
@@ -517,19 +543,24 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         download_key = f"{self._matchmaking.group_key_manager.prefix}.all_averagers"
         while True:
             if self.allow_state_sharing:
+                self._state_updated.clear()
+                expiration_time = get_dht_time() + self.declare_state_period
                 asyncio.create_task(
                     asyncio.wait_for(
                         self.dht.store(
                             download_key,
                             subkey=self.peer_id.to_bytes(),
                             value=self.last_updated,
-                            expiration_time=get_dht_time() + self._matchmaking.averaging_expiration,
+                            expiration_time=expiration_time,
                             return_future=True,
                         ),
-                        timeout=self._matchmaking.averaging_expiration,
+                        timeout=expiration_time - self.request_timeout,
                     )
                 )
-            await asyncio.sleep(self._matchmaking.averaging_expiration)
+            try:
+                await asyncio.wait_for(self._state_updated.wait(), self.declare_state_period - self.request_timeout)
+            except asyncio.TimeoutError:
+                pass
 
     async def rpc_download_state(
         self, _request: averaging_pb2.DownloadRequest, _context: P2PContext
@@ -584,10 +615,10 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         The exact contents of both metadata and tensors are determined by get_current_state method
         """
         future = MPFuture()
-        self._outer_pipe.send(("_load_state_from_peers", [], dict(future=future)))
+        self._outer_pipe.send(("_load_state_from_peers", [], dict(timeout=timeout, future=future)))
         return future.result(timeout=timeout) if wait else future
 
-    async def _load_state_from_peers(self, future: MPFuture):
+    async def _load_state_from_peers(self, future: MPFuture, timeout: Optional[float] = None):
         try:
             key_manager = self._matchmaking.group_key_manager
             peer_priority, _ = self.dht.get(f"{key_manager.prefix}.all_averagers", latest=True) or ({}, None)
@@ -611,7 +642,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                         stream = await stub.rpc_download_state(averaging_pb2.DownloadRequest())
                         current_tensor_parts, tensors = [], []
 
-                        async for message in aiter_with_timeout(stream, timeout=self.request_timeout):
+                        async for message in aiter_with_timeout(stream, timeout=timeout or self.request_timeout):
                             if message.metadata:
                                 metadata = self.serializer.loads(message.metadata)
                             if message.tensor_part.dtype and current_tensor_parts:
@@ -628,7 +659,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
                         logger.info(f"Finished downloading state from {peer}")
                         future.set_result((metadata, tensors))
-                        self.last_updated = get_dht_time()
                         return
                     except Exception as e:
                         logger.exception(f"Failed to download state from {peer} - {repr(e)}")

+ 148 - 0
hivemind/averaging/control.py

@@ -0,0 +1,148 @@
+import struct
+from enum import Enum
+from typing import Optional
+
+import numpy as np
+import torch
+
+from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger
+
+logger = get_logger(__name__)
+
+
+class AveragingStage(Enum):
+    IDLE = 0  # still initializing
+    LOOKING_FOR_GROUP = 1  # running decentralized matchmaking, can't run allreduce yet
+    AWAITING_TRIGGER = 2  # waiting for user to set the trigger that allows running allreduce
+    RUNNING_ALLREDUCE = 3  # exchanging tensors with groupmates
+    FINISHED = 4  # either done or failed with exception
+
+
+class StepControl(MPFuture):
+    """
+    An auxiliary data structure that allows the user to control stages and track progress in a single averaging step
+
+    :param scheduled_time: estimated time when averaging should begin. Will be used for scheduling
+    :param deadline: if averaging is still in progress at this time, it should be stopped due to TimeoutError
+    :param allow_retries: if True, allow running matchmaking and all-reduce again if previous attempt fails
+    :param weight: averaging weight, can be changed afterwards
+    :param data_for_gather: send this data to all peers in the next group and gather it from groupmates
+    """
+
+    # indices for the shared buffer
+    _SCHEDULED_TIME, _WEIGHT, _STAGE, _BEGAN_ALLREDUCE = slice(0, 8), slice(8, 16), 16, 17
+
+    def __init__(
+        self,
+        scheduled_time: DHTExpiration,
+        deadline: float,
+        allow_retries: bool,
+        weight: float,
+        data_for_gather: bytes,
+    ):
+        super().__init__()
+        self._data_for_gather, self._deadline, self._allow_retries = data_for_gather, deadline, allow_retries
+        self._trigger: Optional[MPFuture] = None
+
+        # Buffer contents:
+        # scheduled_time (double) | weight (double) | stage (AveragingStage, 1 byte) | began_allreduce: (bool, 1 byte)
+        self._shared_buffer = torch.zeros([18], dtype=torch.uint8).share_memory_()
+        self.stage = AveragingStage.IDLE
+        self.scheduled_time = scheduled_time
+        self.weight = weight
+        self.began_allreduce = False
+
+    def attach_trigger(self, trigger: MPFuture):
+        assert self._trigger is None, "Trigger is already attached"
+        self._trigger = trigger
+
+    def allow_allreduce(self):
+        """Allow averager to begin allreduce when it finds a group. Meant to be triggered by user."""
+        assert self._trigger is not None, "StepControl does not have an attached trigger"
+        if self._trigger.done():
+            logger.warning("Trigger is already set")
+        else:
+            self._trigger.set_result(None)
+
+    async def wait_for_trigger(self):
+        assert self._trigger is not None, "StepControl does not have an attached trigger"
+        await self._trigger
+
+    @property
+    def triggered(self) -> bool:
+        assert self._trigger is not None, "StepControl does not have an attached trigger"
+        return self._trigger.done()
+
+    @property
+    def scheduled_time(self) -> DHTExpiration:
+        return struct.unpack("d", self._shared_buffer[StepControl._SCHEDULED_TIME].numpy().data)[0]
+
+    @scheduled_time.setter
+    def scheduled_time(self, scheduled_time):
+        if self.began_allreduce:
+            logger.warning("Changing scheduled time has no effect after all-reduce has already started")
+        if scheduled_time >= self.deadline:
+            logger.warning("Changing scheduled time to after deadline, averaging will likely fail due to timeout.")
+        struct.pack_into("d", self._shared_buffer[StepControl._SCHEDULED_TIME].numpy().data, 0, float(scheduled_time))
+
+    @property
+    def weight(self) -> float:
+        return struct.unpack("d", self._shared_buffer[StepControl._WEIGHT].numpy().data)[0]
+
+    @weight.setter
+    def weight(self, weight: float):
+        assert weight >= 0 and np.isfinite(weight)
+        if self.began_allreduce:
+            logger.warning("Changing weights has no effect after all-reduce has already started")
+        struct.pack_into("d", self._shared_buffer[StepControl._WEIGHT].numpy().data, 0, float(weight))
+
+    @property
+    def stage(self) -> AveragingStage:
+        return AveragingStage(self._shared_buffer[StepControl._STAGE].item())
+
+    @stage.setter
+    def stage(self, stage: AveragingStage):
+        if stage == AveragingStage.RUNNING_ALLREDUCE:
+            self.began_allreduce = True
+        self._shared_buffer[StepControl._STAGE] = stage.value
+
+    @property
+    def began_allreduce(self) -> bool:
+        return bool(self._shared_buffer[StepControl._BEGAN_ALLREDUCE].item())
+
+    @began_allreduce.setter
+    def began_allreduce(self, value: bool):
+        self._shared_buffer[StepControl._BEGAN_ALLREDUCE] = int(value)
+
+    @property
+    def data_for_gather(self) -> bytes:
+        return self._data_for_gather
+
+    @property
+    def deadline(self) -> DHTExpiration:
+        return self._deadline
+
+    def get_timeout(self) -> Optional[DHTExpiration]:
+        return max(0.0, self.deadline - get_dht_time())
+
+    @property
+    def allow_retries(self) -> bool:
+        return self._allow_retries
+
+    def __getstate__(self):
+        return dict(
+            super().__getstate__(),
+            _trigger=self._trigger,
+            _shared_buffer=self._shared_buffer,
+            immutable_params=(self._data_for_gather, self._deadline, self._allow_retries),
+        )
+
+    def __setstate__(self, state):
+        super().__setstate__(state)
+        self._trigger, self._shared_buffer = state["_trigger"], state["_shared_buffer"]
+        self._data_for_gather, self._deadline, self._allow_retries = state["immutable_params"]
+
+    def cancel(self) -> bool:
+        if self._trigger is not None:
+            self._trigger.cancel()
+        return super().cancel()
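StepControl stores its mutable fields in an 18-byte shared-memory tensor so that the user process and the averager process observe each other's updates without extra IPC. A standalone sketch of this trick, assuming the same buffer layout as above:

import struct
import torch

buf = torch.zeros([18], dtype=torch.uint8).share_memory_()  # shared with forked processes
SCHEDULED_TIME, WEIGHT = slice(0, 8), slice(8, 16)  # two doubles packed side by side

def put_double(field: slice, value: float):
    struct.pack_into("d", buf[field].numpy().data, 0, value)

def get_double(field: slice) -> float:
    return struct.unpack("d", buf[field].numpy().data)[0]

put_double(WEIGHT, 0.5)
assert get_double(WEIGHT) == 0.5  # a forked process sharing `buf` reads the same value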

+ 53 - 44
hivemind/averaging/matchmaking.py

@@ -9,6 +9,9 @@ import random
 from math import isfinite
 from typing import AsyncIterator, Dict, Optional, Set, Tuple, Type
 
+import numpy as np
+
+from hivemind.averaging.control import StepControl
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.key_manager import GroupKey, GroupKeyManager
 from hivemind.dht import DHT, DHTID, DHTExpiration
@@ -43,16 +46,16 @@ class Matchmaking:
         prefix: str,
         target_group_size: int,
         min_group_size: int,
+        min_matchmaking_time: float,
         request_timeout: float,
         client_mode: bool,
         initial_group_bits: str = "",
-        averaging_expiration: float = 15,
     ):
         assert "." not in prefix, "group prefix must be a string without ."
-        if request_timeout is None or request_timeout >= averaging_expiration:
+        if request_timeout is None or request_timeout >= min_matchmaking_time:
             logger.warning(
-                "It is recommended to use request_timeout smaller than averaging_expiration. Otherwise,"
-                "matchmaking can cause deadlocks in some rare cases. Please see Matchmaking docstring."
+                "It is recommended to use request_timeout smaller than min_matchmaking_time. Otherwise,"
+                " matchmaking can cause deadlocks in some rare cases. Please see Matchmaking docstring."
             )
 
         super().__init__()
@@ -67,7 +70,7 @@ class Matchmaking:
         self.schema_hash = schema_hash
         self.group_key_manager = GroupKeyManager(dht, prefix, initial_group_bits, target_group_size)
         self.target_group_size, self.min_group_size = target_group_size, min_group_size
-        self.averaging_expiration, self.request_timeout = averaging_expiration, request_timeout
+        self.min_matchmaking_time, self.request_timeout = min_matchmaking_time, request_timeout
         self.client_mode = client_mode
 
         self.lock_looking_for_group = asyncio.Lock()
@@ -78,8 +81,16 @@ class Matchmaking:
 
         self.current_leader: Optional[PeerID] = None  # iff i am a follower, this is a link to my current leader
         self.current_followers: Dict[PeerID, averaging_pb2.JoinRequest] = {}  # my current followers excluding myself
-        self.potential_leaders = PotentialLeaders(self.peer_id, averaging_expiration, target_group_size)
-        self.data_for_gather: Optional[bytes] = None
+        self.potential_leaders = PotentialLeaders(self.peer_id, min_matchmaking_time, target_group_size)
+        self.step_control: Optional[StepControl] = None
+
+    @contextlib.asynccontextmanager
+    async def looking_for_group(self, step_control: StepControl):
+        async with self.lock_looking_for_group:
+            assert self.step_control is None
+            self.step_control = step_control
+            yield
+            self.step_control = None
 
     @property
     def is_looking_for_group(self):
@@ -98,10 +109,9 @@ class Matchmaking:
             f" current key = {self.group_key_manager.current_key}, client_mode={self.client_mode})"
             f" current key = {self.group_key_manager.current_key}, client_mode={self.client_mode})"
         )
         )
 
 
-    async def look_for_group(self, *, data_for_gather: bytes, timeout: Optional[float] = None) -> Optional[GroupInfo]:
+    async def look_for_group(self, step: StepControl) -> Optional[GroupInfo]:
         """
         """
-        :param data_for_gather: optionally send this data to all peers in the next group and gather it from groupmates
-        :param timeout: maximum time that may be spent looking for group (does not include allreduce itself)
+        :param step: step parameters and user control structure for the current step
         :returns: an assembled group if successful, None if failed; does NOT perform the actual averaging
         :returns: an assembled group if successful, None if failed; does NOT perform the actual averaging
         Iterate over the averagers from a given group_identifier that have higher leadership priority than yourself.
         Iterate over the averagers from a given group_identifier that have higher leadership priority than yourself.
         """
         """
@@ -110,11 +120,10 @@ class Matchmaking:
                 "Another look_for_group is already in progress. The current run will be scheduled after"
                 "Another look_for_group is already in progress. The current run will be scheduled after"
                 " the existing group is either assembled or disbanded."
                 " the existing group is either assembled or disbanded."
             )
             )
-        async with self.lock_looking_for_group:
-            self.data_for_gather = data_for_gather
-            request_leaders_task = asyncio.create_task(self._request_join_potential_leaders(timeout))
+        async with self.looking_for_group(step):
+            request_leaders_task = asyncio.create_task(self._request_join_potential_leaders(step))
             try:
             try:
-                return await asyncio.wait_for(self.assembled_group, timeout=timeout)
+                return await asyncio.wait_for(self.assembled_group, timeout=step.get_timeout())
             except asyncio.TimeoutError:
             except asyncio.TimeoutError:
                 return None
                 return None
 
 
@@ -136,16 +145,16 @@ class Matchmaking:
                 # note: the code above ensures that we send all followers away before creating new future
                 self.assembled_group = asyncio.Future()
                 self.was_accepted_to_group.clear()
-                self.data_for_gather = None
 
-    async def _request_join_potential_leaders(self, timeout: Optional[float]) -> GroupInfo:
+    async def _request_join_potential_leaders(self, step: StepControl) -> GroupInfo:
         """Request leaders from queue until we find the first runner. This coroutine is meant to run in background."""
-        async with self.potential_leaders.begin_search(self.group_key_manager, timeout, declare=not self.client_mode):
+        assert self.is_looking_for_group
+        async with self.potential_leaders.begin_search(step, self.group_key_manager, declare=not self.client_mode):
             while True:
                 try:
                     next_leader = await self.potential_leaders.pop_next_leader()  # throws TimeoutError on expiration
 
-                    group = await self.request_join_group(next_leader, self.potential_leaders.request_expiration_time)
+                    group = await self._request_join_group(next_leader)
                     if group is not None:
                         return group
 
@@ -166,26 +175,25 @@ class Matchmaking:
                         self.assembled_group.set_exception(e)
                     raise e
 
-    async def request_join_group(self, leader: PeerID, expiration_time: DHTExpiration) -> Optional[GroupInfo]:
+    async def _request_join_group(self, leader: PeerID) -> Optional[GroupInfo]:
         """
         :param leader: request this peer to be your leader for allreduce
-        :param expiration_time: inform leader that we intend to begin averaging before this expiration_time
         :returns: if the leader accepted us and started AllReduce, return that AllReduce. Otherwise, return None
         :note: this function does not guarantee that your group leader is the same as :leader: parameter
           The originally specified leader can disband group and redirect us to a different leader
         """
         assert self.is_looking_for_group and self.current_leader is None
-        stream: AsyncIterator[averaging_pb2.MessageFromLeader] = None
+        stream: Optional[AsyncIterator[averaging_pb2.MessageFromLeader]] = None
         try:
             async with self.lock_request_join_group:
                 leader_stub = self._servicer_type.get_stub(self._p2p, leader, namespace=self._prefix)
-
+                request_expiration_time = self.get_request_expiration_time()
                 stream = await leader_stub.rpc_join_group(
                     averaging_pb2.JoinRequest(
                         schema_hash=self.schema_hash,
-                        expiration=expiration_time,
+                        expiration=request_expiration_time,
                         client_mode=self.client_mode,
-                        gather=self.data_for_gather,
+                        gather=self.step_control.data_for_gather,
                         group_key=self.group_key_manager.current_key,
                     )
                 )
@@ -204,7 +212,7 @@ class Matchmaking:
                 return None
 
             async with self.potential_leaders.pause_search():
-                time_to_expiration = max(expiration_time - get_dht_time(), 0.0)
+                time_to_expiration = max(0.0, request_expiration_time - get_dht_time())
                 message = await asyncio.wait_for(anext(stream), time_to_expiration + self.request_timeout)
 
                 if message.code == averaging_pb2.BEGIN_ALLREDUCE:
@@ -218,7 +226,7 @@ class Matchmaking:
                         logger.debug(f"{self} - leader disbanded group and redirected us to {suggested_leader}")
                         logger.debug(f"{self} - leader disbanded group and redirected us to {suggested_leader}")
                         self.current_leader = None
                         self.current_leader = None
                         await stream.aclose()
                         await stream.aclose()
-                        return await self.request_join_group(suggested_leader, expiration_time)
+                        return await self._request_join_group(suggested_leader)
                 logger.debug(f"{self} - leader disbanded group")
                 logger.debug(f"{self} - leader disbanded group")
                 return None
                 return None
 
 
@@ -237,6 +245,14 @@ class Matchmaking:
             if stream is not None:
                 await stream.aclose()
 
+    def get_request_expiration_time(self) -> float:
+        """Returns the averager's current expiration time, which is used to send join requests to leaders"""
+        if isfinite(self.potential_leaders.declared_expiration_time):
+            return self.potential_leaders.declared_expiration_time
+        else:
+            scheduled_time = max(self.step_control.scheduled_time, get_dht_time() + self.min_matchmaking_time)
+            return min(scheduled_time, self.potential_leaders.search_end_time)
+
     async def rpc_join_group(
         self, request: averaging_pb2.JoinRequest, context: P2PContext
     ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
@@ -352,7 +368,7 @@ class Matchmaking:
         random.shuffle(ordered_peer_ids)
 
         gathered = tuple(
-            self.data_for_gather if peer_id == self.peer_id else self.current_followers[peer_id].gather
+            self.step_control.data_for_gather if peer_id == self.peer_id else self.current_followers[peer_id].gather
             for peer_id in ordered_peer_ids
         )
 
@@ -388,8 +404,8 @@ class Matchmaking:
 class PotentialLeaders:
     """A utility class that searches for averagers that could become our leaders"""
 
-    def __init__(self, peer_id: PeerID, averaging_expiration: DHTExpiration, target_group_size: Optional[int]):
-        self.peer_id, self.averaging_expiration = peer_id, averaging_expiration
+    def __init__(self, peer_id: PeerID, min_matchmaking_time: DHTExpiration, target_group_size: Optional[int]):
+        self.peer_id, self.min_matchmaking_time = peer_id, min_matchmaking_time
         self.target_group_size = target_group_size
         self.running, self.update_triggered, self.update_finished = asyncio.Event(), asyncio.Event(), asyncio.Event()
         self.declared_expiration, self.lock_search, self.lock_declare = asyncio.Event(), asyncio.Lock(), asyncio.Lock()
@@ -401,13 +417,13 @@ class PotentialLeaders:
         self.search_end_time = float("inf")
         self.search_end_time = float("inf")
 
 
     @contextlib.asynccontextmanager
     @contextlib.asynccontextmanager
-    async def begin_search(self, key_manager: GroupKeyManager, timeout: Optional[float], declare: bool = True):
+    async def begin_search(self, step: StepControl, key_manager: GroupKeyManager, declare: bool = True):
         async with self.lock_search:
         async with self.lock_search:
             self.running.set()
             self.running.set()
-            self.search_end_time = get_dht_time() + timeout if timeout is not None else float("inf")
+            self.search_end_time = step.deadline if step.deadline is not None else float("inf")
             update_queue_task = asyncio.create_task(self._update_queue_periodically(key_manager))
             update_queue_task = asyncio.create_task(self._update_queue_periodically(key_manager))
             if declare:
             if declare:
-                declare_averager_task = asyncio.create_task(self._declare_averager_periodically(key_manager))
+                declare_averager_task = asyncio.create_task(self._declare_averager_periodically(step, key_manager))
 
 
             try:
             try:
                 yield self
                 yield self
@@ -467,20 +483,12 @@ class PotentialLeaders:
             self.past_attempts.add((maybe_next_leader, entry.expiration_time))
             return maybe_next_leader
 
-    @property
-    def request_expiration_time(self) -> float:
-        """this averager's current expiration time - used to send join requests to leaders"""
-        if isfinite(self.declared_expiration_time):
-            return self.declared_expiration_time
-        else:
-            return min(get_dht_time() + self.averaging_expiration, self.search_end_time)
-
     async def _update_queue_periodically(self, key_manager: GroupKeyManager) -> None:
         DISCREPANCY = timed_storage.MAX_DHT_TIME_DISCREPANCY_SECONDS
         while get_dht_time() < self.search_end_time:
             new_peers = await key_manager.get_averagers(key_manager.current_key, only_active=True)
             self.max_assured_time = max(
-                self.max_assured_time, get_dht_time() + self.averaging_expiration - DISCREPANCY
+                self.max_assured_time, get_dht_time() + self.min_matchmaking_time - DISCREPANCY
             )
 
             self.leader_queue.clear()
@@ -499,13 +507,14 @@ class PotentialLeaders:
             )
             self.update_triggered.clear()
 
-    async def _declare_averager_periodically(self, key_manager: GroupKeyManager) -> None:
+    async def _declare_averager_periodically(self, step: StepControl, key_manager: GroupKeyManager) -> None:
         async with self.lock_declare:
             try:
                 while True:
                     await self.running.wait()
-
-                    new_expiration_time = min(get_dht_time() + self.averaging_expiration, self.search_end_time)
+                    new_expiration_time = float(
+                        min(max(step.scheduled_time, get_dht_time() + self.min_matchmaking_time), self.search_end_time)
+                    )
                     self.declared_group_key = group_key = key_manager.current_key
                     self.declared_expiration_time = new_expiration_time
                     self.declared_expiration.set()
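To make the clamping in _declare_averager_periodically concrete, a worked example with made-up numbers:

now = 1000.0                    # pretend get_dht_time() returned this
scheduled_time = 1002.0         # the user pre-scheduled averaging for t=1002
min_matchmaking_time = 5.0
search_end_time = float("inf")  # no step deadline in this example

# never declare earlier than now + min_matchmaking_time, never past the search deadline
new_expiration_time = min(max(scheduled_time, now + min_matchmaking_time), search_end_time)
assert new_expiration_time == 1005.0  # t=1002 is too soon, so it is pushed back to t=1005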

+ 3 - 2
requirements-docs.txt

@@ -1,2 +1,3 @@
-recommonmark
-sphinx_rtd_theme
+recommonmark==0.5.0
+sphinx_rtd_theme==0.4.3
+sphinx==4.2.0

+ 48 - 0
tests/test_averaging.py

@@ -1,4 +1,5 @@
 import random
+import time
 
 import numpy as np
 import pytest
@@ -7,6 +8,7 @@ import torch
 import hivemind
 import hivemind.averaging.averager
 from hivemind.averaging.allreduce import AveragingMode
+from hivemind.averaging.control import AveragingStage
 from hivemind.averaging.key_manager import GroupKeyManager
 from hivemind.averaging.load_balancing import load_balance_peers
 from hivemind.averaging.partition import AllreduceException
@@ -420,6 +422,52 @@ def test_getset_bits():
     assert averager.get_group_bits() == "00101011101010"
 
 
+@pytest.mark.forked
+def test_averaging_trigger():
+    averagers = tuple(
+        hivemind.averaging.DecentralizedAverager(
+            averaged_tensors=[torch.randn(3)],
+            dht=dht,
+            target_group_size=4,
+            min_matchmaking_time=0.5,
+            request_timeout=0.3,
+            prefix="mygroup",
+            initial_group_bits="",
+            start=True,
+        )
+        for dht in launch_dht_instances(4)
+    )
+
+    controls = []
+    for i, averager in enumerate(averagers):
+        controls.append(
+            averager.step(
+                wait=False,
+                scheduled_time=hivemind.get_dht_time() + 0.5,
+                weight=1.0,
+                require_trigger=i in (1, 2),
+            )
+        )
+
+    time.sleep(0.6)
+
+    c0, c1, c2, c3 = controls
+    assert not any(c.done() for c in controls)
+    assert c0.stage == AveragingStage.RUNNING_ALLREDUCE
+    assert c1.stage == AveragingStage.AWAITING_TRIGGER
+    assert c2.stage == AveragingStage.AWAITING_TRIGGER
+    assert c3.stage == AveragingStage.RUNNING_ALLREDUCE
+
+    c1.allow_allreduce()
+    c2.allow_allreduce()
+    time.sleep(0.5)
+    assert all(c.stage == AveragingStage.FINISHED for c in controls)
+    assert all(c.done() for c in controls)
+
+    # check that setting the trigger twice does not raise an error
+    c0.allow_allreduce()
+
+
 @pytest.mark.forked
 def test_training_averager(n_steps: int = 10, n_dims: int = 16):
     torch.manual_seed(42)