
Improve All-Reduce fault-tolerance (#423)

- allow AllreduceRunner to tolerate clients that
   - do not send some of their local tensors
   - do not show up at all after matchmaking is over
- allow AllreduceRunner to tolerate full/aux peers that do not send some or all results
- introduce timeout after which sender/reducer is considered failed
- AllreduceRunner & DecentralizedAverager will no longer _send_error_to_peer
   - log spam is gone!
- report allreduce integrity
   - TensorPartReducer will report the fraction of expected parts received if that fraction is not 1
   - TensorPartContainer will report the fraction of parts that did not fail if that fraction is not 1
- miscellaneous improvements to Optimizer
  - set sensible default sender/reducer timeouts (see the usage sketch after this list)
  - pre-schedule state averaging ahead of time
  - no longer block the entire peer if it is time to pre-schedule gradients but background state averaging is still underway
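
For reference, a minimal usage sketch of the new timeouts from the user's side (hypothetical values; `dht` and `model` are assumed to exist, and the non-timeout arguments follow the usual hivemind.Optimizer signature, which may differ slightly between versions):

    import torch
    import hivemind

    dht = hivemind.DHT(start=True)  # or join an existing swarm via initial_peers
    opt = hivemind.Optimizer(
        dht=dht,
        run_id="demo_run",                                  # hypothetical experiment name
        params=model.parameters(),                          # `model` is assumed to be defined elsewhere
        optimizer=lambda params: torch.optim.Adam(params),
        target_batch_size=4096,
        batch_size_per_step=32,
        matchmaking_time=15.0,
        averaging_timeout=60.0,
        # new in this PR: a peer that stalls longer than this between chunks is dropped
        # from the current round instead of failing the whole all-reduce
        next_chunk_timeout=30.0,
    )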

Test cases:
- test with peers that fail early
- test with peers that fail to send a certain part
- test with peers that fail to reduce their part
- test cancelling

Sanity checks:
- run tests 100 times
- benchmark_optimizer
- run in a test environment with 64+ nodes for 4+ hours

Co-authored-by: Alexander Borzunov <hxrussia@gmail.com>
justheuristic · 3 years ago · 6da8683975

hivemind/averaging/allreduce.py  (+139 -78)

@@ -1,6 +1,6 @@
 import asyncio
 from enum import Enum
-from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Type
+from typing import Any, AsyncIterator, Dict, Optional, Sequence, Set, Tuple, Type
 
 import torch
 
@@ -11,8 +11,7 @@ from hivemind.proto import averaging_pb2
 from hivemind.utils import get_logger
 from hivemind.utils.asyncio import (
     achain,
-    aenumerate,
-    afirst,
+    aiter_with_timeout,
     amap_in_executor,
     anext,
     as_aiter,
@@ -52,6 +51,10 @@ class AllReduceRunner(ServicerBase):
       (the actual number of values by peer will be nearly proportional, but there are no exact guarantees)
     :param modes: AveragingMode for each peer in ordered_peer_ids (normal, client-only or auxiliary)
     :param gathered: additional user-defined data collected from this group
+    :param sender_timeout: during all_reduce, any sender that fails to send tensor chunk within this many seconds from
+      previous chunk will be marked as failed and excluded from averaging. default: equal to next_chunk_timeout
+    :param reducer_timeout: during all_reduce, any reducer that fails to send results chunk within this many seconds
+      from previous chunk will be marked as failed and excluded from averaging. default: 2 x sender_timeout
     :param kwargs: additional parameters (e.g. part_size_bytes) will be passed to TensorPartContainer
     :note: Full-mode peers send and receive tensor parts concurrently, assuming a full-duplex TCP stream. In turn,
       non-averaging peers receive results only after they finish sending, which helps them avoid
@@ -71,11 +74,18 @@ class AllReduceRunner(ServicerBase):
         peer_fractions: Tuple[float, ...],
         modes: Optional[Sequence[AveragingMode]] = None,
         gathered: Optional[Dict[PeerID, Any]] = None,
+        sender_timeout: Optional[float] = None,
+        reducer_timeout: Optional[float] = None,
         **kwargs,
     ):
         self._p2p = p2p
         self.peer_id = p2p.peer_id
         assert self.peer_id in ordered_peer_ids, "peer_id is not a part of the group"
+        if reducer_timeout is not None and (sender_timeout is None or reducer_timeout <= sender_timeout):
+            raise ValueError(
+                "If reducer_timeout is enabled, sender_timeout must be shorter than reducer_timeout. "
+                "Otherwise, there is a chance that reducers will be banned while they await senders."
+            )
 
         if not issubclass(servicer_type, ServicerBase):
             raise TypeError("`servicer_type` is expected to be a ServicerBase subclass")
@@ -102,8 +112,19 @@ class AllReduceRunner(ServicerBase):
             if mode != AveragingMode.AUX:
                 self.sender_peer_ids.append(peer_id)
 
+        self.sender_timeout, self.reducer_timeout = sender_timeout, reducer_timeout
+        self.all_senders_started = asyncio.Event()
+        self.banned_senders: Set[PeerID] = set()  # peers that did not send data by next_chunk_timeout
+        self.banlock = asyncio.Lock()
+
+        self.active_senders: Set[PeerID] = set()  # peers that began sending data via rpc_aggregate_part
+        if self.peer_id in self.sender_peer_ids:
+            self.active_senders.add(self.peer_id)
+        if len(self.active_senders) == len(self.sender_peer_ids):
+            self.all_senders_started.set()
+
         peer_id_index = self.ordered_peer_ids.index(self.peer_id)
-        self.tensor_part_container = TensorPartContainer(tensors, peer_fractions, **kwargs)
+        self.tensor_part_container = TensorPartContainer(tensors, peer_fractions, return_deltas=True, **kwargs)
         self.parts_for_local_averaging = self.tensor_part_container.get_raw_input_parts(peer_id_index)
         self.tensor_part_reducer = TensorPartReducer(
             tuple(part.shape for part in self.parts_for_local_averaging),
@@ -132,6 +153,10 @@ class AllReduceRunner(ServicerBase):
     async def run(self) -> AsyncIterator[torch.Tensor]:
         """Run all-reduce, return differences between averaged and original tensors as they are computed"""
         pending_tasks = set()
+
+        if self.tensor_part_container.num_parts_by_peer[self.ordered_peer_ids.index(self.peer_id)] != 0:
+            pending_tasks.add(asyncio.create_task(self._handle_missing_senders()))
+
         try:
             if len(self.sender_peer_ids) == 0:
                 logger.debug(f"{self} - finished all-reduce early: all peers are auxiliaries ({self.modes})")
@@ -144,6 +169,7 @@ class AllReduceRunner(ServicerBase):
 
                 async for averaged_tensor_delta in self.tensor_part_container.iterate_output_tensors():
                     yield averaged_tensor_delta  # delta = averaged_tensor - original_tensor
+
                 self.finalize()
 
             else:  # auxiliary peer
@@ -156,6 +182,24 @@ class AllReduceRunner(ServicerBase):
                 task.cancel()
             raise
 
+        finally:
+            for task in pending_tasks:
+                try:
+                    await task
+                except asyncio.CancelledError:
+                    pass
+                except Exception as inner_exc:
+                    logger.debug(f"Task {task} failed with {inner_exc}", exc_info=True)
+
+    async def _handle_missing_senders(self):
+        """Detect senders that should have sent tensors for averaging, but did not send anything within timeout"""
+        try:
+            await asyncio.wait_for(self.all_senders_started.wait(), self.sender_timeout)
+        except asyncio.TimeoutError:
+            for peer_id in self.sender_peer_ids:
+                if peer_id not in self.active_senders and peer_id not in self.banned_senders:
+                    await self._ban_sender(peer_id)
+
     async def _communicate_with_peer(self, peer_id: PeerID):
         """Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors"""
         peer_index = self.ordered_peer_ids.index(peer_id)
@@ -168,25 +212,39 @@ class AllReduceRunner(ServicerBase):
                 self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part - tensor_part)
 
         else:
-            code = None
-            stream = await self._get_peer_stub(peer_id).rpc_aggregate_part(self._generate_input_for_peer(peer_index))
-            async for part_index, (averaged_part_delta, msg) in aenumerate(
-                amap_in_executor(
-                    lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg),
-                    stream,
+            try:
+                done_sending = asyncio.Event()
+                inputs_aiter = attach_event_on_finished(self._generate_input_for_peer(peer_index), done_sending)
+                stream = await self._get_peer_stub(peer_id).rpc_aggregate_part(inputs_aiter)
+
+                if self.should_delay_results(self.peer_id):
+                    await done_sending.wait()
+
+                part_index = 0
+
+                def _try_deserialize(msg):
+                    if msg.code != averaging_pb2.AVERAGED_PART:
+                        raise AllreduceException(f"{peer_id} sent {averaging_pb2.MessageCode.Name(msg.code)}")
+                    return deserialize_torch_tensor(msg.tensor_part), msg
+
+                async for delta, msg in amap_in_executor(
+                    _try_deserialize,
+                    aiter_with_timeout(stream, self.reducer_timeout),
                     max_prefetch=self.tensor_part_container.prefetch,
-                )
-            ):
-                if code is None:
-                    code = msg.code
-                self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part_delta)
-
-            if code != averaging_pb2.AVERAGED_PART:
-                raise AllreduceException(
-                    f"peer {peer_id} returned {averaging_pb2.MessageCode.Name(code)} "
-                    f"instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)}"
-                    f", allreduce failed"
-                )
+                ):
+                    self.tensor_part_container.register_processed_part(peer_index, part_index, delta)
+                    part_index += 1
+
+                if part_index != self.tensor_part_container.num_parts_by_peer[peer_index]:
+                    raise AllreduceException(
+                        f"peer {peer_id} sent {part_index} parts, but we expected "
+                        f"{self.tensor_part_container.num_parts_by_peer[peer_index]}"
+                    )
+            except BaseException as e:
+                if isinstance(e, Exception):
+                    logger.warning(f"Caught {repr(e)} when communicating to {peer_id}")
+                self.tensor_part_container.register_failed_reducer(peer_index)
+                raise
 
     async def _generate_input_for_peer(self, peer_index: int) -> AsyncIterator[averaging_pb2.AveragingData]:
         parts_aiter = self.tensor_part_container.iterate_input_parts_for(peer_index)
@@ -204,18 +262,22 @@ class AllReduceRunner(ServicerBase):
         self, stream: AsyncIterator[averaging_pb2.AveragingData], context: P2PContext
     ) -> AsyncIterator[averaging_pb2.AveragingData]:
         """a peer sends us a part of his tensor; we should average it with other peers and return the difference"""
-        request: averaging_pb2.AveragingData = await anext(stream)
-        reason_to_reject = self._check_reasons_to_reject(request)
-        if reason_to_reject:
-            yield reason_to_reject
-            return
-
-        elif request.code == averaging_pb2.PART_FOR_AVERAGING:
-            try:
-                sender_index = self.sender_peer_ids.index(context.remote_id)
+        sender_index = self.sender_peer_ids.index(context.remote_id)
+        self.active_senders.add(context.remote_id)
+        if len(self.active_senders) == len(self.sender_peer_ids):
+            self.all_senders_started.set()
 
+        try:
+            request: averaging_pb2.AveragingData = await asyncio.wait_for(anext(stream), self.sender_timeout)
+            reason_to_reject = self._check_reasons_to_reject(request, context)
+            if reason_to_reject:
+                yield reason_to_reject
+                return
+
+            elif request.code == averaging_pb2.PART_FOR_AVERAGING:
+                stream = aiter_with_timeout(achain(as_aiter(request), stream), self.sender_timeout)
                 if not self.should_delay_results(context.remote_id):
-                    async for msg in self._accumulate_parts_streaming(achain(as_aiter(request), stream), sender_index):
+                    async for msg in self._accumulate_parts_streaming(stream, sender_index):
                         yield msg
 
                 else:
@@ -223,10 +285,13 @@ class AllReduceRunner(ServicerBase):
                     delayed_results = asyncio.Queue()
 
                     async def _accumulate_parts():
-                        inputs_aiter = attach_event_on_finished(achain(as_aiter(request), stream), done_receiving)
-                        async for msg in self._accumulate_parts_streaming(inputs_aiter, sender_index):
-                            delayed_results.put_nowait(msg)
-                        delayed_results.put_nowait(None)
+                        try:
+                            async for msg in self._accumulate_parts_streaming(
+                                attach_event_on_finished(stream, done_receiving), sender_index
+                            ):
+                                delayed_results.put_nowait(msg)
+                        finally:
+                            delayed_results.put_nowait(None)
 
                     accumulate_task = asyncio.create_task(_accumulate_parts())
 
@@ -239,63 +304,61 @@ class AllReduceRunner(ServicerBase):
                         yield next_result
                     await accumulate_task
 
-            except Exception as e:
-                self.finalize(exception=e)
+            else:
                 yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
-        else:
-            error_code = averaging_pb2.MessageCode.Name(request.code)
-            logger.debug(f"{self} - peer {context.remote_id} sent {error_code}, allreduce cannot continue")
-            self.finalize(exception=AllreduceException(f"Peer {context.remote_id} sent {error_code}"))
-            yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
+                raise AllreduceException(f"{context.remote_id} sent {averaging_pb2.MessageCode.Name(request.code)}")
+
+        except BaseException as e:
+            await self._ban_sender(context.remote_id)
+            if isinstance(e, Exception):
+                logger.warning(f"Caught {repr(e)} when communicating with {context.remote_id}")
+                yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
+            else:
+                raise  # CancelledError, StopIteration and similar
+
+    async def _ban_sender(self, peer_id: PeerID):
+        async with self.banlock:
+            if peer_id not in self.banned_senders:
+                self.banned_senders.add(peer_id)
+                self.tensor_part_reducer.on_sender_failed(self.sender_peer_ids.index(peer_id))
 
-    def _check_reasons_to_reject(self, request: averaging_pb2.AveragingData) -> Optional[averaging_pb2.AveragingData]:
+    def _check_reasons_to_reject(
+        self, request: averaging_pb2.AveragingData, context: P2PContext
+    ) -> Optional[averaging_pb2.AveragingData]:
         if request.group_id != self.group_id:
             return averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
         elif self._future.cancelled():
            return averaging_pb2.AveragingData(code=averaging_pb2.CANCELLED)
         elif self._future.done():
             return averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
+        elif context.remote_id not in self.sender_peer_ids:
+            return averaging_pb2.AveragingData(code=averaging_pb2.PROTOCOL_VIOLATION)
 
     async def _accumulate_parts_streaming(self, stream: AsyncIterator[averaging_pb2.AveragingData], sender_index: int):
-        loop = asyncio.get_event_loop()
-        async for part_index, (tensor_part, weight, part_compression) in aenumerate(
-            amap_in_executor(
+        part_index = 0
+        try:
+            loop = asyncio.get_event_loop()
+            async for tensor_part, weight, part_compression in amap_in_executor(
                 lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg.weight, msg.tensor_part.compression),
                 stream,
                 max_prefetch=self.tensor_part_container.prefetch,
-            )
-        ):
-            averaged_part = await self.tensor_part_reducer.accumulate_part(
-                sender_index, part_index, tensor_part, weight=weight
-            )
-
-            serialized_delta = await loop.run_in_executor(
-                None, lambda: serialize_torch_tensor(averaged_part - tensor_part, part_compression)
-            )
-            yield averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART, tensor_part=serialized_delta)
+            ):
+                averaged_part = await self.tensor_part_reducer.accumulate_part(
+                    sender_index, part_index, tensor_part, weight=weight
+                )
+                part_index += 1
 
-    async def _send_error_to_peer(self, peer_id: PeerID, code: averaging_pb2.MessageCode):
-        try:
-            error = averaging_pb2.AveragingData(group_id=self.group_id, code=code)
-            await afirst(await self._get_peer_stub(peer_id).rpc_aggregate_part(as_aiter(error)))
-        except Exception as e:
-            logger.debug(f"Caught {e} when sending error {averaging_pb2.MessageCode.Name(code)} to {peer_id}")
+                serialized_delta = await loop.run_in_executor(
+                    None, lambda: serialize_torch_tensor(averaged_part - tensor_part, part_compression)
+                )
+                yield averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART, tensor_part=serialized_delta)
+        finally:
+            if part_index != self.tensor_part_reducer.num_parts:
+                await self._ban_sender(self.sender_peer_ids[sender_index])
 
     def finalize(self, *, cancel: bool = False, exception: Optional[BaseException] = None):
         """finish or terminate AllReduceRunner, propagate any errors / cancellations to peers."""
         assert not cancel or not exception, "finalize accepts either exception or cancel, but not both"
-        pending_tasks = set()
-        if cancel or exception:
-            # propagate error to peers
-            if cancel or isinstance(exception, asyncio.CancelledError):
-                code = averaging_pb2.CANCELLED
-            else:
-                code = averaging_pb2.INTERNAL_ERROR
-            logger.debug(f"{self} - notifying peers about {averaging_pb2.MessageCode.Name(code)}")
-            for peer_id, mode in zip(self.ordered_peer_ids, self.modes):
-                if peer_id != self.peer_id and mode != AveragingMode.CLIENT:
-                    pending_tasks.add(asyncio.create_task(self._send_error_to_peer(peer_id, code)))
-
         if not self._future.done():
             if cancel:
                 logger.debug(f"{self} - cancelled")
@@ -308,7 +371,5 @@ class AllReduceRunner(ServicerBase):
                 self._future.set_result(None)
             self.tensor_part_container.finalize()
             self.tensor_part_reducer.finalize()
-            return pending_tasks
         else:
-            logger.debug(f"{self} - could not finish: allreduce is already finished: {self._future}")
-            return pending_tasks
+            logger.debug(f"{self} - attempted to finalize allreduce that is already finished: {self._future}")
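
The changes above hinge on wrapping every incoming stream in aiter_with_timeout, so that a stalled peer raises an error (and gets banned) instead of hanging the whole round: rpc_aggregate_part wraps the sender's stream with sender_timeout, while _communicate_with_peer wraps the reducer's result stream with reducer_timeout. A self-contained sketch of that wrapping pattern (it mirrors hivemind.utils.asyncio.aiter_with_timeout only in spirit; the real helper may be implemented differently):

    import asyncio
    from typing import AsyncIterator, TypeVar

    T = TypeVar("T")

    async def iter_with_timeout(stream: AsyncIterator[T], timeout: float) -> AsyncIterator[T]:
        """Re-yield items from `stream`; raise asyncio.TimeoutError if the next item is late."""
        iterator = stream.__aiter__()
        while True:
            try:
                # wait_for raises asyncio.TimeoutError if the peer stalls; the caller treats
                # that as a failed sender/reducer and excludes it from the current round
                item = await asyncio.wait_for(iterator.__anext__(), timeout)
            except StopAsyncIteration:
                break  # the stream ended normally
            yield item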

hivemind/averaging/averager.py  (+29 -26)

@@ -70,7 +70,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     :param compression: optionally compress tensors with this compression algorithm before running all-reduce
     :param state_compression: a separate compression strategy for load_state_from_peers (default = no compression)
     :param tensor_infos: CompressionInfo for each respective tensor; this determines how the tensor will be comressed
-    :param allreduce_timeout: spend at most this many seconds for allreduce (after group is formed)
     :param averaging_alpha: optional "learning rate" for averaging. If specified, local parameters will be shifted
       towards the (estimated) average by this coefficient. By default, local parameters are set equal to average.
     :param request_timeout: when looking for group, wait for a response from leader for at most this many seconds.
@@ -87,6 +86,13 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     :param allow_state_sharing: if set to True, other peers can download this peer's state. Can be overwritten
       with averager.allow_state_sharing = True / False
     :param declare_state_period: re-declare averager as a donor for load_state_from_peers every this many seconds
+    :param allreduce_timeout: spend at most this many seconds for allreduce (after group is formed)
+    :param next_chunk_timeout: during all-reduce and load_state_from_peers, if peer does not send next data chunk in
+      this number of seconds, consider it failed and proceed with remaining peers. default: no timeout
+    :param sender_timeout: during all_reduce, any sender that fails to send tensor chunk within this many seconds from
+      previous chunk will be marked as failed and excluded from averaging. default: equal to next_chunk_timeout
+    :param reducer_timeout: during all_reduce, any reducer that fails to send results chunk within this many seconds
+      from previous chunk will be marked as failed and excluded from averaging. default: 2 * sender_timeout
     :param shutdown_timeout: when calling .shutdown, wait for up to this many seconds before terminating
 
     Example:
@@ -124,6 +130,9 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         averaging_alpha: float = 1.0,
         part_size_bytes: int = DEFAULT_PART_SIZE_BYTES,
         allreduce_timeout: Optional[float] = None,
+        next_chunk_timeout: Optional[float] = None,
+        sender_timeout: Optional[float] = None,
+        reducer_timeout: Optional[float] = None,
         compression: CompressionBase = NoCompression(),
         state_compression: CompressionBase = NoCompression(),
         tensor_infos: Optional[Sequence[CompressionInfo]] = None,
@@ -154,6 +163,11 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
         if client_mode is None:
             client_mode = dht.client_mode
+        if sender_timeout is None:
+            sender_timeout = next_chunk_timeout
+        if reducer_timeout is None:
+            reducer_timeout = 2 * sender_timeout if sender_timeout is not None else None
+
         self.client_mode = client_mode
 
         self._parent_pid = os.getpid()
@@ -173,6 +187,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         self.total_size = sum(map(torch.Tensor.numel, self._averaged_tensors))
         self.schema_hash = compute_schema_hash(self._averaged_tensors)
         self.shutdown_timeout = shutdown_timeout
+        self.next_chunk_timeout = next_chunk_timeout
         self.bandwidth = bandwidth
 
         self.matchmaking_kwargs = dict(
@@ -188,6 +203,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             compression=compression,
             part_size_bytes=part_size_bytes,
             min_vector_size=min_vector_size,
+            sender_timeout=sender_timeout,
+            reducer_timeout=reducer_timeout,
         )
         self._averaging_alpha, self._allreduce_timeout = averaging_alpha, allreduce_timeout
         self._running_groups: Dict[GroupID, AllReduceRunner] = {}  # one or more assembled groups that run all-reduce
@@ -417,20 +434,10 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
             async def find_peers_or_notify_cancel():
                 group_info = await self._matchmaking.look_for_group(step)
-                try:
-                    if not step.triggered:
-                        step.stage = AveragingStage.AWAITING_TRIGGER
-                        await step.wait_for_trigger()
-                    return group_info
-                except asyncio.CancelledError:
-                    await asyncio.wait(
-                        {
-                            self._send_error_to_peer(peer_id, group_info.group_id, averaging_pb2.CANCELLED)
-                            for peer_id in group_info.peer_ids
-                            if peer_id != self.peer_id
-                        }
-                    )
-                    raise
+                if not step.triggered:
+                    step.stage = AveragingStage.AWAITING_TRIGGER
+                    await step.wait_for_trigger()
+                return group_info
 
             while not step.done():
                 try:
@@ -496,14 +503,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     )
                 )
 
-    async def _send_error_to_peer(self, peer_id: PeerID, group_id: GroupID, code: averaging_pb2.MessageCode):
-        try:
-            error = averaging_pb2.AveragingData(group_id=group_id, code=code)
-            stub = type(self).get_stub(self._p2p, peer_id, namespace=self.prefix)
-            await afirst(await stub.rpc_aggregate_part(as_aiter(error)))
-        except Exception as e:
-            logger.debug(f"Caught {e} when sending error {averaging_pb2.MessageCode.Name(code)} to {peer_id}")
-
     async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
         """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
         try:
@@ -535,7 +534,8 @@
 
                 with self.register_allreduce_group(group_info.group_id, allreduce):
                     if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
-                        async for tensor, update in azip(as_aiter(*local_tensors), allreduce):
+                        iter_results = allreduce.run()
+                        async for tensor, update in azip(as_aiter(*local_tensors), iter_results):
                             # all-reduce is performed asynchronously while iterating
                             tensor.add_(update, alpha=self._averaging_alpha)
                         self._state_updated.set()
@@ -546,7 +546,8 @@
 
                 return allreduce.gathered
         except BaseException as e:
-            logger.exception(e)
+            if isinstance(e, Exception):
+                logger.exception(e)
             raise MatchmakingException(f"Unable to run All-Reduce: {e}")
 
     @contextlib.contextmanager
@@ -680,6 +681,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         return future.result(timeout=timeout) if wait else future
 
     async def _load_state_from_peers(self, future: MPFuture, timeout: Optional[float] = None):
+        if timeout is not None:
+            timeout = self.next_chunk_timeout if self.next_chunk_timeout is not None else self.request_timeout
         try:
             key_manager = self._matchmaking.group_key_manager
             peer_priority, _ = self.dht.get(f"{key_manager.prefix}.all_averagers", latest=True) or ({}, None)
@@ -703,7 +706,7 @@
                         stream = await stub.rpc_download_state(averaging_pb2.DownloadRequest())
                         current_tensor_parts, tensors = [], []
 
-                        async for message in aiter_with_timeout(stream, timeout=timeout or self.request_timeout):
+                        async for message in aiter_with_timeout(stream, timeout=timeout):
                             if message.metadata:
                                 metadata = self.serializer.loads(message.metadata)
                             if message.tensor_part.dtype and current_tensor_parts:
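
A hedged construction sketch showing where the new DecentralizedAverager arguments plug in (tensor values, the prefix and the group size are placeholders; the remaining arguments follow the signature documented above and may differ slightly between versions):

    import torch
    import hivemind
    from hivemind.averaging.averager import DecentralizedAverager

    dht = hivemind.DHT(start=True)
    averager = DecentralizedAverager(
        averaged_tensors=[torch.zeros(1024)],  # placeholder tensors, averaged in place
        dht=dht,
        prefix="demo_averaging",               # placeholder group prefix
        target_group_size=16,                  # placeholder group size
        next_chunk_timeout=30.0,               # sender_timeout defaults to this value (30.0 s)
        reducer_timeout=None,                  # defaults to 2 * sender_timeout = 60.0 s
        start=True,
    )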

hivemind/averaging/partition.py  (+55 -17)

@@ -10,21 +10,24 @@ import torch
 
 from hivemind.compression import CompressionBase, CompressionInfo, NoCompression
 from hivemind.proto import runtime_pb2
-from hivemind.utils.asyncio import amap_in_executor
+from hivemind.utils import amap_in_executor, as_aiter, get_logger
 
 T = TypeVar("T")
 DEFAULT_PART_SIZE_BYTES = 2 ** 19
+logger = get_logger(__name__)
 
 
 class TensorPartContainer:
     """
     Auxiliary data structure for averaging, responsible for splitting tensors into parts and reassembling them.
     The class is designed to avoid excessive memory allocation and run all heavy computation in background
+
     :param tensors: local tensors to be split and aggregated
     :param peer_fractions: for each peer, a target fraction of vector elements that this peer should average
     :param compression: optionally compress tensors with this compression algorithm before sending them to peers
     :param part_size_bytes: greedily split tensors into parts of up to this many bytes (after compression)
     :param tensor_infos: CompressionInfo for each respective tensor; this determines how the tensor will be comressed
+    :param return_deltas: if True, output tensors are differences (aggregated tensor - local tensor)
     :param prefetch: when compressing, pre-compute this many compressed tensors in background
     """
 
@@ -35,6 +38,7 @@ class TensorPartContainer:
         compression: CompressionBase = NoCompression(),
         part_size_bytes: int = DEFAULT_PART_SIZE_BYTES,
         tensor_infos: Optional[Sequence[CompressionInfo]] = None,
+        return_deltas: bool = True,
         prefetch: int = 1,
     ):
         if tensor_infos is None:
@@ -43,6 +47,8 @@ class TensorPartContainer:
         self.local_tensors, self.peer_fractions, self.group_size = tensors, peer_fractions, len(peer_fractions)
         self.compression, self.part_size_bytes, self.tensor_infos = compression, part_size_bytes, tensor_infos
         self.total_size = sum(tensor.numel() for tensor in tensors)
+        self.failed_size = 0
+        self.return_deltas = return_deltas
         self.prefetch = prefetch
 
         self._input_parts_by_peer = [deque() for _ in range(self.group_size)]
@@ -91,7 +97,6 @@ class TensorPartContainer:
         assert not self._inputs_consumed_by_peer[peer_index], "input parts of a given peer are already deallocated."
         self._inputs_consumed_by_peer[peer_index] = True
         input_parts = tuple(part for part, compression in self._input_parts_by_peer[peer_index])
-        self._input_parts_by_peer[peer_index].clear()
         return input_parts
 
     @torch.no_grad()
@@ -99,13 +104,9 @@ class TensorPartContainer:
         """iterate serialized tensor parts for a peer at a given index. Run serialization in background."""
         assert not self._inputs_consumed_by_peer[peer_index], "input parts of a given peer are already deallocated."
         self._inputs_consumed_by_peer[peer_index] = True
-
-        async def _aiterate_parts():
-            for _ in range(self.num_parts_by_peer[peer_index]):
-                yield self._input_parts_by_peer[peer_index].popleft()
-
+        parts_aiter = as_aiter(*self._input_parts_by_peer[peer_index])
         async for serialized_part in amap_in_executor(
-            lambda x_and_info: self.compression.compress(*x_and_info), _aiterate_parts(), max_prefetch=self.prefetch
+            lambda x_and_info: self.compression.compress(*x_and_info), parts_aiter, max_prefetch=self.prefetch
         ):
             yield serialized_part
 
@@ -123,6 +124,16 @@ class TensorPartContainer:
         self._outputs_registered_by_peer[peer_index] += 1
         self._output_part_available[peer_index].set()
 
+    def register_failed_reducer(self, peer_index: int):
+        """
+        a given peer failed to aggregate a certain part, use our local part instead, keep track of failed parts
+        """
+        for part_index in range(self._outputs_registered_by_peer[peer_index], self.num_parts_by_peer[peer_index]):
+            part_and_info = self._input_parts_by_peer[peer_index][part_index]
+            part_result_or_delta = torch.zeros_like(part_and_info[0]) if self.return_deltas else part_and_info[0]
+            self.register_processed_part(peer_index, part_index, part_result_or_delta)
+            self.failed_size += part_result_or_delta.numel()
+
     async def iterate_output_tensors(self) -> AsyncIterable[torch.Tensor]:
         """iterate over the outputs of averaging (whether they are average, delta or other aggregation result)"""
         assert not self._outputs_consumed, "output tensors are already iterated and no longer available."
@@ -155,9 +166,11 @@ class TensorPartContainer:
         if not self.finished.is_set():
             for peer_index in range(self.group_size):
                 self._inputs_consumed_by_peer[peer_index] = True
+                self._output_part_available[peer_index].set()
                 self._input_parts_by_peer[peer_index].clear()
                 self._output_parts_by_peer[peer_index].clear()
-                self._output_part_available[peer_index].set()
+            if self.failed_size != 0:
+                logger.warning(f"Averaging: received {(1. - self.failed_size / self.total_size) * 100:.1f}% results")
             self._outputs_consumed = True
             self.finished.set()
 
@@ -178,11 +191,16 @@ class TensorPartReducer:
         self.denominator = 0.0  # total weight accumulated from all peers for current part
         self.current_part_future = asyncio.Future()
         self.finished = asyncio.Event()
+
+        self.num_parts_received = [0 for _ in range(self.num_senders)]
+        self.sender_failed_after = [float("inf") for _ in range(self.num_senders)]
+        self.num_current_senders = self.num_senders
+
         self.reset_accumulators()
 
     def reset_accumulators(self):
         """(re)create averaging buffers for the next part in line, prepopulate with local tensor part"""
-        assert self.current_part_accumulated_from == self.num_senders or self.current_part_index == -1
+        assert self.current_part_accumulated_from == self.num_current_senders or self.current_part_index == -1
         if self.current_part_index >= self.num_parts - 1:
             self.finalize()
             return
@@ -190,6 +208,9 @@
         self.current_part_index += 1
         self.current_part_accumulated_from = 0
         self.current_part_future = asyncio.Future()
+        self.num_current_senders = sum(
+            self.current_part_index < failed_index for failed_index in self.sender_failed_after
+        )
         self.accumulator = torch.zeros(self.part_shapes[self.current_part_index])
         self.denominator = 0.0
 
@@ -199,6 +220,7 @@
         """Add vector part to accumulator, wait for all other vectors to be added, then return the average part"""
         assert 0 <= sender_index < self.num_senders, "invalid sender index"
         assert 0 <= part_index < self.num_parts, "invalid part index"
+        self.num_parts_received[sender_index] += 1
 
         while part_index > self.current_part_index:
             # wait for previous parts to finish processing ...
@@ -209,15 +231,25 @@
 
         current_part_future = self.current_part_future
 
-        self.accumulator.add_(tensor_part, alpha=weight)
-        self.current_part_accumulated_from += 1
-        self.denominator += weight
+        if part_index < self.sender_failed_after[sender_index]:
+            self.accumulator.add_(tensor_part, alpha=weight)
+            self.current_part_accumulated_from += 1
+            self.denominator += weight
+            self.check_current_part_finished()
+        return await current_part_future
 
-        assert self.current_part_accumulated_from <= self.num_senders
-        if self.current_part_accumulated_from == self.num_senders:
-            current_part_future.set_result(self.accumulator.div_(self.denominator))
+    def on_sender_failed(self, sender_index: int):
+        """Exclude that sender's data for averaging any parts that it did not submit yet."""
+        self.sender_failed_after[sender_index] = self.num_parts_received[sender_index]
+        if self.current_part_index == self.num_parts_received[sender_index]:
+            self.num_current_senders -= 1
+            self.check_current_part_finished()
+
+    def check_current_part_finished(self):
+        assert self.current_part_accumulated_from <= self.num_current_senders
+        if self.current_part_accumulated_from == self.num_current_senders:
+            self.current_part_future.set_result(self.accumulator.div_(self.denominator))
             self.reset_accumulators()
-        return await current_part_future
 
     def finalize(self):
         if not self.finished.is_set():
@@ -226,6 +258,12 @@ class TensorPartReducer:
                 del self.accumulator
             self.finished.set()
 
+            if self.num_parts != 0 and self.num_senders != 0:
+                parts_expected = self.num_parts * self.num_senders
+                parts_received = sum(self.num_parts_received)
+                if parts_expected != parts_received:
+                    logger.info(f"Reducer: received {parts_received / parts_expected * 100:.1f}% of input tensors")
+
     def __del__(self):
         self.finalize()
 
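The two integrity figures logged above are plain ratios; a toy illustration with hypothetical counts (not from a real run):

    # TensorPartReducer: fraction of expected input parts that actually arrived
    num_senders, num_parts = 4, 10        # hypothetical group: 4 senders, 10 parts each
    parts_received = [10, 10, 7, 0]       # one sender stalled after 7 parts, one never showed up
    reducer_integrity = sum(parts_received) / (num_senders * num_parts)  # 0.675 -> "received 67.5% of input tensors"

    # TensorPartContainer: fraction of output elements produced by live reducers
    # (a failed reducer's share is filled with the local part, i.e. a zero delta)
    total_size, failed_size = 1_000_000, 250_000
    container_integrity = 1.0 - failed_size / total_size  # 0.75 -> "received 75.0% results"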
 

hivemind/optim/experimental/optimizer.py  (+10 -5)

@@ -175,6 +175,7 @@ class Optimizer(torch.optim.Optimizer):
         matchmaking_time: Optional[float] = 15.0,
         matchmaking_time: Optional[float] = 15.0,
         averaging_timeout: Optional[float] = 60.0,
         averaging_timeout: Optional[float] = 60.0,
         allreduce_timeout: Optional[float] = None,
         allreduce_timeout: Optional[float] = None,
+        next_chunk_timeout: Optional[float] = None,
         load_state_timeout: float = 600.0,
         load_state_timeout: float = 600.0,
         reuse_grad_buffers: bool = False,
         reuse_grad_buffers: bool = False,
         offload_optimizer: Optional[bool] = None,
         offload_optimizer: Optional[bool] = None,
@@ -200,6 +201,7 @@ class Optimizer(torch.optim.Optimizer):
         delay_optimizer_step = delay_optimizer_step if delay_optimizer_step is not None else delay_grad_averaging
         delay_optimizer_step = delay_optimizer_step if delay_optimizer_step is not None else delay_grad_averaging
         offload_optimizer = offload_optimizer if offload_optimizer is not None else (params is not None)
         offload_optimizer = offload_optimizer if offload_optimizer is not None else (params is not None)
         allreduce_timeout = allreduce_timeout if allreduce_timeout is not None else averaging_timeout
         allreduce_timeout = allreduce_timeout if allreduce_timeout is not None else averaging_timeout
+        next_chunk_timeout = next_chunk_timeout if next_chunk_timeout is not None else matchmaking_time
         assert not delay_grad_averaging or delay_optimizer_step, "delay_grad_averaging requires delay_optimizer_step"
         assert not delay_grad_averaging or delay_optimizer_step, "delay_grad_averaging requires delay_optimizer_step"
         assert not (client_mode and auxiliary), "Client-mode peers cannot serve as auxiliaries"
         assert not (client_mode and auxiliary), "Client-mode peers cannot serve as auxiliaries"
         assert not auxiliary or batch_size_per_step is None, "Auxiliary peers should not accumulate batches"
         assert not auxiliary or batch_size_per_step is None, "Auxiliary peers should not accumulate batches"
@@ -230,6 +232,7 @@ class Optimizer(torch.optim.Optimizer):
 
 
         self.averaging_timeout, self.allreduce_timeout = averaging_timeout, allreduce_timeout
         self.averaging_timeout, self.allreduce_timeout = averaging_timeout, allreduce_timeout
         self.load_state_timeout, self.shutdown_timeout = load_state_timeout, shutdown_timeout
         self.load_state_timeout, self.shutdown_timeout = load_state_timeout, shutdown_timeout
+        self.next_chunk_timeout = next_chunk_timeout
 
 
         self.status_loglevel = logging.INFO if verbose else logging.DEBUG
         self.status_loglevel = logging.INFO if verbose else logging.DEBUG
         self.scheduled_grads: Optional[StepControl] = None
         self.scheduled_grads: Optional[StepControl] = None
@@ -279,6 +282,7 @@ class Optimizer(torch.optim.Optimizer):
             offload_optimizer=self.offload_optimizer,
             offload_optimizer=self.offload_optimizer,
             custom_gradients=self.offload_optimizer,
             custom_gradients=self.offload_optimizer,
             status_loglevel=self.status_loglevel,
             status_loglevel=self.status_loglevel,
+            next_chunk_timeout=self.next_chunk_timeout,
             client_mode=self.client_mode,
             client_mode=self.client_mode,
             auxiliary=self.auxiliary,
             auxiliary=self.auxiliary,
             start=True,
             start=True,
@@ -294,6 +298,7 @@ class Optimizer(torch.optim.Optimizer):
             min_matchmaking_time=self.matchmaking_time,
             min_matchmaking_time=self.matchmaking_time,
             allreduce_timeout=self.allreduce_timeout,
             allreduce_timeout=self.allreduce_timeout,
             shutdown_timeout=self.shutdown_timeout,
             shutdown_timeout=self.shutdown_timeout,
+            next_chunk_timeout=self.next_chunk_timeout,
             client_mode=self.client_mode,
             client_mode=self.client_mode,
             auxiliary=self.auxiliary,
             auxiliary=self.auxiliary,
             start=True,
             start=True,
@@ -427,6 +432,9 @@ class Optimizer(torch.optim.Optimizer):
 
 
             if self.use_gradient_averaging:
             if self.use_gradient_averaging:
                 logger.log(self.status_loglevel, f"Beginning optimizer step #{self.local_epoch}")
                 logger.log(self.status_loglevel, f"Beginning optimizer step #{self.local_epoch}")
+                if self.delay_optimizer_step:
+                    self.state_averager.step(wait_for_delayed_updates=True)
+
                 began_averaging_gradients = self._begin_averaging_gradients(grad_scaler)
                 began_averaging_gradients = self._begin_averaging_gradients(grad_scaler)
                 if not began_averaging_gradients:
                 if not began_averaging_gradients:
                     pass  # failed to start gradient averaging due to an internal error
                     pass  # failed to start gradient averaging due to an internal error
@@ -534,10 +542,6 @@ class Optimizer(torch.optim.Optimizer):
         assert self.use_gradient_averaging
         if self.tracker.estimated_next_update_time - get_dht_time() <= self.matchmaking_time:
             if self.scheduled_grads is None or self.scheduled_grads.triggered or self.scheduled_grads.done():
-                if self.delay_grad_averaging:
-                    # wait for previous averaging to finish before starting a new one
-                    self.state_averager.step(wait_for_delayed_updates=True)
-
                 eta_seconds = self.tracker.estimated_next_update_time - get_dht_time()
                 eta_seconds = max(eta_seconds, self.grad_averager.matchmaking_kwargs["min_matchmaking_time"])
                 logger.log(self.status_loglevel, f"Pre-scheduling gradient averaging round in {eta_seconds:.2f} sec")
@@ -545,12 +549,13 @@ class Optimizer(torch.optim.Optimizer):
 
     def _maybe_schedule_state_averaging(self) -> None:
         """If next epoch is coming soon, schedule the next state averaging at estimated parameter averaging start"""
-        return
         next_epoch = max(self.local_epoch + 1, self.tracker.global_epoch)
         if next_epoch % self.average_state_every != 0:
             return  # averaging is not performed at this epoch
         if self.state_averager.averaging_in_progress:
             return  # previous run is still in progress
+        if self.delay_before_state_averaging.num_updates == 0:
+            return  # not enough data to accurately pre-schedule
 
         estimated_time = self.tracker.estimated_next_update_time
         estimated_time += self.delay_before_state_averaging.ema_seconds_per_sample

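Usage note (illustrative, not part of this patch): the `next_chunk_timeout` value stored above is forwarded to the state and gradient averagers, so it can be set directly on hivemind.Optimizer. A minimal sketch assuming the public Optimizer constructor; the run_id, toy model and all numeric values are hypothetical:

import torch
import hivemind

dht = hivemind.DHT(start=True)
model = torch.nn.Linear(16, 4)

opt = hivemind.Optimizer(
    dht=dht,
    run_id="demo_run",  # hypothetical experiment name
    params=model.parameters(),
    optimizer=lambda params: torch.optim.SGD(params, lr=0.1),
    target_batch_size=256,
    batch_size_per_step=4,
    matchmaking_time=3.0,
    averaging_timeout=30.0,
    next_chunk_timeout=10.0,  # per this PR: a sender/reducer that stays silent this long is treated as failed
    verbose=True,
)
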
+ 2 - 3
hivemind/optim/grad_scaler.py

@@ -60,11 +60,11 @@ class GradScaler(TorchGradScaler):
                 return False
 
     def step(self, optimizer: TorchOptimizer, *args, **kwargs) -> bool:
-        if self._is_running_global_step:
+        if self._is_running_global_step and not isinstance(optimizer, hivemind.Optimizer):
+            # ^-- invoked privately within hivemind optimizer
            with self._lock:
                 if self._is_ready_to_update:
                     logger.warning("Please call grad_scaler.update() after each step")
-                assert not isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
                 assert (
                     self._per_optimizer_states[id(optimizer)]["stage"] == OptState.UNSCALED
                 ), "InternalError: Optimizer should have called .unscale internally before invoking grad_scaler.step."
@@ -75,7 +75,6 @@ class GradScaler(TorchGradScaler):
                 self._is_ready_to_update = True
                 return True
         else:
-            assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
             super().step(optimizer)
             self._optimizer_states_to_reset.add(id(optimizer))
             return False

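Usage note (illustrative, not part of this patch): with the isinstance asserts above removed, grad_scaler.step() follows the regular AMP pattern regardless of optimizer type. A hedged sketch of that loop; the toy model, random data and the plain SGD stand-in are assumptions:

import torch
import hivemind

model = torch.nn.Linear(16, 4)
opt = torch.optim.SGD(model.parameters(), lr=0.1)  # stand-in; a hivemind.Optimizer is passed the same way
scaler = hivemind.GradScaler()

for _ in range(3):
    x, y = torch.randn(4, 16), torch.randn(4, 4)
    loss = torch.nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()
    scaler.step(opt)  # no longer asserts on the optimizer type outside a global step
    scaler.update()
    opt.zero_grad()
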
+ 23 - 9
hivemind/utils/asyncio.py

@@ -114,9 +114,15 @@ async def amap_in_executor(
     queue = asyncio.Queue(max_prefetch)
 
     async def _put_items():
-        async for args in azip(*iterables):
-            await queue.put(loop.run_in_executor(executor, func, *args))
-        await queue.put(None)
+        try:
+            async for args in azip(*iterables):
+                await queue.put(loop.run_in_executor(executor, func, *args))
+            await queue.put(None)
+        except BaseException as e:
+            future = asyncio.Future()
+            future.set_exception(e)
+            await queue.put(future)
+            raise
 
     task = asyncio.create_task(_put_items())
     try:
@@ -124,13 +130,21 @@ async def amap_in_executor(
         while future is not None:
             yield await future
             future = await queue.get()
-        await task
     finally:
-        if not task.done():
-            task.cancel()
-
-
-async def aiter_with_timeout(iterable: AsyncIterable[T], timeout: float) -> AsyncIterator[T]:
+        task.cancel()
+        try:
+            await task
+        except asyncio.CancelledError:
+            pass
+        except Exception as e:
+            logger.debug(f"Caught {e} while iterating over inputs", exc_info=True)
+        while not queue.empty():
+            future = queue.get_nowait()
+            if future is not None:
+                future.cancel()
+
+
+async def aiter_with_timeout(iterable: AsyncIterable[T], timeout: Optional[float]) -> AsyncIterator[T]:
     """Iterate over an async iterable, raise TimeoutError if another portion of data does not arrive within timeout"""
     # based on https://stackoverflow.com/a/50245879
     iterator = iterable.__aiter__()

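Usage note (illustrative, not part of this patch): aiter_with_timeout is the helper behind the new stall detection, turning a silent stream into a TimeoutError instead of a hang. A self-contained sketch; the slow generator and the 0.5 s timeout are made-up values:

import asyncio

from hivemind.utils.asyncio import aiter_with_timeout

async def slow_source():
    yield "first chunk"
    await asyncio.sleep(10)  # stalls long enough for the consumer below to give up
    yield "second chunk"

async def consume():
    try:
        async for item in aiter_with_timeout(slow_source(), timeout=0.5):
            print("received:", item)
    except asyncio.TimeoutError:
        print("no chunk within 0.5 s, peer considered failed")

asyncio.run(consume())
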
+ 213 - 0
tests/test_allreduce_fault_tolerance.py

@@ -0,0 +1,213 @@
+from __future__ import annotations
+
+import asyncio
+from enum import Enum, auto
+from typing import AsyncIterator
+
+import pytest
+import torch
+
+import hivemind
+from hivemind.averaging.allreduce import AllReduceRunner, AveragingMode
+from hivemind.averaging.averager import *
+from hivemind.averaging.group_info import GroupInfo
+from hivemind.averaging.load_balancing import load_balance_peers
+from hivemind.averaging.matchmaking import MatchmakingException
+from hivemind.proto import averaging_pb2
+from hivemind.utils.asyncio import aenumerate, as_aiter, azip, enter_asynchronously
+from hivemind.utils.logging import get_logger
+
+logger = get_logger(__name__)
+
+
+class Fault(Enum):
+    NONE = auto()
+    FAIL_BEFORE = auto()
+    FAIL_SENDING = auto()
+    SLOW_SENDING = auto()
+    FAIL_REDUCING = auto()
+    SLOW_REDUCING = auto()
+    CANCEL = auto()
+
+
+class FaultyAverager(hivemind.DecentralizedAverager):
+    def __init__(self, *args, fault: Fault = Fault.NONE, **kwargs):
+        self.fault = fault
+        super().__init__(*args, **kwargs)
+
+    async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
+        """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
+        try:
+            bandwidths, mode_ids, user_gathered_bytes = zip(*map(self.serializer.loads, group_info.gathered))
+            user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered_bytes)))
+            modes = tuple(map(AveragingMode, mode_ids))
+            download_bandwidths = [
+                thr if mode != AveragingMode.CLIENT else 0.0 for thr, mode in zip(bandwidths, modes)
+            ]
+            peer_fractions = await asyncio.get_event_loop().run_in_executor(
+                None, load_balance_peers, self.total_size, download_bandwidths, min_vector_size
+            )
+
+            if self.fault == Fault.FAIL_BEFORE:
+                raise Exception("Oops, I failed!")
+
+            async with enter_asynchronously(self.get_tensors()) as local_tensors:
+                allreduce = FaultyAllReduceRunner(
+                    p2p=self._p2p,
+                    servicer_type=type(self),
+                    prefix=self.prefix,
+                    group_id=group_info.group_id,
+                    tensors=local_tensors,
+                    ordered_peer_ids=group_info.peer_ids,
+                    peer_fractions=peer_fractions,
+                    gathered=user_gathered,
+                    modes=modes,
+                    fault=self.fault,
+                    **kwargs,
+                )
+
+                with self.register_allreduce_group(group_info.group_id, allreduce):
+                    if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
+                        async for tensor, update in azip(as_aiter(*local_tensors), allreduce):
+                            # all-reduce is performed asynchronously while iterating
+                            tensor.add_(update, alpha=self._averaging_alpha)
+                        self._state_updated.set()
+
+                    else:
+                        async for _ in allreduce:  # trigger all-reduce by iterating
+                            raise ValueError("aux peers should not receive averaged tensors")
+
+                return allreduce.gathered
+        except BaseException as e:
+            logger.exception(e)
+            raise MatchmakingException(f"Unable to run All-Reduce: {e}")
+
+
+class FaultyAllReduceRunner(AllReduceRunner):
+    def __init__(self, *args, fault: Fault, **kwargs):
+        self.fault = fault
+        super().__init__(*args, **kwargs)
+
+    async def rpc_aggregate_part(self, stream, context) -> AsyncIterator[averaging_pb2.AveragingData]:
+        if self.fault in (Fault.FAIL_REDUCING, Fault.SLOW_REDUCING):
+            async for i, message in aenumerate(super().rpc_aggregate_part(stream, context)):
+                yield message
+                if i == 2:
+                    if self.fault == Fault.FAIL_REDUCING:
+                        yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
+                        break
+                    else:
+                        await asyncio.sleep(10)
+
+        elif self.fault == Fault.CANCEL:
+            yield averaging_pb2.AveragingData(code=averaging_pb2.CANCELLED)
+        else:
+            async for message in super().rpc_aggregate_part(stream, context):
+                yield message
+
+    async def _generate_input_for_peer(self, peer_index: int) -> AsyncIterator[averaging_pb2.AveragingData]:
+        parts_aiter = self.tensor_part_container.iterate_input_parts_for(peer_index)
+
+        first_part = await anext(parts_aiter)
+        yield averaging_pb2.AveragingData(
+            code=averaging_pb2.PART_FOR_AVERAGING,
+            group_id=self.group_id,
+            tensor_part=first_part,
+            weight=self.weight,
+        )
+        if self.fault in (Fault.FAIL_SENDING, Fault.SLOW_SENDING):
+            last_reducer_index = self.group_size - 1 - (self.tensor_part_container.num_parts_by_peer[-1] == 0)
+            if peer_index == last_reducer_index:
+                if self.fault == Fault.FAIL_SENDING:
+                    raise Exception("Oops, I failed!")
+                else:
+                    await asyncio.sleep(10)
+        async for part in parts_aiter:
+            yield averaging_pb2.AveragingData(tensor_part=part, weight=self.weight)
+
+
+@pytest.mark.forked
+@pytest.mark.parametrize(
+    "fault0, fault1",
+    [
+        (Fault.NONE, Fault.FAIL_BEFORE),
+        (Fault.FAIL_BEFORE, Fault.FAIL_BEFORE),
+        (Fault.SLOW_SENDING, Fault.FAIL_SENDING),
+        (Fault.FAIL_SENDING, Fault.FAIL_BEFORE),
+        (Fault.SLOW_REDUCING, Fault.FAIL_SENDING),
+        (Fault.FAIL_REDUCING, Fault.FAIL_REDUCING),
+        (Fault.NONE, Fault.CANCEL),
+    ],
+)
+def test_fault_tolerance(fault0: Fault, fault1: Fault):
+    def _make_tensors():
+        return [torch.rand(16, 1024), -torch.rand(3, 8192), 2 * torch.randn(4, 4, 4), torch.randn(1024, 1024)]
+
+    dht = hivemind.DHT(start=True)
+
+    averagers = []
+    for i in range(5):
+        averager = FaultyAverager(
+            _make_tensors(),
+            hivemind.DHT(initial_peers=dht.get_visible_maddrs(), start=True),
+            prefix="test",
+            request_timeout=0.3,
+            min_matchmaking_time=1.0,
+            next_chunk_timeout=0.5,
+            allreduce_timeout=5,
+            part_size_bytes=2 ** 16,
+            client_mode=(i == 1),
+            start=True,
+            fault=fault0 if i == 0 else fault1 if i == 1 else Fault.NONE,
+        )
+        averagers.append(averager)
+
+    ref_numerators = [0, 0, 0, 0]
+    ref_denominator = 0
+
+    for averager in averagers:
+        if averager.fault not in (Fault.FAIL_BEFORE, Fault.CANCEL):
+            with averager.get_tensors() as tensors:
+                for i, tensor in enumerate(tensors):
+                    ref_numerators[i] = ref_numerators[i] + tensor.clone()
+                ref_denominator += 1
+
+    ref_tensors = [ref_numerator / ref_denominator for ref_numerator in ref_numerators]
+    flat_ref = torch.cat(list(map(torch.flatten, ref_tensors)))
+
+    flat_local_tensors = []
+    for averager in averagers:
+        with averager.get_tensors() as tensors:
+            flat_local_tensors.append(torch.cat(list(map(torch.flatten, tensors))))
+
+    futures = [averager.step(timeout=5, wait=False, allow_retries=False) for averager in averagers]
+    for i, averager in enumerate(averagers):
+        if averager.fault == Fault.CANCEL:
+            futures[i].cancel()
+
+    for future in futures[2:]:
+        assert future.result()
+
+    for averager, prev_local_tensors in zip(averagers[2:], flat_local_tensors[2:]):
+        with averager.get_tensors() as tensors:
+            flat_tensors = torch.cat(list(map(torch.flatten, tensors)))
+
+        diff_with_reference = abs(flat_ref - flat_tensors)
+
+        if all(fault in (Fault.FAIL_SENDING, Fault.SLOW_SENDING) for fault in (fault0, fault1)):
+            assert fault0 != Fault.FAIL_REDUCING and fault1 != Fault.FAIL_REDUCING
+            assert diff_with_reference[: len(diff_with_reference) // 2].max() < 1e-5
+        elif all(fault in (Fault.FAIL_REDUCING, Fault.SLOW_REDUCING) for fault in (fault0, fault1)):
+            diff_to_reference = abs(flat_ref - flat_tensors)
+            diff_to_local = abs(prev_local_tensors - flat_tensors)
+            assert (diff_with_reference < 1e-5).numpy().mean() > 0.5
+            assert torch.all(torch.minimum(diff_to_reference, diff_to_local) < 1e-5).item()
+        elif any(fault == Fault.CANCEL for fault in (fault0, fault1)):
+            pass  # late cancel may result in an arbitrary mix of averaging results with and without the cancelled peer
+        elif fault0 == Fault.NONE:  # only peer1 in client mode may have failed
+            assert diff_with_reference.max() < 1e-5
+        else:
+            assert (diff_with_reference < 1e-5).numpy().mean() > 0.5
+
+    for averager in averagers:
+        averager.shutdown()
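Usage note (illustrative, not part of this patch): the same knobs can be exercised outside pytest. A condensed, hypothetical two-healthy-peer version of the setup above; the prefix, tensor shapes and timeouts mirror the test, but the concrete values are made up:

import torch
import hivemind

dht = hivemind.DHT(start=True)
averagers = [
    hivemind.DecentralizedAverager(
        averaged_tensors=[torch.full((4,), float(i))],
        dht=hivemind.DHT(initial_peers=dht.get_visible_maddrs(), start=True),
        prefix="demo",
        min_matchmaking_time=1.0,
        next_chunk_timeout=0.5,  # per this PR: a silent sender/reducer is dropped after 0.5 s
        start=True,
    )
    for i in range(2)
]

futures = [averager.step(wait=False) for averager in averagers]
assert all(future.result() for future in futures)  # both peers formed a group and averaged

with averagers[0].get_tensors() as tensors:
    print(tensors[0])  # expected: tensor([0.5000, 0.5000, 0.5000, 0.5000])

for averager in averagers:
    averager.shutdown()
dht.shutdown()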