mpfuture.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. from __future__ import annotations
  2. import time
  3. import multiprocessing as mp
  4. import multiprocessing.connection
  5. import concurrent.futures._base as base
  6. import asyncio
  7. from functools import lru_cache
  8. from typing import Optional, Tuple, Generic, TypeVar
  9. from hivemind.utils.threading import run_in_background
  10. ResultType = TypeVar('ResultType')
  11. class MPFuture(base.Future, Generic[ResultType]):
  12. """ Multiprocessing version of concurrent.futures.Future. Can also be awaited like asyncio.Future """
  13. TERMINAL_STATES = {base.FINISHED, base.CANCELLED, base.CANCELLED_AND_NOTIFIED}
  14. def __init__(self, connection: mp.connection.Connection):
  15. """ manually create MPFuture. Please use MPFuture.make_pair instead """
  16. self._state, self._result, self._exception = base.PENDING, None, None
  17. self.connection = connection
  18. @classmethod
  19. def make_pair(cls) -> Tuple[MPFuture, MPFuture]:
  20. """ Create a pair of linked futures to be used in two processes """
  21. connection1, connection2 = mp.Pipe()
  22. return cls(connection1), cls(connection2)
  23. def _send_updates(self):
  24. """ Send updates to a paired MPFuture """
  25. try:
  26. self.connection.send((self._state, self._result, self._exception))
  27. if self._state in self.TERMINAL_STATES:
  28. self._shutdown_trigger.set_result(True)
  29. self.connection.close()
  30. return True
  31. except BrokenPipeError:
  32. return False
  33. def _recv_updates(self, timeout: Optional[float]):
  34. """ Await updates from a paired MPFuture """
  35. try:
  36. future = base.wait([run_in_background(self.connection.poll, timeout), self._shutdown_trigger],
  37. return_when=base.FIRST_COMPLETED)[0].pop()
  38. if future is self._shutdown_trigger:
  39. raise BrokenPipeError()
  40. if not future.result():
  41. raise TimeoutError()
  42. self._state, result, exception = self.connection.recv()
  43. self._result = result if result is not None else self._result
  44. self._exception = exception if exception is not None else self._exception
  45. if self._state in self.TERMINAL_STATES:
  46. self.connection.close()
  47. except TimeoutError as e:
  48. raise e
  49. except (BrokenPipeError, OSError, EOFError) as e:
  50. if self._state in (base.PENDING, base.RUNNING):
  51. self._state, self._exception = base.FINISHED, e
  52. def _await_terminal_state(self, timeout: Optional[float]):
  53. """ Await updates until future is either finished, cancelled or got an exception """
  54. time_left = float('inf') if timeout is None else timeout
  55. time_before = time.monotonic()
  56. while self._state not in self.TERMINAL_STATES and time_left > 0:
  57. self._recv_updates(time_left if timeout else None)
  58. time_spent = time.monotonic() - time_before
  59. time_left, time_before = time_left - time_spent, time_before + time_spent
  60. def _sync_updates(self):
  61. """ Apply queued updates from a paired MPFuture without waiting for new ones """
  62. try:
  63. self._recv_updates(timeout=0)
  64. except TimeoutError:
  65. pass
  66. def set_result(self, result: ResultType):
  67. self._sync_updates()
  68. if self._state in self.TERMINAL_STATES:
  69. raise RuntimeError(f"Can't set_result to a future that is in {self._state}")
  70. self._state, self._result = base.FINISHED, result
  71. return self._send_updates()
  72. def set_exception(self, exception: BaseException):
  73. self._sync_updates()
  74. if self._state in self.TERMINAL_STATES:
  75. raise RuntimeError(f"Can't set_exception to a future that is in {self._state}")
  76. self._state, self._exception = base.FINISHED, exception
  77. self._send_updates()
  78. def set_running_or_notify_cancel(self):
  79. self._sync_updates()
  80. if self._state == base.PENDING:
  81. self._state = base.RUNNING
  82. return self._send_updates()
  83. elif self._state == base.CANCELLED:
  84. return False
  85. else:
  86. raise RuntimeError(f"Can't set_running_or_notify_cancel to a future that is in {self._state}")
  87. def cancel(self):
  88. self._sync_updates()
  89. if self._state in self.TERMINAL_STATES:
  90. return False
  91. self._state, self._exception = base.CANCELLED, base.CancelledError()
  92. return self._send_updates()
  93. def result(self, timeout: Optional[float] = None) -> ResultType:
  94. self._await_terminal_state(timeout)
  95. if self._exception is not None:
  96. raise self._exception
  97. return self._result
  98. def exception(self, timeout=None) -> BaseException:
  99. self._await_terminal_state(timeout)
  100. if self._state == base.CANCELLED:
  101. raise base.CancelledError()
  102. return self._exception
  103. def done(self) -> bool:
  104. self._sync_updates()
  105. return self._state in self.TERMINAL_STATES
  106. def running(self):
  107. self._sync_updates()
  108. return self._state == base.RUNNING
  109. def cancelled(self):
  110. self._sync_updates()
  111. return self._state == base.CANCELLED
  112. def add_done_callback(self, callback):
  113. raise NotImplementedError(f"MPFuture doesn't support callbacks.")
  114. def remove_done_callback(self, callback):
  115. raise NotImplementedError(f"MPFuture doesn't support callbacks.")
  116. def get_loop(self):
  117. raise NotImplementedError(f"MPFuture doesn't support get_loop")
  118. @property
  119. @lru_cache()
  120. def _shutdown_trigger(self):
  121. return base.Future()
  122. def __repr__(self):
  123. self._sync_updates()
  124. if self._state == base.FINISHED:
  125. if self._exception:
  126. return "<MPFuture at 0x{:x} state=finished raised {}>".format(id(self), type(self._exception))
  127. else:
  128. return "<MPFuture at 0x{:x} state=finished returned {}>".format(id(self), type(self._result))
  129. else:
  130. return "<MPFuture at 0x{:x} state={}>".format(id(self), self._state)
  131. def __await__(self):
  132. yield from asyncio.get_running_loop().run_in_executor(None, self._await_terminal_state, None).__await__()
  133. if self._exception:
  134. raise self._exception
  135. return self._result
  136. def __del__(self):
  137. self._shutdown_trigger.set_result(True)
  138. if hasattr(self, 'connection'):
  139. self.connection.close()