justheuristic 3 years ago
parent
commit
1f29b7882c
3 changed files with 98 additions and 69 deletions
  1. hivemind/averaging/averager.py  + 43 - 20
  2. hivemind/averaging/control.py  + 19 - 15
  3. hivemind/averaging/matchmaking.py  + 36 - 34

+ 43 - 20
hivemind/averaging/averager.py

@@ -29,7 +29,7 @@ from hivemind.compression import (
     serialize_torch_tensor,
 )
 from hivemind.dht import DHT, DHTID
-from hivemind.p2p import P2PContext, P2PHandlerError, PeerID, ServicerBase
+from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
 from hivemind.proto import averaging_pb2
 from hivemind.utils import MPFuture, TensorDescriptor, get_logger
 from hivemind.utils.asyncio import achain, aiter_with_timeout, anext, as_aiter, azip, switch_to_uvloop
@@ -55,8 +55,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     :param prefix: a shared prefix for all group keys
     :param target_group_size: attempts to form groups with up to this many peers (recommended: a power of 2, e.g. 16)
     :param initial_group_bits: a string of bits ('0' and '1') that define the initial group key (bucket index)
-    :param averaging_expiration: attempt to find a group for this many seconds, otherwise try again
-      note - this expiration time only applies to looking for group, passing tensors in allreduce may take more time
+    :param min_matchmaking_time: when looking for group, wait for requests for at least this many seconds
     :param compression: optionally compress tensors with this compression algorithm before running all-reduce
     :param state_compression: a separate compression strategy for load_state_from_peers (default = no compression)
     :param tensor_infos: CompressionInfo for each respective tensor; this determines how the tensor will be compressed
@@ -93,6 +92,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):

     _matchmaking: Matchmaking
     _pending_group_assembled: asyncio.Event
+    _state_updated: asyncio.Event
+    _p2p: P2P
     serializer = MSGPackSerializer

     def __init__(
@@ -105,8 +106,9 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         target_group_size: int,
         min_group_size: int = 2,
         initial_group_bits: str = "",
-        averaging_expiration: float = 15,
-        request_timeout: float = 3,
+        averaging_expiration: float = None,
+        min_matchmaking_time: float = 5.0,
+        request_timeout: float = 3.0,
         averaging_alpha: float = 1.0,
         part_size_bytes: int = DEFAULT_PART_SIZE_BYTES,
         allreduce_timeout: Optional[float] = None,
@@ -117,6 +119,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         min_vector_size: int = 0,
         auxiliary: bool = False,
         allow_state_sharing: Optional[bool] = None,
+        declare_state_period: float = 30,
         client_mode: Optional[bool] = None,
         daemon: bool = True,
         shutdown_timeout: float = 5,
@@ -130,6 +133,12 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         assert all(bit in "01" for bit in initial_group_bits)
         assert not client_mode or not auxiliary, "auxiliary peers must accept incoming connections"

+        if averaging_expiration is not None:
+            logger.warning("averaging_expiration is deprecated and will be removed in v1.0.1, use min_matchmaking_time")
+            assert min_matchmaking_time == 5.0, "can't set both averaging_expiration and min_matchmaking_time"
+            min_matchmaking_time = averaging_expiration
+            del averaging_expiration
+
         super().__init__()
         self.dht = dht
         self.prefix = prefix
@@ -164,8 +173,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             initial_group_bits=initial_group_bits,
             target_group_size=target_group_size,
             min_group_size=min_group_size,
-            averaging_expiration=averaging_expiration,
             request_timeout=request_timeout,
+            min_matchmaking_time=min_matchmaking_time
         )
         self.allreduce_kwargs = dict(
             compression=compression,
@@ -181,6 +190,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         if allow_state_sharing is None:
             allow_state_sharing = not client_mode and not auxiliary
         self.allow_state_sharing = allow_state_sharing
+        self.declare_state_period = declare_state_period
         self.state_compression = state_compression
         self.tensor_infos = tensor_infos

@@ -251,6 +261,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                 if not self.client_mode:
                     asyncio.create_task(self._declare_for_download_periodically())

+                self._state_updated = asyncio.Event()
                 self._pending_group_assembled = asyncio.Event()
                 self._pending_group_assembled.set()
             except Exception as e:
@@ -341,18 +352,23 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         if self.mode == AveragingMode.AUX and weight is not None:
             logger.warning("Averager is running in auxiliary mode, weight is unused.")
         if scheduled_time is None:
-            scheduled_time = get_dht_time() + self.matchmaking_kwargs["averaging_expiration"]
+            scheduled_time = get_dht_time() + self.matchmaking_kwargs["min_matchmaking_time"]
         if weight is None:
             weight = float(self.mode != AveragingMode.AUX)
-        deadline = get_dht_time() + timeout if timeout else None
+        deadline = get_dht_time() + timeout if timeout else float('inf')
         assert isinstance(weight, (int, float)) and weight >= 0, f"Expected a positive int/float, got {type(weight)}"
         assert not wait_for_trigger or wait, "Non-asynchronous step cannot wait for trigger (use wait=False)"
         assert scheduled_time < deadline, "Scheduled start time does not fit within timeout"

         user_gather_bytes = self.serializer.dumps(gather)  # serialize here to avoid imports in the averager process
         gather_binary = self.serializer.dumps([self.bandwidth, self.mode.value, user_gather_bytes])
-        step = StepControl(scheduled_time=scheduled_time, deadline=deadline, allow_retries=allow_retries,
-                           weight=weight, gather_binary=gather_binary)
+        step = StepControl(
+            scheduled_time=scheduled_time,
+            deadline=deadline,
+            allow_retries=allow_retries,
+            weight=weight,
+            gather_binary=gather_binary,
+        )

         future_for_trigger = MPFuture()
         self._outer_pipe.send(("_step", [], dict(step=step, future_for_trigger=future_for_trigger)))
@@ -363,11 +379,11 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         return step.result() if wait else step

     async def _step(self, *, step: StepControl, future_for_trigger: MPFuture):
-        trigger = MPFuture()
-        step.attach_trigger(trigger)
-        future_for_trigger.set_result(trigger)
-
         try:
+            trigger = MPFuture()
+            step.attach_trigger(trigger)
+            future_for_trigger.set_result(trigger)
+
             while not step.done():
                 try:
                     self._pending_group_assembled.clear()
@@ -455,7 +471,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                         async for tensor, update in azip(as_aiter(*local_tensors), allreduce):
                             # all-reduce is performed asynchronously while iterating
                             tensor.add_(update, alpha=self._averaging_alpha)
-                        self.last_updated = get_dht_time()
+                            self._mark_state_updated()
                     else:
                         async for _ in allreduce:  # trigger all-reduce by iterating
                             raise ValueError("aux peers should not receive averaged tensors")
@@ -485,7 +501,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         """
         """
         with self.lock_averaged_tensors:
         with self.lock_averaged_tensors:
             yield self._averaged_tensors
             yield self._averaged_tensors
-        self.last_updated = get_dht_time()
 
 
     @contextlib.asynccontextmanager
     @contextlib.asynccontextmanager
     async def get_tensors_async(self) -> Sequence[torch.Tensor]:
     async def get_tensors_async(self) -> Sequence[torch.Tensor]:
@@ -525,19 +540,28 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         download_key = f"{self._matchmaking.group_key_manager.prefix}.all_averagers"
         download_key = f"{self._matchmaking.group_key_manager.prefix}.all_averagers"
         while True:
         while True:
             if self.allow_state_sharing:
             if self.allow_state_sharing:
+                self._state_updated.clear()
+                expiration_time = get_dht_time() + self.declare_state_period
                 asyncio.create_task(
                 asyncio.create_task(
                     asyncio.wait_for(
                     asyncio.wait_for(
                         self.dht.store(
                         self.dht.store(
                             download_key,
                             download_key,
                             subkey=self.peer_id.to_bytes(),
                             subkey=self.peer_id.to_bytes(),
                             value=self.last_updated,
                             value=self.last_updated,
-                            expiration_time=get_dht_time() + self._matchmaking.averaging_expiration,
+                            expiration_time=expiration_time,
                             return_future=True,
                             return_future=True,
                         ),
                         ),
-                        timeout=self._matchmaking.averaging_expiration,
+                        timeout=expiration_time - self.request_timeout,
                     )
                     )
                 )
                 )
-            await asyncio.sleep(self._matchmaking.averaging_expiration)
+            try:
+                await asyncio.wait_for(self._state_updated.wait(), self.declare_state_period - self.request_timeout)
+            except asyncio.TimeoutError:
+                pass
+
+    def _mark_state_updated(self):
+        self.last_updated = get_dht_time()
+        self._state_updated.set()
 
 
     async def rpc_download_state(
     async def rpc_download_state(
         self, _request: averaging_pb2.DownloadRequest, _context: P2PContext
         self, _request: averaging_pb2.DownloadRequest, _context: P2PContext
@@ -636,7 +660,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):

                         logger.info(f"Finished downloading state from {peer}")
                         future.set_result((metadata, tensors))
-                        self.last_updated = get_dht_time()
                         return
                     except Exception as e:
                         logger.exception(f"Failed to download state from {peer} - {repr(e)}")

+ 19 - 15
hivemind/averaging/control.py

@@ -5,18 +5,18 @@ from typing import Optional
 import numpy as np
 import torch

-from hivemind.utils import MPFuture, DHTExpiration, get_logger
+from hivemind.utils import MPFuture, DHTExpiration, get_logger, get_dht_time


 logger = get_logger(__file__)


 class AveragingStage(Enum):
-    IDLE = 0               # still initializing
+    IDLE = 0  # still initializing
     LOOKING_FOR_GROUP = 1  # running decentralized matchmaking, can't run allreduce yet
-    AWAITING_TRIGGER = 2   # waiting for user to set the trigger that allows running allreduce
+    AWAITING_TRIGGER = 2  # waiting for user to set the trigger that allows running allreduce
     RUNNING_ALLREDUCE = 3  # exchanging tensors with groupmates
-    FINISHED = 4           # either done or failed with exception
+    FINISHED = 4  # either done or failed with exception


 class StepControl(MPFuture):
@@ -30,12 +30,13 @@ class StepControl(MPFuture):

     """

-    def __init__(self, scheduled_time: DHTExpiration, deadline: Optional[float], allow_retries: bool,
-                 weight: float, gather_binary: bytes):
+    def __init__(
+        self, scheduled_time: DHTExpiration, deadline: float, allow_retries: bool, weight: float, gather_binary: bytes
+    ):
         super().__init__()
         self._gather_binary, self._deadline, self._allow_retries = gather_binary, deadline, allow_retries
         self._trigger: Optional[MPFuture] = None
-        self._metadata = torch.zeros([18], dtype=torch.uint8).share_memory_()
+        self._shared_buffer = torch.zeros([18], dtype=torch.uint8).share_memory_()
         self.stage = AveragingStage.IDLE
         self.scheduled_time = scheduled_time
         self.weight = weight
@@ -58,7 +59,7 @@ class StepControl(MPFuture):

     @property
     def scheduled_time(self) -> DHTExpiration:
-        return struct.unpack('d', self._metadata[0:8].numpy().data)[0]
+        return struct.unpack("d", self._shared_buffer[0:8].numpy().data)[0]

     @scheduled_time.setter
     def scheduled_time(self, scheduled_time):
@@ -66,36 +67,36 @@ class StepControl(MPFuture):
             logger.warning("Changing scheduled time has no effect after all-reduce has already started")
             logger.warning("Changing scheduled time has no effect after all-reduce has already started")
         if scheduled_time >= self.deadline:
         if scheduled_time >= self.deadline:
             logger.warning("Changing scheduled time to after deadline, averaging will likely fail due to timeout.")
             logger.warning("Changing scheduled time to after deadline, averaging will likely fail due to timeout.")
-        struct.pack_into('d', self._metadata[0:8].numpy().data, 0, float(scheduled_time))
+        struct.pack_into("d", self._shared_buffer[0:8].numpy().data, 0, float(scheduled_time))
 
 
     @property
     @property
     def weight(self) -> float:
     def weight(self) -> float:
-        return struct.unpack('d', self._metadata[8:16].numpy().data)[0]
+        return struct.unpack("d", self._shared_buffer[8:16].numpy().data)[0]
 
 
     @weight.setter
     @weight.setter
     def weight(self, weight: float):
     def weight(self, weight: float):
         assert weight >= 0 and np.isfinite(weight)
         assert weight >= 0 and np.isfinite(weight)
         if self.began_allreduce:
         if self.began_allreduce:
             logger.warning("Changing weights has no effect after all-reduce has already started")
             logger.warning("Changing weights has no effect after all-reduce has already started")
-        struct.pack_into('d', self._metadata[8:16].numpy().data, 0, float(weight))
+        struct.pack_into("d", self._shared_buffer[8:16].numpy().data, 0, float(weight))
 
 
     @property
     @property
     def stage(self) -> AveragingStage:
     def stage(self) -> AveragingStage:
-        return AveragingStage(self._metadata[16].item())
+        return AveragingStage(self._shared_buffer[16].item())
 
 
     @stage.setter
     @stage.setter
     def stage(self, stage: AveragingStage):
     def stage(self, stage: AveragingStage):
         if stage == AveragingStage.RUNNING_ALLREDUCE:
         if stage == AveragingStage.RUNNING_ALLREDUCE:
             self.can_modify = False
             self.can_modify = False
-        self._metadata[16] = stage.value
+        self._shared_buffer[16] = stage.value
 
 
     @property
     @property
     def began_allreduce(self) -> bool:
     def began_allreduce(self) -> bool:
-        return bool(self._metadata[17].item())
+        return bool(self._shared_buffer[17].item())
 
 
     @began_allreduce.setter
     @began_allreduce.setter
     def began_allreduce(self, value: bool):
     def began_allreduce(self, value: bool):
-        self._metadata[17] = int(value)
+        self._shared_buffer[17] = int(value)
 
 
     @property
     @property
     def gather_binary(self) -> bytes:
     def gather_binary(self) -> bytes:
@@ -105,6 +106,9 @@ class StepControl(MPFuture):
     def deadline(self) -> DHTExpiration:
         return self._deadline

+    def get_timeout(self) -> Optional[DHTExpiration]:
+        return max(0.0, self.deadline - get_dht_time())
+
     @property
     def allow_retries(self) -> bool:
         return self._allow_retries
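
Side note, not part of the diff above: the renamed _shared_buffer is an 18-byte torch tensor in shared memory. From the accessors above, bytes 0-7 hold scheduled_time as a packed float64, bytes 8-15 hold weight, byte 16 the stage enum value, and byte 17 the began_allreduce flag. A standalone sketch of that layout, with arbitrary demo values:

    import struct

    import torch

    # same layout as StepControl's shared buffer: two packed doubles plus two status bytes
    shared_buffer = torch.zeros([18], dtype=torch.uint8).share_memory_()

    struct.pack_into("d", shared_buffer[0:8].numpy().data, 0, 1234.5)  # scheduled_time
    struct.pack_into("d", shared_buffer[8:16].numpy().data, 0, 0.25)   # weight
    shared_buffer[16] = 1  # stage, e.g. AveragingStage.LOOKING_FOR_GROUP
    shared_buffer[17] = 0  # began_allreduce = False

    # any process holding a handle to the shared tensor reads the same values back
    scheduled_time = struct.unpack("d", shared_buffer[0:8].numpy().data)[0]
    weight = struct.unpack("d", shared_buffer[8:16].numpy().data)[0]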

+ 36 - 34
hivemind/averaging/matchmaking.py

@@ -9,6 +9,8 @@ import random
 from math import isfinite
 from typing import AsyncIterator, Dict, Optional, Set, Tuple, Type

+import numpy as np
+
 from hivemind.averaging.control import StepControl
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.key_manager import GroupKey, GroupKeyManager
@@ -44,16 +46,16 @@ class Matchmaking:
         prefix: str,
         target_group_size: int,
         min_group_size: int,
+        min_matchmaking_time: float,
         request_timeout: float,
         client_mode: bool,
         initial_group_bits: str = "",
-        averaging_expiration: float = 15,
     ):
         assert "." not in prefix, "group prefix must be a string without ."
-        if request_timeout is None or request_timeout >= averaging_expiration:
+        if request_timeout is None or request_timeout >= min_matchmaking_time:
             logger.warning(
-                "It is recommended to use request_timeout smaller than averaging_expiration. Otherwise,"
-                "matchmaking can cause deadlocks in some rare cases. Please see Matchmaking docstring."
+                "It is recommended to use request_timeout smaller than min_matchmaking_time. Otherwise,"
+                " matchmaking can cause deadlocks in some rare cases. Please see Matchmaking docstring."
             )

         super().__init__()
@@ -68,7 +70,7 @@ class Matchmaking:
         self.schema_hash = schema_hash
         self.group_key_manager = GroupKeyManager(dht, prefix, initial_group_bits, target_group_size)
         self.target_group_size, self.min_group_size = target_group_size, min_group_size
-        self.averaging_expiration, self.request_timeout = averaging_expiration, request_timeout
+        self.min_matchmaking_time, self.request_timeout = min_matchmaking_time, request_timeout
         self.client_mode = client_mode

         self.lock_looking_for_group = asyncio.Lock()
@@ -79,11 +81,11 @@ class Matchmaking:

         self.current_leader: Optional[PeerID] = None  # iff i am a follower, this is a link to my current leader
         self.current_followers: Dict[PeerID, averaging_pb2.JoinRequest] = {}  # my current followers excluding myself
-        self.potential_leaders = PotentialLeaders(self.peer_id, averaging_expiration, target_group_size)
+        self.potential_leaders = PotentialLeaders(self.peer_id, min_matchmaking_time, target_group_size)
         self.step: Optional[StepControl] = None

     @contextlib.asynccontextmanager
-    def looking_for_group(self, step: StepControl):
+    async def looking_for_group(self, step: StepControl):
         async with self.lock_looking_for_group:
             assert self.step is None
             self.step = step
@@ -121,7 +123,7 @@ class Matchmaking:
         async with self.looking_for_group(step):
             request_leaders_task = asyncio.create_task(self._request_join_potential_leaders(step))
             try:
-                return await asyncio.wait_for(self.assembled_group, timeout=step.timeout)
+                return await asyncio.wait_for(self.assembled_group, timeout=step.get_timeout())
             except asyncio.TimeoutError:
                 return None

@@ -144,15 +146,15 @@ class Matchmaking:
                 self.assembled_group = asyncio.Future()
                 self.was_accepted_to_group.clear()

-    async def _request_join_potential_leaders(self, timeout: Optional[float]) -> GroupInfo:
+    async def _request_join_potential_leaders(self, step: StepControl) -> GroupInfo:
         """Request leaders from queue until we find the first runner. This coroutine is meant to run in background."""
         assert self.is_looking_for_group
-        async with self.potential_leaders.begin_search(self.group_key_manager, timeout, declare=not self.client_mode):
+        async with self.potential_leaders.begin_search(step, self.group_key_manager, declare=not self.client_mode):
             while True:
                 try:
                     next_leader = await self.potential_leaders.pop_next_leader()  # throws TimeoutError on expiration

-                    group = await self.request_join_group(next_leader, self.potential_leaders.request_expiration_time)
+                    group = await self.request_join_group(next_leader)
                     if group is not None:
                         return group

@@ -173,26 +175,25 @@ class Matchmaking:
                         self.assembled_group.set_exception(e)
                     raise e

-    async def request_join_group(self, leader: PeerID, expiration_time: DHTExpiration) -> Optional[GroupInfo]:
+    async def request_join_group(self, leader: PeerID) -> Optional[GroupInfo]:
         """
         :param leader: request this peer to be your leader for allreduce
-        :param expiration_time: inform leader that we intend to begin averaging before this expiration_time
         :returns: if the leader accepted us and started AllReduce, return that AllReduce. Otherwise, return None
         :note: this function does not guarantee that your group leader is the same as :leader: parameter
           The originally specified leader can disband group and redirect us to a different leader
         """
         assert self.is_looking_for_group and self.current_leader is None
-        stream: AsyncIterator[averaging_pb2.MessageFromLeader] = None
+        stream: Optional[AsyncIterator[averaging_pb2.MessageFromLeader]] = None
         try:
             async with self.lock_request_join_group:
                 leader_stub = self._servicer_type.get_stub(self._p2p, leader, namespace=self._prefix)
-
+                request_expiration_time = self.get_request_expiration_time()
                 stream = await leader_stub.rpc_join_group(
                     averaging_pb2.JoinRequest(
                         schema_hash=self.schema_hash,
-                        expiration=expiration_time,
+                        expiration=request_expiration_time,
                         client_mode=self.client_mode,
-                        gather=self.control.gather_binary,
+                        gather=self.step.gather_binary,
                         group_key=self.group_key_manager.current_key,
                     )
                 )
@@ -211,7 +212,7 @@ class Matchmaking:
                 return None

             async with self.potential_leaders.pause_search():
-                time_to_expiration = max(expiration_time - get_dht_time(), 0.0)
+                time_to_expiration = max(0.0, request_expiration_time - get_dht_time())
                 message = await asyncio.wait_for(anext(stream), time_to_expiration + self.request_timeout)

                 if message.code == averaging_pb2.BEGIN_ALLREDUCE:
@@ -225,7 +226,7 @@ class Matchmaking:
                         logger.debug(f"{self} - leader disbanded group and redirected us to {suggested_leader}")
                         logger.debug(f"{self} - leader disbanded group and redirected us to {suggested_leader}")
                         self.current_leader = None
                         self.current_leader = None
                         await stream.aclose()
                         await stream.aclose()
-                        return await self.request_join_group(suggested_leader, expiration_time)
+                        return await self.request_join_group(suggested_leader)
                 logger.debug(f"{self} - leader disbanded group")
                 logger.debug(f"{self} - leader disbanded group")
                 return None
                 return None
 
 
@@ -244,6 +245,14 @@ class Matchmaking:
             if stream is not None:
                 await stream.aclose()

+    def get_request_expiration_time(self) -> float:
+        """this averager's current expiration time - used to send join requests to leaders"""
+        if isfinite(self.potential_leaders.declared_expiration_time):
+            return self.potential_leaders.declared_expiration_time
+        else:
+            scheduled_time = max(self.step.scheduled_time, get_dht_time() + self.min_matchmaking_time)
+            return min(scheduled_time, self.potential_leaders.search_end_time)
+
     async def rpc_join_group(
         self, request: averaging_pb2.JoinRequest, context: P2PContext
     ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
@@ -359,7 +368,7 @@ class Matchmaking:
         random.shuffle(ordered_peer_ids)

         gathered = tuple(
-            self.control.gather_binary if peer_id == self.peer_id else self.current_followers[peer_id].gather
+            self.step.gather_binary if peer_id == self.peer_id else self.current_followers[peer_id].gather
             for peer_id in ordered_peer_ids
         )

@@ -395,8 +404,8 @@ class Matchmaking:
 class PotentialLeaders:
     """A utility class that searches for averagers that could become our leaders"""

-    def __init__(self, peer_id: PeerID, averaging_expiration: DHTExpiration, target_group_size: Optional[int]):
-        self.peer_id, self.averaging_expiration = peer_id, averaging_expiration
+    def __init__(self, peer_id: PeerID, min_matchmaking_time: DHTExpiration, target_group_size: Optional[int]):
+        self.peer_id, self.min_matchmaking_time = peer_id, min_matchmaking_time
         self.target_group_size = target_group_size
         self.running, self.update_triggered, self.update_finished = asyncio.Event(), asyncio.Event(), asyncio.Event()
         self.declared_expiration, self.lock_search, self.lock_declare = asyncio.Event(), asyncio.Lock(), asyncio.Lock()
@@ -411,7 +420,7 @@ class PotentialLeaders:
     async def begin_search(self, step: StepControl, key_manager: GroupKeyManager, declare: bool = True):
         async with self.lock_search:
             self.running.set()
-            self.search_end_time = get_dht_time() + step.timeout if step.timeout is not None else float("inf")
+            self.search_end_time = step.deadline if step.deadline is not None else float('inf')
             update_queue_task = asyncio.create_task(self._update_queue_periodically(key_manager))
             if declare:
                 declare_averager_task = asyncio.create_task(self._declare_averager_periodically(step, key_manager))
@@ -474,20 +483,12 @@ class PotentialLeaders:
             self.past_attempts.add((maybe_next_leader, entry.expiration_time))
             return maybe_next_leader

-    @property
-    def request_expiration_time(self) -> float:
-        """this averager's current expiration time - used to send join requests to leaders"""
-        if isfinite(self.declared_expiration_time):
-            return self.declared_expiration_time
-        else:
-            return min(get_dht_time() + self.averaging_expiration, self.search_end_time)
-
     async def _update_queue_periodically(self, key_manager: GroupKeyManager) -> None:
         DISCREPANCY = timed_storage.MAX_DHT_TIME_DISCREPANCY_SECONDS
         while get_dht_time() < self.search_end_time:
             new_peers = await key_manager.get_averagers(key_manager.current_key, only_active=True)
             self.max_assured_time = max(
-                self.max_assured_time, get_dht_time() + self.averaging_expiration - DISCREPANCY
+                self.max_assured_time, get_dht_time() + self.min_matchmaking_time - DISCREPANCY
             )

             self.leader_queue.clear()
@@ -511,8 +512,9 @@ class PotentialLeaders:
             try:
                 while True:
                     await self.running.wait()
-                    #TODO account for scheduled time here!
-                    new_expiration_time = min(get_dht_time() + self.averaging_expiration, self.search_end_time)
+                    new_expiration_time = float(np.clip(step.scheduled_time,
+                                                        a_min=get_dht_time() + self.min_matchmaking_time,
+                                                        a_max=self.search_end_time))
                     self.declared_group_key = group_key = key_manager.current_key
                     self.declared_expiration_time = new_expiration_time
                     self.declared_expiration.set()
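
Side note, not part of the diff above: the declared expiration is now the step's scheduled_time clamped to lie no earlier than now + min_matchmaking_time and no later than search_end_time (the step deadline), replacing the old fixed now + averaging_expiration. A toy illustration of the clamping rule, with made-up numbers:

    import numpy as np

    def declared_expiration(scheduled_time, now, min_matchmaking_time, search_end_time):
        # same rule as _declare_averager_periodically: never declare earlier than
        # now + min_matchmaking_time, never later than the end of the search window
        return float(np.clip(scheduled_time, a_min=now + min_matchmaking_time, a_max=search_end_time))

    now = 1000.0
    print(declared_expiration(1002.0, now, 5.0, float("inf")))  # too soon -> 1005.0
    print(declared_expiration(1030.0, now, 5.0, 1020.0))        # past the deadline -> 1020.0
    print(declared_expiration(1010.0, now, 5.0, float("inf")))  # within bounds -> 1010.0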