asyncio.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. import asyncio
  2. import concurrent.futures
  3. from concurrent.futures import ThreadPoolExecutor
  4. from contextlib import AbstractAsyncContextManager, AbstractContextManager, asynccontextmanager
  5. from typing import AsyncIterable, AsyncIterator, Awaitable, Callable, ContextManager, Optional, Tuple, TypeVar, Union
  6. import uvloop
  7. from hivemind.utils.logging import get_logger
  8. T = TypeVar("T")
  9. logger = get_logger(__name__)
  10. def switch_to_uvloop() -> asyncio.AbstractEventLoop:
  11. """stop any running event loops; install uvloop; then create, set and return a new event loop"""
  12. try:
  13. asyncio.get_event_loop().stop() # if we're in jupyter, get rid of its built-in event loop
  14. except RuntimeError as error_no_event_loop:
  15. pass # this allows running DHT from background threads with no event loop
  16. uvloop.install()
  17. loop = asyncio.new_event_loop()
  18. asyncio.set_event_loop(loop)
  19. return loop
  20. async def anext(aiter: AsyncIterator[T]) -> Union[T, StopAsyncIteration]:
  21. """equivalent to next(iter) for asynchronous iterators. Modifies aiter in-place!"""
  22. return await aiter.__anext__()
  23. async def as_aiter(*args: T) -> AsyncIterator[T]:
  24. """create an asynchronous iterator from a sequence of values"""
  25. for arg in args:
  26. yield arg
  27. async def azip(*iterables: AsyncIterable[T]) -> AsyncIterator[Tuple[T, ...]]:
  28. """equivalent of zip for asynchronous iterables"""
  29. iterators = [iterable.__aiter__() for iterable in iterables]
  30. while True:
  31. try:
  32. yield tuple(await asyncio.gather(*(itr.__anext__() for itr in iterators)))
  33. except StopAsyncIteration:
  34. break
  35. async def achain(*async_iters: AsyncIterable[T]) -> AsyncIterator[T]:
  36. """equivalent to chain(iter1, iter2, ...) for asynchronous iterators."""
  37. for aiter in async_iters:
  38. async for elem in aiter:
  39. yield elem
  40. async def aenumerate(aiterable: AsyncIterable[T]) -> AsyncIterable[Tuple[int, T]]:
  41. """equivalent to enumerate(iter) for asynchronous iterators."""
  42. index = 0
  43. async for elem in aiterable:
  44. yield index, elem
  45. index += 1
  46. async def asingle(aiter: AsyncIterable[T]) -> T:
  47. """If ``aiter`` has exactly one item, returns this item. Otherwise, raises ``ValueError``."""
  48. count = 0
  49. async for item in aiter:
  50. count += 1
  51. if count == 2:
  52. raise ValueError("asingle() expected an iterable with exactly one item, but got two or more items")
  53. if count == 0:
  54. raise ValueError("asingle() expected an iterable with exactly one item, but got an empty iterable")
  55. return item
  56. async def afirst(aiter: AsyncIterable[T], default: Optional[T] = None) -> Optional[T]:
  57. """Returns the first item of ``aiter`` or ``default`` if ``aiter`` is empty."""
  58. async for item in aiter:
  59. return item
  60. return default
  61. async def await_cancelled(awaitable: Awaitable) -> bool:
  62. try:
  63. await awaitable
  64. return False
  65. except (asyncio.CancelledError, concurrent.futures.CancelledError):
  66. # In Python 3.7, awaiting a cancelled asyncio.Future raises concurrent.futures.CancelledError
  67. # instead of asyncio.CancelledError
  68. return True
  69. except BaseException:
  70. logger.exception(f"Exception in {awaitable}:")
  71. return False
  72. async def cancel_and_wait(awaitable: Awaitable) -> bool:
  73. """
  74. Cancels ``awaitable`` and waits for its cancellation.
  75. In case of ``asyncio.Task``, helps to avoid ``Task was destroyed but it is pending!`` errors.
  76. In case of ``asyncio.Future``, equal to ``future.cancel()``.
  77. """
  78. awaitable.cancel()
  79. return await await_cancelled(awaitable)
  80. async def amap_in_executor(
  81. func: Callable[..., T],
  82. *iterables: AsyncIterable,
  83. max_prefetch: Optional[int] = None,
  84. executor: Optional[ThreadPoolExecutor] = None,
  85. ) -> AsyncIterator[T]:
  86. """iterate from an async iterable in a background thread, yield results to async iterable"""
  87. loop = asyncio.get_event_loop()
  88. queue = asyncio.Queue(max_prefetch)
  89. async def _put_items():
  90. try:
  91. async for args in azip(*iterables):
  92. await queue.put(loop.run_in_executor(executor, func, *args))
  93. await queue.put(None)
  94. except Exception as e:
  95. future = asyncio.Future()
  96. future.set_exception(e)
  97. await queue.put(future)
  98. raise
  99. task = asyncio.create_task(_put_items())
  100. try:
  101. future = await queue.get()
  102. while future is not None:
  103. yield await future
  104. future = await queue.get()
  105. finally:
  106. task.cancel()
  107. try:
  108. await task
  109. except asyncio.CancelledError:
  110. pass
  111. except Exception as e:
  112. logger.debug(f"Caught {e} while iterating over inputs", exc_info=True)
  113. while not queue.empty():
  114. future = queue.get_nowait()
  115. if future is not None:
  116. future.cancel()
  117. async def aiter_with_timeout(iterable: AsyncIterable[T], timeout: Optional[float]) -> AsyncIterator[T]:
  118. """Iterate over an async iterable, raise TimeoutError if another portion of data does not arrive within timeout"""
  119. # based on https://stackoverflow.com/a/50245879
  120. iterator = iterable.__aiter__()
  121. while True:
  122. try:
  123. yield await asyncio.wait_for(iterator.__anext__(), timeout=timeout)
  124. except StopAsyncIteration:
  125. break
  126. async def attach_event_on_finished(iterable: AsyncIterable[T], event: asyncio.Event()) -> AsyncIterator[T]:
  127. """Iterate over an async iterable and set an event when the iteration has stopped, failed or terminated"""
  128. try:
  129. async for item in iterable:
  130. yield item
  131. finally:
  132. event.set()
  133. class _AsyncContextWrapper(AbstractAsyncContextManager):
  134. """Wrapper for a non-async context manager that allows entering and exiting it in EventLoop-friendly manner"""
  135. def __init__(self, context: AbstractContextManager):
  136. self._context = context
  137. async def __aenter__(self):
  138. loop = asyncio.get_event_loop()
  139. return await loop.run_in_executor(None, self._context.__enter__)
  140. async def __aexit__(self, exc_type, exc_value, traceback):
  141. return self._context.__exit__(exc_type, exc_value, traceback)
  142. @asynccontextmanager
  143. async def enter_asynchronously(context: AbstractContextManager):
  144. """Wrap a non-async context so that it can be entered asynchronously"""
  145. async with _AsyncContextWrapper(context) as ret_value:
  146. yield ret_value