|
@@ -257,21 +257,24 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
|
|
|
# averaging is finished, exit the loop
|
|
|
future.set_result(allreduce_runner.gathered)
|
|
|
|
|
|
- except (AllreduceException, MatchmakingException, AssertionError,
|
|
|
- asyncio.InvalidStateError, grpc.RpcError, grpc.aio.AioRpcError, InternalError) as e:
|
|
|
+ except (AllreduceException, MatchmakingException, AssertionError, StopAsyncIteration, InternalError,
|
|
|
+ asyncio.CancelledError, asyncio.InvalidStateError, grpc.RpcError, grpc.aio.AioRpcError) as e:
|
|
|
time_elapsed = get_dht_time() - start_time
|
|
|
if not allow_retries or (timeout is not None and timeout < time_elapsed):
|
|
|
- logger.warning(f"Averager caught {e}")
|
|
|
- future.set_result(None)
|
|
|
+ logger.exception(f"Averager caught {repr(e)}")
|
|
|
+ future.set_exception(e)
|
|
|
else:
|
|
|
- logger.warning(f"Averager caught {e}, retrying")
|
|
|
+ logger.warning(f"Averager caught {repr(e)}, retrying")
|
|
|
|
|
|
- except Exception as e:
|
|
|
+ except BaseException as e:
|
|
|
future.set_exception(e)
|
|
|
raise
|
|
|
finally:
|
|
|
_ = self._running_groups.pop(group_id, None)
|
|
|
self._pending_group_assembled.set()
|
|
|
+ if not future.done():
|
|
|
+ future.set_exception(RuntimeError("Internal sanity check failed: averager.step left future pending."
|
|
|
+ " Please report this to hivemind issues."))
|
|
|
|
|
|
async def _make_allreduce_runner(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> AllReduceRunner:
|
|
|
""" Use a group description found by Matchmaking to form AllreduceRunner """
|