justheuristic 3 years ago
parent
commit
1f29b7882c
3 changed files with 98 additions and 69 deletions
  1. 43 20
      hivemind/averaging/averager.py
  2. 19 15
      hivemind/averaging/control.py
  3. 36 34
      hivemind/averaging/matchmaking.py

+ 43 - 20
hivemind/averaging/averager.py

@@ -29,7 +29,7 @@ from hivemind.compression import (
     serialize_torch_tensor,
 )
 from hivemind.dht import DHT, DHTID
-from hivemind.p2p import P2PContext, P2PHandlerError, PeerID, ServicerBase
+from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
 from hivemind.proto import averaging_pb2
 from hivemind.utils import MPFuture, TensorDescriptor, get_logger
 from hivemind.utils.asyncio import achain, aiter_with_timeout, anext, as_aiter, azip, switch_to_uvloop
@@ -55,8 +55,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     :param prefix: a shared prefix for all group keys
     :param target_group_size: attempts to form groups with up to this many peers (recommended: a power of 2, e.g. 16)
     :param initial_group_bits: a string of bits ('0' and '1') that define the initial group key (bucket index)
-    :param averaging_expiration: attempt to find a group for this many seconds, otherwise try again
-      note - this expiration time only applies to looking for group, passing tensors in allreduce may take more time
+    :param min_matchmaking_time: when looking for group, wait for requests for at least this many seconds
     :param compression: optionally compress tensors with this compression algorithm before running all-reduce
     :param state_compression: a separate compression strategy for load_state_from_peers (default = no compression)
     :param tensor_infos: CompressionInfo for each respective tensor; this determines how the tensor will be comressed
@@ -93,6 +92,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
     _matchmaking: Matchmaking
     _pending_group_assembled: asyncio.Event
+    _state_updated: asyncio.Event
+    _p2p: P2P
     serializer = MSGPackSerializer
 
     def __init__(
@@ -105,8 +106,9 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         target_group_size: int,
         min_group_size: int = 2,
         initial_group_bits: str = "",
-        averaging_expiration: float = 15,
-        request_timeout: float = 3,
+        averaging_expiration: float = None,
+        min_matchmaking_time: float = 5.0,
+        request_timeout: float = 3.0,
         averaging_alpha: float = 1.0,
         part_size_bytes: int = DEFAULT_PART_SIZE_BYTES,
         allreduce_timeout: Optional[float] = None,
@@ -117,6 +119,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         min_vector_size: int = 0,
         auxiliary: bool = False,
         allow_state_sharing: Optional[bool] = None,
+        declare_state_period: float = 30,
         client_mode: Optional[bool] = None,
         daemon: bool = True,
         shutdown_timeout: float = 5,
@@ -130,6 +133,12 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         assert all(bit in "01" for bit in initial_group_bits)
         assert not client_mode or not auxiliary, "auxiliary peers must accept incoming connections"
 
+        if averaging_expiration is not None:
+            logger.warning("averaging_expiration is deprecated and will be removed in v1.0.1, use min_matchmaking_time")
+            assert min_matchmaking_time == 5.0, "can't set both averaging_expiration and min_matchmaking_time"
+            min_matchmaking_time = averaging_expiration
+            del averaging_expiration
+
         super().__init__()
         self.dht = dht
         self.prefix = prefix
@@ -164,8 +173,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             initial_group_bits=initial_group_bits,
             target_group_size=target_group_size,
             min_group_size=min_group_size,
-            averaging_expiration=averaging_expiration,
             request_timeout=request_timeout,
+            min_matchmaking_time=min_matchmaking_time
         )
         self.allreduce_kwargs = dict(
             compression=compression,
@@ -181,6 +190,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         if allow_state_sharing is None:
             allow_state_sharing = not client_mode and not auxiliary
         self.allow_state_sharing = allow_state_sharing
+        self.declare_state_period = declare_state_period
         self.state_compression = state_compression
         self.tensor_infos = tensor_infos
 
@@ -251,6 +261,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                 if not self.client_mode:
                     asyncio.create_task(self._declare_for_download_periodically())
 
+                self._state_updated = asyncio.Event()
                 self._pending_group_assembled = asyncio.Event()
                 self._pending_group_assembled.set()
             except Exception as e:
@@ -341,18 +352,23 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         if self.mode == AveragingMode.AUX and weight is not None:
             logger.warning("Averager is running in auxiliary mode, weight is unused.")
         if scheduled_time is None:
-            scheduled_time = get_dht_time() + self.matchmaking_kwargs["averaging_expiration"]
+            scheduled_time = get_dht_time() + self.matchmaking_kwargs["min_matchmaking_time"]
         if weight is None:
             weight = float(self.mode != AveragingMode.AUX)
-        deadline = get_dht_time() + timeout if timeout else None
+        deadline = get_dht_time() + timeout if timeout else float('inf')
         assert isinstance(weight, (int, float)) and weight >= 0, f"Expected a positive int/float, got {type(weight)}"
         assert not wait_for_trigger or wait, "Non-asynchronous step cannot wait for trigger (use wait=False)"
         assert scheduled_time < deadline, "Scheduled start time does not fit within timeout"
 
         user_gather_bytes = self.serializer.dumps(gather)  # serialize here to avoid imports in the averager process
         gather_binary = self.serializer.dumps([self.bandwidth, self.mode.value, user_gather_bytes])
-        step = StepControl(scheduled_time=scheduled_time, deadline=deadline, allow_retries=allow_retries,
-                           weight=weight, gather_binary=gather_binary)
+        step = StepControl(
+            scheduled_time=scheduled_time,
+            deadline=deadline,
+            allow_retries=allow_retries,
+            weight=weight,
+            gather_binary=gather_binary,
+        )
 
         future_for_trigger = MPFuture()
         self._outer_pipe.send(("_step", [], dict(step=step, future_for_trigger=future_for_trigger)))
@@ -363,11 +379,11 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         return step.result() if wait else step
 
     async def _step(self, *, step: StepControl, future_for_trigger: MPFuture):
-        trigger = MPFuture()
-        step.attach_trigger(trigger)
-        future_for_trigger.set_result(trigger)
-
         try:
+            trigger = MPFuture()
+            step.attach_trigger(trigger)
+            future_for_trigger.set_result(trigger)
+
             while not step.done():
                 try:
                     self._pending_group_assembled.clear()
@@ -455,7 +471,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                         async for tensor, update in azip(as_aiter(*local_tensors), allreduce):
                             # all-reduce is performed asynchronously while iterating
                             tensor.add_(update, alpha=self._averaging_alpha)
-                        self.last_updated = get_dht_time()
+                            self._mark_state_updated()
                     else:
                         async for _ in allreduce:  # trigger all-reduce by iterating
                             raise ValueError("aux peers should not receive averaged tensors")
@@ -485,7 +501,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         """
         with self.lock_averaged_tensors:
             yield self._averaged_tensors
-        self.last_updated = get_dht_time()
 
     @contextlib.asynccontextmanager
     async def get_tensors_async(self) -> Sequence[torch.Tensor]:
@@ -525,19 +540,28 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         download_key = f"{self._matchmaking.group_key_manager.prefix}.all_averagers"
         while True:
             if self.allow_state_sharing:
+                self._state_updated.clear()
+                expiration_time = get_dht_time() + self.declare_state_period
                 asyncio.create_task(
                     asyncio.wait_for(
                         self.dht.store(
                             download_key,
                             subkey=self.peer_id.to_bytes(),
                             value=self.last_updated,
-                            expiration_time=get_dht_time() + self._matchmaking.averaging_expiration,
+                            expiration_time=expiration_time,
                             return_future=True,
                         ),
-                        timeout=self._matchmaking.averaging_expiration,
+                        timeout=expiration_time - self.request_timeout,
                     )
                 )
-            await asyncio.sleep(self._matchmaking.averaging_expiration)
+            try:
+                await asyncio.wait_for(self._state_updated.wait(), self.declare_state_period - self.request_timeout)
+            except asyncio.TimeoutError:
+                pass
+
+    def _mark_state_updated(self):
+        self.last_updated = get_dht_time()
+        self._state_updated.set()
 
     async def rpc_download_state(
         self, _request: averaging_pb2.DownloadRequest, _context: P2PContext
@@ -636,7 +660,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
                         logger.info(f"Finished downloading state from {peer}")
                         future.set_result((metadata, tensors))
-                        self.last_updated = get_dht_time()
                         return
                     except Exception as e:
                         logger.exception(f"Failed to download state from {peer} - {repr(e)}")

+ 19 - 15
hivemind/averaging/control.py

@@ -5,18 +5,18 @@ from typing import Optional
 import numpy as np
 import torch
 
-from hivemind.utils import MPFuture, DHTExpiration, get_logger
+from hivemind.utils import MPFuture, DHTExpiration, get_logger, get_dht_time
 
 
 logger = get_logger(__file__)
 
 
 class AveragingStage(Enum):
-    IDLE = 0               # still initializing
+    IDLE = 0  # still initializing
     LOOKING_FOR_GROUP = 1  # running decentralized matchmaking, can't run allreduce yet
-    AWAITING_TRIGGER = 2   # waiting for user to set the trigger that allows running allreduce
+    AWAITING_TRIGGER = 2  # waiting for user to set the trigger that allows running allreduce
     RUNNING_ALLREDUCE = 3  # exchanging tensors with groupmates
-    FINISHED = 4           # either done or failed with exception
+    FINISHED = 4  # either done or failed with exception
 
 
 class StepControl(MPFuture):
@@ -30,12 +30,13 @@ class StepControl(MPFuture):
 
     """
 
-    def __init__(self, scheduled_time: DHTExpiration, deadline: Optional[float], allow_retries: bool,
-                 weight: float, gather_binary: bytes):
+    def __init__(
+        self, scheduled_time: DHTExpiration, deadline: float, allow_retries: bool, weight: float, gather_binary: bytes
+    ):
         super().__init__()
         self._gather_binary, self._deadline, self._allow_retries = gather_binary, deadline, allow_retries
         self._trigger: Optional[MPFuture] = None
-        self._metadata = torch.zeros([18], dtype=torch.uint8).share_memory_()
+        self._shared_buffer = torch.zeros([18], dtype=torch.uint8).share_memory_()
         self.stage = AveragingStage.IDLE
         self.scheduled_time = scheduled_time
         self.weight = weight
@@ -58,7 +59,7 @@ class StepControl(MPFuture):
 
     @property
     def scheduled_time(self) -> DHTExpiration:
-        return struct.unpack('d', self._metadata[0:8].numpy().data)[0]
+        return struct.unpack("d", self._shared_buffer[0:8].numpy().data)[0]
 
     @scheduled_time.setter
     def scheduled_time(self, scheduled_time):
@@ -66,36 +67,36 @@ class StepControl(MPFuture):
             logger.warning("Changing scheduled time has no effect after all-reduce has already started")
         if scheduled_time >= self.deadline:
             logger.warning("Changing scheduled time to after deadline, averaging will likely fail due to timeout.")
-        struct.pack_into('d', self._metadata[0:8].numpy().data, 0, float(scheduled_time))
+        struct.pack_into("d", self._shared_buffer[0:8].numpy().data, 0, float(scheduled_time))
 
     @property
     def weight(self) -> float:
-        return struct.unpack('d', self._metadata[8:16].numpy().data)[0]
+        return struct.unpack("d", self._shared_buffer[8:16].numpy().data)[0]
 
     @weight.setter
     def weight(self, weight: float):
         assert weight >= 0 and np.isfinite(weight)
         if self.began_allreduce:
             logger.warning("Changing weights has no effect after all-reduce has already started")
-        struct.pack_into('d', self._metadata[8:16].numpy().data, 0, float(weight))
+        struct.pack_into("d", self._shared_buffer[8:16].numpy().data, 0, float(weight))
 
     @property
     def stage(self) -> AveragingStage:
-        return AveragingStage(self._metadata[16].item())
+        return AveragingStage(self._shared_buffer[16].item())
 
     @stage.setter
     def stage(self, stage: AveragingStage):
         if stage == AveragingStage.RUNNING_ALLREDUCE:
             self.can_modify = False
-        self._metadata[16] = stage.value
+        self._shared_buffer[16] = stage.value
 
     @property
     def began_allreduce(self) -> bool:
-        return bool(self._metadata[17].item())
+        return bool(self._shared_buffer[17].item())
 
     @began_allreduce.setter
     def began_allreduce(self, value: bool):
-        self._metadata[17] = int(value)
+        self._shared_buffer[17] = int(value)
 
     @property
     def gather_binary(self) -> bytes:
@@ -105,6 +106,9 @@ class StepControl(MPFuture):
     def deadline(self) -> DHTExpiration:
         return self._deadline
 
+    def get_timeout(self) -> Optional[DHTExpiration]:
+        return max(0.0, self.deadline - get_dht_time())
+
     @property
     def allow_retries(self) -> bool:
         return self._allow_retries

+ 36 - 34
hivemind/averaging/matchmaking.py

@@ -9,6 +9,8 @@ import random
 from math import isfinite
 from typing import AsyncIterator, Dict, Optional, Set, Tuple, Type
 
+import numpy as np
+
 from hivemind.averaging.control import StepControl
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.key_manager import GroupKey, GroupKeyManager
@@ -44,16 +46,16 @@ class Matchmaking:
         prefix: str,
         target_group_size: int,
         min_group_size: int,
+        min_matchmaking_time: float,
         request_timeout: float,
         client_mode: bool,
         initial_group_bits: str = "",
-        averaging_expiration: float = 15,
     ):
         assert "." not in prefix, "group prefix must be a string without ."
-        if request_timeout is None or request_timeout >= averaging_expiration:
+        if request_timeout is None or request_timeout >= min_matchmaking_time:
             logger.warning(
-                "It is recommended to use request_timeout smaller than averaging_expiration. Otherwise,"
-                "matchmaking can cause deadlocks in some rare cases. Please see Matchmaking docstring."
+                "It is recommended to use request_timeout smaller than min_matchmaking_time. Otherwise,"
+                " matchmaking can cause deadlocks in some rare cases. Please see Matchmaking docstring."
             )
 
         super().__init__()
@@ -68,7 +70,7 @@ class Matchmaking:
         self.schema_hash = schema_hash
         self.group_key_manager = GroupKeyManager(dht, prefix, initial_group_bits, target_group_size)
         self.target_group_size, self.min_group_size = target_group_size, min_group_size
-        self.averaging_expiration, self.request_timeout = averaging_expiration, request_timeout
+        self.min_matchmaking_time, self.request_timeout = min_matchmaking_time, request_timeout
         self.client_mode = client_mode
 
         self.lock_looking_for_group = asyncio.Lock()
@@ -79,11 +81,11 @@ class Matchmaking:
 
         self.current_leader: Optional[PeerID] = None  # iff i am a follower, this is a link to my current leader
         self.current_followers: Dict[PeerID, averaging_pb2.JoinRequest] = {}  # my current followers excluding myself
-        self.potential_leaders = PotentialLeaders(self.peer_id, averaging_expiration, target_group_size)
+        self.potential_leaders = PotentialLeaders(self.peer_id, min_matchmaking_time, target_group_size)
         self.step: Optional[StepControl] = None
 
     @contextlib.asynccontextmanager
-    def looking_for_group(self, step: StepControl):
+    async def looking_for_group(self, step: StepControl):
         async with self.lock_looking_for_group:
             assert self.step is None
             self.step = step
@@ -121,7 +123,7 @@ class Matchmaking:
         async with self.looking_for_group(step):
             request_leaders_task = asyncio.create_task(self._request_join_potential_leaders(step))
             try:
-                return await asyncio.wait_for(self.assembled_group, timeout=step.timeout)
+                return await asyncio.wait_for(self.assembled_group, timeout=step.get_timeout())
             except asyncio.TimeoutError:
                 return None
 
@@ -144,15 +146,15 @@ class Matchmaking:
                 self.assembled_group = asyncio.Future()
                 self.was_accepted_to_group.clear()
 
-    async def _request_join_potential_leaders(self, timeout: Optional[float]) -> GroupInfo:
+    async def _request_join_potential_leaders(self, step: StepControl) -> GroupInfo:
         """Request leaders from queue until we find the first runner. This coroutine is meant to run in background."""
         assert self.is_looking_for_group
-        async with self.potential_leaders.begin_search(self.group_key_manager, timeout, declare=not self.client_mode):
+        async with self.potential_leaders.begin_search(step, self.group_key_manager, declare=not self.client_mode):
             while True:
                 try:
                     next_leader = await self.potential_leaders.pop_next_leader()  # throws TimeoutError on expiration
 
-                    group = await self.request_join_group(next_leader, self.potential_leaders.request_expiration_time)
+                    group = await self.request_join_group(next_leader)
                     if group is not None:
                         return group
 
@@ -173,26 +175,25 @@ class Matchmaking:
                         self.assembled_group.set_exception(e)
                     raise e
 
-    async def request_join_group(self, leader: PeerID, expiration_time: DHTExpiration) -> Optional[GroupInfo]:
+    async def request_join_group(self, leader: PeerID) -> Optional[GroupInfo]:
         """
         :param leader: request this peer to be your leader for allreduce
-        :param expiration_time: inform leader that we intend to begin averaging before this expiration_time
         :returns: if leader leader accepted us and started AllReduce, return that AllReduce. Otherwise, return None
         :note: this function does not guarantee that your group leader is the same as :leader: parameter
           The originally specified leader can disband group and redirect us to a different leader
         """
         assert self.is_looking_for_group and self.current_leader is None
-        stream: AsyncIterator[averaging_pb2.MessageFromLeader] = None
+        stream: Optional[AsyncIterator[averaging_pb2.MessageFromLeader]] = None
         try:
             async with self.lock_request_join_group:
                 leader_stub = self._servicer_type.get_stub(self._p2p, leader, namespace=self._prefix)
-
+                request_expiration_time = self.get_request_expiration_time()
                 stream = await leader_stub.rpc_join_group(
                     averaging_pb2.JoinRequest(
                         schema_hash=self.schema_hash,
-                        expiration=expiration_time,
+                        expiration=request_expiration_time,
                         client_mode=self.client_mode,
-                        gather=self.control.gather_binary,
+                        gather=self.step.gather_binary,
                         group_key=self.group_key_manager.current_key,
                     )
                 )
@@ -211,7 +212,7 @@ class Matchmaking:
                 return None
 
             async with self.potential_leaders.pause_search():
-                time_to_expiration = max(expiration_time - get_dht_time(), 0.0)
+                time_to_expiration = max(0.0, request_expiration_time - get_dht_time())
                 message = await asyncio.wait_for(anext(stream), time_to_expiration + self.request_timeout)
 
                 if message.code == averaging_pb2.BEGIN_ALLREDUCE:
@@ -225,7 +226,7 @@ class Matchmaking:
                         logger.debug(f"{self} - leader disbanded group and redirected us to {suggested_leader}")
                         self.current_leader = None
                         await stream.aclose()
-                        return await self.request_join_group(suggested_leader, expiration_time)
+                        return await self.request_join_group(suggested_leader)
                 logger.debug(f"{self} - leader disbanded group")
                 return None
 
@@ -244,6 +245,14 @@ class Matchmaking:
             if stream is not None:
                 await stream.aclose()
 
+    def get_request_expiration_time(self) -> float:
+        """this averager's current expiration time - used to send join requests to leaders"""
+        if isfinite(self.potential_leaders.declared_expiration_time):
+            return self.potential_leaders.declared_expiration_time
+        else:
+            scheduled_time = max(self.step.scheduled_time, get_dht_time() + self.min_matchmaking_time)
+            return min(scheduled_time, self.potential_leaders.search_end_time)
+
     async def rpc_join_group(
         self, request: averaging_pb2.JoinRequest, context: P2PContext
     ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
@@ -359,7 +368,7 @@ class Matchmaking:
         random.shuffle(ordered_peer_ids)
 
         gathered = tuple(
-            self.control.gather_binary if peer_id == self.peer_id else self.current_followers[peer_id].gather
+            self.step.gather_binary if peer_id == self.peer_id else self.current_followers[peer_id].gather
             for peer_id in ordered_peer_ids
         )
 
@@ -395,8 +404,8 @@ class Matchmaking:
 class PotentialLeaders:
     """An utility class that searches for averagers that could become our leaders"""
 
-    def __init__(self, peer_id: PeerID, averaging_expiration: DHTExpiration, target_group_size: Optional[int]):
-        self.peer_id, self.averaging_expiration = peer_id, averaging_expiration
+    def __init__(self, peer_id: PeerID, min_matchmaking_time: DHTExpiration, target_group_size: Optional[int]):
+        self.peer_id, self.min_matchmaking_time = peer_id, min_matchmaking_time
         self.target_group_size = target_group_size
         self.running, self.update_triggered, self.update_finished = asyncio.Event(), asyncio.Event(), asyncio.Event()
         self.declared_expiration, self.lock_search, self.lock_declare = asyncio.Event(), asyncio.Lock(), asyncio.Lock()
@@ -411,7 +420,7 @@ class PotentialLeaders:
     async def begin_search(self, step: StepControl, key_manager: GroupKeyManager, declare: bool = True):
         async with self.lock_search:
             self.running.set()
-            self.search_end_time = get_dht_time() + step.timeout if step.timeout is not None else float("inf")
+            self.search_end_time = step.deadline if step.deadline is not None else float('inf')
             update_queue_task = asyncio.create_task(self._update_queue_periodically(key_manager))
             if declare:
                 declare_averager_task = asyncio.create_task(self._declare_averager_periodically(step, key_manager))
@@ -474,20 +483,12 @@ class PotentialLeaders:
             self.past_attempts.add((maybe_next_leader, entry.expiration_time))
             return maybe_next_leader
 
-    @property
-    def request_expiration_time(self) -> float:
-        """this averager's current expiration time - used to send join requests to leaders"""
-        if isfinite(self.declared_expiration_time):
-            return self.declared_expiration_time
-        else:
-            return min(get_dht_time() + self.averaging_expiration, self.search_end_time)
-
     async def _update_queue_periodically(self, key_manager: GroupKeyManager) -> None:
         DISCREPANCY = timed_storage.MAX_DHT_TIME_DISCREPANCY_SECONDS
         while get_dht_time() < self.search_end_time:
             new_peers = await key_manager.get_averagers(key_manager.current_key, only_active=True)
             self.max_assured_time = max(
-                self.max_assured_time, get_dht_time() + self.averaging_expiration - DISCREPANCY
+                self.max_assured_time, get_dht_time() + self.min_matchmaking_time - DISCREPANCY
             )
 
             self.leader_queue.clear()
@@ -511,8 +512,9 @@ class PotentialLeaders:
             try:
                 while True:
                     await self.running.wait()
-                    #TODO account for scheduled time here!
-                    new_expiration_time = min(get_dht_time() + self.averaging_expiration, self.search_end_time)
+                    new_expiration_time = float(np.clip(step.scheduled_time,
+                                                        a_min=get_dht_time() + self.min_matchmaking_time,
+                                                        a_max=self.search_end_time))
                     self.declared_group_key = group_key = key_manager.current_key
                     self.declared_expiration_time = new_expiration_time
                     self.declared_expiration.set()