Browse Source

Improve Matchmaking finalizers (#357)

This PR resolves some frequent "Task was destroyed but it is pending!" warnings, as well as a number of other finalizer issues.
Alexander Borzunov 4 năm trước cách đây
mục cha
commit
7f296a43c3

+ 29 - 41
hivemind/averaging/matchmaking.py

@@ -15,7 +15,7 @@ from hivemind.dht import DHT, DHTID, DHTExpiration
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
 from hivemind.proto import averaging_pb2
 from hivemind.utils import TimedStorage, get_dht_time, get_logger, timed_storage
-from hivemind.utils.asyncio import anext
+from hivemind.utils.asyncio import anext, cancel_and_wait
 
 logger = get_logger(__name__)
 
@@ -127,10 +127,9 @@ class Matchmaking:
                 raise
 
             finally:
-                if not request_leaders_task.done():
-                    request_leaders_task.cancel()
-                if not self.assembled_group.done():
-                    self.assembled_group.cancel()
+                await cancel_and_wait(request_leaders_task)
+                self.assembled_group.cancel()
+
                 while len(self.current_followers) > 0:
                     await self.follower_was_discarded.wait()
                     self.follower_was_discarded.clear()
@@ -229,7 +228,7 @@ class Matchmaking:
             logger.debug(f"{self} - potential leader {leader} did not respond within {self.request_timeout}")
             return None
         except (P2PHandlerError, StopAsyncIteration) as e:
-            logger.error(f"{self} - failed to request potential leader {leader}: {e}")
+            logger.exception(f"{self} - failed to request potential leader {leader}:")
             return None
 
         finally:
@@ -413,10 +412,9 @@ class PotentialLeaders:
             try:
                 yield self
             finally:
-                if not update_queue_task.done():
-                    update_queue_task.cancel()
-                if declare and not declare_averager_task.done():
-                    declare_averager_task.cancel()
+                await cancel_and_wait(update_queue_task)
+                if declare:
+                    await cancel_and_wait(declare_averager_task)
 
                 for field in (
                     self.past_attempts,
@@ -477,37 +475,31 @@ class PotentialLeaders:
         else:
             return min(get_dht_time() + self.averaging_expiration, self.search_end_time)
 
-    async def _update_queue_periodically(self, key_manager: GroupKeyManager):
-        try:
-            DISCREPANCY = timed_storage.MAX_DHT_TIME_DISCREPANCY_SECONDS
-            while get_dht_time() < self.search_end_time:
-                new_peers = await key_manager.get_averagers(key_manager.current_key, only_active=True)
-                self.max_assured_time = max(
-                    self.max_assured_time, get_dht_time() + self.averaging_expiration - DISCREPANCY
-                )
+    async def _update_queue_periodically(self, key_manager: GroupKeyManager) -> None:
+        DISCREPANCY = timed_storage.MAX_DHT_TIME_DISCREPANCY_SECONDS
+        while get_dht_time() < self.search_end_time:
+            new_peers = await key_manager.get_averagers(key_manager.current_key, only_active=True)
+            self.max_assured_time = max(
+                self.max_assured_time, get_dht_time() + self.averaging_expiration - DISCREPANCY
+            )
 
-                self.leader_queue.clear()
-                for peer, peer_expiration_time in new_peers:
-                    if peer == self.peer_id or (peer, peer_expiration_time) in self.past_attempts:
-                        continue
-                    self.leader_queue.store(peer, peer_expiration_time, peer_expiration_time)
-                    self.max_assured_time = max(self.max_assured_time, peer_expiration_time - DISCREPANCY)
+            self.leader_queue.clear()
+            for peer, peer_expiration_time in new_peers:
+                if peer == self.peer_id or (peer, peer_expiration_time) in self.past_attempts:
+                    continue
+                self.leader_queue.store(peer, peer_expiration_time, peer_expiration_time)
+                self.max_assured_time = max(self.max_assured_time, peer_expiration_time - DISCREPANCY)
 
-                self.update_finished.set()
+            self.update_finished.set()
 
-                await asyncio.wait(
-                    {self.running.wait(), self.update_triggered.wait()},
-                    return_when=asyncio.ALL_COMPLETED,
-                    timeout=self.search_end_time - get_dht_time() if isfinite(self.search_end_time) else None,
-                )
-                self.update_triggered.clear()
-        except (concurrent.futures.CancelledError, asyncio.CancelledError):
-            return  # note: this is a compatibility layer for python3.7
-        except Exception as e:
-            logger.error(f"{self.peer_id} - caught {type(e)}: {e}")
-            raise
+            await asyncio.wait(
+                {self.running.wait(), self.update_triggered.wait()},
+                return_when=asyncio.ALL_COMPLETED,
+                timeout=self.search_end_time - get_dht_time() if isfinite(self.search_end_time) else None,
+            )
+            self.update_triggered.clear()
 
-    async def _declare_averager_periodically(self, key_manager: GroupKeyManager):
+    async def _declare_averager_periodically(self, key_manager: GroupKeyManager) -> None:
         async with self.lock_declare:
             try:
                 while True:
@@ -521,10 +513,6 @@ class PotentialLeaders:
                     await asyncio.sleep(self.declared_expiration_time - get_dht_time())
                     if self.running.is_set() and len(self.leader_queue) == 0:
                         await key_manager.update_key_on_not_enough_peers()
-            except (concurrent.futures.CancelledError, asyncio.CancelledError):
-                pass  # note: this is a compatibility layer for python3.7
-            except Exception as e:  # note: we catch exceptions here because otherwise they are never printed
-                logger.error(f"{self.peer_id} - caught {type(e)}: {e}")
             finally:
                 if self.declared_group_key is not None:
                     prev_declared_key, prev_expiration_time = self.declared_group_key, self.declared_expiration_time

+ 4 - 11
hivemind/dht/__init__.py

@@ -27,7 +27,7 @@ from hivemind.dht.node import DEFAULT_NUM_WORKERS, DHTNode
 from hivemind.dht.routing import DHTID, DHTKey, DHTValue, Subkey
 from hivemind.dht.validation import CompositeValidator, RecordValidatorBase
 from hivemind.p2p import P2P, PeerID
-from hivemind.utils import DHTExpiration, MPFuture, ValueWithExpiration, await_cancelled, get_logger, switch_to_uvloop
+from hivemind.utils import DHTExpiration, MPFuture, ValueWithExpiration, get_logger, switch_to_uvloop
 
 logger = get_logger(__name__)
 
@@ -261,18 +261,11 @@ class DHT(mp.Process):
     async def _run_coroutine(
         self, coro: Callable[[DHT, DHTNode], Awaitable[ReturnType]], future: MPFuture[ReturnType]
     ):
-        main_task = asyncio.create_task(coro(self, self._node))
-        cancel_task = asyncio.create_task(await_cancelled(future))
         try:
-            await asyncio.wait({main_task, cancel_task}, return_when=asyncio.FIRST_COMPLETED)
-            if future.cancelled():
-                main_task.cancel()
-            else:
-                future.set_result(await main_task)
+            future.set_result(await coro(self, self._node))
         except BaseException as e:
-            logger.exception(f"Caught an exception when running a coroutine: {e}")
-            if not future.done():
-                future.set_exception(e)
+            logger.exception("Caught an exception when running a coroutine:")
+            future.set_exception(e)
 
     def add_validators(self, record_validators: Iterable[RecordValidatorBase]) -> None:
         if not self._ready.done():

+ 16 - 1
hivemind/utils/asyncio.py

@@ -1,4 +1,5 @@
 import asyncio
+import concurrent.futures
 from concurrent.futures import ThreadPoolExecutor
 from typing import AsyncIterable, AsyncIterator, Awaitable, Callable, Optional, Tuple, TypeVar, Union
 
@@ -81,12 +82,26 @@ async def await_cancelled(awaitable: Awaitable) -> bool:
     try:
         await awaitable
         return False
-    except asyncio.CancelledError:
+    except (asyncio.CancelledError, concurrent.futures.CancelledError):
+        # In Python 3.7, awaiting a cancelled asyncio.Future raises concurrent.futures.CancelledError
+        # instead of asyncio.CancelledError
         return True
     except BaseException:
+        logger.exception(f"Exception in {awaitable}:")
         return False
 
 
+async def cancel_and_wait(awaitable: Awaitable) -> bool:
+    """
+    Cancels ``awaitable`` and waits for its cancellation.
+    In case of ``asyncio.Task``, helps to avoid ``Task was destroyed but it is pending!`` errors.
+    In case of ``asyncio.Future``, equal to ``future.cancel()``.
+    """
+
+    awaitable.cancel()
+    return await await_cancelled(awaitable)
+
+
 async def amap_in_executor(
     func: Callable[..., T],
     *iterables: AsyncIterable,

+ 43 - 1
tests/test_util_modules.py

@@ -13,7 +13,17 @@ from hivemind.proto.dht_pb2_grpc import DHTStub
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
 from hivemind.utils import DHTExpiration, HeapEntry, MSGPackSerializer, ValueWithExpiration
-from hivemind.utils.asyncio import achain, aenumerate, afirst, aiter, amap_in_executor, anext, asingle, azip
+from hivemind.utils.asyncio import (
+    achain,
+    aenumerate,
+    afirst,
+    aiter,
+    amap_in_executor,
+    anext,
+    asingle,
+    azip,
+    cancel_and_wait,
+)
 from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.utils.mpfuture import InvalidStateError
 
@@ -509,3 +519,35 @@ async def test_asyncio_utils():
     assert await afirst(aiter()) is None
     assert await afirst(aiter(), -1) == -1
     assert await afirst(aiter(1, 2, 3)) == 1
+
+
+@pytest.mark.asyncio
+async def test_cancel_and_wait():
+    finished_gracefully = False
+
+    async def coro_with_finalizer():
+        nonlocal finished_gracefully
+
+        try:
+            await asyncio.Event().wait()
+        except asyncio.CancelledError:
+            await asyncio.sleep(0.05)
+            finished_gracefully = True
+            raise
+
+    task = asyncio.create_task(coro_with_finalizer())
+    await asyncio.sleep(0.05)
+    assert await cancel_and_wait(task)
+    assert finished_gracefully
+
+    async def coro_with_result():
+        return 777
+
+    async def coro_with_error():
+        raise ValueError("error")
+
+    task_with_result = asyncio.create_task(coro_with_result())
+    task_with_error = asyncio.create_task(coro_with_error())
+    await asyncio.sleep(0.05)
+    assert not await cancel_and_wait(task_with_result)
+    assert not await cancel_and_wait(task_with_error)