|
@@ -1,5 +1,7 @@
|
|
|
import asyncio
|
|
|
import concurrent.futures
|
|
|
+import multiprocessing as mp
|
|
|
+import os
|
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
|
from contextlib import AbstractAsyncContextManager, AbstractContextManager, asynccontextmanager
|
|
|
from typing import AsyncIterable, AsyncIterator, Awaitable, Callable, Iterable, Optional, Tuple, TypeVar, Union
|
|
@@ -167,12 +169,25 @@ async def attach_event_on_finished(iterable: AsyncIterable[T], event: asyncio.Ev
|
|
|
class _AsyncContextWrapper(AbstractAsyncContextManager):
|
|
|
"""Wrapper for a non-async context manager that allows entering and exiting it in EventLoop-friendly manner"""
|
|
|
|
|
|
+ EXECUTOR_PID = None
|
|
|
+ CONTEXT_EXECUTOR = None
|
|
|
+ EXECUTOR_LOCK = mp.Lock()
|
|
|
+
|
|
|
def __init__(self, context: AbstractContextManager):
|
|
|
self._context = context
|
|
|
|
|
|
+ @classmethod
|
|
|
+ def get_process_wide_executor(cls):
|
|
|
+ if os.getpid() != cls.EXECUTOR_PID:
|
|
|
+ with cls.EXECUTOR_LOCK:
|
|
|
+ if os.getpid() != cls.EXECUTOR_PID:
|
|
|
+ cls.CONTEXT_EXECUTOR = ThreadPoolExecutor(max_workers=float("inf"))
|
|
|
+ cls.EXECUTOR_PID = os.getpid()
|
|
|
+ return cls.CONTEXT_EXECUTOR
|
|
|
+
|
|
|
async def __aenter__(self):
|
|
|
loop = asyncio.get_event_loop()
|
|
|
- return await loop.run_in_executor(None, self._context.__enter__)
|
|
|
+ return await loop.run_in_executor(self.get_process_wide_executor(), self._context.__enter__)
|
|
|
|
|
|
async def __aexit__(self, exc_type, exc_value, traceback):
|
|
|
return self._context.__exit__(exc_type, exc_value, traceback)
|