import asyncio
from enum import Enum
from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Type

import torch

from hivemind.averaging.partition import AllreduceException, TensorPartContainer, TensorPartReducer
from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase, StubBase
from hivemind.proto import averaging_pb2
from hivemind.utils import get_logger
from hivemind.utils.asyncio import (
    achain,
    aenumerate,
    afirst,
    amap_in_executor,
    anext,
    as_aiter,
    attach_event_on_finished,
)

# flavour types
GroupID = bytes

logger = get_logger(__name__)


class AveragingMode(Enum):
    NODE = 0
    CLIENT = 1
    AUX = 2


class AllReduceRunner(ServicerBase):
    """
    An internal class that runs butterfly AllReduce in a predefined group of averagers.

    This class inherits hivemind.p2p.ServicerBase, so it can be used as an RPCServicer for testing purposes without
    creating a full DecentralizedAverager.

    :note: this class returns **differences** between averaged and local tensors in order to improve numerical stability
    :param p2p: a hivemind.p2p.P2P instance used for communication with other peers
    :param servicer_type: a hivemind.p2p.ServicerBase subclass whose RPC signatures are used when requesting other
      peers. Typically, it is DecentralizedAverager, its derivative, or AllReduceRunner itself (for testing purposes).
    :param prefix: namespace for servicer's RPCs (typically, equal to prefix for group keys)
    :param group_id: unique identifier of this specific all-reduce run
    :param tensors: local tensors that should be averaged with groupmates
    :param weight: scalar weight of this peer's tensors in the average (doesn't need to sum up to 1)
    :param peer_id: your peer_id, must be included in ordered_peer_ids
    :param ordered_peer_ids: group peer_ids ordered s.t. i-th peer_id is responsible for averaging i-th part
    :param peer_fractions: for each peer, a target fraction of vector elements that this peer should average
      (the actual number of values by peer will be nearly proportional, but there are no exact guarantees)
    :param modes: AveragingMode for each peer in ordered_peer_ids (normal, client-only or auxiliary)
    :param gathered: additional user-defined data collected from this group
    :param kwargs: additional parameters (e.g. part_size_bytes) will be passed to TensorPartContainer
    :note: Full-mode peers send and receive tensor parts concurrently, assuming a full-duplex TCP stream. In turn,
      non-averaging peers receive results only after they finish sending, which helps them avoid
      throughput issues in case of asymmetric high-latency connections (e.g. ACK compression).
""" def __init__( self, *, p2p: P2P, servicer_type: Type[ServicerBase], prefix: Optional[str], group_id: GroupID, tensors: Sequence[torch.Tensor], weight: Optional[float] = None, ordered_peer_ids: Sequence[PeerID], peer_fractions: Tuple[float, ...], modes: Optional[Sequence[AveragingMode]] = None, gathered: Optional[Dict[PeerID, Any]] = None, **kwargs, ): self._p2p = p2p self.peer_id = p2p.peer_id assert self.peer_id in ordered_peer_ids, "peer_id is not a part of the group" if not issubclass(servicer_type, ServicerBase): raise TypeError("`servicer_type` is expected to be a ServicerBase subclass") self._servicer_type = servicer_type self._prefix = prefix modes = modes or tuple(AveragingMode.CLIENT if frac == 0 else AveragingMode.NODE for frac in peer_fractions) assert len(modes) == len(ordered_peer_ids), "lists have inconsistent length" assert any(mode != AveragingMode.CLIENT for mode in modes), "cannot run allreduce without reducers" for mode, frac in zip(modes, peer_fractions): assert mode != AveragingMode.CLIENT or frac == 0, "client-mode peer should have zero all-reduce fraction" self.group_id, self.ordered_peer_ids = group_id, ordered_peer_ids self.modes, self.peer_fractions, self.gathered = modes, peer_fractions, gathered if weight is None: weight = float(modes[self.ordered_peer_ids.index(self.peer_id)] != AveragingMode.AUX) self.weight = weight self._future = asyncio.Future() self.sender_peer_ids = [] for peer_id, mode in zip(self.ordered_peer_ids, modes): if mode != AveragingMode.AUX: self.sender_peer_ids.append(peer_id) peer_id_index = self.ordered_peer_ids.index(self.peer_id) self.tensor_part_container = TensorPartContainer(tensors, peer_fractions, **kwargs) self.parts_for_local_averaging = self.tensor_part_container.get_raw_input_parts(peer_id_index) self.tensor_part_reducer = TensorPartReducer( tuple(part.shape for part in self.parts_for_local_averaging), len(self.sender_peer_ids), ) def __repr__(self): return f"{self.__class__.__name__}({self.peer_id}, group_size={self.group_size})" def __aiter__(self): return self.run() def __contains__(self, peer_id: PeerID): return peer_id in self.ordered_peer_ids @property def group_size(self): return len(self.ordered_peer_ids) def _get_peer_stub(self, peer: PeerID) -> StubBase: return self._servicer_type.get_stub(self._p2p, peer, namespace=self._prefix) def should_delay_results(self, peer_id: PeerID) -> bool: return self.peer_fractions[self.ordered_peer_ids.index(peer_id)] == 0 async def run(self) -> AsyncIterator[torch.Tensor]: """Run all-reduce, return differences between averaged and original tensors as they are computed""" pending_tasks = set() try: if len(self.sender_peer_ids) == 0: logger.debug(f"{self} - finished all-reduce early: all peers are auxiliaries ({self.modes})") self.finalize() elif self.peer_id in self.sender_peer_ids: for peer_id, parts in zip(self.ordered_peer_ids, self.tensor_part_container.num_parts_by_peer): if parts != 0: pending_tasks.add(asyncio.create_task(self._communicate_with_peer(peer_id))) async for averaged_tensor_delta in self.tensor_part_container.iterate_output_tensors(): yield averaged_tensor_delta # delta = averaged_tensor - original_tensor self.finalize() else: # auxiliary peer await self.tensor_part_reducer.finished.wait() self.finalize() except BaseException as e: self.finalize(exception=e) for task in pending_tasks: task.cancel() raise async def _communicate_with_peer(self, peer_id: PeerID): """Send a part of local tensors and metadata to a single peer, receive the average for that part of 
tensors""" peer_index = self.ordered_peer_ids.index(peer_id) if peer_id == self.peer_id: sender_index = self.sender_peer_ids.index(peer_id) for part_index, tensor_part in enumerate(self.parts_for_local_averaging): averaged_part = await self.tensor_part_reducer.accumulate_part( sender_index, part_index, tensor_part, weight=self.weight ) self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part - tensor_part) else: code = None stream = await self._get_peer_stub(peer_id).rpc_aggregate_part(self._generate_input_for_peer(peer_index)) async for part_index, (averaged_part_delta, msg) in aenumerate( amap_in_executor( lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg), stream, max_prefetch=self.tensor_part_container.prefetch, ) ): if code is None: code = msg.code self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part_delta) if code != averaging_pb2.AVERAGED_PART: raise AllreduceException( f"peer {peer_id} returned {averaging_pb2.MessageCode.Name(code)} " f"instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)}" f", allreduce failed" ) async def _generate_input_for_peer(self, peer_index: int) -> AsyncIterator[averaging_pb2.AveragingData]: parts_aiter = self.tensor_part_container.iterate_input_parts_for(peer_index) first_part = await anext(parts_aiter) yield averaging_pb2.AveragingData( code=averaging_pb2.PART_FOR_AVERAGING, group_id=self.group_id, tensor_part=first_part, weight=self.weight, ) async for part in parts_aiter: yield averaging_pb2.AveragingData(tensor_part=part, weight=self.weight) async def rpc_aggregate_part( self, stream: AsyncIterator[averaging_pb2.AveragingData], context: P2PContext ) -> AsyncIterator[averaging_pb2.AveragingData]: """a peer sends us a part of his tensor; we should average it with other peers and return the difference""" request: averaging_pb2.AveragingData = await anext(stream) reason_to_reject = self._check_reasons_to_reject(request) if reason_to_reject: yield reason_to_reject return elif request.code == averaging_pb2.PART_FOR_AVERAGING: try: sender_index = self.sender_peer_ids.index(context.remote_id) if not self.should_delay_results(context.remote_id): async for msg in self._accumulate_parts_streaming(achain(as_aiter(request), stream), sender_index): yield msg else: done_receiving = asyncio.Event() delayed_results = asyncio.Queue() async def _accumulate_parts(): inputs_aiter = attach_event_on_finished(achain(as_aiter(request), stream), done_receiving) async for msg in self._accumulate_parts_streaming(inputs_aiter, sender_index): delayed_results.put_nowait(msg) delayed_results.put_nowait(None) accumulate_task = asyncio.create_task(_accumulate_parts()) await done_receiving.wait() while True: next_result = await delayed_results.get() if next_result is None: break yield next_result await accumulate_task except Exception as e: self.finalize(exception=e) yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR) else: error_code = averaging_pb2.MessageCode.Name(request.code) logger.debug(f"{self} - peer {context.remote_id} sent {error_code}, allreduce cannot continue") self.finalize(exception=AllreduceException(f"peer {context.remote_id} sent {error_code}.")) yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR) def _check_reasons_to_reject(self, request: averaging_pb2.AveragingData) -> Optional[averaging_pb2.AveragingData]: if request.group_id != self.group_id: return averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID) elif 
    async def rpc_aggregate_part(
        self, stream: AsyncIterator[averaging_pb2.AveragingData], context: P2PContext
    ) -> AsyncIterator[averaging_pb2.AveragingData]:
        """a peer sends us a part of his tensor; we should average it with other peers and return the difference"""
        request: averaging_pb2.AveragingData = await anext(stream)
        reason_to_reject = self._check_reasons_to_reject(request)
        if reason_to_reject:
            yield reason_to_reject
            return

        elif request.code == averaging_pb2.PART_FOR_AVERAGING:
            try:
                sender_index = self.sender_peer_ids.index(context.remote_id)

                if not self.should_delay_results(context.remote_id):
                    async for msg in self._accumulate_parts_streaming(achain(as_aiter(request), stream), sender_index):
                        yield msg

                else:
                    done_receiving = asyncio.Event()
                    delayed_results = asyncio.Queue()

                    async def _accumulate_parts():
                        inputs_aiter = attach_event_on_finished(achain(as_aiter(request), stream), done_receiving)
                        async for msg in self._accumulate_parts_streaming(inputs_aiter, sender_index):
                            delayed_results.put_nowait(msg)
                        delayed_results.put_nowait(None)

                    accumulate_task = asyncio.create_task(_accumulate_parts())

                    await done_receiving.wait()

                    while True:
                        next_result = await delayed_results.get()
                        if next_result is None:
                            break
                        yield next_result
                    await accumulate_task

            except Exception as e:
                self.finalize(exception=e)
                yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
        else:
            error_code = averaging_pb2.MessageCode.Name(request.code)
            logger.debug(f"{self} - peer {context.remote_id} sent {error_code}, allreduce cannot continue")
            self.finalize(exception=AllreduceException(f"peer {context.remote_id} sent {error_code}."))
            yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)

    def _check_reasons_to_reject(self, request: averaging_pb2.AveragingData) -> Optional[averaging_pb2.AveragingData]:
        if request.group_id != self.group_id:
            return averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
        elif self._future.cancelled():
            return averaging_pb2.AveragingData(code=averaging_pb2.CANCELLED)
        elif self._future.done():
            return averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)

    async def _accumulate_parts_streaming(self, stream: AsyncIterator[averaging_pb2.AveragingData], sender_index: int):
        loop = asyncio.get_event_loop()
        async for part_index, (tensor_part, weight, part_compression) in aenumerate(
            amap_in_executor(
                lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg.weight, msg.tensor_part.compression),
                stream,
                max_prefetch=self.tensor_part_container.prefetch,
            )
        ):
            averaged_part = await self.tensor_part_reducer.accumulate_part(
                sender_index, part_index, tensor_part, weight=weight
            )

            serialized_delta = await loop.run_in_executor(
                None, lambda: serialize_torch_tensor(averaged_part - tensor_part, part_compression)
            )
            yield averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART, tensor_part=serialized_delta)

    async def _send_error_to_peer(self, peer_id: PeerID, code: averaging_pb2.MessageCode):
        error = averaging_pb2.AveragingData(group_id=self.group_id, code=code)
        await afirst(await self._get_peer_stub(peer_id).rpc_aggregate_part(as_aiter(error)))

    def finalize(self, *, cancel: bool = False, exception: Optional[BaseException] = None):
        """finish or terminate AllReduceRunner, propagate any errors / cancellations to peers."""
        assert not cancel or not exception, "finalize accepts either exception or cancel, but not both"
        pending_tasks = set()
        if cancel or exception:
            # propagate error to peers
            if cancel or isinstance(exception, asyncio.CancelledError):
                code = averaging_pb2.CANCELLED
            else:
                code = averaging_pb2.INTERNAL_ERROR
            logger.debug(f"{self} - notifying peers about {averaging_pb2.MessageCode.Name(code)}")
            for peer_id, mode in zip(self.ordered_peer_ids, self.modes):
                if peer_id != self.peer_id and mode != AveragingMode.CLIENT:
                    pending_tasks.add(asyncio.create_task(self._send_error_to_peer(peer_id, code)))

        if not self._future.done():
            if cancel:
                logger.debug(f"{self} - cancelled")
                self._future.cancel()
            elif exception:
                logger.debug(f"{self} - caught {exception}")
                self._future.set_exception(exception)
            else:
                logger.debug(f"{self} - finished")
                self._future.set_result(None)
            self.tensor_part_container.finalize()
            self.tensor_part_reducer.finalize()
            return pending_tasks
        else:
            logger.debug(f"{self} - could not finish: allreduce is already finished: {self._future}")
            return pending_tasks
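
# Illustrative sketch only: the helper below is hypothetical (not part of hivemind's API) and
# merely shows how a caller is expected to consume an AllReduceRunner. Since run() yields
# (averaged - original) differences, adding each delta onto the corresponding local tensor
# turns it into the group average.
async def _example_apply_allreduce_deltas(runner: AllReduceRunner, tensors: Sequence[torch.Tensor]) -> None:
    async for index, delta in aenumerate(runner):  # iterating the runner drives the all-reduce
        tensors[index].add_(delta)  # local + (averaged - local) == averaged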