mpfuture.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364
from __future__ import annotations

import asyncio
import concurrent.futures._base as base
import multiprocessing as mp
import multiprocessing.connection  # `import multiprocessing` alone does not load the `connection` submodule
import os
import threading
import uuid
from contextlib import nullcontext, suppress
from enum import Enum, auto
from selectors import EVENT_READ, DefaultSelector
from typing import Any, Callable, Dict, Generic, Optional, Tuple, Type, TypeVar
from weakref import ref

from hivemind.utils.logging import get_logger
# module-level logger, named after this module via hivemind's logging helper
logger = get_logger(__name__)
  15. # flavour types
  16. ResultType = TypeVar("ResultType")
  17. PID, UID, State, PipeEnd = int, int, str, mp.connection.Connection
  18. ALL_STATES = base.PENDING, base.RUNNING, base.FINISHED, base.CANCELLED, base.CANCELLED_AND_NOTIFIED
  19. TERMINAL_STATES = {base.FINISHED, base.CANCELLED, base.CANCELLED_AND_NOTIFIED}
  20. try:
  21. from concurrent.futures import InvalidStateError
  22. except ImportError:
  23. # Python 3.7 doesn't raise concurrent.futures.InvalidStateError for repeating set_result/set_exception calls and
  24. # doesn't even define this error. In this module, we simulate the Python 3.8+ behavior,
  25. # defining and raising this error if necessary.
  26. class InvalidStateError(Exception):
  27. """Raised when attempting to change state of a future in a terminal state (e.g. finished)"""
  28. class MessageType(Enum):
  29. RESULT = auto()
  30. EXCEPTION = auto()
  31. RUNNING = auto()
  32. CANCEL = auto()
  33. STATE_REQUEST = auto()
  34. STATE_RESPONSE = auto()
  35. class MPFuture(base.Future, Generic[ResultType]):
  36. """
  37. A version of concurrent.futures.Future / asyncio.Future that can be fulfilled from a separate process.
  38. Any process can access future status and set the result / exception and check for state.
  39. However, only the original process (i.e. the process that created the future) can await the result or exception.
  40. :param synchronize: if True (default), future will request state from origin, otherwise it will only use local state
  41. Setting synchronize=False results in slightly better performance of done or set_running_or_notify_cancel
  42. :param use_lock: if True, operations with MPFuture use a global lock to prevent concurrent writes to the same pipe;
  43. If set to False, writing to this future ignores global lock, slightly improving performance, but making user
  44. responsible for avoiding concurrent set_result / set_exception calls to futures with the same process of origin.
  45. :note: This is an internal primitive that is not guaranteed to work outside of hivemind applications.
  46. More specifically, there are two known limitations:
  47. - MPFuture works between processes created through inheritance (e.g. fork), *not* for independent processes
  48. - MPFuture is deterministic if only one process can call set_result/set_exception/set_running_or_notify_cancel
  49. and only the origin process can call result/exception/cancel.
  50. """
  51. _initialization_lock = mp.Lock() # global lock that prevents simultaneous initialization of two processes
  52. _update_lock = mp.Lock() # global lock that prevents simultaneous writing of results/exceptions through same pipe
  53. _status_lock = mp.Lock() # global lock that prevents simultaneous sending of status updates through same pipe
  54. _process_inner_pipe: Optional[PipeEnd] = None # a pipe that is used to read results and send status updates
  55. _process_outer_pipe: Optional[PipeEnd] = None # a pipe that is used to send results and receive status updates
  56. _pipe_waiter_thread: Optional[threading.Thread] = None # process-specific thread that receives results/exceptions
  57. _active_futures: Optional[Dict[UID, Type[ref][MPFuture]]] = None # non-done futures originated from this process
  58. _status_requests: Optional[Dict[UID, Tuple[MPFuture, threading.Event]]] = None # futures to be updated by origin
  59. _active_pid: Optional[PID] = None # pid of currently active process; used to handle forks natively
  60. SOFT_UPDATE_TIMEOUT = 0.5 # seconds spent awaiting status update before warning is printed
  61. HARD_UPDATE_TIMEOUT = 10.0 # seconds spent awaiting status update before future is automatically cancelled
  62. def __init__(self, *, synchronize: bool = True, use_lock: bool = True):
  63. super().__init__()
  64. self.synchronize = synchronize
  65. self._origin_pid, self._uid = os.getpid(), uuid.uuid4().int
  66. self._state, self._result, self._exception = base.PENDING, None, None
  67. self._use_lock = use_lock
  68. self._initialize_backend_if_necessary()
  69. assert self._uid not in MPFuture._active_futures
  70. MPFuture._active_futures[self._uid] = ref(self)
  71. self._pipe_to_origin = MPFuture._process_outer_pipe
  72. try:
  73. self._loop = asyncio.get_event_loop()
  74. self._aio_event = asyncio.Event()
  75. except RuntimeError:
  76. self._loop, self._aio_event = None, None
  77. def _set_event_if_necessary(self):
  78. if self._aio_event is None or self._aio_event.is_set():
  79. return
  80. try:
  81. loop = asyncio.get_running_loop()
  82. except RuntimeError:
  83. loop = None
  84. async def _event_setter():
  85. self._aio_event.set()
  86. if self._loop.is_running() and loop == self.get_loop():
  87. asyncio.create_task(_event_setter())
  88. elif self._loop.is_running() and loop != self.get_loop():
  89. asyncio.run_coroutine_threadsafe(_event_setter(), self._loop)
  90. else:
  91. self._loop.run_until_complete(_event_setter())
  92. @classmethod
  93. def _initialize_backend_if_necessary(cls):
  94. pid = os.getpid()
  95. if MPFuture._active_pid != pid:
  96. with MPFuture._initialization_lock:
  97. if MPFuture._active_pid != pid:
  98. # note: the second if is intentional, see https://en.wikipedia.org/wiki/Double-checked_locking
  99. logger.debug(f"Initializing MPFuture backend for pid {pid}")
  100. cls._process_inner_pipe, cls._process_outer_pipe = mp.Pipe(duplex=True)
  101. cls._active_pid, cls._active_futures, cls._status_requests = pid, {}, {}
  102. cls._pipe_waiter_thread = threading.Thread(
  103. target=cls._process_updates_in_background,
  104. name=f"{__name__}.BACKEND",
  105. daemon=True,
  106. )
  107. cls._pipe_waiter_thread.start()
  108. @classmethod
  109. def reset_backend(cls):
  110. """
  111. Reset the MPFuture backend. This is useful when the state may have been corrupted
  112. (e.g. killing child processes may leave the locks acquired and the background thread blocked).
  113. This method is neither thread-safe nor process-safe.
  114. """
  115. cls._initialization_lock = mp.Lock()
  116. cls._update_lock = mp.Lock()
  117. cls._status_lock = mp.Lock()
  118. cls._active_pid = None
  119. @classmethod
  120. def _process_updates_in_background(cls):
  121. pid = os.getpid()
  122. with DefaultSelector() as selector:
  123. selector.register(cls._process_inner_pipe, EVENT_READ, data=cls._process_inner_pipe)
  124. selector.register(cls._process_outer_pipe, EVENT_READ, data=cls._process_outer_pipe)
  125. while True:
  126. try:
  127. if cls._pipe_waiter_thread is not threading.current_thread():
  128. break # Backend was reset, a new background thread has started
  129. (key, events), *_ = selector.select()
  130. uid, msg_type, payload = key.fileobj.recv()
  131. future = None
  132. future_ref = cls._active_futures.get(uid)
  133. if future_ref is not None:
  134. future = future_ref()
  135. if msg_type == MessageType.STATE_REQUEST:
  136. future_state = None if future is None else future.__getstate__()
  137. use_lock, return_pipe = payload
  138. with MPFuture._status_lock if use_lock else nullcontext():
  139. return_pipe.send((uid, MessageType.STATE_RESPONSE, future_state))
  140. elif msg_type == MessageType.STATE_RESPONSE:
  141. future, state_updated_event = cls._status_requests.get(uid, (None, None))
  142. if future is None:
  143. logger.debug("Received a state update for a future that does not await status update.")
  144. else:
  145. if payload is not None:
  146. future.__setstate__(payload)
  147. else:
  148. base.Future.cancel(future)
  149. state_updated_event.set()
  150. elif future is None:
  151. logger.debug(
  152. f"Received {msg_type} for MPFuture uid={uid}, but future is already done or destroyed"
  153. )
  154. elif msg_type == MessageType.RESULT:
  155. future.set_result(payload)
  156. elif msg_type == MessageType.EXCEPTION:
  157. future.set_exception(payload)
  158. elif msg_type == MessageType.RUNNING:
  159. try:
  160. future.set_running_or_notify_cancel()
  161. except (InvalidStateError, RuntimeError) as e:
  162. logger.debug(f"Could not set MPFuture (uid={uid}) to running due to {e}", exc_info=True)
  163. elif msg_type == MessageType.CANCEL:
  164. future.cancel()
  165. else:
  166. raise RuntimeError(f"Received unexpected update type {msg_type}")
  167. if future is None or future.done():
  168. cls._active_futures.pop(uid, None)
  169. except (BrokenPipeError, EOFError, ConnectionError):
  170. logger.debug(f"Update pipe was was shut down unexpectedly (pid={pid})")
  171. except Exception as e:
  172. logger.exception(f"Could not retrieve update: caught {repr(e)} (pid={pid})")
  173. def _send_update(self, update_type: MessageType, payload: Any = None):
  174. """This method sends result, exception or cancel to the MPFuture origin."""
  175. try:
  176. with MPFuture._update_lock if self._use_lock else nullcontext():
  177. self._pipe_to_origin.send((self._uid, update_type, payload))
  178. except (ConnectionError, BrokenPipeError, EOFError) as e:
  179. logger.debug(f"No updates were sent: pipe to origin process was broken ({e}).", exc_info=True)
  180. def _synchronize_if_necessary(self):
  181. if not self.synchronize or os.getpid() == self._origin_pid or self._state in TERMINAL_STATES:
  182. return
  183. self._initialize_backend_if_necessary()
  184. status_updated = threading.Event()
  185. _, existing_status_event = self._status_requests.setdefault(self._uid, (self, status_updated))
  186. # this line checks if another thread is synchronizing concurrently, assuming that setdefault to be atomic
  187. if existing_status_event != status_updated:
  188. existing_status_event.wait(MPFuture.HARD_UPDATE_TIMEOUT)
  189. return
  190. # otherwise create a new request for synchronization
  191. try:
  192. with MPFuture._update_lock if self._use_lock else nullcontext():
  193. payload = (self._use_lock, self._process_inner_pipe)
  194. self._pipe_to_origin.send((self._uid, MessageType.STATE_REQUEST, payload))
  195. status_updated.wait(MPFuture.SOFT_UPDATE_TIMEOUT)
  196. if not status_updated.is_set():
  197. logger.warning(f"Status update took over {MPFuture.SOFT_UPDATE_TIMEOUT}, expect performance issues")
  198. status_updated.wait(MPFuture.HARD_UPDATE_TIMEOUT - MPFuture.SOFT_UPDATE_TIMEOUT)
  199. if not status_updated.is_set() and not self.cancel():
  200. with suppress(InvalidStateError, RuntimeError):
  201. self.set_exception(
  202. TimeoutError(
  203. f"Status update took over {MPFuture.HARD_UPDATE_TIMEOUT} seconds, "
  204. f"MPFuture is cancelled"
  205. )
  206. )
  207. status_updated.set() # this triggers any concurrent _synchronize_if_necessary calls to finish
  208. except (ConnectionError, BrokenPipeError, EOFError) as e:
  209. logger.error(f"MPFuture was cancelled because sender pipe is broken. Origin process is probably down.")
  210. if not self.cancel():
  211. with suppress(InvalidStateError, RuntimeError):
  212. self.set_exception(e)
  213. finally:
  214. self._status_requests.pop(self._uid, None)
  215. def set_result(self, result: ResultType):
  216. if self._state in TERMINAL_STATES:
  217. raise InvalidStateError(f"Can't set_result to a future that is {self._state} ({self._uid})")
  218. elif os.getpid() == self._origin_pid:
  219. MPFuture._active_futures.pop(self._uid, None)
  220. self._set_event_if_necessary()
  221. else:
  222. self._send_update(MessageType.RESULT, result)
  223. super().set_result(result)
  224. def set_exception(self, exception: Optional[BaseException]):
  225. if self._state in TERMINAL_STATES:
  226. raise InvalidStateError(f"Can't set_exception to a future that is {self._state} ({self._uid})")
  227. elif os.getpid() == self._origin_pid:
  228. MPFuture._active_futures.pop(self._uid, None)
  229. self._set_event_if_necessary()
  230. else:
  231. self._send_update(MessageType.EXCEPTION, exception)
  232. super().set_exception(exception)
  233. def cancel(self) -> bool:
  234. if self._state in [base.RUNNING, base.FINISHED]:
  235. return False
  236. elif os.getpid() == self._origin_pid:
  237. MPFuture._active_futures.pop(self._uid, None)
  238. self._set_event_if_necessary()
  239. else:
  240. self._send_update(MessageType.CANCEL)
  241. return super().cancel()
  242. def set_running_or_notify_cancel(self):
  243. """if synchronize is set to False, this future will ignore any state changes from origin"""
  244. self._synchronize_if_necessary()
  245. try:
  246. is_running = super().set_running_or_notify_cancel()
  247. if is_running and os.getpid() != self._origin_pid:
  248. self._send_update(MessageType.RUNNING)
  249. return is_running
  250. except RuntimeError as e:
  251. raise InvalidStateError(str(e))
  252. def result(self, timeout: Optional[float] = None) -> ResultType:
  253. if self._state not in TERMINAL_STATES:
  254. if os.getpid() != self._origin_pid:
  255. raise RuntimeError("Only the process that created MPFuture can await result")
  256. return super().result(timeout)
  257. def exception(self, timeout: Optional[float] = None) -> Optional[BaseException]:
  258. if self._state not in TERMINAL_STATES:
  259. if os.getpid() != self._origin_pid:
  260. raise RuntimeError("Only the process that created MPFuture can await exception")
  261. return super().exception(timeout)
  262. def done(self) -> bool:
  263. self._synchronize_if_necessary()
  264. return self._state in TERMINAL_STATES
  265. def running(self):
  266. self._synchronize_if_necessary()
  267. return self._state == base.RUNNING
  268. def cancelled(self):
  269. self._synchronize_if_necessary()
  270. return self._state == base.CANCELLED
  271. def add_done_callback(self, callback: Callable[[MPFuture], None]):
  272. if os.getpid() != self._origin_pid:
  273. raise RuntimeError("Only the process that created MPFuture can set callbacks")
  274. return super().add_done_callback(callback)
  275. def get_loop(self) -> Optional[asyncio.BaseEventLoop]:
  276. return self._loop
  277. def __await__(self):
  278. if not self._aio_event:
  279. raise RuntimeError("Can't await: MPFuture was created with no event loop")
  280. yield from self._aio_event.wait().__await__()
  281. try:
  282. return super().result()
  283. except base.CancelledError:
  284. raise asyncio.CancelledError()
  285. def __del__(self):
  286. if getattr(self, "_origin_pid", None) == os.getpid():
  287. MPFuture._active_futures.pop(self._uid, None)
  288. if getattr(self, "_aio_event", None):
  289. self._aio_event.set()
  290. def __getstate__(self):
  291. return dict(
  292. synchronize=self.synchronize,
  293. _pipe_to_origin=self._pipe_to_origin,
  294. _state=self._state,
  295. _origin_pid=self._origin_pid,
  296. _uid=self._uid,
  297. _use_lock=self._use_lock,
  298. _result=self._result,
  299. _exception=self._exception,
  300. )
  301. def __setstate__(self, state):
  302. self.synchronize = state["synchronize"]
  303. self._pipe_to_origin = state["_pipe_to_origin"]
  304. self._state, self._origin_pid, self._uid = state["_state"], state["_origin_pid"], state["_uid"]
  305. self._result, self._exception = state["_result"], state["_exception"]
  306. self._use_lock = state["_use_lock"]
  307. self._waiters, self._done_callbacks = [], []
  308. self._condition = threading.Condition()
  309. self._aio_event, self._loop = None, None