|
@@ -2,16 +2,15 @@ from __future__ import annotations
|
|
|
|
|
|
import asyncio
|
|
|
import concurrent.futures._base as base
|
|
|
-from contextlib import nullcontext
|
|
|
+from contextlib import nullcontext, suppress
|
|
|
import multiprocessing as mp
|
|
|
import multiprocessing.connection
|
|
|
import os
|
|
|
import threading
|
|
|
import uuid
|
|
|
+from weakref import ref
|
|
|
from enum import Enum, auto
|
|
|
-from typing import Generic, TypeVar, Dict, Optional, Any, Callable
|
|
|
-
|
|
|
-import torch # used for py3.7-compatible shared memory
|
|
|
+from typing import Generic, TypeVar, Dict, Optional, Any, Callable, Type, Tuple
|
|
|
|
|
|
from hivemind.utils.logging import get_logger
|
|
|
|
|
@@ -34,10 +33,13 @@ except ImportError:
|
|
|
"""Raised when attempting to change state of a future in a terminal state (e.g. finished)"""
|
|
|
|
|
|
|
|
|
-class UpdateType(Enum):
|
|
|
+class MessageType(Enum):
|
|
|
RESULT = auto()
|
|
|
EXCEPTION = auto()
|
|
|
+ RUNNING = auto()
|
|
|
CANCEL = auto()
|
|
|
+ STATE_REQUEST = auto()
|
|
|
+ STATE_RESPONSE = auto()
|
|
|
|
|
|
|
|
|
class MPFuture(base.Future, Generic[ResultType]):
|
|
@@ -46,6 +48,8 @@ class MPFuture(base.Future, Generic[ResultType]):
|
|
|
Any process can access future status and set the result / exception and check for state.
|
|
|
However, only the original process (i.e. the process that created the future) can await the result or exception.
|
|
|
|
|
|
+ :param synchronize: if True (default), future will request state from origin, otherwise it will only use local state
|
|
|
+ Setting synchronize=False results in slightly better performance of done or set_running_or_notify_cancel
|
|
|
:param use_lock: if True, operations with MPFuture use a global lock to prevent concurrent writes to the same pipe;
|
|
|
If set to False, writing to this future ignores global lock, slightly improving performance, but making user
|
|
|
responsible for avoiding concurrent set_result / set_exception calls to futures with the same process of origin.
|
|
@@ -60,49 +64,36 @@ class MPFuture(base.Future, Generic[ResultType]):
|
|
|
|
|
|
_initialization_lock = mp.Lock() # global lock that prevents simultaneous initialization of two processes
|
|
|
_update_lock = mp.Lock() # global lock that prevents simultaneous writing to the same pipe
|
|
|
- _global_sender_pipe: Optional[PipeEnd] = None # a pipe that is used to send results/exceptions to this process
|
|
|
+ _process_wide_pipe: Optional[PipeEnd] = None # a pipe that is used to send results/exceptions to this process
|
|
|
_pipe_waiter_thread: Optional[threading.Thread] = None # process-specific thread that receives results/exceptions
|
|
|
- _active_futures: Optional[Dict[UID, MPFuture]] = None # pending or running futures originated from current process
|
|
|
+ _active_futures: Optional[Dict[UID, Type[ref][MPFuture]]] = None # non-done futures originated from this process
|
|
|
+ _status_requests: Optional[Dict[UID, Tuple[MPFuture, threading.Event]]] = None # futures to be updated by origin
|
|
|
_active_pid: Optional[PID] = None # pid of currently active process; used to handle forks natively
|
|
|
|
|
|
- def __init__(self, use_lock: bool = True, loop: Optional[asyncio.BaseEventLoop] = None):
|
|
|
- self._origin_pid, self._uid = os.getpid(), uuid.uuid4().int
|
|
|
- self._shared_state_code = torch.empty([], dtype=torch.uint8).share_memory_()
|
|
|
- self._state_cache: Dict[State, State] = {}
|
|
|
- # mapping from global to cached local future used that makes updates immediately
|
|
|
- # available on setter side; dictionary-based cache works because future can visit any state at most once
|
|
|
+ SOFT_UPDATE_TIMEOUT = 0.1 # seconds spent awaiting status update before warning is printed
|
|
|
+ HARD_UPDATE_TIMEOUT = 10.0 # seconds spent awaiting status update before future is automatically cancelled
|
|
|
|
|
|
- base.Future.__init__(self) # parent init is deferred because it uses self._shared_state_code
|
|
|
+ def __init__(self, *, synchronize: bool = True, use_lock: bool = True):
|
|
|
+ super().__init__()
|
|
|
+ self.synchronize = synchronize
|
|
|
+ self._origin_pid, self._uid = os.getpid(), uuid.uuid4().int
|
|
|
self._state, self._result, self._exception = base.PENDING, None, None
|
|
|
self._use_lock = use_lock
|
|
|
|
|
|
- if self._origin_pid != MPFuture._active_pid:
|
|
|
- with MPFuture._initialization_lock:
|
|
|
- if self._origin_pid != MPFuture._active_pid:
|
|
|
- # note: the second if is intentional, see https://en.wikipedia.org/wiki/Double-checked_locking
|
|
|
- self._initialize_mpfuture_backend()
|
|
|
+ self._initialize_backend_if_necessary()
|
|
|
assert self._uid not in MPFuture._active_futures
|
|
|
- MPFuture._active_futures[self._uid] = self
|
|
|
- self._sender_pipe = MPFuture._global_sender_pipe
|
|
|
+ MPFuture._active_futures[self._uid] = ref(self)
|
|
|
+ self._sender_pipe = MPFuture._process_wide_pipe
|
|
|
|
|
|
try:
|
|
|
- self._loop = loop or asyncio.get_event_loop()
|
|
|
+ self._loop = asyncio.get_event_loop()
|
|
|
self._aio_event = asyncio.Event()
|
|
|
except RuntimeError:
|
|
|
self._loop, self._aio_event = None, None
|
|
|
|
|
|
- @property
|
|
|
- def _state(self) -> State:
|
|
|
- shared_state = ALL_STATES[self._shared_state_code.item()]
|
|
|
- return self._state_cache.get(shared_state, shared_state)
|
|
|
-
|
|
|
- @_state.setter
|
|
|
- def _state(self, new_state: State):
|
|
|
- self._shared_state_code[...] = ALL_STATES.index(new_state)
|
|
|
- if self._state in TERMINAL_STATES and self._loop is not None and not self._aio_event.is_set():
|
|
|
- self._set_event_threadsafe()
|
|
|
-
|
|
|
- def _set_event_threadsafe(self):
|
|
|
+ def _set_event_if_necessary(self):
|
|
|
+ if self._aio_event is None or self._aio_event.is_set():
|
|
|
+ return
|
|
|
try:
|
|
|
loop = asyncio.get_running_loop()
|
|
|
except RuntimeError:
|
|
@@ -111,120 +102,197 @@ class MPFuture(base.Future, Generic[ResultType]):
|
|
|
async def _event_setter():
|
|
|
self._aio_event.set()
|
|
|
|
|
|
- if loop == self.get_loop():
|
|
|
+ if self._loop.is_running() and loop == self.get_loop():
|
|
|
asyncio.create_task(_event_setter())
|
|
|
- else:
|
|
|
+ elif self._loop.is_running() and loop != self.get_loop():
|
|
|
asyncio.run_coroutine_threadsafe(_event_setter(), self._loop)
|
|
|
+ else:
|
|
|
+ self._loop.run_until_complete(_event_setter())
|
|
|
|
|
|
@classmethod
|
|
|
- def _initialize_mpfuture_backend(cls):
|
|
|
+ def _initialize_backend_if_necessary(cls):
|
|
|
pid = os.getpid()
|
|
|
- logger.debug(f"Initializing MPFuture backend for pid {pid}")
|
|
|
- assert pid != cls._active_pid, "already initialized"
|
|
|
-
|
|
|
- receiver_pipe, cls._global_sender_pipe = mp.Pipe(duplex=False)
|
|
|
- cls._active_pid, cls._active_futures = pid, {}
|
|
|
- cls._pipe_waiter_thread = threading.Thread(
|
|
|
- target=cls._process_updates_in_background, args=[receiver_pipe], name=f"{__name__}.BACKEND", daemon=True
|
|
|
- )
|
|
|
- cls._pipe_waiter_thread.start()
|
|
|
+ if MPFuture._active_pid != pid:
|
|
|
+ with MPFuture._initialization_lock:
|
|
|
+ if MPFuture._active_pid != pid:
|
|
|
+ # note: the second if is intentional, see https://en.wikipedia.org/wiki/Double-checked_locking
|
|
|
+ logger.debug(f"Initializing MPFuture backend for pid {pid}")
|
|
|
+ receiver_pipe, cls._process_wide_pipe = mp.Pipe(duplex=False)
|
|
|
+ cls._active_pid, cls._active_futures, cls._status_requests = pid, {}, {}
|
|
|
+ cls._pipe_waiter_thread = threading.Thread(
|
|
|
+ target=cls._process_updates_in_background,
|
|
|
+ args=[receiver_pipe],
|
|
|
+ name=f"{__name__}.BACKEND",
|
|
|
+ daemon=True,
|
|
|
+ )
|
|
|
+ cls._pipe_waiter_thread.start()
|
|
|
|
|
|
@classmethod
|
|
|
def _process_updates_in_background(cls, receiver_pipe: mp.connection.Connection):
|
|
|
pid = os.getpid()
|
|
|
while True:
|
|
|
try:
|
|
|
- uid, update_type, payload = receiver_pipe.recv()
|
|
|
- if uid not in cls._active_futures:
|
|
|
- logger.debug(f"Ignoring update to future with uid={uid}: the future is already done or destroyed")
|
|
|
- elif update_type == UpdateType.RESULT:
|
|
|
- cls._active_futures.pop(uid).set_result(payload)
|
|
|
- elif update_type == UpdateType.EXCEPTION:
|
|
|
- cls._active_futures.pop(uid).set_exception(payload)
|
|
|
- elif update_type == UpdateType.CANCEL:
|
|
|
- cls._active_futures.pop(uid).cancel()
|
|
|
+ uid, msg_type, payload = receiver_pipe.recv()
|
|
|
+ future = None
|
|
|
+ future_ref = cls._active_futures.get(uid)
|
|
|
+ if future_ref is not None:
|
|
|
+ future = future_ref()
|
|
|
+
|
|
|
+ if msg_type == MessageType.STATE_REQUEST:
|
|
|
+ future_state = None if future is None else future.__getstate__()
|
|
|
+ use_lock, return_pipe = payload
|
|
|
+ with MPFuture._update_lock if use_lock else nullcontext():
|
|
|
+ return_pipe.send((uid, MessageType.STATE_RESPONSE, future_state))
|
|
|
+
|
|
|
+ elif msg_type == MessageType.STATE_RESPONSE:
|
|
|
+ future, state_updated_event = cls._status_requests.get(uid, (None, None))
|
|
|
+ if future is None:
|
|
|
+ logger.debug("Received a state update for a future that does not await status update.")
|
|
|
+ else:
|
|
|
+ if payload is not None:
|
|
|
+ future.__setstate__(payload)
|
|
|
+ else:
|
|
|
+ base.Future.cancel(future)
|
|
|
+ state_updated_event.set()
|
|
|
+
|
|
|
+ elif future is None:
|
|
|
+ logger.debug(
|
|
|
+ f"Received {msg_type} for MPFuture uid={uid}, but future is already done or destroyed"
|
|
|
+ )
|
|
|
+ elif msg_type == MessageType.RESULT:
|
|
|
+ future.set_result(payload)
|
|
|
+ elif msg_type == MessageType.EXCEPTION:
|
|
|
+ future.set_exception(payload)
|
|
|
+ elif msg_type == MessageType.RUNNING:
|
|
|
+ try:
|
|
|
+ future.set_running_or_notify_cancel()
|
|
|
+ except (InvalidStateError, RuntimeError) as e:
|
|
|
+ logger.debug(f"Could not set MPFuture (uid={uid}) to running due to {e}", exc_info=True)
|
|
|
+ elif msg_type == MessageType.CANCEL:
|
|
|
+ future.cancel()
|
|
|
else:
|
|
|
- raise RuntimeError(f"Received unexpected update type {update_type}")
|
|
|
- except (BrokenPipeError, EOFError):
|
|
|
+ raise RuntimeError(f"Received unexpected update type {msg_type}")
|
|
|
+
|
|
|
+ if future is None or future.done():
|
|
|
+ cls._active_futures.pop(uid, None)
|
|
|
+
|
|
|
+ except (BrokenPipeError, EOFError, ConnectionError):
|
|
|
logger.debug(f"Update pipe was was shut down unexpectedly (pid={pid})")
|
|
|
except Exception as e:
|
|
|
logger.exception(f"Could not retrieve update: caught {repr(e)} (pid={pid})")
|
|
|
|
|
|
- def _send_update(self, update_type: UpdateType, payload: Any = None):
|
|
|
+ def _send_update(self, update_type: MessageType, payload: Any = None):
|
|
|
"""This method sends result, exception or cancel to the MPFuture origin."""
|
|
|
- with MPFuture._update_lock if self._use_lock else nullcontext():
|
|
|
- self._sender_pipe.send((self._uid, update_type, payload))
|
|
|
+ try:
|
|
|
+ with MPFuture._update_lock if self._use_lock else nullcontext():
|
|
|
+ self._sender_pipe.send((self._uid, update_type, payload))
|
|
|
+ except (ConnectionError, BrokenPipeError, EOFError) as e:
|
|
|
+ logger.debug(f"No updates were sent: pipe to origin process was broken ({e}).", exc_info=True)
|
|
|
+
|
|
|
+ def _synchronize_if_necessary(self):
|
|
|
+ if not self.synchronize or os.getpid() == self._origin_pid or self._state in TERMINAL_STATES:
|
|
|
+ return
|
|
|
+
|
|
|
+ self._initialize_backend_if_necessary()
|
|
|
+
|
|
|
+ status_updated = threading.Event()
|
|
|
+ _, existing_status_event = self._status_requests.setdefault(self._uid, (self, status_updated))
|
|
|
+ # this line checks whether another thread is already synchronizing concurrently, assuming that dict.setdefault is atomic
|
|
|
+
|
|
|
+ if existing_status_event != status_updated:
|
|
|
+ existing_status_event.wait(MPFuture.HARD_UPDATE_TIMEOUT)
|
|
|
+ return
|
|
|
+
|
|
|
+ # otherwise create a new request for synchronization
|
|
|
+
|
|
|
+ try:
|
|
|
+ with MPFuture._update_lock if self._use_lock else nullcontext():
|
|
|
+ payload = (self._use_lock, self._process_wide_pipe)
|
|
|
+ self._sender_pipe.send((self._uid, MessageType.STATE_REQUEST, payload))
|
|
|
+ status_updated.wait(MPFuture.SOFT_UPDATE_TIMEOUT)
|
|
|
+ if not status_updated.is_set():
|
|
|
+ logger.warning(f"Status update took over {MPFuture.SOFT_UPDATE_TIMEOUT}, expect performance issues")
|
|
|
+ status_updated.wait(MPFuture.HARD_UPDATE_TIMEOUT - MPFuture.SOFT_UPDATE_TIMEOUT)
|
|
|
+ if not status_updated.is_set() and not self.cancel():
|
|
|
+ with suppress(InvalidStateError, RuntimeError):
|
|
|
+ self.set_exception(
|
|
|
+ TimeoutError(
|
|
|
+ f"Status update took over {MPFuture.HARD_UPDATE_TIMEOUT} seconds, "
|
|
|
+ f"MPFuture is cancelled"
|
|
|
+ )
|
|
|
+ )
|
|
|
+ status_updated.set() # this triggers any concurrent _synchronize_if_necessary calls to finish
|
|
|
+ except (ConnectionError, BrokenPipeError, EOFError) as e:
|
|
|
+ logger.error(f"MPFuture was cancelled because sender pipe is broken. Origin process is probably down.")
|
|
|
+ if not self.cancel():
|
|
|
+ with suppress(InvalidStateError, RuntimeError):
|
|
|
+ self.set_exception(e)
|
|
|
+ finally:
|
|
|
+ self._status_requests.pop(self._uid, None)
|
|
|
|
|
|
def set_result(self, result: ResultType):
|
|
|
- if os.getpid() == self._origin_pid:
|
|
|
- super().set_result(result)
|
|
|
- MPFuture._active_futures.pop(self._uid, None)
|
|
|
- elif self._state in TERMINAL_STATES:
|
|
|
+ if self._state in TERMINAL_STATES:
|
|
|
raise InvalidStateError(f"Can't set_result to a future that is {self._state} ({self._uid})")
|
|
|
+ elif os.getpid() == self._origin_pid:
|
|
|
+ MPFuture._active_futures.pop(self._uid, None)
|
|
|
+ self._set_event_if_necessary()
|
|
|
else:
|
|
|
- self._state_cache[self._state], self._result = base.FINISHED, result
|
|
|
- self._send_update(UpdateType.RESULT, result)
|
|
|
+ self._send_update(MessageType.RESULT, result)
|
|
|
+ super().set_result(result)
|
|
|
|
|
|
def set_exception(self, exception: Optional[BaseException]):
|
|
|
- if os.getpid() == self._origin_pid:
|
|
|
- super().set_exception(exception)
|
|
|
- MPFuture._active_futures.pop(self._uid, None)
|
|
|
- elif self._state in TERMINAL_STATES:
|
|
|
+ if self._state in TERMINAL_STATES:
|
|
|
raise InvalidStateError(f"Can't set_exception to a future that is {self._state} ({self._uid})")
|
|
|
+ elif os.getpid() == self._origin_pid:
|
|
|
+ MPFuture._active_futures.pop(self._uid, None)
|
|
|
+ self._set_event_if_necessary()
|
|
|
else:
|
|
|
- self._state_cache[self._state], self._exception = base.FINISHED, exception
|
|
|
- self._send_update(UpdateType.EXCEPTION, exception)
|
|
|
+ self._send_update(MessageType.EXCEPTION, exception)
|
|
|
+ super().set_exception(exception)
|
|
|
|
|
|
def cancel(self) -> bool:
|
|
|
- if os.getpid() == self._origin_pid:
|
|
|
- MPFuture._active_futures.pop(self._uid, None)
|
|
|
- return super().cancel()
|
|
|
- elif self._state in [base.RUNNING, base.FINISHED]:
|
|
|
+ if self._state in [base.RUNNING, base.FINISHED]:
|
|
|
return False
|
|
|
+ elif os.getpid() == self._origin_pid:
|
|
|
+ MPFuture._active_futures.pop(self._uid, None)
|
|
|
+ self._set_event_if_necessary()
|
|
|
else:
|
|
|
- self._state_cache[self._state] = base.CANCELLED
|
|
|
- self._send_update(UpdateType.CANCEL)
|
|
|
- return True
|
|
|
+ self._send_update(MessageType.CANCEL)
|
|
|
+ return super().cancel()
|
|
|
|
|
|
def set_running_or_notify_cancel(self):
|
|
|
- if self._state == base.PENDING:
|
|
|
- self._state = base.RUNNING
|
|
|
- return True
|
|
|
- elif self._state == base.CANCELLED:
|
|
|
- return False
|
|
|
- else:
|
|
|
- raise InvalidStateError(
|
|
|
- f"Can't set_running_or_notify_cancel when future is in {self._state} ({self._uid})"
|
|
|
- )
|
|
|
+ """if synchronize is set to False, this future will ignore any state changes from origin"""
|
|
|
+ self._synchronize_if_necessary()
|
|
|
+ try:
|
|
|
+ is_running = super().set_running_or_notify_cancel()
|
|
|
+ if is_running and os.getpid() != self._origin_pid:
|
|
|
+ self._send_update(MessageType.RUNNING)
|
|
|
+ return is_running
|
|
|
+ except RuntimeError as e:
|
|
|
+ raise InvalidStateError(str(e))
|
|
|
|
|
|
def result(self, timeout: Optional[float] = None) -> ResultType:
|
|
|
if self._state not in TERMINAL_STATES:
|
|
|
if os.getpid() != self._origin_pid:
|
|
|
raise RuntimeError("Only the process that created MPFuture can await result")
|
|
|
- return super().result(timeout)
|
|
|
- elif self._state == base.CANCELLED:
|
|
|
- raise base.CancelledError()
|
|
|
- elif self._exception:
|
|
|
- raise self._exception
|
|
|
- else:
|
|
|
- return self._result
|
|
|
+ return super().result(timeout)
|
|
|
|
|
|
def exception(self, timeout: Optional[float] = None) -> Optional[BaseException]:
|
|
|
if self._state not in TERMINAL_STATES:
|
|
|
if os.getpid() != self._origin_pid:
|
|
|
raise RuntimeError("Only the process that created MPFuture can await exception")
|
|
|
- return super().exception(timeout)
|
|
|
- elif self._state == base.CANCELLED:
|
|
|
- raise base.CancelledError()
|
|
|
- return self._exception
|
|
|
+ return super().exception(timeout)
|
|
|
|
|
|
def done(self) -> bool:
|
|
|
+ self._synchronize_if_necessary()
|
|
|
return self._state in TERMINAL_STATES
|
|
|
|
|
|
def running(self):
|
|
|
+ self._synchronize_if_necessary()
|
|
|
return self._state == base.RUNNING
|
|
|
|
|
|
def cancelled(self):
|
|
|
+ self._synchronize_if_necessary()
|
|
|
return self._state == base.CANCELLED
|
|
|
|
|
|
def add_done_callback(self, callback: Callable[[MPFuture], None]):
|
|
@@ -240,7 +308,7 @@ class MPFuture(base.Future, Generic[ResultType]):
|
|
|
raise RuntimeError("Can't await: MPFuture was created with no event loop")
|
|
|
yield from self._aio_event.wait().__await__()
|
|
|
try:
|
|
|
- return super().result(timeout=0)
|
|
|
+ return super().result()
|
|
|
except base.CancelledError:
|
|
|
raise asyncio.CancelledError()
|
|
|
|
|
@@ -252,8 +320,9 @@ class MPFuture(base.Future, Generic[ResultType]):
|
|
|
|
|
|
def __getstate__(self):
|
|
|
return dict(
|
|
|
+ synchronize=self.synchronize,
|
|
|
_sender_pipe=self._sender_pipe,
|
|
|
- _shared_state_code=self._shared_state_code,
|
|
|
+ _state=self._state,
|
|
|
_origin_pid=self._origin_pid,
|
|
|
_uid=self._uid,
|
|
|
_use_lock=self._use_lock,
|
|
@@ -262,13 +331,12 @@ class MPFuture(base.Future, Generic[ResultType]):
|
|
|
)
|
|
|
|
|
|
def __setstate__(self, state):
|
|
|
+ self.synchronize = state["synchronize"]
|
|
|
self._sender_pipe = state["_sender_pipe"]
|
|
|
- self._shared_state_code = state["_shared_state_code"]
|
|
|
- self._origin_pid, self._uid = state["_origin_pid"], state["_uid"]
|
|
|
+ self._state, self._origin_pid, self._uid = state["_state"], state["_origin_pid"], state["_uid"]
|
|
|
self._result, self._exception = state["_result"], state["_exception"]
|
|
|
self._use_lock = state["_use_lock"]
|
|
|
|
|
|
self._waiters, self._done_callbacks = [], []
|
|
|
self._condition = threading.Condition()
|
|
|
self._aio_event, self._loop = None, None
|
|
|
- self._state_cache = {}
|