asyncio.py 3.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. from concurrent.futures import ThreadPoolExecutor
  2. from typing import TypeVar, AsyncIterator, Union, AsyncIterable, Awaitable, Tuple, Optional, Callable
  3. import asyncio
  4. import uvloop
  5. from hivemind.utils.logging import get_logger
  6. T = TypeVar('T')
  7. logger = get_logger(__name__)
  8. def switch_to_uvloop() -> asyncio.AbstractEventLoop:
  9. """ stop any running event loops; install uvloop; then create, set and return a new event loop """
  10. try:
  11. asyncio.get_event_loop().stop() # if we're in jupyter, get rid of its built-in event loop
  12. except RuntimeError as error_no_event_loop:
  13. pass # this allows running DHT from background threads with no event loop
  14. uvloop.install()
  15. loop = asyncio.new_event_loop()
  16. asyncio.set_event_loop(loop)
  17. return loop
  18. async def anext(aiter: AsyncIterator[T]) -> Union[T, StopAsyncIteration]:
  19. """ equivalent to next(iter) for asynchronous iterators. Modifies aiter in-place! """
  20. return await aiter.__anext__()
  21. async def aiter(*args: T) -> AsyncIterator[T]:
  22. """ create an asynchronous iterator from a sequence of values """
  23. for arg in args:
  24. yield arg
  25. async def azip(*iterables: AsyncIterable[T]) -> AsyncIterator[Tuple[T, ...]]:
  26. """ equivalent of zip for asynchronous iterables """
  27. iterators = [iterable.__aiter__() for iterable in iterables]
  28. while True:
  29. try:
  30. yield tuple(await asyncio.gather(*(itr.__anext__() for itr in iterators)))
  31. except StopAsyncIteration:
  32. break
  33. async def achain(*async_iters: AsyncIterable[T]) -> AsyncIterator[T]:
  34. """ equivalent to chain(iter1, iter2, ...) for asynchronous iterators. """
  35. for aiter in async_iters:
  36. async for elem in aiter:
  37. yield elem
  38. async def aenumerate(aiterable: AsyncIterable[T]) -> AsyncIterable[Tuple[int, T]]:
  39. """ equivalent to enumerate(iter) for asynchronous iterators. """
  40. index = 0
  41. async for elem in aiterable:
  42. yield index, elem
  43. index += 1
  44. async def await_cancelled(awaitable: Awaitable) -> bool:
  45. try:
  46. await awaitable
  47. return False
  48. except asyncio.CancelledError:
  49. return True
  50. except BaseException:
  51. return False
  52. async def amap_in_executor(func: Callable[..., T], *iterables: AsyncIterable, max_prefetch: Optional[int] = None,
  53. executor: Optional[ThreadPoolExecutor] = None) -> AsyncIterator[T]:
  54. """ iterate from an async iterable in a background thread, yield results to async iterable """
  55. loop = asyncio.get_event_loop()
  56. queue = asyncio.Queue(max_prefetch)
  57. async def _put_items():
  58. async for args in azip(*iterables):
  59. await queue.put(loop.run_in_executor(executor, func, *args))
  60. await queue.put(None)
  61. task = asyncio.create_task(_put_items())
  62. try:
  63. future = await queue.get()
  64. while future is not None:
  65. yield await future
  66. future = await queue.get()
  67. await task
  68. finally:
  69. if not task.done():
  70. task.cancel()