|
@@ -2,171 +2,262 @@ from __future__ import annotations
|
|
|
|
|
|
import asyncio
|
|
|
import concurrent.futures._base as base
|
|
|
+from contextlib import nullcontext
|
|
|
import multiprocessing as mp
|
|
|
import multiprocessing.connection
|
|
|
-import time
|
|
|
-from functools import lru_cache
|
|
|
-from typing import Optional, Tuple, Generic, TypeVar
|
|
|
+import os
|
|
|
+import threading
|
|
|
+import uuid
|
|
|
+from enum import Enum, auto
|
|
|
+from typing import Generic, TypeVar, Dict, Optional, Any, Callable
|
|
|
|
|
|
-from hivemind.utils.threading import run_in_background
|
|
|
+import torch # used for py3.7-compatible shared memory
|
|
|
|
|
|
+from hivemind.utils.logging import get_logger
|
|
|
+
|
|
|
+
|
|
|
+logger = get_logger(__name__)
|
|
|
+
|
|
|
+# flavour types
|
|
|
ResultType = TypeVar('ResultType')
|
|
|
+PID, UID, State, PipeEnd = int, int, str, mp.connection.Connection
|
|
|
+ALL_STATES = base.PENDING, base.RUNNING, base.FINISHED, base.CANCELLED, base.CANCELLED_AND_NOTIFIED
|
|
|
+TERMINAL_STATES = {base.FINISHED, base.CANCELLED, base.CANCELLED_AND_NOTIFIED}
|
|
|
|
|
|
+try:
|
|
|
+ from concurrent.futures import InvalidStateError
|
|
|
+except ImportError:
|
|
|
+ # Python 3.7 doesn't raise concurrent.futures.InvalidStateError for repeating set_result/set_exception calls and
|
|
|
+ # doesn't even define this error. In this module, we simulate the Python 3.8+ behavior,
|
|
|
+ # defining and raising this error if necessary.
|
|
|
+ class InvalidStateError(Exception):
|
|
|
+ """Raised when attempting to change state of a future in a terminal state (e.g. finished)"""
|
|
|
|
|
|
-class FutureStateError(RuntimeError):
|
|
|
- """Raised when attempting to change state of a future in a terminal state (e.g. finished)"""
|
|
|
- pass
|
|
|
+
|
|
|
+class UpdateType(Enum):
|
|
|
+ RESULT = auto()
|
|
|
+ EXCEPTION = auto()
|
|
|
+ CANCEL = auto()
|
|
|
|
|
|
|
|
|
class MPFuture(base.Future, Generic[ResultType]):
|
|
|
- """ Multiprocessing version of concurrent.futures.Future. Can also be awaited like asyncio.Future """
|
|
|
+ """
|
|
|
+ A version of concurrent.futures.Future / asyncio.Future that can be fulfilled from a separate process.
|
|
|
+ Any process can access future status and set the result / exception and check for state.
|
|
|
+ However, only the original process (i.e. the process that created the future) can await the result or exception.
|
|
|
+
|
|
|
+ :param use_lock: if True, operations with MPFuture use a global lock to prevent concurrent writes to the same pipe;
|
|
|
+ If set to False, writing to this future ignores global lock, slightly improving performance, but making user
|
|
|
+ responsible for avoiding concurrent set_result / set_exception calls to futures with the same process of origin.
|
|
|
+ :param loop: if specified, overrides default asyncio event loop for the purpose of awaiting MPFuture
|
|
|
+
|
|
|
+ :note: This is an internal primitive that is not guaranteed to work outside of hivemind applications.
|
|
|
+ More specifically, there are two known limitations:
|
|
|
+ - MPFuture works between processes created through inheritance (e.g. fork), *not* for independent processes
|
|
|
+ - MPFuture is deterministic if only one process can call set_result/set_exception/set_running_or_notify_cancel
|
|
|
+ and only the origin process can call result/exception/cancel.
|
|
|
+ """
|
|
|
+ _initialization_lock = mp.Lock() # global lock that prevents simultaneous initialization of two processes
|
|
|
+ _update_lock = mp.Lock() # global lock that prevents simultaneous writing to the same pipe
|
|
|
+ _global_sender_pipe: Optional[PipeEnd] = None # a pipe that is used to send results/exceptions to this process
|
|
|
+ _pipe_waiter_thread: Optional[threading.Thread] = None # process-specific thread that receives results/exceptions
|
|
|
+ _active_futures: Optional[Dict[UID, MPFuture]] = None # pending or running futures originated from current process
|
|
|
+ _active_pid: Optional[PID] = None # pid of currently active process; used to handle forks natively
|
|
|
|
|
|
- TERMINAL_STATES = {base.FINISHED, base.CANCELLED, base.CANCELLED_AND_NOTIFIED}
|
|
|
+ def __init__(self, use_lock: bool = True, loop: Optional[asyncio.BaseEventLoop] = None):
|
|
|
+ self._origin_pid, self._uid = os.getpid(), uuid.uuid4().int
|
|
|
+ self._shared_state_code = torch.empty([], dtype=torch.uint8).share_memory_()
|
|
|
+ self._state_cache: Dict[State, State] = {} # mapping from global to cached local future used that makes updates immediately
|
|
|
+ # available on setter side; dictionary-based cache works because future can visit any state at most once
|
|
|
|
|
|
- def __init__(self, connection: mp.connection.Connection):
|
|
|
- """ manually create MPFuture. Please use MPFuture.make_pair instead """
|
|
|
+ base.Future.__init__(self) # parent init is deferred because it uses self._shared_state_code
|
|
|
self._state, self._result, self._exception = base.PENDING, None, None
|
|
|
- self.connection = connection
|
|
|
+ self._use_lock = use_lock
|
|
|
|
|
|
- @classmethod
|
|
|
- def make_pair(cls) -> Tuple[MPFuture, MPFuture]:
|
|
|
- """ Create a pair of linked futures to be used in two processes """
|
|
|
- connection1, connection2 = mp.Pipe()
|
|
|
- return cls(connection1), cls(connection2)
|
|
|
+ if self._origin_pid != MPFuture._active_pid:
|
|
|
+ with MPFuture._initialization_lock:
|
|
|
+ if self._origin_pid != MPFuture._active_pid:
|
|
|
+ # note: the second if is intentional, see https://en.wikipedia.org/wiki/Double-checked_locking
|
|
|
+ self._initialize_mpfuture_backend()
|
|
|
+ assert self._uid not in MPFuture._active_futures
|
|
|
+ MPFuture._active_futures[self._uid] = self
|
|
|
+ self._sender_pipe = MPFuture._global_sender_pipe
|
|
|
|
|
|
- def _send_updates(self):
|
|
|
- """ Send updates to a paired MPFuture """
|
|
|
try:
|
|
|
- self.connection.send((self._state, self._result, self._exception))
|
|
|
- if self._state in self.TERMINAL_STATES:
|
|
|
- self._shutdown_trigger.set_result(True)
|
|
|
- self.connection.close()
|
|
|
- return True
|
|
|
- except BrokenPipeError:
|
|
|
- return False
|
|
|
+ self._loop = loop or asyncio.get_event_loop()
|
|
|
+ self._aio_event = asyncio.Event()
|
|
|
+ except RuntimeError:
|
|
|
+ self._loop, self._aio_event = None, None
|
|
|
|
|
|
- def _recv_updates(self, timeout: Optional[float]):
|
|
|
- """ Await updates from a paired MPFuture """
|
|
|
- try:
|
|
|
- future = base.wait([run_in_background(self.connection.poll, timeout), self._shutdown_trigger],
|
|
|
- return_when=base.FIRST_COMPLETED)[0].pop()
|
|
|
- if future is self._shutdown_trigger:
|
|
|
- raise BrokenPipeError()
|
|
|
- if not future.result():
|
|
|
- raise TimeoutError()
|
|
|
- self._state, result, exception = self.connection.recv()
|
|
|
- self._result = result if result is not None else self._result
|
|
|
- self._exception = exception if exception is not None else self._exception
|
|
|
- if self._state in self.TERMINAL_STATES:
|
|
|
- self.connection.close()
|
|
|
- except TimeoutError as e:
|
|
|
- raise e
|
|
|
- except (BrokenPipeError, OSError, EOFError) as e:
|
|
|
- if self._state in (base.PENDING, base.RUNNING):
|
|
|
- self._state, self._exception = base.FINISHED, e
|
|
|
-
|
|
|
- def _await_terminal_state(self, timeout: Optional[float]):
|
|
|
- """ Await updates until future is either finished, cancelled or got an exception """
|
|
|
- time_left = float('inf') if timeout is None else timeout
|
|
|
- time_before = time.monotonic()
|
|
|
- while self._state not in self.TERMINAL_STATES and time_left > 0:
|
|
|
- self._recv_updates(time_left if timeout else None)
|
|
|
- time_spent = time.monotonic() - time_before
|
|
|
- time_left, time_before = time_left - time_spent, time_before + time_spent
|
|
|
-
|
|
|
- def _sync_updates(self):
|
|
|
- """ Apply queued updates from a paired MPFuture without waiting for new ones """
|
|
|
+ @property
|
|
|
+ def _state(self) -> State:
|
|
|
+ shared_state = ALL_STATES[self._shared_state_code.item()]
|
|
|
+ return self._state_cache.get(shared_state, shared_state)
|
|
|
+
|
|
|
+ @_state.setter
|
|
|
+ def _state(self, new_state: State):
|
|
|
+ self._shared_state_code[...] = ALL_STATES.index(new_state)
|
|
|
+ if self._state in TERMINAL_STATES and self._loop is not None and not self._aio_event.is_set():
|
|
|
+ self._set_event_threadsafe()
|
|
|
+
|
|
|
+ def _set_event_threadsafe(self):
|
|
|
try:
|
|
|
- self._recv_updates(timeout=0)
|
|
|
- except TimeoutError:
|
|
|
- pass
|
|
|
+ loop = asyncio.get_running_loop()
|
|
|
+ except RuntimeError:
|
|
|
+ loop = None
|
|
|
+
|
|
|
+ async def _event_setter():
|
|
|
+ self._aio_event.set()
|
|
|
+
|
|
|
+ if loop == self.get_loop():
|
|
|
+ asyncio.create_task(_event_setter())
|
|
|
+ else:
|
|
|
+ asyncio.run_coroutine_threadsafe(_event_setter(), self._loop)
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def _initialize_mpfuture_backend(cls):
|
|
|
+ pid = os.getpid()
|
|
|
+ logger.debug(f"Initializing MPFuture backend for pid {pid}")
|
|
|
+ assert pid != cls._active_pid, "already initialized"
|
|
|
+
|
|
|
+ receiver_pipe, cls._global_sender_pipe = mp.Pipe(duplex=False)
|
|
|
+ cls._active_pid, cls._active_futures = pid, {}
|
|
|
+ cls._pipe_waiter_thread = threading.Thread(target=cls._process_updates_in_background, args=[receiver_pipe],
|
|
|
+ name=f'{__name__}.BACKEND', daemon=True)
|
|
|
+ cls._pipe_waiter_thread.start()
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def _process_updates_in_background(cls, receiver_pipe: mp.connection.Connection):
|
|
|
+ pid = os.getpid()
|
|
|
+ while True:
|
|
|
+ try:
|
|
|
+ uid, update_type, payload = receiver_pipe.recv()
|
|
|
+ if uid not in cls._active_futures:
|
|
|
+ logger.debug(f"Ignoring update to future with uid={uid}: the future is already done or destroyed")
|
|
|
+ elif update_type == UpdateType.RESULT:
|
|
|
+ cls._active_futures.pop(uid).set_result(payload)
|
|
|
+ elif update_type == UpdateType.EXCEPTION:
|
|
|
+ cls._active_futures.pop(uid).set_exception(payload)
|
|
|
+ elif update_type == UpdateType.CANCEL:
|
|
|
+ cls._active_futures.pop(uid).cancel()
|
|
|
+ else:
|
|
|
+ raise RuntimeError(f"Received unexpected update type {update_type}")
|
|
|
+ except (BrokenPipeError, EOFError):
|
|
|
+ logger.debug(f"Update pipe was was shut down unexpectedly (pid={pid})")
|
|
|
+ except Exception as e:
|
|
|
+ logger.exception(f"Could not retrieve update: caught {repr(e)} (pid={pid})")
|
|
|
+
|
|
|
+ def _send_update(self, update_type: UpdateType, payload: Any = None):
|
|
|
+ """ This method sends result, exception or cancel to the MPFuture origin. """
|
|
|
+ with MPFuture._update_lock if self._use_lock else nullcontext():
|
|
|
+ self._sender_pipe.send((self._uid, update_type, payload))
|
|
|
|
|
|
def set_result(self, result: ResultType):
|
|
|
- self._sync_updates()
|
|
|
- if self._state in self.TERMINAL_STATES:
|
|
|
- raise FutureStateError(f"Can't set_result to a future that is {self._state} ({self})")
|
|
|
- self._state, self._result = base.FINISHED, result
|
|
|
- return self._send_updates()
|
|
|
-
|
|
|
- def set_exception(self, exception: BaseException):
|
|
|
- self._sync_updates()
|
|
|
- if self._state in self.TERMINAL_STATES:
|
|
|
- raise FutureStateError(f"Can't set_exception to a future that is {self._state} ({self})")
|
|
|
- self._state, self._exception = base.FINISHED, exception
|
|
|
- self._send_updates()
|
|
|
+ if os.getpid() == self._origin_pid:
|
|
|
+ super().set_result(result)
|
|
|
+ MPFuture._active_futures.pop(self._uid, None)
|
|
|
+ elif self._state in TERMINAL_STATES:
|
|
|
+ raise InvalidStateError(f"Can't set_result to a future that is {self._state} ({self._uid})")
|
|
|
+ else:
|
|
|
+ self._state_cache[self._state], self._result = base.FINISHED, result
|
|
|
+ self._send_update(UpdateType.RESULT, result)
|
|
|
+
|
|
|
+ def set_exception(self, exception: Optional[BaseException]):
|
|
|
+ if os.getpid() == self._origin_pid:
|
|
|
+ super().set_exception(exception)
|
|
|
+ MPFuture._active_futures.pop(self._uid, None)
|
|
|
+ elif self._state in TERMINAL_STATES:
|
|
|
+ raise InvalidStateError(f"Can't set_exception to a future that is {self._state} ({self._uid})")
|
|
|
+ else:
|
|
|
+ self._state_cache[self._state], self._exception = base.FINISHED, exception
|
|
|
+ self._send_update(UpdateType.EXCEPTION, exception)
|
|
|
+
|
|
|
+ def cancel(self) -> bool:
|
|
|
+ if os.getpid() == self._origin_pid:
|
|
|
+ MPFuture._active_futures.pop(self._uid, None)
|
|
|
+ return super().cancel()
|
|
|
+ elif self._state in [base.RUNNING, base.FINISHED]:
|
|
|
+ return False
|
|
|
+ else:
|
|
|
+ self._state_cache[self._state] = base.CANCELLED
|
|
|
+ self._send_update(UpdateType.CANCEL)
|
|
|
+ return True
|
|
|
|
|
|
def set_running_or_notify_cancel(self):
|
|
|
- self._sync_updates()
|
|
|
if self._state == base.PENDING:
|
|
|
self._state = base.RUNNING
|
|
|
- return self._send_updates()
|
|
|
+ return True
|
|
|
elif self._state == base.CANCELLED:
|
|
|
return False
|
|
|
else:
|
|
|
- raise FutureStateError(f"Can't set_running_or_notify_cancel to a future that is in {self._state} ({self})")
|
|
|
-
|
|
|
- def cancel(self):
|
|
|
- self._sync_updates()
|
|
|
- if self._state in self.TERMINAL_STATES:
|
|
|
- return False
|
|
|
- self._state, self._exception = base.CANCELLED, base.CancelledError()
|
|
|
- return self._send_updates()
|
|
|
+ raise InvalidStateError(f"Can't set_running_or_notify_cancel when future is in {self._state} ({self._uid})")
|
|
|
|
|
|
def result(self, timeout: Optional[float] = None) -> ResultType:
|
|
|
- self._await_terminal_state(timeout)
|
|
|
- if self._exception is not None:
|
|
|
+ if self._state not in TERMINAL_STATES:
|
|
|
+ if os.getpid() != self._origin_pid:
|
|
|
+ raise RuntimeError("Only the process that created MPFuture can await result")
|
|
|
+ return super().result(timeout)
|
|
|
+ elif self._state == base.CANCELLED:
|
|
|
+ raise base.CancelledError()
|
|
|
+ elif self._exception:
|
|
|
raise self._exception
|
|
|
- return self._result
|
|
|
+ else:
|
|
|
+ return self._result
|
|
|
|
|
|
- def exception(self, timeout=None) -> BaseException:
|
|
|
- self._await_terminal_state(timeout)
|
|
|
- if self._state == base.CANCELLED:
|
|
|
+ def exception(self, timeout: Optional[float] = None) -> Optional[BaseException]:
|
|
|
+ if self._state not in TERMINAL_STATES:
|
|
|
+ if os.getpid() != self._origin_pid:
|
|
|
+ raise RuntimeError("Only the process that created MPFuture can await exception")
|
|
|
+ return super().exception(timeout)
|
|
|
+ elif self._state == base.CANCELLED:
|
|
|
raise base.CancelledError()
|
|
|
return self._exception
|
|
|
|
|
|
def done(self) -> bool:
|
|
|
- self._sync_updates()
|
|
|
- return self._state in self.TERMINAL_STATES
|
|
|
+ return self._state in TERMINAL_STATES
|
|
|
|
|
|
def running(self):
|
|
|
- self._sync_updates()
|
|
|
return self._state == base.RUNNING
|
|
|
|
|
|
def cancelled(self):
|
|
|
- self._sync_updates()
|
|
|
return self._state == base.CANCELLED
|
|
|
|
|
|
- def add_done_callback(self, callback):
|
|
|
- raise NotImplementedError(f"MPFuture doesn't support callbacks.")
|
|
|
-
|
|
|
- def remove_done_callback(self, callback):
|
|
|
- raise NotImplementedError(f"MPFuture doesn't support callbacks.")
|
|
|
+ def add_done_callback(self, callback: Callable[[MPFuture], None]):
|
|
|
+ if os.getpid() != self._origin_pid:
|
|
|
+ raise RuntimeError("Only the process that created MPFuture can set callbacks")
|
|
|
+ return super().add_done_callback(callback)
|
|
|
|
|
|
- def get_loop(self):
|
|
|
- raise NotImplementedError(f"MPFuture doesn't support get_loop")
|
|
|
-
|
|
|
- @property
|
|
|
- @lru_cache()
|
|
|
- def _shutdown_trigger(self):
|
|
|
- return base.Future()
|
|
|
-
|
|
|
- def __repr__(self):
|
|
|
- self._sync_updates()
|
|
|
- if self._state == base.FINISHED:
|
|
|
- if self._exception:
|
|
|
- return "<MPFuture at 0x{:x} state=finished raised {}>".format(id(self), type(self._exception))
|
|
|
- else:
|
|
|
- return "<MPFuture at 0x{:x} state=finished returned {}>".format(id(self), type(self._result))
|
|
|
- else:
|
|
|
- return "<MPFuture at 0x{:x} state={}>".format(id(self), self._state)
|
|
|
+ def get_loop(self) -> Optional[asyncio.BaseEventLoop]:
|
|
|
+ return self._loop
|
|
|
|
|
|
def __await__(self):
|
|
|
- yield from asyncio.get_running_loop().run_in_executor(None, self._await_terminal_state, None).__await__()
|
|
|
- if self._exception:
|
|
|
- raise self._exception
|
|
|
- return self._result
|
|
|
+ if not self._aio_event:
|
|
|
+ raise RuntimeError("Can't await: MPFuture was created with no event loop")
|
|
|
+ yield from self._aio_event.wait().__await__()
|
|
|
+ try:
|
|
|
+ return super().result(timeout=0)
|
|
|
+ except base.CancelledError:
|
|
|
+ raise asyncio.CancelledError()
|
|
|
|
|
|
def __del__(self):
|
|
|
- self._shutdown_trigger.set_result(True)
|
|
|
- if hasattr(self, 'connection'):
|
|
|
- self.connection.close()
|
|
|
+ if getattr(self, '_origin_pid', None) == os.getpid():
|
|
|
+ MPFuture._active_futures.pop(self._uid, None)
|
|
|
+ if getattr(self, '_aio_event', None):
|
|
|
+ self._aio_event.set()
|
|
|
+
|
|
|
+ def __getstate__(self):
|
|
|
+ return dict(_sender_pipe=self._sender_pipe, _shared_state_code=self._shared_state_code,
|
|
|
+ _origin_pid=self._origin_pid, _uid=self._uid, _use_lock=self._use_lock,
|
|
|
+ _result=self._result, _exception=self._exception)
|
|
|
+
|
|
|
+ def __setstate__(self, state):
|
|
|
+ self._sender_pipe = state['_sender_pipe']
|
|
|
+ self._shared_state_code = state['_shared_state_code']
|
|
|
+ self._origin_pid, self._uid = state['_origin_pid'], state['_uid']
|
|
|
+ self._result, self._exception = state['_result'], state['_exception']
|
|
|
+ self._use_lock = state['_use_lock']
|
|
|
+
|
|
|
+ self._waiters, self._done_callbacks = [], []
|
|
|
+ self._condition = threading.Condition()
|
|
|
+ self._aio_event, self._loop = None, None
|
|
|
+ self._state_cache = {}
|