mpfuture.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324
  1. from __future__ import annotations
  2. import asyncio
  3. import concurrent.futures._base as base
  4. import multiprocessing as mp
  5. import multiprocessing.connection
  6. import os
  7. import threading
  8. import uuid
  9. from contextlib import nullcontext
  10. from enum import Enum, auto
  11. from typing import Any, Callable, Dict, Generic, Optional, Type, TypeVar
  12. from weakref import ref
  13. import torch # used for py3.7-compatible shared memory
  14. from hivemind.utils.logging import get_logger
  15. logger = get_logger(__name__)
  16. torch.multiprocessing.set_sharing_strategy(os.environ.get("HIVEMIND_MEMORY_SHARING_STRATEGY", "file_system"))
  17. # flavour types
  18. ResultType = TypeVar("ResultType")
  19. PID, UID, State, PipeEnd = int, int, str, mp.connection.Connection
  20. ALL_STATES = base.PENDING, base.RUNNING, base.FINISHED, base.CANCELLED, base.CANCELLED_AND_NOTIFIED
  21. TERMINAL_STATES = {base.FINISHED, base.CANCELLED, base.CANCELLED_AND_NOTIFIED}
  22. try:
  23. from concurrent.futures import InvalidStateError
  24. except ImportError:
  25. # Python 3.7 doesn't raise concurrent.futures.InvalidStateError for repeating set_result/set_exception calls and
  26. # doesn't even define this error. In this module, we simulate the Python 3.8+ behavior,
  27. # defining and raising this error if necessary.
  28. class InvalidStateError(Exception):
  29. """Raised when attempting to change state of a future in a terminal state (e.g. finished)"""
  30. class SharedBytes:
  31. """
  32. A process-wide object that allocates large chunks of shared memory and partitions it into individual bytes.
  33. Note: this process is only responsible for bulk allocation, it does not manage/free unused bytes.
  34. The chunks are deallocated by the garbage collector,
  35. when it detects that all processes no longer use any bytes from this chunk.
  36. """
  37. _lock = mp.Lock()
  38. _pid: Optional[PID] = None
  39. _buffer: Optional[torch.Tensor] = None
  40. _index: int = 0
  41. @classmethod
  42. def next(cls) -> torch.Tensor:
  43. """Create another shared byte value, represented as a scalar uint8 tensor"""
  44. with cls._lock:
  45. if cls._pid != os.getpid() or cls._buffer is None or cls._index >= len(cls._buffer):
  46. buffer_size = int(os.environ.get("HIVEMIND_SHM_BUFFER_SIZE", 16))
  47. cls._pid = os.getpid()
  48. cls._buffer = torch.empty([buffer_size], dtype=torch.uint8).share_memory_()
  49. cls._index = 0
  50. cls._index += 1
  51. return cls._buffer[cls._index - 1]
  52. class UpdateType(Enum):
  53. RESULT = auto()
  54. EXCEPTION = auto()
  55. CANCEL = auto()
  56. class MPFuture(base.Future, Generic[ResultType]):
  57. """
  58. A version of concurrent.futures.Future / asyncio.Future that can be fulfilled from a separate process.
  59. Any process can access future status and set the result / exception and check for state.
  60. However, only the original process (i.e. the process that created the future) can await the result or exception.
  61. :param use_lock: if True, operations with MPFuture use a global lock to prevent concurrent writes to the same pipe;
  62. If set to False, writing to this future ignores global lock, slightly improving performance, but making user
  63. responsible for avoiding concurrent set_result / set_exception calls to futures with the same process of origin.
  64. :note: This is an internal primitive that is not guaranteed to work outside of hivemind applications.
  65. More specifically, there are two known limitations:
  66. - MPFuture works between processes created through inheritance (e.g. fork), *not* for independent processes
  67. - MPFuture is deterministic if only one process can call set_result/set_exception/set_running_or_notify_cancel
  68. and only the origin process can call result/exception/cancel.
  69. """
  70. _initialization_lock = mp.Lock() # global lock that prevents simultaneous initialization of two processes
  71. _update_lock = mp.Lock() # global lock that prevents simultaneous writing to the same pipe
  72. _global_sender_pipe: Optional[PipeEnd] = None # a pipe that is used to send results/exceptions to this process
  73. _pipe_waiter_thread: Optional[threading.Thread] = None # process-specific thread that receives results/exceptions
  74. _active_futures: Optional[Dict[UID, "ref[MPFuture]"]] = None # non-done futures originated from this process
  75. _active_pid: Optional[PID] = None # pid of currently active process; used to handle forks natively
  76. def __init__(self, *, use_lock: bool = True):
  77. self._origin_pid, self._uid = os.getpid(), uuid.uuid4().int
  78. self._shared_state_code = SharedBytes.next()
  79. self._state_cache: Dict[State, State] = {}
  80. # mapping from global to cached local future used that makes updates immediately
  81. # available on setter side; dictionary-based cache works because future can visit any state at most once
  82. base.Future.__init__(self) # parent init is deferred because it uses self._shared_state_code
  83. self._state, self._result, self._exception = base.PENDING, None, None
  84. self._use_lock = use_lock
  85. if self._origin_pid != MPFuture._active_pid:
  86. with MPFuture._initialization_lock:
  87. if self._origin_pid != MPFuture._active_pid:
  88. # note: the second if is intentional, see https://en.wikipedia.org/wiki/Double-checked_locking
  89. self._initialize_mpfuture_backend()
  90. assert self._uid not in MPFuture._active_futures
  91. MPFuture._active_futures[self._uid] = ref(self)
  92. self._sender_pipe = MPFuture._global_sender_pipe
  93. try:
  94. self._loop = asyncio.get_event_loop()
  95. self._aio_event = asyncio.Event()
  96. except RuntimeError:
  97. self._loop, self._aio_event = None, None
  98. @property
  99. def _state(self) -> State:
  100. shared_state = ALL_STATES[self._shared_state_code.item()]
  101. return self._state_cache.get(shared_state, shared_state)
  102. @_state.setter
  103. def _state(self, new_state: State):
  104. self._shared_state_code[...] = ALL_STATES.index(new_state)
  105. if self._state in TERMINAL_STATES and self._loop is not None and not self._aio_event.is_set():
  106. self._set_event_threadsafe()
  107. def _set_event_threadsafe(self):
  108. try:
  109. running_loop = asyncio.get_running_loop()
  110. except RuntimeError:
  111. running_loop = None
  112. async def _event_setter():
  113. self._aio_event.set()
  114. if self._loop.is_closed():
  115. return # do nothing, the loop is already closed
  116. elif self._loop.is_running() and running_loop == self._loop:
  117. asyncio.create_task(_event_setter())
  118. elif self._loop.is_running() and running_loop != self._loop:
  119. asyncio.run_coroutine_threadsafe(_event_setter(), self._loop)
  120. else:
  121. self._loop.run_until_complete(_event_setter())
  122. @classmethod
  123. def _initialize_mpfuture_backend(cls):
  124. pid = os.getpid()
  125. logger.debug(f"Initializing MPFuture backend for pid {pid}")
  126. receiver_pipe, cls._global_sender_pipe = mp.Pipe(duplex=False)
  127. cls._active_pid, cls._active_futures = pid, {}
  128. cls._pipe_waiter_thread = threading.Thread(
  129. target=cls._process_updates_in_background, args=[receiver_pipe], name=f"{__name__}.BACKEND", daemon=True
  130. )
  131. cls._pipe_waiter_thread.start()
  132. @staticmethod
  133. def reset_backend():
  134. """Last-resort function to reset internals of MPFuture. All current MPFuture instances will be broken"""
  135. MPFuture._active_pid = None
  136. MPFuture._initialization_lock = mp.Lock()
  137. MPFuture._update_lock = mp.Lock()
  138. SharedBytes._lock = mp.Lock()
  139. @classmethod
  140. def _process_updates_in_background(cls, receiver_pipe: mp.connection.Connection):
  141. pid = os.getpid()
  142. while True:
  143. try:
  144. if cls._pipe_waiter_thread is not threading.current_thread():
  145. break # backend was reset, a new background thread has started
  146. uid, update_type, payload = receiver_pipe.recv()
  147. future = None
  148. future_ref = cls._active_futures.pop(uid, None)
  149. if future_ref is not None:
  150. future = future_ref()
  151. if future is None:
  152. # The MPFuture instance is already destroyed in this process
  153. # (the caller is not interested in the result)
  154. continue
  155. if update_type == UpdateType.RESULT:
  156. future.set_result(payload)
  157. elif update_type == UpdateType.EXCEPTION:
  158. future.set_exception(payload)
  159. elif update_type == UpdateType.CANCEL:
  160. future.cancel()
  161. else:
  162. raise RuntimeError(f"Received unexpected update type {update_type}")
  163. except (BrokenPipeError, EOFError, ConnectionError):
  164. logger.debug(f"Update pipe was was shut down unexpectedly (pid={pid})")
  165. except Exception as e:
  166. logger.exception(f"Could not retrieve update: caught {repr(e)} (pid={pid})")
  167. def _send_update(self, update_type: UpdateType, payload: Any = None):
  168. """This method sends result, exception or cancel to the MPFuture origin."""
  169. try:
  170. with MPFuture._update_lock if self._use_lock else nullcontext():
  171. self._sender_pipe.send((self._uid, update_type, payload))
  172. except (ConnectionError, BrokenPipeError, EOFError, OSError) as e:
  173. logger.debug(f"No updates were sent: pipe to origin process was broken ({e})", exc_info=True)
  174. def set_result(self, result: ResultType):
  175. if os.getpid() == self._origin_pid:
  176. super().set_result(result)
  177. MPFuture._active_futures.pop(self._uid, None)
  178. elif self._state in TERMINAL_STATES:
  179. raise InvalidStateError(f"Can't set_result to a future that is {self._state} ({self._uid})")
  180. else:
  181. self._state_cache[self._state], self._result = base.FINISHED, result
  182. self._send_update(UpdateType.RESULT, result)
  183. def set_exception(self, exception: Optional[BaseException]):
  184. if os.getpid() == self._origin_pid:
  185. super().set_exception(exception)
  186. MPFuture._active_futures.pop(self._uid, None)
  187. elif self._state in TERMINAL_STATES:
  188. raise InvalidStateError(f"Can't set_exception to a future that is {self._state} ({self._uid})")
  189. else:
  190. self._state_cache[self._state], self._exception = base.FINISHED, exception
  191. self._send_update(UpdateType.EXCEPTION, exception)
  192. def cancel(self) -> bool:
  193. if os.getpid() == self._origin_pid:
  194. MPFuture._active_futures.pop(self._uid, None)
  195. return super().cancel()
  196. elif self._state in [base.RUNNING, base.FINISHED]:
  197. return False
  198. else:
  199. self._state_cache[self._state] = base.CANCELLED
  200. self._send_update(UpdateType.CANCEL)
  201. return True
  202. def set_running_or_notify_cancel(self):
  203. if self._state == base.PENDING:
  204. self._state = base.RUNNING
  205. return True
  206. elif self._state == base.CANCELLED:
  207. return False
  208. else:
  209. raise InvalidStateError(
  210. f"Can't set_running_or_notify_cancel when future is in {self._state} ({self._uid})"
  211. )
  212. def result(self, timeout: Optional[float] = None) -> ResultType:
  213. if self._state not in TERMINAL_STATES:
  214. if os.getpid() != self._origin_pid:
  215. raise RuntimeError("Only the process that created MPFuture can await result")
  216. return super().result(timeout)
  217. elif self._state == base.CANCELLED:
  218. raise base.CancelledError()
  219. elif self._exception:
  220. raise self._exception
  221. else:
  222. return self._result
  223. def exception(self, timeout: Optional[float] = None) -> Optional[BaseException]:
  224. if self._state not in TERMINAL_STATES:
  225. if os.getpid() != self._origin_pid:
  226. raise RuntimeError("Only the process that created MPFuture can await exception")
  227. return super().exception(timeout)
  228. elif self._state == base.CANCELLED:
  229. raise base.CancelledError()
  230. return self._exception
  231. def done(self) -> bool:
  232. return self._state in TERMINAL_STATES
  233. def running(self):
  234. return self._state == base.RUNNING
  235. def cancelled(self):
  236. return self._state == base.CANCELLED
  237. def add_done_callback(self, callback: Callable[[MPFuture], None]):
  238. if os.getpid() != self._origin_pid:
  239. raise RuntimeError("Only the process that created MPFuture can set callbacks")
  240. return super().add_done_callback(callback)
  241. def __await__(self):
  242. if not self._aio_event:
  243. raise RuntimeError("Can't await: MPFuture was created with no event loop")
  244. yield from self._aio_event.wait().__await__()
  245. try:
  246. return super().result()
  247. except base.CancelledError:
  248. raise asyncio.CancelledError()
  249. def __del__(self):
  250. if getattr(self, "_origin_pid", None) == os.getpid():
  251. MPFuture._active_futures.pop(self._uid, None)
  252. if getattr(self, "_aio_event", None):
  253. self._aio_event.set()
  254. def __getstate__(self):
  255. return dict(
  256. _sender_pipe=self._sender_pipe,
  257. _shared_state_code=self._shared_state_code,
  258. _origin_pid=self._origin_pid,
  259. _uid=self._uid,
  260. _use_lock=self._use_lock,
  261. _result=self._result,
  262. _exception=self._exception,
  263. )
  264. def __setstate__(self, state):
  265. self._sender_pipe = state["_sender_pipe"]
  266. self._shared_state_code = state["_shared_state_code"]
  267. self._origin_pid, self._uid = state["_origin_pid"], state["_uid"]
  268. self._result, self._exception = state["_result"], state["_exception"]
  269. self._use_lock = state["_use_lock"]
  270. self._waiters, self._done_callbacks = [], []
  271. self._condition = threading.Condition()
  272. self._aio_event, self._loop = None, None
  273. self._state_cache = {}