justheuristic · 3 years ago · commit 465989bed2
3 changed files with 163 additions and 49 deletions
  1. hivemind/averaging/averager.py  (+35 -33)
  2. hivemind/averaging/control.py  (+105 -0)
  3. hivemind/averaging/matchmaking.py  (+23 -16)

+ 35 - 33
hivemind/averaging/averager.py

@@ -16,6 +16,7 @@ import numpy as np
 import torch
 
 from hivemind.averaging.allreduce import AllreduceException, AllReduceRunner, AveragingMode, GroupID
+from hivemind.averaging.control import StepControl, AveragingStage
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.load_balancing import load_balance_peers
 from hivemind.averaging.matchmaking import Matchmaking, MatchmakingException
@@ -303,7 +304,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         else:
             logger.exception("Averager shutdown has no effect: the process is already not alive")
 
-    async def _shutdown(self, timeout: Optional[float] = None) -> None:
+    async def _shutdown(self) -> None:
         remaining_tasks = set()
         for group in self._running_groups.values():
             remaining_tasks.update(group.finalize(cancel=True))
@@ -316,68 +317,68 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     def step(
         self,
         gather: Optional[GatheredData] = None,
+        scheduled_time: Optional[DHTExpiration] = None,
         weight: Optional[float] = None,
         timeout: Optional[float] = None,
         allow_retries: bool = True,
+        wait_for_trigger: bool = False,
         wait: bool = True,
-    ) -> Union[Optional[Dict[PeerID, GatheredData]], MPFuture]:
+    ) -> Union[Optional[Dict[PeerID, GatheredData]], StepControl]:
         """
         Set up the averager to look for a group and run one round of averaging, return True on success, False on failure
 
         :param gather: optionally send this information to all peers in the next group and gather it from every groupmate
           (this operation is known as all-gather). The gathered data will be available as the output of this function.
+        :param scheduled_time: when matchmaking, assume that all-reduce will begin at this moment
         :param weight: averaging weight for this peer, int or float, must be strictly positive
         :param allow_retries: if averager fails to run one round of allreduce, this option will allow it to try again
           within the specified timeout
+        :param wait_for_trigger: if True, wait for the user to call .allow_allreduce() before running all-reduce
         :param timeout: if averager was unable to *find* a group in this many seconds, consider allreduce failed
-        :param wait: if True (default), return when finished. Otherwise return MPFuture and run in background.
+        :param wait: if True (default), return when finished. Otherwise return StepControl and run in background.
         :returns: on success, update averaged_tensors and return group info; on failure, return None
         """
         if self.mode == AveragingMode.AUX and weight is not None:
             logger.warning("Averager is running in auxiliary mode, weight is unused.")
+        if scheduled_time is None:
+            scheduled_time = get_dht_time() + self.matchmaking_kwargs["averaging_expiration"]
         if weight is None:
             weight = float(self.mode != AveragingMode.AUX)
         assert isinstance(weight, (int, float)) and weight >= 0, f"Expected a positive int/float, got {type(weight)}"
+        assert not wait_for_trigger or not wait, "Non-asynchronous step cannot wait for trigger (use wait=False)"
+        gather_binary = self.serializer.dumps(gather)  # serialize here to avoid imports in the averager process
 
-        future = MPFuture()
-        gather_binary = self.serializer.dumps(
-            gather
-        )  # serialize here to avoid loading modules in the averager process
-        self._outer_pipe.send(
-            (
-                "_step",
-                [],
-                dict(
-                    future=future,
-                    gather_binary=gather_binary,
-                    weight=weight,
-                    allow_retries=allow_retries,
-                    timeout=timeout,
-                ),
-            )
-        )
-        return future.result() if wait else future
+        step = StepControl(scheduled_time, weight, wait_for_trigger=wait_for_trigger,
+                           gather_binary=gather_binary, timeout=timeout, allow_retries=allow_retries)
+        self._outer_pipe.send(("_step", [], dict(step=step, gather_binary=gather_binary, allow_retries=allow_retries, timeout=timeout)))
+        return step.result() if wait else step
 
     async def _step(
-        self, *, future: MPFuture, gather_binary: bytes, weight: float, allow_retries: bool, timeout: Optional[float]
+        self, *, step: StepControl, gather_binary: bytes, allow_retries: bool, timeout: Optional[float]
     ):
         start_time = get_dht_time()
 
         try:
-            while not future.done():
+            while not step.done():
                 try:
                     self._pending_group_assembled.clear()
                     data_for_gather = self.serializer.dumps([self.bandwidth, self.mode.value, gather_binary])
-                    group_info = await self._matchmaking.look_for_group(
-                        timeout=timeout, data_for_gather=data_for_gather
-                    )
+
+                    step.stage = AveragingStage.LOOKING_FOR_GROUP
+                    group_info = await self._matchmaking.look_for_group(step)
                     if group_info is None:
                         raise AllreduceException("Averaging step failed: could not find a group.")
 
-                    future.set_result(
+                    if not step.done():
+                        step.stage = AveragingStage.AWAITING_TRIGGER
+
+                    await step.wait_for_trigger()
+                    step.stage = AveragingStage.RUNNING_ALLREDUCE
+
+                    step.set_result(
                         await asyncio.wait_for(
                             self._run_allreduce(
-                                group_info, tensor_infos=self.tensor_infos, weight=weight, **self.allreduce_kwargs
+                                group_info, tensor_infos=self.tensor_infos, weight=step.weight, **self.allreduce_kwargs
                             ),
                             timeout=self._allreduce_timeout,
                         )
@@ -396,17 +397,18 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     time_elapsed = get_dht_time() - start_time
                     if not allow_retries or (timeout is not None and timeout < time_elapsed):
                         logger.exception(f"Averager caught {repr(e)}")
-                        future.set_exception(e)
+                        step.set_exception(e)
                     else:
                         logger.warning(f"Averager caught {repr(e)}, retrying")
 
         except BaseException as e:
-            if not future.done():
-                future.set_exception(e)
+            if not step.done():
+                step.set_exception(e)
             raise
         finally:
-            if not future.done():
-                future.set_exception(
+            step.stage = AveragingStage.FINISHED
+            if not step.done():
+                step.set_exception(
                     RuntimeError(
                         "Internal sanity check failed: averager.step left future pending."
                         " Please report this to hivemind issues."

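For orientation, here is a minimal usage sketch of the control flow this commit introduces. It is hypothetical caller code, not part of the diff, and assumes an already-constructed DecentralizedAverager named `averager`. With wait=False the call returns a StepControl immediately, and with wait_for_trigger=True the averager pauses in the AWAITING_TRIGGER stage until allow_allreduce() is called:

    from hivemind.utils import get_dht_time

    # schedule a round roughly five seconds from now and keep control of the trigger
    control = averager.step(wait=False, wait_for_trigger=True,
                            scheduled_time=get_dht_time() + 5.0)
    # ... the caller can do other work while matchmaking runs in the background ...
    control.weight = 2.0        # still adjustable until all-reduce actually starts
    control.allow_allreduce()   # release the AWAITING_TRIGGER stage
    gathered = control.result(timeout=60)  # block until the round succeeds or raises
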
+ 105 - 0
hivemind/averaging/control.py

@@ -0,0 +1,105 @@
+import struct
+from enum import Enum
+from typing import Optional
+
+import numpy as np
+import torch
+
+from hivemind.utils import MPFuture, DHTExpiration
+
+
+class AveragingStage(Enum):
+    IDLE = 0               # still initializing
+    LOOKING_FOR_GROUP = 1  # running decentralized matchmaking, can't run allreduce yet
+    AWAITING_TRIGGER = 2   # waiting for user to set the trigger that allows running allreduce
+    RUNNING_ALLREDUCE = 3  # exchanging tensors with groupmates
+    FINISHED = 4           # either done or failed with exception
+
+
+class StepControl(MPFuture):
+    """
+    An auxiliary data structure that allows the user to control stages and track progress of a single averaging step.
+
+    :param scheduled_time: estimated time when this peer should begin all-reduce (used for matchmaking)
+    :param weight: averaging weight for this peer, int or float, must be non-negative
+    :param wait_for_trigger: if True, the averager will await allow_allreduce() before running all-reduce
+    :param gather_binary: optionally send this data to all peers in the next group and gather it from groupmates
+    :param timeout: maximum time that may be spent looking for a group (does not include allreduce itself)
+    :param allow_retries: if True, failed rounds of allreduce may be retried within the specified timeout
+    """
+    def __init__(self, scheduled_time: DHTExpiration, weight: float, wait_for_trigger: bool,
+                 gather_binary: bytes, timeout: Optional[float], allow_retries: bool):
+        super().__init__()
+        self._gather_binary, self._timeout, self._allow_retries = gather_binary, timeout, allow_retries
+        self._trigger: Optional[MPFuture] = None
+        self._wait_for_trigger = wait_for_trigger  # if False, allow all-reduce as soon as the trigger is attached
+        # shared-memory layout: bytes [0:8) scheduled_time (float64), [8:16) weight (float64), [16] stage, [17] can_modify
+        self._metadata = torch.zeros([18], dtype=torch.uint8).share_memory_()
+        self.stage = AveragingStage.IDLE
+        self.scheduled_time = scheduled_time
+        self.weight = weight
+        self.can_modify = True
+
+    def _attach_trigger(self, trigger: MPFuture):
+        assert self._trigger is None, "trigger is already attached"
+        self._trigger = trigger
+        if not self._wait_for_trigger:
+            self.allow_allreduce()  # no explicit trigger requested: allow all-reduce as soon as a group is found
+
+    def allow_allreduce(self):
+        """Allows averager to begin allreduce when it finds a group."""
+        self._trigger.set_result(None)
+
+    async def wait_for_trigger(self):
+        assert self._trigger is not None, "a trigger was never attached to this step"
+        await self._trigger
+
+    @property
+    def scheduled_time(self) -> DHTExpiration:
+        return struct.unpack('d', self._metadata[0:8].numpy().data)[0]
+
+    @scheduled_time.setter
+    def scheduled_time(self, scheduled_time):
+        assert self.can_modify, "cannot change scheduling after all-reduce has already started"
+        #TODO check that scheduled time is still within timeout
+        struct.pack_into('d', self._metadata[0:8].numpy().data, 0, float(scheduled_time))
+
+    @property
+    def weight(self) -> float:
+        return struct.unpack('d', self._metadata[8:16].numpy().data)[0]
+
+    @weight.setter
+    def weight(self, weight: float):
+        assert self.can_modify, "cannot change weights after all-reduce has already started"
+        assert weight >= 0 and np.isfinite(weight)
+        struct.pack_into('d', self._metadata[8:16].numpy().data, 0, float(weight))
+
+    @property
+    def stage(self) -> AveragingStage:
+        return AveragingStage(self._metadata[16].item())
+
+    @stage.setter
+    def stage(self, stage: AveragingStage):
+        if stage == AveragingStage.RUNNING_ALLREDUCE:
+            self.can_modify = False
+        self._metadata[16] = stage.value
+
+    @property
+    def can_modify(self) -> bool:
+        return bool(self._metadata[17].item())
+
+    @can_modify.setter
+    def can_modify(self, value: bool):
+        self._metadata[17] = int(value)
+
+    @property
+    def gather_binary(self) -> bytes:
+        return self._gather_binary
+
+    @property
+    def timeout(self) -> Optional[float]:
+        return self._timeout
+
+    @property
+    def allow_retries(self) -> bool:
+        return self._allow_retries
+
+    def cancel(self) -> bool:
+        if self._trigger is not None:
+            self._trigger.cancel()
+        return super().cancel()

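The properties above all rely on one trick: scalar fields are serialized with struct into a shared-memory uint8 tensor, so a value assigned in the user process is immediately visible in the averager process and vice versa. A standalone illustration of that pattern (toy code with made-up names, not part of this commit):

    import struct

    import torch

    # one float64 slot backed by shared memory, analogous to StepControl._metadata
    buffer = torch.zeros([8], dtype=torch.uint8).share_memory_()

    def write_value(value: float) -> None:
        # pack a double into the shared buffer; every process holding this tensor sees the update
        struct.pack_into("d", buffer.numpy().data, 0, float(value))

    def read_value() -> float:
        return struct.unpack("d", buffer[0:8].numpy().data)[0]

    write_value(0.5)
    assert read_value() == 0.5
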
+ 23 - 16
hivemind/averaging/matchmaking.py

@@ -9,6 +9,7 @@ import random
 from math import isfinite
 from typing import AsyncIterator, Dict, Optional, Set, Tuple, Type
 
+from hivemind.averaging.control import StepControl
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.key_manager import GroupKey, GroupKeyManager
 from hivemind.dht import DHT, DHTID, DHTExpiration
@@ -79,7 +80,15 @@ class Matchmaking:
         self.current_leader: Optional[PeerID] = None  # iff i am a follower, this is a link to my current leader
         self.current_followers: Dict[PeerID, averaging_pb2.JoinRequest] = {}  # my current followers excluding myself
         self.potential_leaders = PotentialLeaders(self.peer_id, averaging_expiration, target_group_size)
-        self.data_for_gather: Optional[bytes] = None
+        self.step: Optional[StepControl] = None
+
+    @contextlib.asynccontextmanager
+    async def looking_for_group(self, step: StepControl):
+        async with self.lock_looking_for_group:
+            assert self.step is None
+            self.step = step
+            try:
+                yield
+            finally:
+                self.step = None
 
     @property
     def is_looking_for_group(self):
@@ -98,10 +107,9 @@ class Matchmaking:
             f" current key = {self.group_key_manager.current_key}, client_mode={self.client_mode})"
             f" current key = {self.group_key_manager.current_key}, client_mode={self.client_mode})"
         )
         )
 
 
-    async def look_for_group(self, *, data_for_gather: bytes, timeout: Optional[float] = None) -> Optional[GroupInfo]:
+    async def look_for_group(self, step: StepControl) -> Optional[GroupInfo]:
         """
         """
-        :param data_for_gather: optionally send this data to all peers in the next group and gather it from groupmates
-        :param timeout: maximum time that may be spent looking for group (does not include allreduce itself)
+        :param step: step parameters and user control structure for the current step
         :returns: an assembled group if successful, None if failed; does NOT perform the actual averaging
         :returns: an assembled group if successful, None if failed; does NOT perform the actual averaging
         Iterate over the averagers from a given group_identifier that have higher leadership priority than yourself.
         Iterate over the averagers from a given group_identifier that have higher leadership priority than yourself.
         """
         """
@@ -110,11 +118,10 @@ class Matchmaking:
                 "Another look_for_group is already in progress. The current run will be scheduled after"
                 "Another look_for_group is already in progress. The current run will be scheduled after"
                 " the existing group is either assembled or disbanded."
                 " the existing group is either assembled or disbanded."
             )
             )
-        async with self.lock_looking_for_group:
-            self.data_for_gather = data_for_gather
-            request_leaders_task = asyncio.create_task(self._request_join_potential_leaders(timeout))
+        async with self.looking_for_group(step):
+            request_leaders_task = asyncio.create_task(self._request_join_potential_leaders(step))
             try:
             try:
-                return await asyncio.wait_for(self.assembled_group, timeout=timeout)
+                return await asyncio.wait_for(self.assembled_group, timeout=step.timeout)
             except asyncio.TimeoutError:
             except asyncio.TimeoutError:
                 return None
                 return None
 
 
@@ -136,10 +143,10 @@ class Matchmaking:
                 # note: the code above ensures that we send all followers away before creating new future
                 self.assembled_group = asyncio.Future()
                 self.was_accepted_to_group.clear()
-                self.data_for_gather = None
 
     async def _request_join_potential_leaders(self, timeout: Optional[float]) -> GroupInfo:
         """Request leaders from queue until we find the first runner. This coroutine is meant to run in background."""
+        assert self.is_looking_for_group
         async with self.potential_leaders.begin_search(self.group_key_manager, timeout, declare=not self.client_mode):
             while True:
                 try:
@@ -185,7 +192,7 @@ class Matchmaking:
                         schema_hash=self.schema_hash,
                         expiration=expiration_time,
                         client_mode=self.client_mode,
-                        gather=self.data_for_gather,
+                        gather=self.step.gather_binary,
                         group_key=self.group_key_manager.current_key,
                     )
                 )
@@ -352,7 +359,7 @@ class Matchmaking:
         random.shuffle(ordered_peer_ids)
 
         gathered = tuple(
-            self.data_for_gather if peer_id == self.peer_id else self.current_followers[peer_id].gather
+            self.step.gather_binary if peer_id == self.peer_id else self.current_followers[peer_id].gather
             for peer_id in ordered_peer_ids
         )
 
@@ -401,13 +408,13 @@ class PotentialLeaders:
         self.search_end_time = float("inf")
 
     @contextlib.asynccontextmanager
-    async def begin_search(self, key_manager: GroupKeyManager, timeout: Optional[float], declare: bool = True):
+    async def begin_search(self, step: StepControl, key_manager: GroupKeyManager, declare: bool = True):
         async with self.lock_search:
             self.running.set()
-            self.search_end_time = get_dht_time() + timeout if timeout is not None else float("inf")
+            self.search_end_time = get_dht_time() + step.timeout if step.timeout is not None else float("inf")
             update_queue_task = asyncio.create_task(self._update_queue_periodically(key_manager))
             if declare:
-                declare_averager_task = asyncio.create_task(self._declare_averager_periodically(key_manager))
+                declare_averager_task = asyncio.create_task(self._declare_averager_periodically(step, key_manager))
 
             try:
                 yield self
@@ -499,12 +506,12 @@ class PotentialLeaders:
             )
             self.update_triggered.clear()
 
-    async def _declare_averager_periodically(self, key_manager: GroupKeyManager) -> None:
+    async def _declare_averager_periodically(self, step: StepControl, key_manager: GroupKeyManager) -> None:
         async with self.lock_declare:
             try:
                 while True:
                     await self.running.wait()
-
+                    #TODO account for scheduled time here!
                     new_expiration_time = min(get_dht_time() + self.averaging_expiration, self.search_end_time)
                     self.declared_group_key = group_key = key_manager.current_key
                     self.declared_expiration_time = new_expiration_time
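
The TODO above hints that the declared expiration should take the user's scheduled_time into account. One plausible shape for that change, shown only as an illustration of intent (hypothetical helper, not part of this commit): declare no earlier than the scheduled start time, but never past the end of the search window.

    from hivemind.utils import get_dht_time

    def pick_declared_expiration(scheduled_time: float, averaging_expiration: float, search_end_time: float) -> float:
        # hypothetical sketch: prefer the user-scheduled start time, but stay within the search window
        return min(max(scheduled_time, get_dht_time() + averaging_expiration), search_end_time)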