
Improve All-Reduce fault-tolerance (#423)

- allow AllReduceRunner to tolerate clients that
   - do not send some of their local tensors
   - do not show up at all after matchmaking is over
- allow AllReduceRunner to tolerate full/aux peers that do not send some or all results
- introduce a timeout after which a sender/reducer is considered failed (see the sketch after this list)
- AllReduceRunner & DecentralizedAverager no longer call _send_error_to_peer
   - the resulting log spam is gone!
- report allreduce integrity
   - TensorPartReducer will report the fraction of expected parts received if that fraction is not 1
   - TensorPartContainer will report the fraction of parts that did not fail if that fraction is not 1
- miscellaneous improvements to Optimizer
  - set good default sender/reducer timeouts
  - pre-schedule state averaging ahead of time
  - no longer block the entire peer if it is time to pre-schedule gradients but background state averaging is still underway
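
The new timeouts are plain constructor arguments. A minimal sketch of passing them to a DecentralizedAverager, using the keyword names added in this diff (next_chunk_timeout, sender_timeout, reducer_timeout); the concrete values and the "demo" prefix are illustrative only:

```python
import torch
import hivemind

dht = hivemind.DHT(start=True)
averager = hivemind.DecentralizedAverager(
    [torch.randn(16, 1024)],         # tensors to be averaged
    dht,
    prefix="demo",                   # illustrative group prefix
    next_chunk_timeout=5.0,          # ban a peer that stalls for 5 s during all-reduce or state download
    # sender_timeout defaults to next_chunk_timeout and reducer_timeout to 2 * sender_timeout;
    # if both are set explicitly, sender_timeout must be the shorter one (see the ValueError below)
    start=True,
)
```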

Test cases:
- test with peers that fail early
- test with peers that fail to send a certain part
- test with peers that fail to reduce their part
- test cancelling

Sanity checks:
- run tests 100 times
- benchmark_optimizer
- run in a test environment with 64+ nodes for 4+ hours

Co-authored-by: Alexander Borzunov <hxrussia@gmail.com>
justheuristic committed 3 years ago
commit 6da8683975

hivemind/averaging/allreduce.py  (+139, -78)

@@ -1,6 +1,6 @@
 import asyncio
 from enum import Enum
-from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Type
+from typing import Any, AsyncIterator, Dict, Optional, Sequence, Set, Tuple, Type
 
 import torch
 
@@ -11,8 +11,7 @@ from hivemind.proto import averaging_pb2
 from hivemind.utils import get_logger
 from hivemind.utils.asyncio import (
     achain,
-    aenumerate,
-    afirst,
+    aiter_with_timeout,
     amap_in_executor,
     anext,
     as_aiter,
@@ -52,6 +51,10 @@ class AllReduceRunner(ServicerBase):
       (the actual number of values by peer will be nearly proportional, but there are no exact guarantees)
     :param modes: AveragingMode for each peer in ordered_peer_ids (normal, client-only or auxiliary)
     :param gathered: additional user-defined data collected from this group
+    :param sender_timeout: during all_reduce, any sender that fails to send tensor chunk within this many seconds from
+      previous chunk will be marked as failed and excluded from averaging. default: equal to next_chunk_timeout
+    :param reducer_timeout: during all_reduce, any reducer that fails to send results chunk within this many seconds
+      from previous chunk will be marked as failed and excluded from averaging. default: 2 x sender_timeout
     :param kwargs: additional parameters (e.g. part_size_bytes) will be passed to TensorPartContainer
     :note: Full-mode peers send and receive tensor parts concurrently, assuming a full-duplex TCP stream. In turn,
       non-averaging peers receive results only after they finish sending, which helps them avoid
@@ -71,11 +74,18 @@ class AllReduceRunner(ServicerBase):
         peer_fractions: Tuple[float, ...],
         modes: Optional[Sequence[AveragingMode]] = None,
         gathered: Optional[Dict[PeerID, Any]] = None,
+        sender_timeout: Optional[float] = None,
+        reducer_timeout: Optional[float] = None,
         **kwargs,
     ):
         self._p2p = p2p
         self.peer_id = p2p.peer_id
         assert self.peer_id in ordered_peer_ids, "peer_id is not a part of the group"
+        if reducer_timeout is not None and (sender_timeout is None or reducer_timeout <= sender_timeout):
+            raise ValueError(
+                "If reducer_timeout is enabled, sender_timeout must be shorter than reducer_timeout. "
+                "Otherwise, there is a chance that reducers will be banned while they await senders."
+            )
 
         if not issubclass(servicer_type, ServicerBase):
             raise TypeError("`servicer_type` is expected to be a ServicerBase subclass")
@@ -102,8 +112,19 @@ class AllReduceRunner(ServicerBase):
             if mode != AveragingMode.AUX:
                 self.sender_peer_ids.append(peer_id)
 
+        self.sender_timeout, self.reducer_timeout = sender_timeout, reducer_timeout
+        self.all_senders_started = asyncio.Event()
+        self.banned_senders: Set[PeerID] = set()  # peers that did not send data by next_chunk_timeout
+        self.banlock = asyncio.Lock()
+
+        self.active_senders: Set[PeerID] = set()  # peers that began sending data via rpc_aggregate_part
+        if self.peer_id in self.sender_peer_ids:
+            self.active_senders.add(self.peer_id)
+        if len(self.active_senders) == len(self.sender_peer_ids):
+            self.all_senders_started.set()
+
         peer_id_index = self.ordered_peer_ids.index(self.peer_id)
-        self.tensor_part_container = TensorPartContainer(tensors, peer_fractions, **kwargs)
+        self.tensor_part_container = TensorPartContainer(tensors, peer_fractions, return_deltas=True, **kwargs)
         self.parts_for_local_averaging = self.tensor_part_container.get_raw_input_parts(peer_id_index)
         self.tensor_part_reducer = TensorPartReducer(
             tuple(part.shape for part in self.parts_for_local_averaging),
@@ -132,6 +153,10 @@ class AllReduceRunner(ServicerBase):
     async def run(self) -> AsyncIterator[torch.Tensor]:
         """Run all-reduce, return differences between averaged and original tensors as they are computed"""
         pending_tasks = set()
+
+        if self.tensor_part_container.num_parts_by_peer[self.ordered_peer_ids.index(self.peer_id)] != 0:
+            pending_tasks.add(asyncio.create_task(self._handle_missing_senders()))
+
         try:
             if len(self.sender_peer_ids) == 0:
                 logger.debug(f"{self} - finished all-reduce early: all peers are auxiliaries ({self.modes})")
@@ -144,6 +169,7 @@ class AllReduceRunner(ServicerBase):
 
                 async for averaged_tensor_delta in self.tensor_part_container.iterate_output_tensors():
                     yield averaged_tensor_delta  # delta = averaged_tensor - original_tensor
+
                 self.finalize()
 
             else:  # auxiliary peer
@@ -156,6 +182,24 @@ class AllReduceRunner(ServicerBase):
                 task.cancel()
             raise
 
+        finally:
+            for task in pending_tasks:
+                try:
+                    await task
+                except asyncio.CancelledError:
+                    pass
+                except Exception as inner_exc:
+                    logger.debug(f"Task {task} failed with {inner_exc}", exc_info=True)
+
+    async def _handle_missing_senders(self):
+        """Detect senders that should have sent tensors for averaging, but did not send anything within timeout"""
+        try:
+            await asyncio.wait_for(self.all_senders_started.wait(), self.sender_timeout)
+        except asyncio.TimeoutError:
+            for peer_id in self.sender_peer_ids:
+                if peer_id not in self.active_senders and peer_id not in self.banned_senders:
+                    await self._ban_sender(peer_id)
+
     async def _communicate_with_peer(self, peer_id: PeerID):
         """Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors"""
         peer_index = self.ordered_peer_ids.index(peer_id)
@@ -168,25 +212,39 @@ class AllReduceRunner(ServicerBase):
                 self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part - tensor_part)
 
         else:
-            code = None
-            stream = await self._get_peer_stub(peer_id).rpc_aggregate_part(self._generate_input_for_peer(peer_index))
-            async for part_index, (averaged_part_delta, msg) in aenumerate(
-                amap_in_executor(
-                    lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg),
-                    stream,
+            try:
+                done_sending = asyncio.Event()
+                inputs_aiter = attach_event_on_finished(self._generate_input_for_peer(peer_index), done_sending)
+                stream = await self._get_peer_stub(peer_id).rpc_aggregate_part(inputs_aiter)
+
+                if self.should_delay_results(self.peer_id):
+                    await done_sending.wait()
+
+                part_index = 0
+
+                def _try_deserialize(msg):
+                    if msg.code != averaging_pb2.AVERAGED_PART:
+                        raise AllreduceException(f"{peer_id} sent {averaging_pb2.MessageCode.Name(msg.code)}")
+                    return deserialize_torch_tensor(msg.tensor_part), msg
+
+                async for delta, msg in amap_in_executor(
+                    _try_deserialize,
+                    aiter_with_timeout(stream, self.reducer_timeout),
                     max_prefetch=self.tensor_part_container.prefetch,
-                )
-            ):
-                if code is None:
-                    code = msg.code
-                self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part_delta)
-
-            if code != averaging_pb2.AVERAGED_PART:
-                raise AllreduceException(
-                    f"peer {peer_id} returned {averaging_pb2.MessageCode.Name(code)} "
-                    f"instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)}"
-                    f", allreduce failed"
-                )
+                ):
+                    self.tensor_part_container.register_processed_part(peer_index, part_index, delta)
+                    part_index += 1
+
+                if part_index != self.tensor_part_container.num_parts_by_peer[peer_index]:
+                    raise AllreduceException(
+                        f"peer {peer_id} sent {part_index} parts, but we expected "
+                        f"{self.tensor_part_container.num_parts_by_peer[peer_index]}"
+                    )
+            except BaseException as e:
+                if isinstance(e, Exception):
+                    logger.warning(f"Caught {repr(e)} when communicating to {peer_id}")
+                self.tensor_part_container.register_failed_reducer(peer_index)
+                raise
 
     async def _generate_input_for_peer(self, peer_index: int) -> AsyncIterator[averaging_pb2.AveragingData]:
         parts_aiter = self.tensor_part_container.iterate_input_parts_for(peer_index)
@@ -204,18 +262,22 @@ class AllReduceRunner(ServicerBase):
         self, stream: AsyncIterator[averaging_pb2.AveragingData], context: P2PContext
     ) -> AsyncIterator[averaging_pb2.AveragingData]:
         """a peer sends us a part of his tensor; we should average it with other peers and return the difference"""
-        request: averaging_pb2.AveragingData = await anext(stream)
-        reason_to_reject = self._check_reasons_to_reject(request)
-        if reason_to_reject:
-            yield reason_to_reject
-            return
-
-        elif request.code == averaging_pb2.PART_FOR_AVERAGING:
-            try:
-                sender_index = self.sender_peer_ids.index(context.remote_id)
+        sender_index = self.sender_peer_ids.index(context.remote_id)
+        self.active_senders.add(context.remote_id)
+        if len(self.active_senders) == len(self.sender_peer_ids):
+            self.all_senders_started.set()
 
+        try:
+            request: averaging_pb2.AveragingData = await asyncio.wait_for(anext(stream), self.sender_timeout)
+            reason_to_reject = self._check_reasons_to_reject(request, context)
+            if reason_to_reject:
+                yield reason_to_reject
+                return
+
+            elif request.code == averaging_pb2.PART_FOR_AVERAGING:
+                stream = aiter_with_timeout(achain(as_aiter(request), stream), self.sender_timeout)
                 if not self.should_delay_results(context.remote_id):
-                    async for msg in self._accumulate_parts_streaming(achain(as_aiter(request), stream), sender_index):
+                    async for msg in self._accumulate_parts_streaming(stream, sender_index):
                         yield msg
 
                 else:
@@ -223,10 +285,13 @@ class AllReduceRunner(ServicerBase):
                     delayed_results = asyncio.Queue()
 
                     async def _accumulate_parts():
-                        inputs_aiter = attach_event_on_finished(achain(as_aiter(request), stream), done_receiving)
-                        async for msg in self._accumulate_parts_streaming(inputs_aiter, sender_index):
-                            delayed_results.put_nowait(msg)
-                        delayed_results.put_nowait(None)
+                        try:
+                            async for msg in self._accumulate_parts_streaming(
+                                attach_event_on_finished(stream, done_receiving), sender_index
+                            ):
+                                delayed_results.put_nowait(msg)
+                        finally:
+                            delayed_results.put_nowait(None)
 
                     accumulate_task = asyncio.create_task(_accumulate_parts())
 
@@ -239,63 +304,61 @@ class AllReduceRunner(ServicerBase):
                         yield next_result
                     await accumulate_task
 
-            except Exception as e:
-                self.finalize(exception=e)
+            else:
                 yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
-        else:
-            error_code = averaging_pb2.MessageCode.Name(request.code)
-            logger.debug(f"{self} - peer {context.remote_id} sent {error_code}, allreduce cannot continue")
-            self.finalize(exception=AllreduceException(f"Peer {context.remote_id} sent {error_code}"))
-            yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
+                raise AllreduceException(f"{context.remote_id} sent {averaging_pb2.MessageCode.Name(request.code)}")
+
+        except BaseException as e:
+            await self._ban_sender(context.remote_id)
+            if isinstance(e, Exception):
+                logger.warning(f"Caught {repr(e)} when communicating with {context.remote_id}")
+                yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
+            else:
+                raise  # CancelledError, StopIteration and similar
+
+    async def _ban_sender(self, peer_id: PeerID):
+        async with self.banlock:
+            if peer_id not in self.banned_senders:
+                self.banned_senders.add(peer_id)
+                self.tensor_part_reducer.on_sender_failed(self.sender_peer_ids.index(peer_id))
 
-    def _check_reasons_to_reject(self, request: averaging_pb2.AveragingData) -> Optional[averaging_pb2.AveragingData]:
+    def _check_reasons_to_reject(
+        self, request: averaging_pb2.AveragingData, context: P2PContext
+    ) -> Optional[averaging_pb2.AveragingData]:
         if request.group_id != self.group_id:
             return averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
         elif self._future.cancelled():
             return averaging_pb2.AveragingData(code=averaging_pb2.CANCELLED)
         elif self._future.done():
             return averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
+        elif context.remote_id not in self.sender_peer_ids:
+            return averaging_pb2.AveragingData(code=averaging_pb2.PROTOCOL_VIOLATION)
 
     async def _accumulate_parts_streaming(self, stream: AsyncIterator[averaging_pb2.AveragingData], sender_index: int):
-        loop = asyncio.get_event_loop()
-        async for part_index, (tensor_part, weight, part_compression) in aenumerate(
-            amap_in_executor(
+        part_index = 0
+        try:
+            loop = asyncio.get_event_loop()
+            async for tensor_part, weight, part_compression in amap_in_executor(
                 lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg.weight, msg.tensor_part.compression),
                 stream,
                 max_prefetch=self.tensor_part_container.prefetch,
-            )
-        ):
-            averaged_part = await self.tensor_part_reducer.accumulate_part(
-                sender_index, part_index, tensor_part, weight=weight
-            )
-
-            serialized_delta = await loop.run_in_executor(
-                None, lambda: serialize_torch_tensor(averaged_part - tensor_part, part_compression)
-            )
-            yield averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART, tensor_part=serialized_delta)
+            ):
+                averaged_part = await self.tensor_part_reducer.accumulate_part(
+                    sender_index, part_index, tensor_part, weight=weight
+                )
+                part_index += 1
 
-    async def _send_error_to_peer(self, peer_id: PeerID, code: averaging_pb2.MessageCode):
-        try:
-            error = averaging_pb2.AveragingData(group_id=self.group_id, code=code)
-            await afirst(await self._get_peer_stub(peer_id).rpc_aggregate_part(as_aiter(error)))
-        except Exception as e:
-            logger.debug(f"Caught {e} when sending error {averaging_pb2.MessageCode.Name(code)} to {peer_id}")
+                serialized_delta = await loop.run_in_executor(
+                    None, lambda: serialize_torch_tensor(averaged_part - tensor_part, part_compression)
+                )
+                yield averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART, tensor_part=serialized_delta)
+        finally:
+            if part_index != self.tensor_part_reducer.num_parts:
+                await self._ban_sender(self.sender_peer_ids[sender_index])
 
     def finalize(self, *, cancel: bool = False, exception: Optional[BaseException] = None):
         """finish or terminate AllReduceRunner, propagate any errors / cancellations to peers."""
         assert not cancel or not exception, "finalize accepts either exception or cancel, but not both"
-        pending_tasks = set()
-        if cancel or exception:
-            # propagate error to peers
-            if cancel or isinstance(exception, asyncio.CancelledError):
-                code = averaging_pb2.CANCELLED
-            else:
-                code = averaging_pb2.INTERNAL_ERROR
-            logger.debug(f"{self} - notifying peers about {averaging_pb2.MessageCode.Name(code)}")
-            for peer_id, mode in zip(self.ordered_peer_ids, self.modes):
-                if peer_id != self.peer_id and mode != AveragingMode.CLIENT:
-                    pending_tasks.add(asyncio.create_task(self._send_error_to_peer(peer_id, code)))
-
         if not self._future.done():
             if cancel:
                 logger.debug(f"{self} - cancelled")
@@ -308,7 +371,5 @@ class AllReduceRunner(ServicerBase):
                 self._future.set_result(None)
             self.tensor_part_container.finalize()
             self.tensor_part_reducer.finalize()
-            return pending_tasks
         else:
-            logger.debug(f"{self} - could not finish: allreduce is already finished: {self._future}")
-            return pending_tasks
+            logger.debug(f"{self} - attempted to finalize allreduce that is already finished: {self._future}")

hivemind/averaging/averager.py  (+29, -26)

@@ -70,7 +70,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     :param compression: optionally compress tensors with this compression algorithm before running all-reduce
     :param state_compression: a separate compression strategy for load_state_from_peers (default = no compression)
     :param tensor_infos: CompressionInfo for each respective tensor; this determines how the tensor will be comressed
-    :param allreduce_timeout: spend at most this many seconds for allreduce (after group is formed)
     :param averaging_alpha: optional "learning rate" for averaging. If specified, local parameters will be shifted
       towards the (estimated) average by this coefficient. By default, local parameters are set equal to average.
     :param request_timeout: when looking for group, wait for a response from leader for at most this many seconds.
@@ -87,6 +86,13 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     :param allow_state_sharing: if set to True, other peers can download this peer's state. Can be overwritten
       with averager.allow_state_sharing = True / False
     :param declare_state_period: re-declare averager as a donor for load_state_from_peers every this many seconds
+    :param allreduce_timeout: spend at most this many seconds for allreduce (after group is formed)
+    :param next_chunk_timeout: during all-reduce and load_state_from_peers, if peer does not send next data chunk in
+      this number of seconds, consider it failed and proceed with remaining peers. default: no timeout
+    :param sender_timeout: during all_reduce, any sender that fails to send tensor chunk within this many seconds from
+      previous chunk will be marked as failed and excluded from averaging. default: equal to next_chunk_timeout
+    :param reducer_timeout: during all_reduce, any reducer that fails to send results chunk within this many seconds
+      from previous chunk will be marked as failed and excluded from averaging. default: 2 * sender_timeout
     :param shutdown_timeout: when calling .shutdown, wait for up to this many seconds before terminating
 
     Example:
@@ -124,6 +130,9 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         averaging_alpha: float = 1.0,
         part_size_bytes: int = DEFAULT_PART_SIZE_BYTES,
         allreduce_timeout: Optional[float] = None,
+        next_chunk_timeout: Optional[float] = None,
+        sender_timeout: Optional[float] = None,
+        reducer_timeout: Optional[float] = None,
         compression: CompressionBase = NoCompression(),
         state_compression: CompressionBase = NoCompression(),
         tensor_infos: Optional[Sequence[CompressionInfo]] = None,
@@ -154,6 +163,11 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
         if client_mode is None:
             client_mode = dht.client_mode
+        if sender_timeout is None:
+            sender_timeout = next_chunk_timeout
+        if reducer_timeout is None:
+            reducer_timeout = 2 * sender_timeout if sender_timeout is not None else None
+
         self.client_mode = client_mode
 
         self._parent_pid = os.getpid()
@@ -173,6 +187,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         self.total_size = sum(map(torch.Tensor.numel, self._averaged_tensors))
         self.schema_hash = compute_schema_hash(self._averaged_tensors)
         self.shutdown_timeout = shutdown_timeout
+        self.next_chunk_timeout = next_chunk_timeout
         self.bandwidth = bandwidth
 
         self.matchmaking_kwargs = dict(
@@ -188,6 +203,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             compression=compression,
             part_size_bytes=part_size_bytes,
             min_vector_size=min_vector_size,
+            sender_timeout=sender_timeout,
+            reducer_timeout=reducer_timeout,
         )
         self._averaging_alpha, self._allreduce_timeout = averaging_alpha, allreduce_timeout
         self._running_groups: Dict[GroupID, AllReduceRunner] = {}  # one or more assembled groups that run all-reduce
@@ -417,20 +434,10 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
             async def find_peers_or_notify_cancel():
                 group_info = await self._matchmaking.look_for_group(step)
-                try:
-                    if not step.triggered:
-                        step.stage = AveragingStage.AWAITING_TRIGGER
-                        await step.wait_for_trigger()
-                    return group_info
-                except asyncio.CancelledError:
-                    await asyncio.wait(
-                        {
-                            self._send_error_to_peer(peer_id, group_info.group_id, averaging_pb2.CANCELLED)
-                            for peer_id in group_info.peer_ids
-                            if peer_id != self.peer_id
-                        }
-                    )
-                    raise
+                if not step.triggered:
+                    step.stage = AveragingStage.AWAITING_TRIGGER
+                    await step.wait_for_trigger()
+                return group_info
 
             while not step.done():
                 try:
@@ -496,14 +503,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     )
                 )
 
-    async def _send_error_to_peer(self, peer_id: PeerID, group_id: GroupID, code: averaging_pb2.MessageCode):
-        try:
-            error = averaging_pb2.AveragingData(group_id=group_id, code=code)
-            stub = type(self).get_stub(self._p2p, peer_id, namespace=self.prefix)
-            await afirst(await stub.rpc_aggregate_part(as_aiter(error)))
-        except Exception as e:
-            logger.debug(f"Caught {e} when sending error {averaging_pb2.MessageCode.Name(code)} to {peer_id}")
-
     async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
         """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
         try:
@@ -535,7 +534,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
                 with self.register_allreduce_group(group_info.group_id, allreduce):
                     if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
-                        async for tensor, update in azip(as_aiter(*local_tensors), allreduce):
+                        iter_results = allreduce.run()
+                        async for tensor, update in azip(as_aiter(*local_tensors), iter_results):
                             # all-reduce is performed asynchronously while iterating
                             tensor.add_(update, alpha=self._averaging_alpha)
                         self._state_updated.set()
@@ -546,7 +546,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
                 return allreduce.gathered
         except BaseException as e:
-            logger.exception(e)
+            if isinstance(e, Exception):
+                logger.exception(e)
             raise MatchmakingException(f"Unable to run All-Reduce: {e}")
 
     @contextlib.contextmanager
@@ -680,6 +681,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         return future.result(timeout=timeout) if wait else future
 
     async def _load_state_from_peers(self, future: MPFuture, timeout: Optional[float] = None):
+        if timeout is not None:
+            timeout = self.next_chunk_timeout if self.next_chunk_timeout is not None else self.request_timeout
         try:
             key_manager = self._matchmaking.group_key_manager
             peer_priority, _ = self.dht.get(f"{key_manager.prefix}.all_averagers", latest=True) or ({}, None)
@@ -703,7 +706,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                         stream = await stub.rpc_download_state(averaging_pb2.DownloadRequest())
                         current_tensor_parts, tensors = [], []
 
-                        async for message in aiter_with_timeout(stream, timeout=timeout or self.request_timeout):
+                        async for message in aiter_with_timeout(stream, timeout=timeout):
                             if message.metadata:
                                 metadata = self.serializer.loads(message.metadata)
                             if message.tensor_part.dtype and current_tensor_parts:
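
Taken together with the check in AllReduceRunner.__init__, the fallbacks above define how the three timeouts relate. resolve_timeouts below is a hypothetical helper written only to spell out those rules; it is not part of hivemind:

```python
from typing import Optional, Tuple

def resolve_timeouts(
    next_chunk_timeout: Optional[float],
    sender_timeout: Optional[float] = None,
    reducer_timeout: Optional[float] = None,
) -> Tuple[Optional[float], Optional[float]]:
    # defaults applied in DecentralizedAverager.__init__
    if sender_timeout is None:
        sender_timeout = next_chunk_timeout
    if reducer_timeout is None:
        reducer_timeout = 2 * sender_timeout if sender_timeout is not None else None
    # constraint enforced in AllReduceRunner.__init__
    if reducer_timeout is not None and (sender_timeout is None or reducer_timeout <= sender_timeout):
        raise ValueError("if reducer_timeout is enabled, sender_timeout must be set and shorter")
    return sender_timeout, reducer_timeout

print(resolve_timeouts(5.0))   # (5.0, 10.0)
print(resolve_timeouts(None))  # (None, None) -- no timeouts enabled
```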

hivemind/averaging/partition.py  (+55, -17)

@@ -10,21 +10,24 @@ import torch
 
 from hivemind.compression import CompressionBase, CompressionInfo, NoCompression
 from hivemind.proto import runtime_pb2
-from hivemind.utils.asyncio import amap_in_executor
+from hivemind.utils import amap_in_executor, as_aiter, get_logger
 
 T = TypeVar("T")
 DEFAULT_PART_SIZE_BYTES = 2 ** 19
+logger = get_logger(__name__)
 
 
 class TensorPartContainer:
     """
     Auxiliary data structure for averaging, responsible for splitting tensors into parts and reassembling them.
     The class is designed to avoid excessive memory allocation and run all heavy computation in background
+
     :param tensors: local tensors to be split and aggregated
     :param peer_fractions: for each peer, a target fraction of vector elements that this peer should average
     :param compression: optionally compress tensors with this compression algorithm before sending them to peers
     :param part_size_bytes: greedily split tensors into parts of up to this many bytes (after compression)
     :param tensor_infos: CompressionInfo for each respective tensor; this determines how the tensor will be comressed
+    :param return_deltas: if True, output tensors are differences (aggregated tensor - local tensor)
     :param prefetch: when compressing, pre-compute this many compressed tensors in background
     """
 
@@ -35,6 +38,7 @@ class TensorPartContainer:
         compression: CompressionBase = NoCompression(),
         part_size_bytes: int = DEFAULT_PART_SIZE_BYTES,
         tensor_infos: Optional[Sequence[CompressionInfo]] = None,
+        return_deltas: bool = True,
         prefetch: int = 1,
     ):
         if tensor_infos is None:
@@ -43,6 +47,8 @@ class TensorPartContainer:
         self.local_tensors, self.peer_fractions, self.group_size = tensors, peer_fractions, len(peer_fractions)
         self.compression, self.part_size_bytes, self.tensor_infos = compression, part_size_bytes, tensor_infos
         self.total_size = sum(tensor.numel() for tensor in tensors)
+        self.failed_size = 0
+        self.return_deltas = return_deltas
         self.prefetch = prefetch
 
         self._input_parts_by_peer = [deque() for _ in range(self.group_size)]
@@ -91,7 +97,6 @@ class TensorPartContainer:
         assert not self._inputs_consumed_by_peer[peer_index], "input parts of a given peer are already deallocated."
         self._inputs_consumed_by_peer[peer_index] = True
         input_parts = tuple(part for part, compression in self._input_parts_by_peer[peer_index])
-        self._input_parts_by_peer[peer_index].clear()
         return input_parts
 
     @torch.no_grad()
@@ -99,13 +104,9 @@ class TensorPartContainer:
         """iterate serialized tensor parts for a peer at a given index. Run serialization in background."""
         assert not self._inputs_consumed_by_peer[peer_index], "input parts of a given peer are already deallocated."
         self._inputs_consumed_by_peer[peer_index] = True
-
-        async def _aiterate_parts():
-            for _ in range(self.num_parts_by_peer[peer_index]):
-                yield self._input_parts_by_peer[peer_index].popleft()
-
+        parts_aiter = as_aiter(*self._input_parts_by_peer[peer_index])
         async for serialized_part in amap_in_executor(
-            lambda x_and_info: self.compression.compress(*x_and_info), _aiterate_parts(), max_prefetch=self.prefetch
+            lambda x_and_info: self.compression.compress(*x_and_info), parts_aiter, max_prefetch=self.prefetch
         ):
             yield serialized_part
 
@@ -123,6 +124,16 @@ class TensorPartContainer:
         self._outputs_registered_by_peer[peer_index] += 1
         self._output_part_available[peer_index].set()
 
+    def register_failed_reducer(self, peer_index: int):
+        """
+        a given peer failed to aggregate a certain part, use our local part instead, keep track of failed parts
+        """
+        for part_index in range(self._outputs_registered_by_peer[peer_index], self.num_parts_by_peer[peer_index]):
+            part_and_info = self._input_parts_by_peer[peer_index][part_index]
+            part_result_or_delta = torch.zeros_like(part_and_info[0]) if self.return_deltas else part_and_info[0]
+            self.register_processed_part(peer_index, part_index, part_result_or_delta)
+            self.failed_size += part_result_or_delta.numel()
+
     async def iterate_output_tensors(self) -> AsyncIterable[torch.Tensor]:
         """iterate over the outputs of averaging (whether they are average, delta or other aggregation result)"""
         assert not self._outputs_consumed, "output tensors are already iterated and no longer available."
@@ -155,9 +166,11 @@ class TensorPartContainer:
         if not self.finished.is_set():
             for peer_index in range(self.group_size):
                 self._inputs_consumed_by_peer[peer_index] = True
+                self._output_part_available[peer_index].set()
                 self._input_parts_by_peer[peer_index].clear()
                 self._output_parts_by_peer[peer_index].clear()
-                self._output_part_available[peer_index].set()
+            if self.failed_size != 0:
+                logger.warning(f"Averaging: received {(1. - self.failed_size / self.total_size) * 100:.1f}% results")
             self._outputs_consumed = True
             self.finished.set()
 
@@ -178,11 +191,16 @@ class TensorPartReducer:
         self.denominator = 0.0  # total weight accumulated from all peers for current part
         self.current_part_future = asyncio.Future()
         self.finished = asyncio.Event()
+
+        self.num_parts_received = [0 for _ in range(self.num_senders)]
+        self.sender_failed_after = [float("inf") for _ in range(self.num_senders)]
+        self.num_current_senders = self.num_senders
+
         self.reset_accumulators()
 
     def reset_accumulators(self):
         """(re)create averaging buffers for the next part in line, prepopulate with local tensor part"""
-        assert self.current_part_accumulated_from == self.num_senders or self.current_part_index == -1
+        assert self.current_part_accumulated_from == self.num_current_senders or self.current_part_index == -1
         if self.current_part_index >= self.num_parts - 1:
             self.finalize()
             return
@@ -190,6 +208,9 @@ class TensorPartReducer:
         self.current_part_index += 1
         self.current_part_accumulated_from = 0
         self.current_part_future = asyncio.Future()
+        self.num_current_senders = sum(
+            self.current_part_index < failed_index for failed_index in self.sender_failed_after
+        )
         self.accumulator = torch.zeros(self.part_shapes[self.current_part_index])
         self.denominator = 0.0
 
@@ -199,6 +220,7 @@ class TensorPartReducer:
         """Add vector part to accumulator, wait for all other vectors to be added, then return the average part"""
         assert 0 <= sender_index < self.num_senders, "invalid sender index"
         assert 0 <= part_index < self.num_parts, "invalid part index"
+        self.num_parts_received[sender_index] += 1
 
         while part_index > self.current_part_index:
             # wait for previous parts to finish processing ...
@@ -209,15 +231,25 @@ class TensorPartReducer:
 
         current_part_future = self.current_part_future
 
-        self.accumulator.add_(tensor_part, alpha=weight)
-        self.current_part_accumulated_from += 1
-        self.denominator += weight
+        if part_index < self.sender_failed_after[sender_index]:
+            self.accumulator.add_(tensor_part, alpha=weight)
+            self.current_part_accumulated_from += 1
+            self.denominator += weight
+            self.check_current_part_finished()
+        return await current_part_future
 
-        assert self.current_part_accumulated_from <= self.num_senders
-        if self.current_part_accumulated_from == self.num_senders:
-            current_part_future.set_result(self.accumulator.div_(self.denominator))
+    def on_sender_failed(self, sender_index: int):
+        """Exclude that sender's data for averaging any parts that it did not submit yet."""
+        self.sender_failed_after[sender_index] = self.num_parts_received[sender_index]
+        if self.current_part_index == self.num_parts_received[sender_index]:
+            self.num_current_senders -= 1
+            self.check_current_part_finished()
+
+    def check_current_part_finished(self):
+        assert self.current_part_accumulated_from <= self.num_current_senders
+        if self.current_part_accumulated_from == self.num_current_senders:
+            self.current_part_future.set_result(self.accumulator.div_(self.denominator))
             self.reset_accumulators()
-        return await current_part_future
 
     def finalize(self):
         if not self.finished.is_set():
@@ -226,6 +258,12 @@ class TensorPartReducer:
                 del self.accumulator
             self.finished.set()
 
+            if self.num_parts != 0 and self.num_senders != 0:
+                parts_expected = self.num_parts * self.num_senders
+                parts_received = sum(self.num_parts_received)
+                if parts_expected != parts_received:
+                    logger.info(f"Reducer: received {parts_received / parts_expected * 100:.1f}% of input tensors")
+
     def __del__(self):
         self.finalize()
 

hivemind/optim/experimental/optimizer.py  (+10, -5)

@@ -175,6 +175,7 @@ class Optimizer(torch.optim.Optimizer):
         matchmaking_time: Optional[float] = 15.0,
         averaging_timeout: Optional[float] = 60.0,
         allreduce_timeout: Optional[float] = None,
+        next_chunk_timeout: Optional[float] = None,
         load_state_timeout: float = 600.0,
         reuse_grad_buffers: bool = False,
         offload_optimizer: Optional[bool] = None,
@@ -200,6 +201,7 @@ class Optimizer(torch.optim.Optimizer):
         delay_optimizer_step = delay_optimizer_step if delay_optimizer_step is not None else delay_grad_averaging
         offload_optimizer = offload_optimizer if offload_optimizer is not None else (params is not None)
         allreduce_timeout = allreduce_timeout if allreduce_timeout is not None else averaging_timeout
+        next_chunk_timeout = next_chunk_timeout if next_chunk_timeout is not None else matchmaking_time
         assert not delay_grad_averaging or delay_optimizer_step, "delay_grad_averaging requires delay_optimizer_step"
         assert not (client_mode and auxiliary), "Client-mode peers cannot serve as auxiliaries"
         assert not auxiliary or batch_size_per_step is None, "Auxiliary peers should not accumulate batches"
@@ -230,6 +232,7 @@ class Optimizer(torch.optim.Optimizer):
 
         self.averaging_timeout, self.allreduce_timeout = averaging_timeout, allreduce_timeout
         self.load_state_timeout, self.shutdown_timeout = load_state_timeout, shutdown_timeout
+        self.next_chunk_timeout = next_chunk_timeout
 
         self.status_loglevel = logging.INFO if verbose else logging.DEBUG
         self.scheduled_grads: Optional[StepControl] = None
@@ -279,6 +282,7 @@ class Optimizer(torch.optim.Optimizer):
             offload_optimizer=self.offload_optimizer,
             custom_gradients=self.offload_optimizer,
             status_loglevel=self.status_loglevel,
+            next_chunk_timeout=self.next_chunk_timeout,
             client_mode=self.client_mode,
             auxiliary=self.auxiliary,
             start=True,
@@ -294,6 +298,7 @@ class Optimizer(torch.optim.Optimizer):
             min_matchmaking_time=self.matchmaking_time,
             allreduce_timeout=self.allreduce_timeout,
             shutdown_timeout=self.shutdown_timeout,
+            next_chunk_timeout=self.next_chunk_timeout,
             client_mode=self.client_mode,
             auxiliary=self.auxiliary,
             start=True,
@@ -427,6 +432,9 @@ class Optimizer(torch.optim.Optimizer):
 
             if self.use_gradient_averaging:
                 logger.log(self.status_loglevel, f"Beginning optimizer step #{self.local_epoch}")
+                if self.delay_optimizer_step:
+                    self.state_averager.step(wait_for_delayed_updates=True)
+
                 began_averaging_gradients = self._begin_averaging_gradients(grad_scaler)
                 if not began_averaging_gradients:
                     pass  # failed to start gradient averaging due to an internal error
@@ -534,10 +542,6 @@ class Optimizer(torch.optim.Optimizer):
         assert self.use_gradient_averaging
         if self.tracker.estimated_next_update_time - get_dht_time() <= self.matchmaking_time:
             if self.scheduled_grads is None or self.scheduled_grads.triggered or self.scheduled_grads.done():
-                if self.delay_grad_averaging:
-                    # wait for previous averaging to finish before starting a new one
-                    self.state_averager.step(wait_for_delayed_updates=True)
-
                 eta_seconds = self.tracker.estimated_next_update_time - get_dht_time()
                 eta_seconds = max(eta_seconds, self.grad_averager.matchmaking_kwargs["min_matchmaking_time"])
                 logger.log(self.status_loglevel, f"Pre-scheduling gradient averaging round in {eta_seconds:.2f} sec")
@@ -545,12 +549,13 @@ class Optimizer(torch.optim.Optimizer):
 
     def _maybe_schedule_state_averaging(self) -> None:
         """If next epoch is coming soon, schedule the next state averaging at estimated parameter averaging start"""
-        return
         next_epoch = max(self.local_epoch + 1, self.tracker.global_epoch)
         if next_epoch % self.average_state_every != 0:
             return  # averaging is not performed at this epoch
         if self.state_averager.averaging_in_progress:
             return  # previous run is still in progress
+        if self.delay_before_state_averaging.num_updates == 0:
+            return  # not enough data to accurately pre-schedule
 
         estimated_time = self.tracker.estimated_next_update_time
         estimated_time += self.delay_before_state_averaging.ema_seconds_per_sample
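
On the trainer side, next_chunk_timeout only needs to be passed once to hivemind.Optimizer, which forwards it to both the gradient averager and the state averager as shown above. A hedged sketch, assuming the usual Optimizer keywords from the quickstart (run_id, target_batch_size, batch_size_per_step, optimizer); only matchmaking_time, averaging_timeout and next_chunk_timeout come from this diff, and all values are placeholders:

```python
import torch
import hivemind

dht = hivemind.DHT(start=True)
model = torch.nn.Linear(16, 1)

opt = hivemind.Optimizer(
    dht=dht,
    run_id="demo_run",                      # placeholder experiment name
    optimizer=torch.optim.SGD(model.parameters(), lr=0.1),
    target_batch_size=4096,                 # placeholder global batch size
    batch_size_per_step=32,
    matchmaking_time=15.0,
    averaging_timeout=60.0,
    next_chunk_timeout=30.0,                # defaults to matchmaking_time if omitted
    verbose=True,
)
```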

hivemind/optim/grad_scaler.py  (+2, -3)

@@ -60,11 +60,11 @@ class GradScaler(TorchGradScaler):
                 return False
 
     def step(self, optimizer: TorchOptimizer, *args, **kwargs) -> bool:
-        if self._is_running_global_step:
+        if self._is_running_global_step and not isinstance(optimizer, hivemind.Optimizer):
+            # ^-- invoked privately within hivemind optimizer
             with self._lock:
                 if self._is_ready_to_update:
                     logger.warning("Please call grad_scaler.update() after each step")
-                assert not isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
                 assert (
                     self._per_optimizer_states[id(optimizer)]["stage"] == OptState.UNSCALED
                 ), "InternalError: Optimizer should have called .unscale internally before invoking grad_scaler.step."
@@ -75,7 +75,6 @@ class GradScaler(TorchGradScaler):
                 self._is_ready_to_update = True
                 return True
         else:
-            assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
             super().step(optimizer)
             self._optimizer_states_to_reset.add(id(optimizer))
             return False

hivemind/utils/asyncio.py  (+23, -9)

@@ -114,9 +114,15 @@ async def amap_in_executor(
     queue = asyncio.Queue(max_prefetch)
 
     async def _put_items():
-        async for args in azip(*iterables):
-            await queue.put(loop.run_in_executor(executor, func, *args))
-        await queue.put(None)
+        try:
+            async for args in azip(*iterables):
+                await queue.put(loop.run_in_executor(executor, func, *args))
+            await queue.put(None)
+        except BaseException as e:
+            future = asyncio.Future()
+            future.set_exception(e)
+            await queue.put(future)
+            raise
 
     task = asyncio.create_task(_put_items())
     try:
@@ -124,13 +130,21 @@ async def amap_in_executor(
         while future is not None:
             yield await future
             future = await queue.get()
-        await task
     finally:
-        if not task.done():
-            task.cancel()
-
-
-async def aiter_with_timeout(iterable: AsyncIterable[T], timeout: float) -> AsyncIterator[T]:
+        task.cancel()
+        try:
+            await task
+        except asyncio.CancelledError:
+            pass
+        except Exception as e:
+            logger.debug(f"Caught {e} while iterating over inputs", exc_info=True)
+        while not queue.empty():
+            future = queue.get_nowait()
+            if future is not None:
+                future.cancel()
+
+
+async def aiter_with_timeout(iterable: AsyncIterable[T], timeout: Optional[float]) -> AsyncIterator[T]:
     """Iterate over an async iterable, raise TimeoutError if another portion of data does not arrive within timeout"""
     # based on https://stackoverflow.com/a/50245879
     iterator = iterable.__aiter__()
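
aiter_with_timeout (which now also accepts timeout=None) raises TimeoutError whenever the next chunk does not arrive in time; rpc_aggregate_part and _communicate_with_peer treat that as a failed sender or reducer. A self-contained sketch of the same pattern, not the hivemind implementation:

```python
import asyncio
from typing import AsyncIterable, AsyncIterator, Optional, TypeVar

T = TypeVar("T")

async def iter_with_timeout(iterable: AsyncIterable[T], timeout: Optional[float]) -> AsyncIterator[T]:
    """Yield items from an async iterable, failing if the next item takes longer than `timeout` seconds."""
    iterator = iterable.__aiter__()
    while True:
        try:
            yield await asyncio.wait_for(iterator.__anext__(), timeout)
        except StopAsyncIteration:
            break

async def slow_stream():
    yield 1
    await asyncio.sleep(10)  # simulates a stalled peer
    yield 2

async def main():
    try:
        async for item in iter_with_timeout(slow_stream(), timeout=0.1):
            print("received", item)
    except asyncio.TimeoutError:
        print("peer timed out")  # AllReduceRunner reacts by banning the sender / failing the reducer

asyncio.run(main())
```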

tests/test_allreduce_fault_tolerance.py  (+213, -0)

@@ -0,0 +1,213 @@
+from __future__ import annotations
+
+import asyncio
+from enum import Enum, auto
+from typing import AsyncIterator
+
+import pytest
+import torch
+
+import hivemind
+from hivemind.averaging.allreduce import AllReduceRunner, AveragingMode
+from hivemind.averaging.averager import *
+from hivemind.averaging.group_info import GroupInfo
+from hivemind.averaging.load_balancing import load_balance_peers
+from hivemind.averaging.matchmaking import MatchmakingException
+from hivemind.proto import averaging_pb2
+from hivemind.utils.asyncio import aenumerate, as_aiter, azip, enter_asynchronously
+from hivemind.utils.logging import get_logger
+
+logger = get_logger(__name__)
+
+
+class Fault(Enum):
+    NONE = auto()
+    FAIL_BEFORE = auto()
+    FAIL_SENDING = auto()
+    SLOW_SENDING = auto()
+    FAIL_REDUCING = auto()
+    SLOW_REDUCING = auto()
+    CANCEL = auto()
+
+
+class FaultyAverager(hivemind.DecentralizedAverager):
+    def __init__(self, *args, fault: Fault = Fault.NONE, **kwargs):
+        self.fault = fault
+        super().__init__(*args, **kwargs)
+
+    async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
+        """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
+        try:
+            bandwidths, mode_ids, user_gathered_bytes = zip(*map(self.serializer.loads, group_info.gathered))
+            user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered_bytes)))
+            modes = tuple(map(AveragingMode, mode_ids))
+            download_bandwidths = [
+                thr if mode != AveragingMode.CLIENT else 0.0 for thr, mode in zip(bandwidths, modes)
+            ]
+            peer_fractions = await asyncio.get_event_loop().run_in_executor(
+                None, load_balance_peers, self.total_size, download_bandwidths, min_vector_size
+            )
+
+            if self.fault == Fault.FAIL_BEFORE:
+                raise Exception("Oops, I failed!")
+
+            async with enter_asynchronously(self.get_tensors()) as local_tensors:
+                allreduce = FaultyAllReduceRunner(
+                    p2p=self._p2p,
+                    servicer_type=type(self),
+                    prefix=self.prefix,
+                    group_id=group_info.group_id,
+                    tensors=local_tensors,
+                    ordered_peer_ids=group_info.peer_ids,
+                    peer_fractions=peer_fractions,
+                    gathered=user_gathered,
+                    modes=modes,
+                    fault=self.fault,
+                    **kwargs,
+                )
+
+                with self.register_allreduce_group(group_info.group_id, allreduce):
+                    if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
+                        async for tensor, update in azip(as_aiter(*local_tensors), allreduce):
+                            # all-reduce is performed asynchronously while iterating
+                            tensor.add_(update, alpha=self._averaging_alpha)
+                        self._state_updated.set()
+
+                    else:
+                        async for _ in allreduce:  # trigger all-reduce by iterating
+                            raise ValueError("aux peers should not receive averaged tensors")
+
+                return allreduce.gathered
+        except BaseException as e:
+            logger.exception(e)
+            raise MatchmakingException(f"Unable to run All-Reduce: {e}")
+
+
+class FaultyAllReduceRunner(AllReduceRunner):
+    def __init__(self, *args, fault: Fault, **kwargs):
+        self.fault = fault
+        super().__init__(*args, **kwargs)
+
+    async def rpc_aggregate_part(self, stream, context) -> AsyncIterator[averaging_pb2.AveragingData]:
+        if self.fault in (Fault.FAIL_REDUCING, Fault.SLOW_REDUCING):
+            async for i, message in aenumerate(super().rpc_aggregate_part(stream, context)):
+                yield message
+                if i == 2:
+                    if self.fault == Fault.FAIL_SENDING:
+                        yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
+                        break
+                    else:
+                        await asyncio.sleep(10)
+
+        elif self.fault == Fault.CANCEL:
+            yield averaging_pb2.AveragingData(code=averaging_pb2.CANCELLED)
+        else:
+            async for message in super().rpc_aggregate_part(stream, context):
+                yield message
+
+    async def _generate_input_for_peer(self, peer_index: int) -> AsyncIterator[averaging_pb2.AveragingData]:
+        parts_aiter = self.tensor_part_container.iterate_input_parts_for(peer_index)
+
+        first_part = await anext(parts_aiter)
+        yield averaging_pb2.AveragingData(
+            code=averaging_pb2.PART_FOR_AVERAGING,
+            group_id=self.group_id,
+            tensor_part=first_part,
+            weight=self.weight,
+        )
+        if self.fault in (Fault.FAIL_SENDING, Fault.SLOW_SENDING):
+            last_reducer_index = self.group_size - 1 - (self.tensor_part_container.num_parts_by_peer[-1] == 0)
+            if peer_index == last_reducer_index:
+                if self.fault == Fault.FAIL_SENDING:
+                    raise Exception("Oops, I failed!")
+                else:
+                    await asyncio.sleep(10)
+        async for part in parts_aiter:
+            yield averaging_pb2.AveragingData(tensor_part=part, weight=self.weight)
+
+
+@pytest.mark.forked
+@pytest.mark.parametrize(
+    "fault0, fault1",
+    [
+        (Fault.NONE, Fault.FAIL_BEFORE),
+        (Fault.FAIL_BEFORE, Fault.FAIL_BEFORE),
+        (Fault.SLOW_SENDING, Fault.FAIL_SENDING),
+        (Fault.FAIL_SENDING, Fault.FAIL_BEFORE),
+        (Fault.SLOW_REDUCING, Fault.FAIL_SENDING),
+        (Fault.FAIL_REDUCING, Fault.FAIL_REDUCING),
+        (Fault.NONE, Fault.CANCEL),
+    ],
+)
+def test_fault_tolerance(fault0: Fault, fault1: Fault):
+    def _make_tensors():
+        return [torch.rand(16, 1024), -torch.rand(3, 8192), 2 * torch.randn(4, 4, 4), torch.randn(1024, 1024)]
+
+    dht = hivemind.DHT(start=True)
+
+    averagers = []
+    for i in range(5):
+        averager = FaultyAverager(
+            _make_tensors(),
+            hivemind.DHT(initial_peers=dht.get_visible_maddrs(), start=True),
+            prefix="test",
+            request_timeout=0.3,
+            min_matchmaking_time=1.0,
+            next_chunk_timeout=0.5,
+            allreduce_timeout=5,
+            part_size_bytes=2 ** 16,
+            client_mode=(i == 1),
+            start=True,
+            fault=fault0 if i == 0 else fault1 if i == 1 else Fault.NONE,
+        )
+        averagers.append(averager)
+
+    ref_numerators = [0, 0, 0, 0]
+    ref_denominator = 0
+
+    for averager in averagers:
+        if averager.fault not in (Fault.FAIL_BEFORE, Fault.CANCEL):
+            with averager.get_tensors() as tensors:
+                for i, tensor in enumerate(tensors):
+                    ref_numerators[i] = ref_numerators[i] + tensor.clone()
+                ref_denominator += 1
+
+    ref_tensors = [ref_numerator / ref_denominator for ref_numerator in ref_numerators]
+    flat_ref = torch.cat(list(map(torch.flatten, ref_tensors)))
+
+    flat_local_tensors = []
+    for averager in averagers:
+        with averager.get_tensors() as tensors:
+            flat_local_tensors.append(torch.cat(list(map(torch.flatten, tensors))))
+
+    futures = [averager.step(timeout=5, wait=False, allow_retries=False) for averager in averagers]
+    for i, averager in enumerate(averagers):
+        if averager.fault == Fault.CANCEL:
+            futures[i].cancel()
+
+    for future in futures[2:]:
+        assert future.result()
+
+    for averager, prev_local_tensors in zip(averagers[2:], flat_local_tensors[2:]):
+        with averager.get_tensors() as tensors:
+            flat_tensors = torch.cat(list(map(torch.flatten, tensors)))
+
+        diff_with_reference = abs(flat_ref - flat_tensors)
+
+        if all(fault == (Fault.FAIL_SENDING, Fault.SLOW_SENDING) for fault in (fault0, fault1)):
+            assert fault0 != Fault.FAIL_REDUCING and fault1 != Fault.FAIL_REDUCING
+            assert diff_with_reference[: len(diff_with_reference) // 2].max() < 1e-5
+        elif all(fault in (Fault.FAIL_REDUCING, Fault.SLOW_REDUCING) for fault in (fault0, fault1)):
+            diff_to_reference = abs(flat_ref - flat_tensors)
+            diff_to_local = abs(prev_local_tensors - flat_tensors)
+            assert (diff_with_reference < 1e-5).numpy().mean() > 0.5
+            assert torch.all(torch.minimum(diff_to_reference, diff_to_local) < 1e-5).item()
+        elif any(fault == Fault.CANCEL for fault in (fault0, fault1)):
+            pass  # late cancel may result in an arbitrary mix of averaging results with and without the cancelled peer
+        elif fault0 == Fault.NONE:  # only peer1 in client mode may have failed
+            assert diff_with_reference.max() < 1e-5
+        else:
+            assert (diff_with_reference < 1e-5).numpy().mean() > 0.5
+
+    for averager in averagers:
+        averager.shutdown()