|
@@ -7,6 +7,7 @@ import multiprocessing as mp
|
|
|
import os
|
|
|
import threading
|
|
|
import uuid
|
|
|
+from selectors import DefaultSelector, EVENT_READ
|
|
|
from weakref import ref
|
|
|
from enum import Enum, auto
|
|
|
from typing import Generic, TypeVar, Dict, Optional, Any, Callable, Type, Tuple
|
|
@@ -62,8 +63,10 @@ class MPFuture(base.Future, Generic[ResultType]):
|
|
|
"""
|
|
|
|
|
|
_initialization_lock = mp.Lock() # global lock that prevents simultaneous initialization of two processes
|
|
|
- _update_lock = mp.Lock() # global lock that prevents simultaneous writing to the same pipe
|
|
|
- _process_wide_pipe: Optional[PipeEnd] = None # a pipe that is used to send results/exceptions to this process
|
|
|
+ _update_lock = mp.Lock() # global lock that prevents simultaneous writing of results/exceptions through same pipe
|
|
|
+ _status_lock = mp.Lock() # global lock that prevents simultaneous sening of status updates through same pipe
|
|
|
+ _process_inner_pipe: Optional[PipeEnd] = None # a pipe that is used to read results and send status updates
|
|
|
+ _process_outer_pipe: Optional[PipeEnd] = None # a pipe that is used to send results and receive status updates
|
|
|
_pipe_waiter_thread: Optional[threading.Thread] = None # process-specific thread that receives results/exceptions
|
|
|
_active_futures: Optional[Dict[UID, Type[ref][MPFuture]]] = None # non-done futures originated from this process
|
|
|
_status_requests: Optional[Dict[UID, Tuple[MPFuture, threading.Event]]] = None # futures to be updated by origin
|
|
@@ -82,7 +85,7 @@ class MPFuture(base.Future, Generic[ResultType]):
|
|
|
self._initialize_backend_if_necessary()
|
|
|
assert self._uid not in MPFuture._active_futures
|
|
|
MPFuture._active_futures[self._uid] = ref(self)
|
|
|
- self._sender_pipe = MPFuture._process_wide_pipe
|
|
|
+ self._pipe_to_origin = MPFuture._process_outer_pipe
|
|
|
|
|
|
try:
|
|
|
self._loop = asyncio.get_event_loop()
|
|
@@ -116,11 +119,10 @@ class MPFuture(base.Future, Generic[ResultType]):
|
|
|
if MPFuture._active_pid != pid:
|
|
|
# note: the second if is intentional, see https://en.wikipedia.org/wiki/Double-checked_locking
|
|
|
logger.debug(f"Initializing MPFuture backend for pid {pid}")
|
|
|
- receiver_pipe, cls._process_wide_pipe = mp.Pipe(duplex=False)
|
|
|
+ cls._process_inner_pipe, cls._process_outer_pipe = mp.Pipe(duplex=True)
|
|
|
cls._active_pid, cls._active_futures, cls._status_requests = pid, {}, {}
|
|
|
cls._pipe_waiter_thread = threading.Thread(
|
|
|
target=cls._process_updates_in_background,
|
|
|
- args=[receiver_pipe],
|
|
|
name=f"{__name__}.BACKEND",
|
|
|
daemon=True,
|
|
|
)
|
|
@@ -140,67 +142,72 @@ class MPFuture(base.Future, Generic[ResultType]):
|
|
|
cls._active_pid = None
|
|
|
|
|
|
@classmethod
|
|
|
- def _process_updates_in_background(cls, receiver_pipe: mp.connection.Connection):
|
|
|
+ def _process_updates_in_background(cls):
|
|
|
pid = os.getpid()
|
|
|
- while True:
|
|
|
- try:
|
|
|
- if cls._pipe_waiter_thread is not threading.current_thread():
|
|
|
- break # Backend was reset, a new background thread has started
|
|
|
-
|
|
|
- uid, msg_type, payload = receiver_pipe.recv()
|
|
|
- future = None
|
|
|
- future_ref = cls._active_futures.get(uid)
|
|
|
- if future_ref is not None:
|
|
|
- future = future_ref()
|
|
|
-
|
|
|
- if msg_type == MessageType.STATE_REQUEST:
|
|
|
- future_state = None if future is None else future.__getstate__()
|
|
|
- use_lock, return_pipe = payload
|
|
|
- with MPFuture._update_lock if use_lock else nullcontext():
|
|
|
- return_pipe.send((uid, MessageType.STATE_RESPONSE, future_state))
|
|
|
-
|
|
|
- elif msg_type == MessageType.STATE_RESPONSE:
|
|
|
- future, state_updated_event = cls._status_requests.get(uid, (None, None))
|
|
|
- if future is None:
|
|
|
- logger.debug("Received a state update for a future that does not await status update.")
|
|
|
- else:
|
|
|
- if payload is not None:
|
|
|
- future.__setstate__(payload)
|
|
|
+ with DefaultSelector() as selector:
|
|
|
+ selector.register(cls._process_inner_pipe, EVENT_READ, data=cls._process_inner_pipe)
|
|
|
+ selector.register(cls._process_outer_pipe, EVENT_READ, data=cls._process_outer_pipe)
|
|
|
+
|
|
|
+ while True:
|
|
|
+ try:
|
|
|
+ if cls._pipe_waiter_thread is not threading.current_thread():
|
|
|
+ break # Backend was reset, a new background thread has started
|
|
|
+
|
|
|
+ pipe = next((key.data for (key, events) in selector.select()))
|
|
|
+ uid, msg_type, payload = pipe.recv()
|
|
|
+ future = None
|
|
|
+ future_ref = cls._active_futures.get(uid)
|
|
|
+ if future_ref is not None:
|
|
|
+ future = future_ref()
|
|
|
+
|
|
|
+ if msg_type == MessageType.STATE_REQUEST:
|
|
|
+ future_state = None if future is None else future.__getstate__()
|
|
|
+ use_lock, return_pipe = payload
|
|
|
+ with MPFuture._status_lock if use_lock else nullcontext():
|
|
|
+ return_pipe.send((uid, MessageType.STATE_RESPONSE, future_state))
|
|
|
+
|
|
|
+ elif msg_type == MessageType.STATE_RESPONSE:
|
|
|
+ future, state_updated_event = cls._status_requests.get(uid, (None, None))
|
|
|
+ if future is None:
|
|
|
+ logger.debug("Received a state update for a future that does not await status update.")
|
|
|
else:
|
|
|
- base.Future.cancel(future)
|
|
|
- state_updated_event.set()
|
|
|
+ if payload is not None:
|
|
|
+ future.__setstate__(payload)
|
|
|
+ else:
|
|
|
+ base.Future.cancel(future)
|
|
|
+ state_updated_event.set()
|
|
|
+
|
|
|
+ elif future is None:
|
|
|
+ logger.debug(
|
|
|
+ f"Received {msg_type} for MPFuture uid={uid}, but future is already done or destroyed"
|
|
|
+ )
|
|
|
+ elif msg_type == MessageType.RESULT:
|
|
|
+ future.set_result(payload)
|
|
|
+ elif msg_type == MessageType.EXCEPTION:
|
|
|
+ future.set_exception(payload)
|
|
|
+ elif msg_type == MessageType.RUNNING:
|
|
|
+ try:
|
|
|
+ future.set_running_or_notify_cancel()
|
|
|
+ except (InvalidStateError, RuntimeError) as e:
|
|
|
+ logger.debug(f"Could not set MPFuture (uid={uid}) to running due to {e}", exc_info=True)
|
|
|
+ elif msg_type == MessageType.CANCEL:
|
|
|
+ future.cancel()
|
|
|
+ else:
|
|
|
+ raise RuntimeError(f"Received unexpected update type {msg_type}")
|
|
|
|
|
|
- elif future is None:
|
|
|
- logger.debug(
|
|
|
- f"Received {msg_type} for MPFuture uid={uid}, but future is already done or destroyed"
|
|
|
- )
|
|
|
- elif msg_type == MessageType.RESULT:
|
|
|
- future.set_result(payload)
|
|
|
- elif msg_type == MessageType.EXCEPTION:
|
|
|
- future.set_exception(payload)
|
|
|
- elif msg_type == MessageType.RUNNING:
|
|
|
- try:
|
|
|
- future.set_running_or_notify_cancel()
|
|
|
- except (InvalidStateError, RuntimeError) as e:
|
|
|
- logger.debug(f"Could not set MPFuture (uid={uid}) to running due to {e}", exc_info=True)
|
|
|
- elif msg_type == MessageType.CANCEL:
|
|
|
- future.cancel()
|
|
|
- else:
|
|
|
- raise RuntimeError(f"Received unexpected update type {msg_type}")
|
|
|
-
|
|
|
- if future is None or future.done():
|
|
|
- cls._active_futures.pop(uid, None)
|
|
|
-
|
|
|
- except (BrokenPipeError, EOFError, ConnectionError):
|
|
|
- logger.debug(f"Update pipe was was shut down unexpectedly (pid={pid})")
|
|
|
- except Exception as e:
|
|
|
- logger.exception(f"Could not retrieve update: caught {repr(e)} (pid={pid})")
|
|
|
+ if future is None or future.done():
|
|
|
+ cls._active_futures.pop(uid, None)
|
|
|
+
|
|
|
+ except (BrokenPipeError, EOFError, ConnectionError):
|
|
|
+ logger.debug(f"Update pipe was was shut down unexpectedly (pid={pid})")
|
|
|
+ except Exception as e:
|
|
|
+ logger.exception(f"Could not retrieve update: caught {repr(e)} (pid={pid})")
|
|
|
|
|
|
def _send_update(self, update_type: MessageType, payload: Any = None):
|
|
|
"""This method sends result, exception or cancel to the MPFuture origin."""
|
|
|
try:
|
|
|
with MPFuture._update_lock if self._use_lock else nullcontext():
|
|
|
- self._sender_pipe.send((self._uid, update_type, payload))
|
|
|
+ self._pipe_to_origin.send((self._uid, update_type, payload))
|
|
|
except (ConnectionError, BrokenPipeError, EOFError) as e:
|
|
|
logger.debug(f"No updates were sent: pipe to origin process was broken ({e}).", exc_info=True)
|
|
|
|
|
@@ -221,9 +228,9 @@ class MPFuture(base.Future, Generic[ResultType]):
|
|
|
# otherwise create a new request for synchronization
|
|
|
|
|
|
try:
|
|
|
- with MPFuture._update_lock if self._use_lock else nullcontext():
|
|
|
- payload = (self._use_lock, self._process_wide_pipe)
|
|
|
- self._sender_pipe.send((self._uid, MessageType.STATE_REQUEST, payload))
|
|
|
+ with MPFuture._status_lock if self._use_lock else nullcontext():
|
|
|
+ payload = (self._use_lock, self._process_inner_pipe)
|
|
|
+ self._pipe_to_origin.send((self._uid, MessageType.STATE_REQUEST, payload))
|
|
|
status_updated.wait(MPFuture.SOFT_UPDATE_TIMEOUT)
|
|
|
if not status_updated.is_set():
|
|
|
logger.warning(f"Status update took over {MPFuture.SOFT_UPDATE_TIMEOUT}, expect performance issues")
|
|
@@ -336,7 +343,7 @@ class MPFuture(base.Future, Generic[ResultType]):
|
|
|
def __getstate__(self):
|
|
|
return dict(
|
|
|
synchronize=self.synchronize,
|
|
|
- _sender_pipe=self._sender_pipe,
|
|
|
+ _pipe_to_origin=self._pipe_to_origin,
|
|
|
_state=self._state,
|
|
|
_origin_pid=self._origin_pid,
|
|
|
_uid=self._uid,
|
|
@@ -347,7 +354,7 @@ class MPFuture(base.Future, Generic[ResultType]):
|
|
|
|
|
|
def __setstate__(self, state):
|
|
|
self.synchronize = state["synchronize"]
|
|
|
- self._sender_pipe = state["_sender_pipe"]
|
|
|
+ self._pipe_to_origin = state["_pipe_to_origin"]
|
|
|
self._state, self._origin_pid, self._uid = state["_state"], state["_origin_pid"], state["_uid"]
|
|
|
self._result, self._exception = state["_result"], state["_exception"]
|
|
|
self._use_lock = state["_use_lock"]
|