|
@@ -16,6 +16,7 @@ import numpy as np
|
|
|
import torch
|
|
|
|
|
|
from hivemind.averaging.allreduce import AllreduceException, AllReduceRunner, AveragingMode, GroupID
|
|
|
+from hivemind.averaging.control import StepControl, AveragingStage
|
|
|
from hivemind.averaging.group_info import GroupInfo
|
|
|
from hivemind.averaging.load_balancing import load_balance_peers
|
|
|
from hivemind.averaging.matchmaking import Matchmaking, MatchmakingException
|
|
@@ -303,7 +304,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
|
else:
|
|
|
logger.exception("Averager shutdown has no effect: the process is already not alive")
|
|
|
|
|
|
- async def _shutdown(self, timeout: Optional[float] = None) -> None:
|
|
|
+ async def _shutdown(self) -> None:
|
|
|
remaining_tasks = set()
|
|
|
for group in self._running_groups.values():
|
|
|
remaining_tasks.update(group.finalize(cancel=True))
|
|
@@ -316,68 +317,68 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
|
def step(
    self,
    gather: Optional[GatheredData] = None,
    scheduled_time: Optional[DHTExpiration] = None,
    weight: Optional[float] = None,
    timeout: Optional[float] = None,
    allow_retries: bool = True,
    wait_for_trigger: bool = False,
    wait: bool = True,
) -> Union[Optional[Dict[PeerID, GatheredData]], StepControl]:
    """
    Set up the averager to look for a group and run one round of averaging.

    :param gather: optionally send this information to all peers in the next group and gather it from every groupmate
      (this operation is known as all-gather). The gathered data will be available as the output of this function.
    :param scheduled_time: when matchmaking, assume that all-reduce will begin at this moment;
      by default, begin at the end of the current matchmaking window (now + averaging_expiration)
    :param weight: averaging weight for this peer, int or float, must be non-negative
    :param timeout: if averager was unable to *find* a group in this many seconds, consider allreduce failed
    :param allow_retries: if averager fails to run one round of allreduce, this option will allow it to try again
      within the specified timeout
    :param wait_for_trigger: if True, await for user to call .allow_allreduce() before running all-reduce
    :param wait: if True (default), return when finished. Otherwise return StepControl and run in background.
    :returns: on success, update averaged_tensors and return group info; on failure, return None
    """
    if self.mode == AveragingMode.AUX and weight is not None:
        logger.warning("Averager is running in auxiliary mode, weight is unused.")
    if scheduled_time is None:
        scheduled_time = get_dht_time() + self.matchmaking_kwargs["averaging_expiration"]
    if weight is None:
        # auxiliary peers contribute no tensors of their own, hence zero averaging weight
        weight = float(self.mode != AveragingMode.AUX)
    assert isinstance(weight, (int, float)) and weight >= 0, f"Expected a non-negative int/float, got {type(weight)}"
    assert not wait_for_trigger or wait, "Non-asynchronous step cannot wait for trigger (use wait=False)"
    gather_binary = self.serializer.dumps(gather)  # serialize here to avoid imports in the averager process

    step = StepControl(
        scheduled_time,
        weight,
        wait_for_trigger=wait_for_trigger,
        gather_binary=gather_binary,
        timeout=timeout,
        allow_retries=allow_retries,
    )
    # NOTE: _step declares gather_binary/allow_retries/timeout as required keyword-only parameters,
    # so they must be forwarded explicitly; sending only `step` would raise TypeError in the averager process.
    self._outer_pipe.send(
        ("_step", [], dict(step=step, gather_binary=gather_binary, allow_retries=allow_retries, timeout=timeout))
    )
    return step.result() if wait else step
|
|
|
|
|
|
async def _step(
|
|
|
- self, *, future: MPFuture, gather_binary: bytes, weight: float, allow_retries: bool, timeout: Optional[float]
|
|
|
+ self, *, step: StepControl, gather_binary: bytes, allow_retries: bool, timeout: Optional[float]
|
|
|
):
|
|
|
start_time = get_dht_time()
|
|
|
|
|
|
try:
|
|
|
- while not future.done():
|
|
|
+ while not step.done():
|
|
|
try:
|
|
|
self._pending_group_assembled.clear()
|
|
|
data_for_gather = self.serializer.dumps([self.bandwidth, self.mode.value, gather_binary])
|
|
|
- group_info = await self._matchmaking.look_for_group(
|
|
|
- timeout=timeout, data_for_gather=data_for_gather
|
|
|
- )
|
|
|
+
|
|
|
+ step.stage = AveragingStage.LOOKING_FOR_GROUP
|
|
|
+ group_info = await self._matchmaking.look_for_group(step)
|
|
|
if group_info is None:
|
|
|
raise AllreduceException("Averaging step failed: could not find a group.")
|
|
|
|
|
|
- future.set_result(
|
|
|
+ if not step.done():
|
|
|
+ step.stage = AveragingStage.AWAITING_TRIGGER
|
|
|
+
|
|
|
+ await step.wait_for_trigger()
|
|
|
+ step.stage = AveragingStage.RUNNING_ALLREDUCE
|
|
|
+
|
|
|
+ step.set_result(
|
|
|
await asyncio.wait_for(
|
|
|
self._run_allreduce(
|
|
|
- group_info, tensor_infos=self.tensor_infos, weight=weight, **self.allreduce_kwargs
|
|
|
+ group_info, tensor_infos=self.tensor_infos, weight=step.weight, **self.allreduce_kwargs
|
|
|
),
|
|
|
timeout=self._allreduce_timeout,
|
|
|
)
|
|
@@ -396,17 +397,18 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
|
time_elapsed = get_dht_time() - start_time
|
|
|
if not allow_retries or (timeout is not None and timeout < time_elapsed):
|
|
|
logger.exception(f"Averager caught {repr(e)}")
|
|
|
- future.set_exception(e)
|
|
|
+ step.set_exception(e)
|
|
|
else:
|
|
|
logger.warning(f"Averager caught {repr(e)}, retrying")
|
|
|
|
|
|
except BaseException as e:
|
|
|
- if not future.done():
|
|
|
- future.set_exception(e)
|
|
|
+ if not step.done():
|
|
|
+ step.set_exception(e)
|
|
|
raise
|
|
|
finally:
|
|
|
- if not future.done():
|
|
|
- future.set_exception(
|
|
|
+ step.stage = AveragingStage.FINISHED
|
|
|
+ if not step.done():
|
|
|
+ step.set_exception(
|
|
|
RuntimeError(
|
|
|
"Internal sanity check failed: averager.step left future pending."
|
|
|
" Please report this to hivemind issues."
|