|
@@ -397,7 +397,16 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
|
try:
|
|
|
self._pending_group_assembled.clear()
|
|
|
step.stage = AveragingStage.LOOKING_FOR_GROUP
|
|
|
- group_info = await self._matchmaking.look_for_group(step)
|
|
|
+ matchmaking_task = asyncio.create_task(self._matchmaking.look_for_group(step))
|
|
|
+
|
|
|
+ await asyncio.wait(
|
|
|
+ {matchmaking_task, step.wait_for_trigger()},
|
|
|
+ return_when=asyncio.FIRST_COMPLETED
|
|
|
+ )
|
|
|
+ if step.cancelled():
|
|
|
+ raise asyncio.CancelledError()
|
|
|
+
|
|
|
+ group_info = await matchmaking_task
|
|
|
if group_info is None:
|
|
|
raise AllreduceException("Averaging step failed: could not find a group.")
|
|
|
|
|
@@ -426,9 +435,11 @@ class DecentralizedAverager(mp.Process, ServicerBase):
|
|
|
asyncio.InvalidStateError,
|
|
|
P2PHandlerError,
|
|
|
) as e:
|
|
|
- if not step.allow_retries or get_dht_time() >= step.deadline:
|
|
|
- logger.exception(e)
|
|
|
- step.set_exception(e)
|
|
|
+ if step.done() or not step.allow_retries or get_dht_time() >= step.deadline:
|
|
|
+ if not step.cancelled():
|
|
|
+ logger.exception(e)
|
|
|
+ if not step.done():
|
|
|
+ step.set_exception(e)
|
|
|
else:
|
|
|
logger.warning(f"{self.__class__.__name__} caught {repr(e)}, retrying")
|
|
|
|