asyncio.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. import asyncio
  2. import concurrent.futures
  3. from concurrent.futures import ThreadPoolExecutor
  4. from typing import AsyncIterable, AsyncIterator, Awaitable, Callable, Optional, Tuple, TypeVar, Union
  5. import uvloop
  6. from hivemind.utils.logging import get_logger
# Generic type variable used in the annotations of the helpers in this module
T = TypeVar("T")
# Module-level logger obtained from hivemind.utils.logging.get_logger (used by await_cancelled)
logger = get_logger(__name__)
  9. def switch_to_uvloop() -> asyncio.AbstractEventLoop:
  10. """stop any running event loops; install uvloop; then create, set and return a new event loop"""
  11. try:
  12. asyncio.get_event_loop().stop() # if we're in jupyter, get rid of its built-in event loop
  13. except RuntimeError as error_no_event_loop:
  14. pass # this allows running DHT from background threads with no event loop
  15. uvloop.install()
  16. loop = asyncio.new_event_loop()
  17. asyncio.set_event_loop(loop)
  18. return loop
  19. async def anext(aiter: AsyncIterator[T]) -> Union[T, StopAsyncIteration]:
  20. """equivalent to next(iter) for asynchronous iterators. Modifies aiter in-place!"""
  21. return await aiter.__anext__()
  22. async def as_aiter(*args: T) -> AsyncIterator[T]:
  23. """create an asynchronous iterator from a sequence of values"""
  24. for arg in args:
  25. yield arg
  26. async def azip(*iterables: AsyncIterable[T]) -> AsyncIterator[Tuple[T, ...]]:
  27. """equivalent of zip for asynchronous iterables"""
  28. iterators = [iterable.__aiter__() for iterable in iterables]
  29. while True:
  30. try:
  31. yield tuple(await asyncio.gather(*(itr.__anext__() for itr in iterators)))
  32. except StopAsyncIteration:
  33. break
  34. async def achain(*async_iters: AsyncIterable[T]) -> AsyncIterator[T]:
  35. """equivalent to chain(iter1, iter2, ...) for asynchronous iterators."""
  36. for aiter in async_iters:
  37. async for elem in aiter:
  38. yield elem
  39. async def aenumerate(aiterable: AsyncIterable[T]) -> AsyncIterable[Tuple[int, T]]:
  40. """equivalent to enumerate(iter) for asynchronous iterators."""
  41. index = 0
  42. async for elem in aiterable:
  43. yield index, elem
  44. index += 1
  45. async def asingle(aiter: AsyncIterable[T]) -> T:
  46. """If ``aiter`` has exactly one item, returns this item. Otherwise, raises ``ValueError``."""
  47. count = 0
  48. async for item in aiter:
  49. count += 1
  50. if count == 2:
  51. raise ValueError("asingle() expected an iterable with exactly one item, but got two or more items")
  52. if count == 0:
  53. raise ValueError("asingle() expected an iterable with exactly one item, but got an empty iterable")
  54. return item
  55. async def afirst(aiter: AsyncIterable[T], default: Optional[T] = None) -> Optional[T]:
  56. """Returns the first item of ``aiter`` or ``default`` if ``aiter`` is empty."""
  57. async for item in aiter:
  58. return item
  59. return default
  60. async def await_cancelled(awaitable: Awaitable) -> bool:
  61. try:
  62. await awaitable
  63. return False
  64. except (asyncio.CancelledError, concurrent.futures.CancelledError):
  65. # In Python 3.7, awaiting a cancelled asyncio.Future raises concurrent.futures.CancelledError
  66. # instead of asyncio.CancelledError
  67. return True
  68. except BaseException:
  69. logger.exception(f"Exception in {awaitable}:")
  70. return False
  71. async def cancel_and_wait(awaitable: Awaitable) -> bool:
  72. """
  73. Cancels ``awaitable`` and waits for its cancellation.
  74. In case of ``asyncio.Task``, helps to avoid ``Task was destroyed but it is pending!`` errors.
  75. In case of ``asyncio.Future``, equal to ``future.cancel()``.
  76. """
  77. awaitable.cancel()
  78. return await await_cancelled(awaitable)
  79. async def amap_in_executor(
  80. func: Callable[..., T],
  81. *iterables: AsyncIterable,
  82. max_prefetch: Optional[int] = None,
  83. executor: Optional[ThreadPoolExecutor] = None,
  84. ) -> AsyncIterator[T]:
  85. """iterate from an async iterable in a background thread, yield results to async iterable"""
  86. loop = asyncio.get_event_loop()
  87. queue = asyncio.Queue(max_prefetch)
  88. async def _put_items():
  89. async for args in azip(*iterables):
  90. await queue.put(loop.run_in_executor(executor, func, *args))
  91. await queue.put(None)
  92. task = asyncio.create_task(_put_items())
  93. try:
  94. future = await queue.get()
  95. while future is not None:
  96. yield await future
  97. future = await queue.get()
  98. await task
  99. finally:
  100. if not task.done():
  101. task.cancel()
  102. async def aiter_with_timeout(iterable: AsyncIterable[T], timeout: float) -> AsyncIterator[T]:
  103. """Iterate over an async iterable, raise TimeoutError if another portion of data does not arrive within timeout"""
  104. # based on https://stackoverflow.com/a/50245879
  105. iterator = iterable.__aiter__()
  106. while True:
  107. try:
  108. yield await asyncio.wait_for(iterator.__anext__(), timeout=timeout)
  109. except StopAsyncIteration:
  110. break