mpfuture.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323
  1. from __future__ import annotations
  2. import asyncio
  3. import concurrent.futures._base as base
  4. import multiprocessing as mp
  5. import os
  6. import threading
  7. import uuid
  8. from contextlib import nullcontext
  9. from enum import Enum, auto
  10. from typing import Any, Callable, Dict, Generic, Optional, TypeVar
  11. from weakref import ref
  12. import torch # used for py3.7-compatible shared memory
  13. from hivemind.utils.logging import get_logger
  14. logger = get_logger(__name__)
  15. torch.multiprocessing.set_sharing_strategy(os.environ.get("HIVEMIND_MEMORY_SHARING_STRATEGY", "file_system"))
  16. # flavour types
  17. ResultType = TypeVar("ResultType")
  18. PID, UID, State, PipeEnd = int, int, str, mp.connection.Connection
  19. ALL_STATES = base.PENDING, base.RUNNING, base.FINISHED, base.CANCELLED, base.CANCELLED_AND_NOTIFIED
  20. TERMINAL_STATES = {base.FINISHED, base.CANCELLED, base.CANCELLED_AND_NOTIFIED}
  21. try:
  22. from concurrent.futures import InvalidStateError
  23. except ImportError:
  24. # Python 3.7 doesn't raise concurrent.futures.InvalidStateError for repeating set_result/set_exception calls and
  25. # doesn't even define this error. In this module, we simulate the Python 3.8+ behavior,
  26. # defining and raising this error if necessary.
  27. class InvalidStateError(Exception):
  28. """Raised when attempting to change state of a future in a terminal state (e.g. finished)"""
  29. class SharedBytes:
  30. """
  31. A process-wide object that allocates large chunks of shared memory and partitions it into individual bytes.
  32. Note: this process is only responsible for bulk allocation, it does not manage/free unused bytes.
  33. The chunks are deallocated by the garbage collector,
  34. when it detects that all processes no longer use any bytes from this chunk.
  35. """
  36. _lock = mp.Lock()
  37. _pid: Optional[PID] = None
  38. _buffer: Optional[torch.Tensor] = None
  39. _index: int = 0
  40. @classmethod
  41. def next(cls) -> torch.Tensor:
  42. """Create another shared byte value, represented as a scalar uint8 tensor"""
  43. with cls._lock:
  44. if cls._pid != os.getpid() or cls._buffer is None or cls._index >= len(cls._buffer):
  45. buffer_size = int(os.environ.get("HIVEMIND_SHM_BUFFER_SIZE", 16))
  46. cls._pid = os.getpid()
  47. cls._buffer = torch.empty([buffer_size], dtype=torch.uint8).share_memory_()
  48. cls._index = 0
  49. cls._index += 1
  50. return cls._buffer[cls._index - 1]
  51. class UpdateType(Enum):
  52. RESULT = auto()
  53. EXCEPTION = auto()
  54. CANCEL = auto()
  55. class MPFuture(base.Future, Generic[ResultType]):
  56. """
  57. A version of concurrent.futures.Future / asyncio.Future that can be fulfilled from a separate process.
  58. Any process can access future status and set the result / exception and check for state.
  59. However, only the original process (i.e. the process that created the future) can await the result or exception.
  60. :param use_lock: if True, operations with MPFuture use a global lock to prevent concurrent writes to the same pipe;
  61. If set to False, writing to this future ignores global lock, slightly improving performance, but making user
  62. responsible for avoiding concurrent set_result / set_exception calls to futures with the same process of origin.
  63. :note: This is an internal primitive that is not guaranteed to work outside of hivemind applications.
  64. More specifically, there are two known limitations:
  65. - MPFuture works between processes created through inheritance (e.g. fork), *not* for independent processes
  66. - MPFuture is deterministic if only one process can call set_result/set_exception/set_running_or_notify_cancel
  67. and only the origin process can call result/exception/cancel.
  68. """
  69. _initialization_lock = mp.Lock() # global lock that prevents simultaneous initialization of two processes
  70. _update_lock = mp.Lock() # global lock that prevents simultaneous writing to the same pipe
  71. _global_sender_pipe: Optional[PipeEnd] = None # a pipe that is used to send results/exceptions to this process
  72. _pipe_waiter_thread: Optional[threading.Thread] = None # process-specific thread that receives results/exceptions
  73. _active_futures: Optional[Dict[UID, "ref[MPFuture]"]] = None # non-done futures originated from this process
  74. _active_pid: Optional[PID] = None # pid of currently active process; used to handle forks natively
  75. def __init__(self, *, use_lock: bool = True):
  76. self._origin_pid, self._uid = os.getpid(), uuid.uuid4().int
  77. self._shared_state_code = SharedBytes.next()
  78. self._state_cache: Dict[State, State] = {}
  79. # mapping from global to cached local future used that makes updates immediately
  80. # available on setter side; dictionary-based cache works because future can visit any state at most once
  81. base.Future.__init__(self) # parent init is deferred because it uses self._shared_state_code
  82. self._state, self._result, self._exception = base.PENDING, None, None
  83. self._use_lock = use_lock
  84. if self._origin_pid != MPFuture._active_pid:
  85. with MPFuture._initialization_lock:
  86. if self._origin_pid != MPFuture._active_pid:
  87. # note: the second if is intentional, see https://en.wikipedia.org/wiki/Double-checked_locking
  88. self._initialize_mpfuture_backend()
  89. assert self._uid not in MPFuture._active_futures
  90. MPFuture._active_futures[self._uid] = ref(self)
  91. self._sender_pipe = MPFuture._global_sender_pipe
  92. try:
  93. self._loop = asyncio.get_event_loop()
  94. self._aio_event = asyncio.Event()
  95. except RuntimeError:
  96. self._loop, self._aio_event = None, None
  97. @property
  98. def _state(self) -> State:
  99. shared_state = ALL_STATES[self._shared_state_code.item()]
  100. return self._state_cache.get(shared_state, shared_state)
  101. @_state.setter
  102. def _state(self, new_state: State):
  103. self._shared_state_code[...] = ALL_STATES.index(new_state)
  104. if self._state in TERMINAL_STATES and self._loop is not None and not self._aio_event.is_set():
  105. self._set_event_threadsafe()
  106. def _set_event_threadsafe(self) -> None:
  107. try:
  108. running_loop = asyncio.get_running_loop()
  109. except RuntimeError:
  110. running_loop = None
  111. async def _event_setter() -> None:
  112. self._aio_event.set()
  113. if self._loop.is_closed():
  114. return # do nothing, the loop is already closed
  115. elif self._loop.is_running() and running_loop == self._loop:
  116. asyncio.create_task(_event_setter())
  117. elif self._loop.is_running() and running_loop != self._loop:
  118. asyncio.run_coroutine_threadsafe(_event_setter(), self._loop)
  119. else:
  120. self._loop.run_until_complete(_event_setter())
  121. @classmethod
  122. def _initialize_mpfuture_backend(cls) -> None:
  123. pid = os.getpid()
  124. logger.debug(f"Initializing MPFuture backend for pid {pid}")
  125. receiver_pipe, cls._global_sender_pipe = mp.Pipe(duplex=False)
  126. cls._active_pid, cls._active_futures = pid, {}
  127. cls._pipe_waiter_thread = threading.Thread(
  128. target=cls._process_updates_in_background, args=[receiver_pipe], name=f"{__name__}.BACKEND", daemon=True
  129. )
  130. cls._pipe_waiter_thread.start()
  131. @staticmethod
  132. def reset_backend() -> None:
  133. """Last-resort function to reset internals of MPFuture. All current MPFuture instances will be broken"""
  134. MPFuture._active_pid = None
  135. MPFuture._initialization_lock = mp.Lock()
  136. MPFuture._update_lock = mp.Lock()
  137. SharedBytes._lock = mp.Lock()
  138. @classmethod
  139. def _process_updates_in_background(cls, receiver_pipe: mp.connection.Connection):
  140. pid = os.getpid()
  141. while True:
  142. try:
  143. if cls._pipe_waiter_thread is not threading.current_thread():
  144. break # backend was reset, a new background thread has started
  145. uid, update_type, payload = receiver_pipe.recv()
  146. future = None
  147. future_ref = cls._active_futures.pop(uid, None)
  148. if future_ref is not None:
  149. future = future_ref()
  150. if future is None:
  151. # The MPFuture instance is already destroyed in this process
  152. # (the caller is not interested in the result)
  153. continue
  154. if update_type == UpdateType.RESULT:
  155. future.set_result(payload)
  156. elif update_type == UpdateType.EXCEPTION:
  157. future.set_exception(payload)
  158. elif update_type == UpdateType.CANCEL:
  159. future.cancel()
  160. else:
  161. raise RuntimeError(f"Received unexpected update type {update_type}")
  162. except (BrokenPipeError, EOFError, ConnectionError):
  163. logger.debug(f"Update pipe was was shut down unexpectedly (pid={pid})")
  164. except Exception as e:
  165. logger.exception(f"Could not retrieve update: caught {repr(e)} (pid={pid})")
  166. def _send_update(self, update_type: UpdateType, payload: Any = None) -> None:
  167. """This method sends result, exception or cancel to the MPFuture origin."""
  168. try:
  169. with MPFuture._update_lock if self._use_lock else nullcontext():
  170. self._sender_pipe.send((self._uid, update_type, payload))
  171. except (ConnectionError, BrokenPipeError, EOFError, OSError) as e:
  172. logger.debug(f"No updates were sent: pipe to origin process was broken ({e})", exc_info=True)
  173. def set_result(self, result: ResultType) -> None:
  174. if os.getpid() == self._origin_pid:
  175. super().set_result(result)
  176. MPFuture._active_futures.pop(self._uid, None)
  177. elif self._state in TERMINAL_STATES:
  178. raise InvalidStateError(f"Can't set_result to a future that is {self._state} ({self._uid})")
  179. else:
  180. self._state_cache[self._state], self._result = base.FINISHED, result
  181. self._send_update(UpdateType.RESULT, result)
  182. def set_exception(self, exception: Optional[BaseException]) -> None:
  183. if os.getpid() == self._origin_pid:
  184. super().set_exception(exception)
  185. MPFuture._active_futures.pop(self._uid, None)
  186. elif self._state in TERMINAL_STATES:
  187. raise InvalidStateError(f"Can't set_exception to a future that is {self._state} ({self._uid})")
  188. else:
  189. self._state_cache[self._state], self._exception = base.FINISHED, exception
  190. self._send_update(UpdateType.EXCEPTION, exception)
  191. def cancel(self) -> bool:
  192. if os.getpid() == self._origin_pid:
  193. MPFuture._active_futures.pop(self._uid, None)
  194. return super().cancel()
  195. elif self._state in [base.RUNNING, base.FINISHED]:
  196. return False
  197. else:
  198. self._state_cache[self._state] = base.CANCELLED
  199. self._send_update(UpdateType.CANCEL)
  200. return True
  201. def set_running_or_notify_cancel(self) -> bool:
  202. if self._state == base.PENDING:
  203. self._state = base.RUNNING
  204. return True
  205. elif self._state == base.CANCELLED:
  206. return False
  207. else:
  208. raise InvalidStateError(
  209. f"Can't set_running_or_notify_cancel when future is in {self._state} ({self._uid})"
  210. )
  211. def result(self, timeout: Optional[float] = None) -> ResultType:
  212. if self._state not in TERMINAL_STATES:
  213. if os.getpid() != self._origin_pid:
  214. raise RuntimeError("Only the process that created MPFuture can await result")
  215. return super().result(timeout)
  216. elif self._state == base.CANCELLED:
  217. raise base.CancelledError()
  218. elif self._exception:
  219. raise self._exception
  220. else:
  221. return self._result
  222. def exception(self, timeout: Optional[float] = None) -> Optional[BaseException]:
  223. if self._state not in TERMINAL_STATES:
  224. if os.getpid() != self._origin_pid:
  225. raise RuntimeError("Only the process that created MPFuture can await exception")
  226. return super().exception(timeout)
  227. elif self._state == base.CANCELLED:
  228. raise base.CancelledError()
  229. return self._exception
  230. def done(self) -> bool:
  231. return self._state in TERMINAL_STATES
  232. def running(self) -> bool:
  233. return self._state == base.RUNNING
  234. def cancelled(self) -> bool:
  235. return self._state == base.CANCELLED
  236. def add_done_callback(self, callback: Callable[[MPFuture], None]) -> None:
  237. if os.getpid() != self._origin_pid:
  238. raise RuntimeError("Only the process that created MPFuture can set callbacks")
  239. return super().add_done_callback(callback)
  240. def __await__(self) -> Any:
  241. if not self._aio_event:
  242. raise RuntimeError("Can't await: MPFuture was created with no event loop")
  243. yield from self._aio_event.wait().__await__()
  244. try:
  245. return super().result()
  246. except base.CancelledError:
  247. raise asyncio.CancelledError()
  248. def __del__(self) -> None:
  249. if getattr(self, "_origin_pid", None) == os.getpid():
  250. MPFuture._active_futures.pop(self._uid, None)
  251. if getattr(self, "_aio_event", None):
  252. self._aio_event.set()
  253. def __getstate__(self) -> Dict[str, Any]:
  254. return dict(
  255. _sender_pipe=self._sender_pipe,
  256. _shared_state_code=self._shared_state_code,
  257. _origin_pid=self._origin_pid,
  258. _uid=self._uid,
  259. _use_lock=self._use_lock,
  260. _result=self._result,
  261. _exception=self._exception,
  262. )
  263. def __setstate__(self, state: Dict[str, Any]) -> None:
  264. self._sender_pipe = state["_sender_pipe"]
  265. self._shared_state_code = state["_shared_state_code"]
  266. self._origin_pid, self._uid = state["_origin_pid"], state["_uid"]
  267. self._result, self._exception = state["_result"], state["_exception"]
  268. self._use_lock = state["_use_lock"]
  269. self._waiters, self._done_callbacks = [], []
  270. self._condition = threading.Condition()
  271. self._aio_event, self._loop = None, None
  272. self._state_cache = {}