@@ -1,6 +1,6 @@
 import asyncio
 from enum import Enum
-from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Type
+from typing import Any, AsyncIterator, Dict, Optional, Sequence, Set, Tuple, Type

 import torch

@@ -11,8 +11,7 @@ from hivemind.proto import averaging_pb2
 from hivemind.utils import get_logger
 from hivemind.utils.asyncio import (
     achain,
-    aenumerate,
-    afirst,
+    aiter_with_timeout,
     amap_in_executor,
     anext,
     as_aiter,
@@ -52,6 +51,10 @@ class AllReduceRunner(ServicerBase):
       (the actual number of values by peer will be nearly proportional, but there are no exact guarantees)
     :param modes: AveragingMode for each peer in ordered_peer_ids (normal, client-only or auxiliary)
     :param gathered: additional user-defined data collected from this group
+    :param sender_timeout: during all_reduce, any sender that fails to send tensor chunk within this many seconds from
+      previous chunk will be marked as failed and excluded from averaging. default: equal to next_chunk_timeout
+    :param reducer_timeout: during all_reduce, any reducer that fails to send results chunk within this many seconds
+      from previous chunk will be marked as failed and excluded from averaging. default: 2 x sender_timeout
     :param kwargs: additional parameters (e.g. part_size_bytes) will be passed to TensorPartContainer
     :note: Full-mode peers send and receive tensor parts concurrently, assuming a full-duplex TCP stream. In turn,
       non-averaging peers receive results only after they finish sending, which helps them avoid
@@ -71,11 +74,18 @@ class AllReduceRunner(ServicerBase):
         peer_fractions: Tuple[float, ...],
         modes: Optional[Sequence[AveragingMode]] = None,
         gathered: Optional[Dict[PeerID, Any]] = None,
+        sender_timeout: Optional[float] = None,
+        reducer_timeout: Optional[float] = None,
         **kwargs,
     ):
         self._p2p = p2p
         self.peer_id = p2p.peer_id
         assert self.peer_id in ordered_peer_ids, "peer_id is not a part of the group"
+        if reducer_timeout is not None and (sender_timeout is None or reducer_timeout <= sender_timeout):
+            raise ValueError(
+                "If reducer_timeout is enabled, sender_timeout must be shorter than reducer_timeout. "
+                "Otherwise, there is a chance that reducers will be banned while they await senders."
+            )

         if not issubclass(servicer_type, ServicerBase):
             raise TypeError("`servicer_type` is expected to be a ServicerBase subclass")
@@ -102,8 +112,19 @@ class AllReduceRunner(ServicerBase):
             if mode != AveragingMode.AUX:
                 self.sender_peer_ids.append(peer_id)

+        self.sender_timeout, self.reducer_timeout = sender_timeout, reducer_timeout
+        self.all_senders_started = asyncio.Event()
+        self.banned_senders: Set[PeerID] = set()  # peers that did not send data by next_chunk_timeout
+        self.banlock = asyncio.Lock()
+
+        self.active_senders: Set[PeerID] = set()  # peers that began sending data via rpc_aggregate_part
+        if self.peer_id in self.sender_peer_ids:
+            self.active_senders.add(self.peer_id)
+            if len(self.active_senders) == len(self.sender_peer_ids):
+                self.all_senders_started.set()
+
         peer_id_index = self.ordered_peer_ids.index(self.peer_id)
-        self.tensor_part_container = TensorPartContainer(tensors, peer_fractions, **kwargs)
+        self.tensor_part_container = TensorPartContainer(tensors, peer_fractions, return_deltas=True, **kwargs)
         self.parts_for_local_averaging = self.tensor_part_container.get_raw_input_parts(peer_id_index)
         self.tensor_part_reducer = TensorPartReducer(
             tuple(part.shape for part in self.parts_for_local_averaging),
@@ -132,6 +153,10 @@ class AllReduceRunner(ServicerBase):
     async def run(self) -> AsyncIterator[torch.Tensor]:
         """Run all-reduce, return differences between averaged and original tensors as they are computed"""
         pending_tasks = set()
+
+        if self.tensor_part_container.num_parts_by_peer[self.ordered_peer_ids.index(self.peer_id)] != 0:
+            pending_tasks.add(asyncio.create_task(self._handle_missing_senders()))
+
         try:
             if len(self.sender_peer_ids) == 0:
                 logger.debug(f"{self} - finished all-reduce early: all peers are auxiliaries ({self.modes})")
@@ -144,6 +169,7 @@ class AllReduceRunner(ServicerBase):

                 async for averaged_tensor_delta in self.tensor_part_container.iterate_output_tensors():
                     yield averaged_tensor_delta  # delta = averaged_tensor - original_tensor
+
                 self.finalize()

             else:  # auxiliary peer
@@ -156,6 +182,24 @@ class AllReduceRunner(ServicerBase):
                 task.cancel()
             raise

+        finally:
+            for task in pending_tasks:
+                try:
+                    await task
+                except asyncio.CancelledError:
+                    pass
+                except Exception as inner_exc:
+                    logger.debug(f"Task {task} failed with {inner_exc}", exc_info=True)
+
+    async def _handle_missing_senders(self):
+        """Detect senders that should have sent tensors for averaging, but did not send anything within timeout"""
+        try:
+            await asyncio.wait_for(self.all_senders_started.wait(), self.sender_timeout)
+        except asyncio.TimeoutError:
+            for peer_id in self.sender_peer_ids:
+                if peer_id not in self.active_senders and peer_id not in self.banned_senders:
+                    await self._ban_sender(peer_id)
+
     async def _communicate_with_peer(self, peer_id: PeerID):
         """Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors"""
         peer_index = self.ordered_peer_ids.index(peer_id)
@@ -168,25 +212,39 @@ class AllReduceRunner(ServicerBase):
                 self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part - tensor_part)

         else:
-            code = None
-            stream = await self._get_peer_stub(peer_id).rpc_aggregate_part(self._generate_input_for_peer(peer_index))
-            async for part_index, (averaged_part_delta, msg) in aenumerate(
-                amap_in_executor(
-                    lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg),
-                    stream,
+            try:
+                done_sending = asyncio.Event()
+                inputs_aiter = attach_event_on_finished(self._generate_input_for_peer(peer_index), done_sending)
+                stream = await self._get_peer_stub(peer_id).rpc_aggregate_part(inputs_aiter)
+
+                if self.should_delay_results(self.peer_id):
+                    await done_sending.wait()
+
+                part_index = 0
+
+                def _try_deserialize(msg):
+                    if msg.code != averaging_pb2.AVERAGED_PART:
+                        raise AllreduceException(f"{peer_id} sent {averaging_pb2.MessageCode.Name(msg.code)}")
+                    return deserialize_torch_tensor(msg.tensor_part), msg
+
+                async for delta, msg in amap_in_executor(
+                    _try_deserialize,
+                    aiter_with_timeout(stream, self.reducer_timeout),
                     max_prefetch=self.tensor_part_container.prefetch,
-                )
-            ):
-                if code is None:
-                    code = msg.code
-                self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part_delta)
-
-            if code != averaging_pb2.AVERAGED_PART:
-                raise AllreduceException(
-                    f"peer {peer_id} returned {averaging_pb2.MessageCode.Name(code)} "
-                    f"instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)}"
-                    f", allreduce failed"
-                )
+                ):
+                    self.tensor_part_container.register_processed_part(peer_index, part_index, delta)
+                    part_index += 1
+
+                if part_index != self.tensor_part_container.num_parts_by_peer[peer_index]:
+                    raise AllreduceException(
+                        f"peer {peer_id} sent {part_index} parts, but we expected "
+                        f"{self.tensor_part_container.num_parts_by_peer[peer_index]}"
+                    )
+            except BaseException as e:
+                if isinstance(e, Exception):
+                    logger.warning(f"Caught {repr(e)} when communicating to {peer_id}")
+                self.tensor_part_container.register_failed_reducer(peer_index)
+                raise

     async def _generate_input_for_peer(self, peer_index: int) -> AsyncIterator[averaging_pb2.AveragingData]:
         parts_aiter = self.tensor_part_container.iterate_input_parts_for(peer_index)
@@ -204,18 +262,22 @@ class AllReduceRunner(ServicerBase):
         self, stream: AsyncIterator[averaging_pb2.AveragingData], context: P2PContext
     ) -> AsyncIterator[averaging_pb2.AveragingData]:
         """a peer sends us a part of his tensor; we should average it with other peers and return the difference"""
-        request: averaging_pb2.AveragingData = await anext(stream)
-        reason_to_reject = self._check_reasons_to_reject(request)
-        if reason_to_reject:
-            yield reason_to_reject
-            return
-
-        elif request.code == averaging_pb2.PART_FOR_AVERAGING:
-            try:
-                sender_index = self.sender_peer_ids.index(context.remote_id)
+        sender_index = self.sender_peer_ids.index(context.remote_id)
+        self.active_senders.add(context.remote_id)
+        if len(self.active_senders) == len(self.sender_peer_ids):
+            self.all_senders_started.set()

+        try:
+            request: averaging_pb2.AveragingData = await asyncio.wait_for(anext(stream), self.sender_timeout)
+            reason_to_reject = self._check_reasons_to_reject(request, context)
+            if reason_to_reject:
+                yield reason_to_reject
+                return
+
+            elif request.code == averaging_pb2.PART_FOR_AVERAGING:
+                stream = aiter_with_timeout(achain(as_aiter(request), stream), self.sender_timeout)
                 if not self.should_delay_results(context.remote_id):
-                    async for msg in self._accumulate_parts_streaming(achain(as_aiter(request), stream), sender_index):
+                    async for msg in self._accumulate_parts_streaming(stream, sender_index):
                         yield msg

                 else:
@@ -223,10 +285,13 @@ class AllReduceRunner(ServicerBase):
                     delayed_results = asyncio.Queue()

                     async def _accumulate_parts():
-                        inputs_aiter = attach_event_on_finished(achain(as_aiter(request), stream), done_receiving)
-                        async for msg in self._accumulate_parts_streaming(inputs_aiter, sender_index):
-                            delayed_results.put_nowait(msg)
-                        delayed_results.put_nowait(None)
+                        try:
+                            async for msg in self._accumulate_parts_streaming(
+                                attach_event_on_finished(stream, done_receiving), sender_index
+                            ):
+                                delayed_results.put_nowait(msg)
+                        finally:
+                            delayed_results.put_nowait(None)

                     accumulate_task = asyncio.create_task(_accumulate_parts())

@@ -239,63 +304,61 @@ class AllReduceRunner(ServicerBase):
                         yield next_result
                     await accumulate_task

-            except Exception as e:
-                self.finalize(exception=e)
+            else:
                 yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
-        else:
-            error_code = averaging_pb2.MessageCode.Name(request.code)
-            logger.debug(f"{self} - peer {context.remote_id} sent {error_code}, allreduce cannot continue")
-            self.finalize(exception=AllreduceException(f"Peer {context.remote_id} sent {error_code}"))
-            yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
+                raise AllreduceException(f"{context.remote_id} sent {averaging_pb2.MessageCode.Name(request.code)}")
+
+        except BaseException as e:
+            await self._ban_sender(context.remote_id)
+            if isinstance(e, Exception):
+                logger.warning(f"Caught {repr(e)} when communicating with {context.remote_id}")
+                yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
+            else:
+                raise  # CancelledError, StopIteration and similar
+
+    async def _ban_sender(self, peer_id: PeerID):
+        async with self.banlock:
+            if peer_id not in self.banned_senders:
+                self.banned_senders.add(peer_id)
+                self.tensor_part_reducer.on_sender_failed(self.sender_peer_ids.index(peer_id))

-    def _check_reasons_to_reject(self, request: averaging_pb2.AveragingData) -> Optional[averaging_pb2.AveragingData]:
+    def _check_reasons_to_reject(
+        self, request: averaging_pb2.AveragingData, context: P2PContext
+    ) -> Optional[averaging_pb2.AveragingData]:
         if request.group_id != self.group_id:
             return averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
         elif self._future.cancelled():
             return averaging_pb2.AveragingData(code=averaging_pb2.CANCELLED)
         elif self._future.done():
             return averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
+        elif context.remote_id not in self.sender_peer_ids:
+            return averaging_pb2.AveragingData(code=averaging_pb2.PROTOCOL_VIOLATION)

     async def _accumulate_parts_streaming(self, stream: AsyncIterator[averaging_pb2.AveragingData], sender_index: int):
-        loop = asyncio.get_event_loop()
-        async for part_index, (tensor_part, weight, part_compression) in aenumerate(
-            amap_in_executor(
+        part_index = 0
+        try:
+            loop = asyncio.get_event_loop()
+            async for tensor_part, weight, part_compression in amap_in_executor(
                 lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg.weight, msg.tensor_part.compression),
                 stream,
                 max_prefetch=self.tensor_part_container.prefetch,
-            )
-        ):
-            averaged_part = await self.tensor_part_reducer.accumulate_part(
-                sender_index, part_index, tensor_part, weight=weight
-            )
-
-            serialized_delta = await loop.run_in_executor(
-                None, lambda: serialize_torch_tensor(averaged_part - tensor_part, part_compression)
-            )
-            yield averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART, tensor_part=serialized_delta)
+            ):
+                averaged_part = await self.tensor_part_reducer.accumulate_part(
+                    sender_index, part_index, tensor_part, weight=weight
+                )
+                part_index += 1

-    async def _send_error_to_peer(self, peer_id: PeerID, code: averaging_pb2.MessageCode):
-        try:
-            error = averaging_pb2.AveragingData(group_id=self.group_id, code=code)
-            await afirst(await self._get_peer_stub(peer_id).rpc_aggregate_part(as_aiter(error)))
-        except Exception as e:
-            logger.debug(f"Caught {e} when sending error {averaging_pb2.MessageCode.Name(code)} to {peer_id}")
+                serialized_delta = await loop.run_in_executor(
+                    None, lambda: serialize_torch_tensor(averaged_part - tensor_part, part_compression)
+                )
+                yield averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART, tensor_part=serialized_delta)
+        finally:
+            if part_index != self.tensor_part_reducer.num_parts:
+                await self._ban_sender(self.sender_peer_ids[sender_index])

     def finalize(self, *, cancel: bool = False, exception: Optional[BaseException] = None):
         """finish or terminate AllReduceRunner, propagate any errors / cancellations to peers."""
         assert not cancel or not exception, "finalize accepts either exception or cancel, but not both"
-        pending_tasks = set()
-        if cancel or exception:
-            # propagate error to peers
-            if cancel or isinstance(exception, asyncio.CancelledError):
-                code = averaging_pb2.CANCELLED
-            else:
-                code = averaging_pb2.INTERNAL_ERROR
-            logger.debug(f"{self} - notifying peers about {averaging_pb2.MessageCode.Name(code)}")
-            for peer_id, mode in zip(self.ordered_peer_ids, self.modes):
-                if peer_id != self.peer_id and mode != AveragingMode.CLIENT:
-                    pending_tasks.add(asyncio.create_task(self._send_error_to_peer(peer_id, code)))
-
         if not self._future.done():
             if cancel:
                 logger.debug(f"{self} - cancelled")
@@ -308,7 +371,5 @@ class AllReduceRunner(ServicerBase):
                 self._future.set_result(None)
             self.tensor_part_container.finalize()
             self.tensor_part_reducer.finalize()
-            return pending_tasks
         else:
-            logger.debug(f"{self} - could not finish: allreduce is already finished: {self._future}")
-            return pending_tasks
+            logger.debug(f"{self} - attempted to finalize allreduce that is already finished: {self._future}")
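
The sketch below is illustrative only and is not part of the patch: it restates the rule that the new constructor arguments enforce (if reducer_timeout is set, sender_timeout must also be set and be strictly smaller, so reducers are not banned while they are still waiting for slow senders). The helper name _validate_timeouts and the example values are hypothetical.

    # Hypothetical standalone sketch of the check added to AllReduceRunner.__init__ above.
    from typing import Optional

    def _validate_timeouts(sender_timeout: Optional[float], reducer_timeout: Optional[float]) -> None:
        # reducer_timeout may only be enabled together with a strictly shorter sender_timeout
        if reducer_timeout is not None and (sender_timeout is None or reducer_timeout <= sender_timeout):
            raise ValueError("sender_timeout must be set and be shorter than reducer_timeout")

    _validate_timeouts(sender_timeout=30.0, reducer_timeout=60.0)   # accepted (matches the default 2 x ratio)
    # _validate_timeouts(sender_timeout=None, reducer_timeout=60.0)  # would raise ValueError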