justheuristic há 3 anos atrás
pai
commit
465989bed2
3 ficheiros alterados com 163 adições e 49 exclusões
  1. 35 33
      hivemind/averaging/averager.py
  2. 105 0
      hivemind/averaging/control.py
  3. 23 16
      hivemind/averaging/matchmaking.py

+ 35 - 33
hivemind/averaging/averager.py

@@ -16,6 +16,7 @@ import numpy as np
 import torch
 
 from hivemind.averaging.allreduce import AllreduceException, AllReduceRunner, AveragingMode, GroupID
+from hivemind.averaging.control import StepControl, AveragingStage
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.load_balancing import load_balance_peers
 from hivemind.averaging.matchmaking import Matchmaking, MatchmakingException
@@ -303,7 +304,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         else:
             logger.exception("Averager shutdown has no effect: the process is already not alive")
 
-    async def _shutdown(self, timeout: Optional[float] = None) -> None:
+    async def _shutdown(self) -> None:
         remaining_tasks = set()
         for group in self._running_groups.values():
             remaining_tasks.update(group.finalize(cancel=True))
@@ -316,68 +317,68 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     def step(
         self,
         gather: Optional[GatheredData] = None,
+        scheduled_time: Optional[DHTExpiration] = None,
         weight: Optional[float] = None,
         timeout: Optional[float] = None,
         allow_retries: bool = True,
+        wait_for_trigger: bool = False,
         wait: bool = True,
-    ) -> Union[Optional[Dict[PeerID, GatheredData]], MPFuture]:
+    ) -> Union[Optional[Dict[PeerID, GatheredData]], StepControl]:
         """
         Set up the averager to look for a group and run one round of averaging, return True on success, False on failure
 
         :param gather: optionally send this informaton to all peers in the next group and gather it from every groupmate
           (this operation is known as all-gather). The gathered data will be available as the output of this function.
+        :param scheduled_time: when matchmaking, assume that all-reduce will begin at this moment
         :param weight: averaging weight for this peer, int or float, must be strictly positive
         :param allow_retries: if averager fails to run one round of allreduce, this option will allow it to try again
           within the specified timeout
+        :param wait_for_trigger: if True, await for user to call .allow_allreduce() before running all-reduce;
+          only valid together with wait=False, because the trigger is fired through the returned StepControl
         :param timeout: if averager was unable to *find* a group in this many seconds, consider allreduce failedK
-        :param wait: if True (default), return when finished. Otherwise return MPFuture and run in background.
+        :param wait: if True (default), return when finished. Otherwise return StepControl and run in background.
         :returns: on success, update averaged_tensors and return group info; on failure, return None
         """
         if self.mode == AveragingMode.AUX and weight is not None:
             logger.warning("Averager is running in auxiliary mode, weight is unused.")
+        if scheduled_time is None:
+            scheduled_time = get_dht_time() + self.matchmaking_kwargs["averaging_expiration"]
         if weight is None:
             weight = float(self.mode != AveragingMode.AUX)
         assert isinstance(weight, (int, float)) and weight >= 0, f"Expected a positive int/float, got {type(weight)}"
+        # a blocking call never hands StepControl back to the caller, so nobody could call .allow_allreduce()
+        assert not (wait and wait_for_trigger), "Non-asynchronous step cannot wait for trigger (use wait=False)"
+        gather_binary = self.serializer.dumps(gather)  # serialize here to avoid imports in the averager process
 
-        future = MPFuture()
-        gather_binary = self.serializer.dumps(
-            gather
-        )  # serialize here to avoid loading modules in the averager process
-        self._outer_pipe.send(
-            (
-                "_step",
-                [],
-                dict(
-                    future=future,
-                    gather_binary=gather_binary,
-                    weight=weight,
-                    allow_retries=allow_retries,
-                    timeout=timeout,
-                ),
-            )
-        )
-        return future.result() if wait else future
+        step = StepControl(scheduled_time, weight, wait_for_trigger=wait_for_trigger,
+                           gather_binary=gather_binary, timeout=timeout, allow_retries=allow_retries)
+        # _step declares gather_binary, allow_retries and timeout as required keyword-only arguments
+        step_kwargs = dict(step=step, gather_binary=gather_binary, allow_retries=allow_retries, timeout=timeout)
+        self._outer_pipe.send(("_step", [], step_kwargs))
+        return step.result() if wait else step
 
     async def _step(
-        self, *, future: MPFuture, gather_binary: bytes, weight: float, allow_retries: bool, timeout: Optional[float]
+        self, *, step: StepControl, gather_binary: bytes, allow_retries: bool, timeout: Optional[float]
     ):
         start_time = get_dht_time()
 
         try:
-            while not future.done():
+            while not step.done():
                 try:
                     self._pending_group_assembled.clear()
                     data_for_gather = self.serializer.dumps([self.bandwidth, self.mode.value, gather_binary])
-                    group_info = await self._matchmaking.look_for_group(
-                        timeout=timeout, data_for_gather=data_for_gather
-                    )
+
+                    step.stage = AveragingStage.LOOKING_FOR_GROUP
+                    group_info = await self._matchmaking.look_for_group(step)
                     if group_info is None:
                         raise AllreduceException("Averaging step failed: could not find a group.")
 
-                    future.set_result(
+                    if not step.done():
+                        step.stage = AveragingStage.AWAITING_TRIGGER
+
+                    await step.wait_for_trigger()
+                    step.stage = AveragingStage.RUNNING_ALLREDUCE
+
+                    step.set_result(
                         await asyncio.wait_for(
                             self._run_allreduce(
-                                group_info, tensor_infos=self.tensor_infos, weight=weight, **self.allreduce_kwargs
+                                group_info, tensor_infos=self.tensor_infos, weight=step.weight, **self.allreduce_kwargs
                             ),
                             timeout=self._allreduce_timeout,
                         )
@@ -396,17 +397,18 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     time_elapsed = get_dht_time() - start_time
                     if not allow_retries or (timeout is not None and timeout < time_elapsed):
                         logger.exception(f"Averager caught {repr(e)}")
-                        future.set_exception(e)
+                        step.set_exception(e)
                     else:
                         logger.warning(f"Averager caught {repr(e)}, retrying")
 
         except BaseException as e:
-            if not future.done():
-                future.set_exception(e)
+            if not step.done():
+                step.set_exception(e)
             raise
         finally:
-            if not future.done():
-                future.set_exception(
+            step.stage = AveragingStage.FINISHED
+            if not step.done():
+                step.set_exception(
                     RuntimeError(
                         "Internal sanity check failed: averager.step left future pending."
                         " Please report this to hivemind issues."

+ 105 - 0
hivemind/averaging/control.py

@@ -0,0 +1,105 @@
+import struct
+from enum import Enum
+from typing import Optional
+
+import numpy as np
+import torch
+
+from hivemind.utils import MPFuture, DHTExpiration
+
+
+class AveragingStage(Enum):
+    """Progress marker of a single averaging step; this is the type returned by StepControl.stage."""
+    # the integer values are what StepControl writes into its one-byte stage slot
+    IDLE = 0               # still initializing
+    LOOKING_FOR_GROUP = 1  # running decentralized matchmaking, can't run allreduce yet
+    AWAITING_TRIGGER = 2   # waiting for user to set the trigger that allows running allreduce
+    RUNNING_ALLREDUCE = 3  # exchanging tensors with groupmates
+    FINISHED = 4           # either done or failed with exception
+
+
+class StepControl(MPFuture):
+    """
+    An auxiliary data structure that allows user to control stages and track progress in a single averaging step.
+    It is itself the future of that step: the averager calls set_result / set_exception on it when the step ends.
+
+    :param scheduled_time: when matchmaking, assume that all-reduce will begin at this moment
+    :param weight: averaging weight for this peer, must be non-negative and finite
+    :param wait_for_trigger: if True, all-reduce should not begin until the user calls .allow_allreduce()
+    :param gather_binary: optionally send this data to all peers in the next group and gather it from groupmates
+    :param timeout: maximum time that may be spent looking for group (does not include allreduce itself)
+    :param allow_retries: if averager fails to run one round of allreduce, allow it to try again within timeout
+
+    NOTE(review): _trigger, _require_trigger and the other private attributes are plain instance attributes;
+      confirm that MPFuture serialization carries them over when this object is sent to the averager process.
+    """
+    def __init__(self, scheduled_time: DHTExpiration, weight: float, wait_for_trigger: bool,
+                 gather_binary: bytes, timeout: Optional[float], allow_retries: bool):
+        super().__init__()
+        self._gather_binary, self._timeout, self._allow_retries = gather_binary, timeout, allow_retries
+        # the trigger is attached later via _attach_trigger; it cannot be fired here because it does not exist yet
+        self._trigger: Optional[MPFuture] = None
+        self._require_trigger = wait_for_trigger
+
+        # buffer layout: scheduled_time (double, bytes 0-7) | weight (double, bytes 8-15)
+        #   | stage (byte 16) | can_modify (byte 17)
+        self._metadata = torch.zeros([18], dtype=torch.uint8).share_memory_()
+        self.can_modify = True  # set first: the scheduled_time and weight setters assert can_modify
+        self.stage = AveragingStage.IDLE
+        self.scheduled_time = scheduled_time
+        self.weight = weight
+
+    def _attach_trigger(self, trigger: MPFuture):
+        """Attach the future that gates all-reduce; fire it right away if the user did not ask to wait."""
+        assert self._trigger is None, "StepControl already has an attached trigger"
+        self._trigger = trigger
+        if not self._require_trigger:
+            self.allow_allreduce()
+
+    def allow_allreduce(self):
+        """Allows averager to begin allreduce when it finds a group. Calling it again has no further effect."""
+        if self._trigger is None:
+            raise RuntimeError("StepControl does not have an attached trigger")
+        if not self._trigger.done():
+            self._trigger.set_result(None)
+
+    async def wait_for_trigger(self):
+        """Wait until all-reduce is allowed; return at once if no trigger was requested and none is attached."""
+        if self._trigger is None:
+            if self._require_trigger:
+                raise RuntimeError("StepControl was asked to wait for trigger, but no trigger is attached")
+            return
+        await self._trigger
+
+    @property
+    def scheduled_time(self) -> DHTExpiration:
+        """The moment at which all-reduce is expected to begin, stored as a double in bytes 0-7 of the buffer."""
+        return struct.unpack('d', self._metadata[0:8].numpy().data)[0]
+
+    @scheduled_time.setter
+    def scheduled_time(self, scheduled_time):
+        assert self.can_modify, "cannot change scheduling after all-reduce has already started"
+        # TODO check that scheduled time is still within timeout
+        struct.pack_into('d', self._metadata[0:8].numpy().data, 0, float(scheduled_time))
+
+    @property
+    def weight(self) -> float:
+        """Averaging weight of this peer, stored as a double in bytes 8-15 of the buffer."""
+        return struct.unpack('d', self._metadata[8:16].numpy().data)[0]
+
+    @weight.setter
+    def weight(self, weight: float):
+        assert self.can_modify, "cannot change weights after all-reduce has already started"
+        assert weight >= 0 and np.isfinite(weight)
+        struct.pack_into('d', self._metadata[8:16].numpy().data, 0, float(weight))
+
+    @property
+    def stage(self) -> AveragingStage:
+        """Current stage of this step, stored in byte 16 of the buffer."""
+        return AveragingStage(self._metadata[16].item())
+
+    @stage.setter
+    def stage(self, stage: AveragingStage):
+        if stage == AveragingStage.RUNNING_ALLREDUCE:
+            # from this point on, the scheduled_time and weight setters refuse further changes
+            self.can_modify = False
+        self._metadata[16] = stage.value
+
+    @property
+    def can_modify(self) -> bool:
+        """Whether scheduled_time and weight may still be changed, stored in byte 17 of the buffer."""
+        return bool(self._metadata[17].item())
+
+    @can_modify.setter
+    def can_modify(self, value: bool):
+        self._metadata[17] = int(value)
+
+    @property
+    def gather_binary(self) -> bytes:
+        return self._gather_binary
+
+    @property
+    def timeout(self) -> Optional[float]:
+        """Maximum time that may be spent looking for group, or None if no timeout was given."""
+        return self._timeout
+
+    @property
+    def allow_retries(self) -> bool:
+        return self._allow_retries
+
+    def cancel(self) -> bool:
+        """Cancel the trigger (if one is attached), then cancel this step via MPFuture.cancel."""
+        if self._trigger is not None:
+            self._trigger.cancel()
+        return super().cancel()

+ 23 - 16
hivemind/averaging/matchmaking.py

@@ -9,6 +9,7 @@ import random
 from math import isfinite
 from typing import AsyncIterator, Dict, Optional, Set, Tuple, Type
 
+from hivemind.averaging.control import StepControl
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.key_manager import GroupKey, GroupKeyManager
 from hivemind.dht import DHT, DHTID, DHTExpiration
@@ -79,7 +80,15 @@ class Matchmaking:
         self.current_leader: Optional[PeerID] = None  # iff i am a follower, this is a link to my current leader
         self.current_followers: Dict[PeerID, averaging_pb2.JoinRequest] = {}  # my current followers excluding myself
         self.potential_leaders = PotentialLeaders(self.peer_id, averaging_expiration, target_group_size)
-        self.data_for_gather: Optional[bytes] = None
+        self.step: Optional[StepControl] = None
+
+    @contextlib.asynccontextmanager
+    async def looking_for_group(self, step: StepControl):
+        """
+        Hold lock_looking_for_group and expose the given step as self.step for the duration of the `async with` body.
+
+        :param step: step parameters and user control structure for the current step
+        """
+        async with self.lock_looking_for_group:
+            assert self.step is None
+            try:
+                self.step = step
+                yield
+            finally:
+                # reset even if the body raised or was cancelled; otherwise the next call trips the assert above
+                self.step = None
 
     @property
     def is_looking_for_group(self):
@@ -98,10 +107,9 @@ class Matchmaking:
             f" current key = {self.group_key_manager.current_key}, client_mode={self.client_mode})"
         )
 
-    async def look_for_group(self, *, data_for_gather: bytes, timeout: Optional[float] = None) -> Optional[GroupInfo]:
+    async def look_for_group(self, step: StepControl) -> Optional[GroupInfo]:
         """
-        :param data_for_gather: optionally send this data to all peers in the next group and gather it from groupmates
-        :param timeout: maximum time that may be spent looking for group (does not include allreduce itself)
+        :param step: step parameters and user control structure for the current step
         :returns: an assembled group if successful, None if failed; does NOT perform the actual averaging
         Iterate over the averagers from a given group_identifier that have higher leadership priority than yourself.
         """
@@ -110,11 +118,10 @@ class Matchmaking:
                 "Another look_for_group is already in progress. The current run will be scheduled after"
                 " the existing group is either assembled or disbanded."
             )
-        async with self.lock_looking_for_group:
-            self.data_for_gather = data_for_gather
-            request_leaders_task = asyncio.create_task(self._request_join_potential_leaders(timeout))
+        async with self.looking_for_group(step):
+            request_leaders_task = asyncio.create_task(self._request_join_potential_leaders(step))
             try:
-                return await asyncio.wait_for(self.assembled_group, timeout=timeout)
+                return await asyncio.wait_for(self.assembled_group, timeout=step.timeout)
             except asyncio.TimeoutError:
                 return None
 
@@ -136,10 +143,10 @@ class Matchmaking:
                 # note: the code above ensures that we send all followers away before creating new future
                 self.assembled_group = asyncio.Future()
                 self.was_accepted_to_group.clear()
-                self.data_for_gather = None
 
     async def _request_join_potential_leaders(self, timeout: Optional[float]) -> GroupInfo:
         """Request leaders from queue until we find the first runner. This coroutine is meant to run in background."""
+        assert self.is_looking_for_group
         async with self.potential_leaders.begin_search(self.group_key_manager, timeout, declare=not self.client_mode):
             while True:
                 try:
@@ -185,7 +192,7 @@ class Matchmaking:
                         schema_hash=self.schema_hash,
                         expiration=expiration_time,
                         client_mode=self.client_mode,
-                        gather=self.data_for_gather,
+                        gather=self.control.gather_binary,
                         group_key=self.group_key_manager.current_key,
                     )
                 )
@@ -352,7 +359,7 @@ class Matchmaking:
         random.shuffle(ordered_peer_ids)
 
         gathered = tuple(
-            self.data_for_gather if peer_id == self.peer_id else self.current_followers[peer_id].gather
+            self.control.gather_binary if peer_id == self.peer_id else self.current_followers[peer_id].gather
             for peer_id in ordered_peer_ids
         )
 
@@ -401,13 +408,13 @@ class PotentialLeaders:
         self.search_end_time = float("inf")
 
     @contextlib.asynccontextmanager
-    async def begin_search(self, key_manager: GroupKeyManager, timeout: Optional[float], declare: bool = True):
+    async def begin_search(self, step: StepControl, key_manager: GroupKeyManager, declare: bool = True):
         async with self.lock_search:
             self.running.set()
-            self.search_end_time = get_dht_time() + timeout if timeout is not None else float("inf")
+            self.search_end_time = get_dht_time() + step.timeout if step.timeout is not None else float("inf")
             update_queue_task = asyncio.create_task(self._update_queue_periodically(key_manager))
             if declare:
-                declare_averager_task = asyncio.create_task(self._declare_averager_periodically(key_manager))
+                declare_averager_task = asyncio.create_task(self._declare_averager_periodically(step, key_manager))
 
             try:
                 yield self
@@ -499,12 +506,12 @@ class PotentialLeaders:
             )
             self.update_triggered.clear()
 
-    async def _declare_averager_periodically(self, key_manager: GroupKeyManager) -> None:
+    async def _declare_averager_periodically(self, step: StepControl, key_manager: GroupKeyManager) -> None:
         async with self.lock_declare:
             try:
                 while True:
                     await self.running.wait()
-
+                    #TODO account for scheduled time here!
                     new_expiration_time = min(get_dht_time() + self.averaging_expiration, self.search_end_time)
                     self.declared_group_key = group_key = key_manager.current_key
                     self.declared_expiration_time = new_expiration_time