Selaa lähdekoodia

Fix a potential deadlock in await_asynchronously with nested locks (#503)

This PR fixes a potential deadlock in hivemind.utils.enter_asynchronously.
This deadlock occurs when many coroutines enter nested locks and exhaust all workers in ThreadPoolExecutor.
In this PR, we mitigate it by creating a dedicated executor for entering locks with no limit to the number of workers.

Co-authored-by: Aleksandr Borzunov <borzunov.alexander@gmail.com>
justheuristic 3 vuotta sitten
vanhempi
commit
b02bdadd0e
2 muutettua tiedostoa jossa 42 lisäystä ja 1 poistoa
  1. 16 1
      hivemind/utils/asyncio.py
  2. 26 0
      tests/test_util_modules.py

+ 16 - 1
hivemind/utils/asyncio.py

@@ -1,5 +1,7 @@
 import asyncio
 import concurrent.futures
+import multiprocessing as mp
+import os
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import AbstractAsyncContextManager, AbstractContextManager, asynccontextmanager
 from typing import AsyncIterable, AsyncIterator, Awaitable, Callable, Iterable, Optional, Tuple, TypeVar, Union
@@ -167,12 +169,25 @@ async def attach_event_on_finished(iterable: AsyncIterable[T], event: asyncio.Ev
 class _AsyncContextWrapper(AbstractAsyncContextManager):
     """Wrapper for a non-async context manager that allows entering and exiting it in EventLoop-friendly manner"""
 
+    EXECUTOR_PID = None
+    CONTEXT_EXECUTOR = None
+    EXECUTOR_LOCK = mp.Lock()
+
     def __init__(self, context: AbstractContextManager):
         self._context = context
 
+    @classmethod
+    def get_process_wide_executor(cls):
+        if os.getpid() != cls.EXECUTOR_PID:
+            with cls.EXECUTOR_LOCK:
+                if os.getpid() != cls.EXECUTOR_PID:
+                    cls.CONTEXT_EXECUTOR = ThreadPoolExecutor(max_workers=float("inf"))
+                    cls.EXECUTOR_PID = os.getpid()
+        return cls.CONTEXT_EXECUTOR
+
     async def __aenter__(self):
         loop = asyncio.get_event_loop()
-        return await loop.run_in_executor(None, self._context.__enter__)
+        return await loop.run_in_executor(self.get_process_wide_executor(), self._context.__enter__)
 
     async def __aexit__(self, exc_type, exc_value, traceback):
         return self._context.__exit__(exc_type, exc_value, traceback)

+ 26 - 0
tests/test_util_modules.py

@@ -507,6 +507,32 @@ async def test_async_context():
     # running this without enter_asynchronously would deadlock the event loop
 
 
+@pytest.mark.asyncio
+async def test_async_context_flooding():
+    """
+    test for a possible deadlock when many coroutines await the lock and overwhelm the underlying ThreadPoolExecutor
+
+    Here's how the test below works: suppose that the thread pool has at most N workers;
+    If at least N + 1 coroutines await lock1 concurrently, N of them occupy workers and the rest are awaiting workers;
+    When the first of N workers acquires lock1, it lets coroutine A inside lock1 and into await sleep(1e-2);
+    During that sleep, one of the worker-less coroutines will take up the worker freed by coroutine A.
+    Finally, coroutine A finishes sleeping and immediately gets stuck at lock2, because there are no free workers.
+    Thus, every single coroutine is either awaiting an already acquired lock, or awaiting for free workers in executor.
+
+    """
+    lock1, lock2 = mp.Lock(), mp.Lock()
+
+    async def coro():
+        async with enter_asynchronously(lock1):
+            await asyncio.sleep(1e-2)
+            async with enter_asynchronously(lock2):
+                await asyncio.sleep(1e-2)
+
+    num_coros = max(100, mp.cpu_count() * 5 + 1)
+    # note: if we deprecate py3.7, this can be reduced to max(33, cpu + 5); see https://bugs.python.org/issue35279
+    await asyncio.wait({coro() for _ in range(num_coros)})
+
+
 def test_batch_tensor_descriptor_msgpack():
     tensor_descr = BatchTensorDescriptor.from_tensor(torch.ones(1, 3, 3, 7))
     tensor_descr_roundtrip = MSGPackSerializer.loads(MSGPackSerializer.dumps(tensor_descr))