浏览代码

cancellation

justheuristic 3 年之前
父节点
当前提交
ebc452ffd8
共有 1 个文件被更改,包括 15 次插入和 4 次删除
  1. 15 4
      hivemind/averaging/averager.py

+ 15 - 4
hivemind/averaging/averager.py

@@ -397,7 +397,16 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                 try:
                     self._pending_group_assembled.clear()
                     step.stage = AveragingStage.LOOKING_FOR_GROUP
-                    group_info = await self._matchmaking.look_for_group(step)
+                    matchmaking_task = asyncio.create_task(self._matchmaking.look_for_group(step))
+
+                    await asyncio.wait(
+                        {matchmaking_task, step.wait_for_trigger()},
+                        return_when=asyncio.FIRST_COMPLETED
+                    )
+                    if step.cancelled():
+                        raise asyncio.CancelledError()
+
+                    group_info = await matchmaking_task
                     if group_info is None:
                         raise AllreduceException("Averaging step failed: could not find a group.")
 
@@ -426,9 +435,11 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     asyncio.InvalidStateError,
                     P2PHandlerError,
                 ) as e:
-                    if not step.allow_retries or get_dht_time() >= step.deadline:
-                        logger.exception(e)
-                        step.set_exception(e)
+                    if step.done() or not step.allow_retries or get_dht_time() >= step.deadline:
+                        if not step.cancelled():
+                            logger.exception(e)
+                        if not step.done():
+                            step.set_exception(e)
                     else:
                         logger.warning(f"{self.__class__.__name__} caught {repr(e)}, retrying")