@@ -16,6 +16,7 @@ import numpy as np
 import torch

 from hivemind.averaging.allreduce import AllreduceException, AllReduceRunner, AveragingMode, GroupID
+from hivemind.averaging.control import AveragingStage, StepControl
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.load_balancing import load_balance_peers
 from hivemind.averaging.matchmaking import Matchmaking, MatchmakingException
@@ -28,7 +29,7 @@ from hivemind.compression import (
     serialize_torch_tensor,
 )
 from hivemind.dht import DHT, DHTID
-from hivemind.p2p import P2PContext, P2PHandlerError, PeerID, ServicerBase
+from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
 from hivemind.proto import averaging_pb2
 from hivemind.utils import MPFuture, TensorDescriptor, get_logger
 from hivemind.utils.asyncio import achain, aiter_with_timeout, anext, as_aiter, azip, switch_to_uvloop
@@ -54,8 +55,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     :param prefix: a shared prefix for all group keys
     :param target_group_size: attempts to form groups with up to this many peers (recommended: a power of 2, e.g. 16)
     :param initial_group_bits: a string of bits ('0' and '1') that define the initial group key (bucket index)
-    :param averaging_expiration: attempt to find a group for this many seconds, otherwise try again
-      note - this expiration time only applies to looking for group, passing tensors in allreduce may take more time
+    :param min_matchmaking_time: when looking for a group, wait for requests for at least this many seconds
     :param compression: optionally compress tensors with this compression algorithm before running all-reduce
     :param state_compression: a separate compression strategy for load_state_from_peers (default = no compression)
     :param tensor_infos: CompressionInfo for each respective tensor; this determines how the tensor will be compressed
@@ -63,7 +63,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     :param averaging_alpha: optional "learning rate" for averaging. If specified, local parameters will be shifted
       towards the (estimated) average by this coefficient. By default, local parameters are set equal to average.
     :param request_timeout: when looking for a group, wait for a response from the leader for at most this many seconds.
-      :note: request_timeout must be smaller than averaging_expiration to avoid potential deadlocks.
+      :note: request_timeout must be smaller than min_matchmaking_time to avoid potential deadlocks.
     :param part_size_bytes: tensors for AllReduce are processed in parts of up to this size (after compression)
     :param bandwidth: if specified, this value represents the network bandwidth available to the averager.
       By default, the averager is assumed to have the average bandwidth of its group.
@@ -75,6 +75,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
       local tensors for averaging
     :param allow_state_sharing: if set to True, other peers can download this peer's state. Can be overwritten
       with averager.allow_state_sharing = True / False
+    :param declare_state_period: re-declare averager as a donor for load_state_from_peers every this many seconds
     :param shutdown_timeout: when calling .shutdown, wait for up to this many seconds before terminating

     Example:
@@ -92,6 +93,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):

     _matchmaking: Matchmaking
     _pending_group_assembled: asyncio.Event
+    _state_updated: asyncio.Event
+    _p2p: P2P
     serializer = MSGPackSerializer

     def __init__(
@@ -104,8 +107,9 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         target_group_size: int,
         min_group_size: int = 2,
         initial_group_bits: str = "",
-        averaging_expiration: float = 15,
-        request_timeout: float = 3,
+        averaging_expiration: Optional[float] = None,
+        min_matchmaking_time: float = 5.0,
+        request_timeout: float = 3.0,
         averaging_alpha: float = 1.0,
         part_size_bytes: int = DEFAULT_PART_SIZE_BYTES,
         allreduce_timeout: Optional[float] = None,
@@ -116,6 +120,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         min_vector_size: int = 0,
         auxiliary: bool = False,
         allow_state_sharing: Optional[bool] = None,
+        declare_state_period: float = 30,
         client_mode: Optional[bool] = None,
         daemon: bool = True,
         shutdown_timeout: float = 5,
@@ -129,6 +134,11 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         assert all(bit in "01" for bit in initial_group_bits)
         assert not client_mode or not auxiliary, "auxiliary peers must accept incoming connections"

+        if averaging_expiration is not None:
+            logger.warning("averaging_expiration is deprecated and will be removed soon, use min_matchmaking_time")
+            assert min_matchmaking_time == 5.0, "Can't set both averaging_expiration and min_matchmaking_time"
+            min_matchmaking_time = averaging_expiration
+
         super().__init__()
         self.dht = dht
         self.prefix = prefix
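The shim above keeps the deprecated keyword working while callers migrate. A minimal before/after sketch, assuming the constructor's usual leading arguments (`averaged_tensors`, `dht`, `start=...`) from the class docstring's Example:

```python
import torch
import hivemind
from hivemind.averaging import DecentralizedAverager

dht = hivemind.DHT(start=True)
tensors = [torch.zeros(5)]

# Deprecated spelling: still accepted, but logs a warning and is
# forwarded into min_matchmaking_time by the shim above.
averager = DecentralizedAverager(
    tensors, dht, start=True, prefix="demo", target_group_size=16,
    averaging_expiration=15,
)

# Preferred spelling after this change:
averager = DecentralizedAverager(
    tensors, dht, start=True, prefix="demo", target_group_size=16,
    min_matchmaking_time=15.0,
)
```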
@@ -163,8 +173,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             initial_group_bits=initial_group_bits,
             target_group_size=target_group_size,
             min_group_size=min_group_size,
-            averaging_expiration=averaging_expiration,
             request_timeout=request_timeout,
+            min_matchmaking_time=min_matchmaking_time,
         )
         self.allreduce_kwargs = dict(
             compression=compression,
@@ -180,6 +190,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         if allow_state_sharing is None:
             allow_state_sharing = not client_mode and not auxiliary
         self.allow_state_sharing = allow_state_sharing
+        self.declare_state_period = declare_state_period
         self.state_compression = state_compression
         self.tensor_infos = tensor_infos

@@ -250,6 +261,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                 if not self.client_mode:
                     asyncio.create_task(self._declare_for_download_periodically())

+                self._state_updated = asyncio.Event()
                 self._pending_group_assembled = asyncio.Event()
                 self._pending_group_assembled.set()
             except Exception as e:
@@ -294,7 +306,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     def shutdown(self) -> None:
         """Shut down the averager process"""
         if self.is_alive():
-            self._outer_pipe.send(("_shutdown", [None], {}))  # shut down the daemon process
+            self._outer_pipe.send(("_shutdown", [self.shutdown_timeout], {}))  # shut down the daemon process
             self._inner_pipe.send(("_SHUTDOWN", None))  # shut down background thread in master
             self.join(self.shutdown_timeout)
             if self.is_alive():
@@ -303,11 +315,11 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         else:
             logger.exception("Averager shutdown has no effect: the process is already not alive")

-    async def _shutdown(self, timeout: Optional[float] = None) -> None:
+    async def _shutdown(self, timeout: Optional[float]) -> None:
         remaining_tasks = set()
         for group in self._running_groups.values():
             remaining_tasks.update(group.finalize(cancel=True))
-        await asyncio.gather(*remaining_tasks)
+        await asyncio.wait_for(asyncio.gather(*remaining_tasks), timeout)

     def __del__(self):
         if self._parent_pid == os.getpid() and self.is_alive():
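Forwarding `shutdown_timeout` into `_shutdown` bounds the final cleanup instead of waiting on stragglers indefinitely. The same wait_for-over-gather pattern in isolation (illustrative helper, not hivemind API):

```python
import asyncio
from typing import Optional, Set

async def finalize_with_timeout(tasks: Set[asyncio.Task], timeout: Optional[float]) -> None:
    # wait_for cancels the gathered tasks if they overrun the timeout;
    # with timeout=None this degenerates to a plain gather
    try:
        await asyncio.wait_for(asyncio.gather(*tasks), timeout)
    except asyncio.TimeoutError:
        pass  # the caller terminates the process afterwards, as shutdown() does

async def main() -> None:
    straggler = asyncio.create_task(asyncio.sleep(10))
    await finalize_with_timeout({straggler}, timeout=0.1)  # returns after ~0.1s

asyncio.run(main())
```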
@@ -316,68 +328,81 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     def step(
         self,
         gather: Optional[GatheredData] = None,
+        scheduled_time: Optional[DHTExpiration] = None,
         weight: Optional[float] = None,
         timeout: Optional[float] = None,
         allow_retries: bool = True,
+        require_trigger: bool = False,
         wait: bool = True,
-    ) -> Union[Optional[Dict[PeerID, GatheredData]], MPFuture]:
+    ) -> Union[Optional[Dict[PeerID, GatheredData]], StepControl]:
         """
         Set up the averager to look for a group and run one round of averaging with that group

         :param gather: optionally send this information to all peers in the next group and gather it from every groupmate
           (this operation is known as all-gather). The gathered data will be available as the output of this function.
+        :param scheduled_time: when matchmaking, assume that all-reduce will begin at this moment.
+          By default, schedule all-reduce for current time plus min_matchmaking_time seconds
         :param weight: averaging weight for this peer, int or float, must be strictly positive
         :param allow_retries: if averager fails to run one round of allreduce, this option will allow it to try again
           within the specified timeout
-        :param timeout: if averager was unable to *find* a group in this many seconds, consider allreduce failedK
-        :param wait: if True (default), return when finished. Otherwise return MPFuture and run in background.
+        :param require_trigger: if True, wait for the user to call .allow_allreduce() before running all-reduce
+        :param timeout: if averager was unable to *find* a group in this many seconds, consider allreduce failed
+        :param wait: if True (default), return when finished. Otherwise return StepControl and run in background.
         :returns: on success, update averaged_tensors and return group info; on failure, return None
         """
         if self.mode == AveragingMode.AUX and weight is not None:
             logger.warning("Averager is running in auxiliary mode, weight is unused.")
+        if scheduled_time is None:
+            scheduled_time = get_dht_time() + self.matchmaking_kwargs["min_matchmaking_time"]
         if weight is None:
             weight = float(self.mode != AveragingMode.AUX)
+        deadline = get_dht_time() + timeout if timeout is not None else float("inf")
         assert isinstance(weight, (int, float)) and weight >= 0, f"Expected a positive int/float, got {type(weight)}"
-
-        future = MPFuture()
-        gather_binary = self.serializer.dumps(
-            gather
-        )  # serialize here to avoid loading modules in the averager process
-        self._outer_pipe.send(
-            (
-                "_step",
-                [],
-                dict(
-                    future=future,
-                    gather_binary=gather_binary,
-                    weight=weight,
-                    allow_retries=allow_retries,
-                    timeout=timeout,
-                ),
-            )
+        assert not (wait and require_trigger), "Non-asynchronous step cannot wait for trigger (use wait=False)"
+        assert scheduled_time < deadline, "Scheduled start time does not fit within timeout"
+
+        user_data_for_gather = self.serializer.dumps(gather)  # serialize here to avoid imports in the averager process
+        data_for_gather = self.serializer.dumps([self.bandwidth, self.mode.value, user_data_for_gather])
+        step = StepControl(
+            scheduled_time=scheduled_time,
+            deadline=deadline,
+            allow_retries=allow_retries,
+            weight=weight,
+            data_for_gather=data_for_gather,
         )
-        return future.result() if wait else future

-    async def _step(
-        self, *, future: MPFuture, gather_binary: bytes, weight: float, allow_retries: bool, timeout: Optional[float]
-    ):
-        start_time = get_dht_time()
+        future_for_trigger = MPFuture()
+        self._outer_pipe.send(("_step", [], dict(step=step, future_for_trigger=future_for_trigger)))
+        step.attach_trigger(future_for_trigger.result())
+
+        if not require_trigger:
+            step.allow_allreduce()
+        return step.result() if wait else step
+
+    async def _step(self, *, step: StepControl, future_for_trigger: MPFuture):
         try:
-            while not future.done():
+            trigger = MPFuture()
+            step.attach_trigger(trigger)
+            future_for_trigger.set_result(trigger)
+
+            while not step.done():
                 try:
                     self._pending_group_assembled.clear()
-                    data_for_gather = self.serializer.dumps([self.bandwidth, self.mode.value, gather_binary])
-                    group_info = await self._matchmaking.look_for_group(
-                        timeout=timeout, data_for_gather=data_for_gather
-                    )
+                    step.stage = AveragingStage.LOOKING_FOR_GROUP
+                    group_info = await self._matchmaking.look_for_group(step)
                     if group_info is None:
                         raise AllreduceException("Averaging step failed: could not find a group.")

-                    future.set_result(
+                    if not step.triggered:
+                        step.stage = AveragingStage.AWAITING_TRIGGER
+                        await step.wait_for_trigger()
+
+                    step.stage = AveragingStage.RUNNING_ALLREDUCE
+
+                    step.set_result(
                         await asyncio.wait_for(
                             self._run_allreduce(
-                                group_info, tensor_infos=self.tensor_infos, weight=weight, **self.allreduce_kwargs
+                                group_info, tensor_infos=self.tensor_infos, weight=step.weight, **self.allreduce_kwargs
                             ),
                             timeout=self._allreduce_timeout,
                         )
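With `require_trigger` and the returned `StepControl`, matchmaking is decoupled from the moment tensors are actually averaged: a caller can schedule a step early, keep computing while the group forms, and release the all-reduce when ready. A usage sketch based on the docstring above (the surrounding training loop is hypothetical):

```python
# schedule averaging in the background, but hold all-reduce until triggered
control = averager.step(wait=False, require_trigger=True)

# ... compute local gradients while matchmaking proceeds ...

control.allow_allreduce()    # release the trigger: all-reduce may now run
gathered = control.result()  # block until the step finishes (or raises)
```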
@@ -393,20 +418,20 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     asyncio.InvalidStateError,
                     P2PHandlerError,
                 ) as e:
-                    time_elapsed = get_dht_time() - start_time
-                    if not allow_retries or (timeout is not None and timeout < time_elapsed):
-                        logger.exception(f"Averager caught {repr(e)}")
-                        future.set_exception(e)
+                    if not step.allow_retries or get_dht_time() >= step.deadline:
+                        logger.exception(e)
+                        step.set_exception(e)
                     else:
-                        logger.warning(f"Averager caught {repr(e)}, retrying")
+                        logger.warning(f"{self.__class__.__name__} caught {repr(e)}, retrying")

         except BaseException as e:
-            if not future.done():
-                future.set_exception(e)
+            if not step.done():
+                step.set_exception(e)
             raise
         finally:
-            if not future.done():
-                future.set_exception(
+            step.stage = AveragingStage.FINISHED
+            if not step.done():
+                step.set_exception(
                     RuntimeError(
                         "Internal sanity check failed: averager.step left future pending."
                         " Please report this to hivemind issues."
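Retry bookkeeping also moved from per-attempt `start_time`/`time_elapsed` arithmetic to a single absolute deadline computed once in `step()`. The same idea in isolation (illustrative names, not hivemind API):

```python
import time

def run_with_deadline(attempt, deadline: float, allow_retries: bool = True):
    # mirrors `if not step.allow_retries or get_dht_time() >= step.deadline`
    while True:
        try:
            return attempt()
        except RuntimeError:
            if not allow_retries or time.time() >= deadline:
                raise  # out of budget: propagate instead of retrying

# usage: keep retrying a flaky call for up to 5 seconds in total
# result = run_with_deadline(flaky_call, deadline=time.time() + 5.0)
```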
@@ -416,8 +441,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
         """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
         try:
-            bandwidths, mode_ids, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
-            user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered)))
+            bandwidths, mode_ids, user_gathered_bytes = zip(*map(self.serializer.loads, group_info.gathered))
+            user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered_bytes)))
             modes = tuple(map(AveragingMode, mode_ids))

             # compute optimal part sizes from peer bandwidths; TODO: replace with proper load balancing
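Each element of `group_info.gathered` is doubly serialized: the outer layer packs `[bandwidth, mode, user_bytes]` (built in `step()` above), and `user_bytes` is itself the serialized user `gather` payload. The rename to `user_gathered_bytes` makes the second `loads` explicit; a standalone round-trip sketch using plain msgpack as a stand-in for the serializer:

```python
import msgpack

def dumps(obj) -> bytes:   # stand-in for MSGPackSerializer.dumps
    return msgpack.packb(obj, use_bin_type=True)

def loads(data: bytes):    # stand-in for MSGPackSerializer.loads
    return msgpack.unpackb(data, raw=False)

# sender side (step): inner user payload first, then peer metadata around it
user_data_for_gather = dumps({"epoch": 7})
payload = dumps([1.0e8, 0, user_data_for_gather])  # [bandwidth, mode, user_bytes]

# receiver side (_run_allreduce): unwrap both layers
bandwidth, mode_id, user_gathered_bytes = loads(payload)
assert loads(user_gathered_bytes) == {"epoch": 7}
```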
@@ -447,7 +472,9 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     async for tensor, update in azip(as_aiter(*local_tensors), allreduce):
                         # all-reduce is performed asynchronously while iterating
                         tensor.add_(update, alpha=self._averaging_alpha)
-                        self.last_updated = get_dht_time()
+                    self.last_updated = get_dht_time()
+                    self._state_updated.set()
+
                 else:
                     async for _ in allreduce:  # trigger all-reduce by iterating
                         raise ValueError("aux peers should not receive averaged tensors")
@@ -477,7 +504,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         """
         with self.lock_averaged_tensors:
             yield self._averaged_tensors
-            self.last_updated = get_dht_time()

     @contextlib.asynccontextmanager
     async def get_tensors_async(self) -> Sequence[torch.Tensor]:
@@ -517,19 +543,24 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         download_key = f"{self._matchmaking.group_key_manager.prefix}.all_averagers"
         while True:
             if self.allow_state_sharing:
+                self._state_updated.clear()
+                expiration_time = get_dht_time() + self.declare_state_period
                 asyncio.create_task(
                     asyncio.wait_for(
                         self.dht.store(
                             download_key,
                             subkey=self.peer_id.to_bytes(),
                             value=self.last_updated,
-                            expiration_time=get_dht_time() + self._matchmaking.averaging_expiration,
+                            expiration_time=expiration_time,
                             return_future=True,
                         ),
-                        timeout=self._matchmaking.averaging_expiration,
+                        timeout=self.declare_state_period - self.request_timeout,
                     )
                 )
-            await asyncio.sleep(self._matchmaking.averaging_expiration)
+            try:
+                await asyncio.wait_for(self._state_updated.wait(), self.declare_state_period - self.request_timeout)
+            except asyncio.TimeoutError:
+                pass

     async def rpc_download_state(
         self, _request: averaging_pb2.DownloadRequest, _context: P2PContext
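The declare loop now wakes early whenever the state actually changes (`_state_updated` is set at the end of a successful all-reduce) and otherwise re-declares once per `declare_state_period`. A minimal standalone sketch of this wake-early-or-timeout pattern (names are illustrative, not hivemind API):

```python
import asyncio

async def declare_periodically(state_updated: asyncio.Event, period: float) -> None:
    while True:
        state_updated.clear()
        print("re-declaring state for download")  # stands in for dht.store(...)
        try:
            # wake up as soon as the state changes, or after `period` seconds
            await asyncio.wait_for(state_updated.wait(), timeout=period)
        except asyncio.TimeoutError:
            pass

async def main() -> None:
    state_updated = asyncio.Event()
    task = asyncio.create_task(declare_periodically(state_updated, period=1.0))
    await asyncio.sleep(0.3)
    state_updated.set()  # e.g. an all-reduce round just finished
    await asyncio.sleep(0.3)
    task.cancel()

asyncio.run(main())
```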
@@ -584,10 +615,10 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         The exact contents of both metadata and tensors are determined by get_current_state method
         """
         future = MPFuture()
-        self._outer_pipe.send(("_load_state_from_peers", [], dict(future=future)))
+        self._outer_pipe.send(("_load_state_from_peers", [], dict(timeout=timeout, future=future)))
         return future.result(timeout=timeout) if wait else future

-    async def _load_state_from_peers(self, future: MPFuture):
+    async def _load_state_from_peers(self, future: MPFuture, timeout: Optional[float] = None):
         try:
             key_manager = self._matchmaking.group_key_manager
             peer_priority, _ = self.dht.get(f"{key_manager.prefix}.all_averagers", latest=True) or ({}, None)
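Since `timeout` is now forwarded into the daemon-side coroutine, one argument bounds both the caller's wait and each message of the download stream. Typical usage, assuming an initialized averager whose subclass implements the state-sharing hooks:

```python
# returns (metadata, tensors) as produced by a donor's get_current_state,
# or None if no donor responded (behavior assumed from the retry loop above)
loaded = averager.load_state_from_peers(timeout=30.0, wait=True)
if loaded is not None:
    metadata, tensors = loaded
```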
@@ -611,7 +642,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     stream = await stub.rpc_download_state(averaging_pb2.DownloadRequest())
                     current_tensor_parts, tensors = [], []

-                    async for message in aiter_with_timeout(stream, timeout=self.request_timeout):
+                    async for message in aiter_with_timeout(stream, timeout=timeout or self.request_timeout):
                         if message.metadata:
                             metadata = self.serializer.loads(message.metadata)
                         if message.tensor_part.dtype and current_tensor_parts:
@@ -628,7 +659,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):

                     logger.info(f"Finished downloading state from {peer}")
                     future.set_result((metadata, tensors))
-                    self.last_updated = get_dht_time()
                     return
                 except Exception as e:
                     logger.exception(f"Failed to download state from {peer} - {repr(e)}")