justheuristic 4 роки тому
батько
коміт
2229f7ec08
1 змінених файлів з 70 додано та 63 видалено
  1. 70 63
      hivemind/utils/mpfuture.py

+ 70 - 63
hivemind/utils/mpfuture.py

@@ -7,6 +7,7 @@ import multiprocessing as mp
 import os
 import threading
 import uuid
+from selectors import DefaultSelector, EVENT_READ
 from weakref import ref
 from enum import Enum, auto
 from typing import Generic, TypeVar, Dict, Optional, Any, Callable, Type, Tuple
@@ -62,8 +63,10 @@ class MPFuture(base.Future, Generic[ResultType]):
     """
 
     _initialization_lock = mp.Lock()  # global lock that prevents simultaneous initialization of two processes
-    _update_lock = mp.Lock()  # global lock that prevents simultaneous writing to the same pipe
-    _process_wide_pipe: Optional[PipeEnd] = None  # a pipe that is used to send results/exceptions to this process
+    _update_lock = mp.Lock()  # global lock that prevents simultaneous writing of results/exceptions through same pipe
+    _status_lock = mp.Lock()  # global lock that prevents simultaneous sending of status updates through same pipe
+    _process_inner_pipe: Optional[PipeEnd] = None  # a pipe that is used to read results and send status updates
+    _process_outer_pipe: Optional[PipeEnd] = None  # a pipe that is used to send results and receive status updates
     _pipe_waiter_thread: Optional[threading.Thread] = None  # process-specific thread that receives results/exceptions
     _active_futures: Optional[Dict[UID, Type[ref][MPFuture]]] = None  # non-done futures originated from this process
     _status_requests: Optional[Dict[UID, Tuple[MPFuture, threading.Event]]] = None  # futures to be updated by origin
@@ -82,7 +85,7 @@ class MPFuture(base.Future, Generic[ResultType]):
         self._initialize_backend_if_necessary()
         assert self._uid not in MPFuture._active_futures
         MPFuture._active_futures[self._uid] = ref(self)
-        self._sender_pipe = MPFuture._process_wide_pipe
+        self._pipe_to_origin = MPFuture._process_outer_pipe
 
         try:
             self._loop = asyncio.get_event_loop()
@@ -116,11 +119,10 @@ class MPFuture(base.Future, Generic[ResultType]):
                 if MPFuture._active_pid != pid:
                     # note: the second if is intentional, see https://en.wikipedia.org/wiki/Double-checked_locking
                     logger.debug(f"Initializing MPFuture backend for pid {pid}")
-                    receiver_pipe, cls._process_wide_pipe = mp.Pipe(duplex=False)
+                    cls._process_inner_pipe, cls._process_outer_pipe = mp.Pipe(duplex=True)
                     cls._active_pid, cls._active_futures, cls._status_requests = pid, {}, {}
                     cls._pipe_waiter_thread = threading.Thread(
                         target=cls._process_updates_in_background,
-                        args=[receiver_pipe],
                         name=f"{__name__}.BACKEND",
                         daemon=True,
                     )
@@ -140,67 +142,72 @@ class MPFuture(base.Future, Generic[ResultType]):
         cls._active_pid = None
 
     @classmethod
-    def _process_updates_in_background(cls, receiver_pipe: mp.connection.Connection):
+    def _process_updates_in_background(cls):
         pid = os.getpid()
-        while True:
-            try:
-                if cls._pipe_waiter_thread is not threading.current_thread():
-                    break  # Backend was reset, a new background thread has started
-
-                uid, msg_type, payload = receiver_pipe.recv()
-                future = None
-                future_ref = cls._active_futures.get(uid)
-                if future_ref is not None:
-                    future = future_ref()
-
-                if msg_type == MessageType.STATE_REQUEST:
-                    future_state = None if future is None else future.__getstate__()
-                    use_lock, return_pipe = payload
-                    with MPFuture._update_lock if use_lock else nullcontext():
-                        return_pipe.send((uid, MessageType.STATE_RESPONSE, future_state))
-
-                elif msg_type == MessageType.STATE_RESPONSE:
-                    future, state_updated_event = cls._status_requests.get(uid, (None, None))
-                    if future is None:
-                        logger.debug("Received a state update for a future that does not await status update.")
-                    else:
-                        if payload is not None:
-                            future.__setstate__(payload)
+        with DefaultSelector() as selector:
+            selector.register(cls._process_inner_pipe, EVENT_READ, data=cls._process_inner_pipe)
+            selector.register(cls._process_outer_pipe, EVENT_READ, data=cls._process_outer_pipe)
+
+            while True:
+                try:
+                    if cls._pipe_waiter_thread is not threading.current_thread():
+                        break  # Backend was reset, a new background thread has started
+
+                    pipe = next((key.data for (key, events) in selector.select()))
+                    uid, msg_type, payload = pipe.recv()
+                    future = None
+                    future_ref = cls._active_futures.get(uid)
+                    if future_ref is not None:
+                        future = future_ref()
+
+                    if msg_type == MessageType.STATE_REQUEST:
+                        future_state = None if future is None else future.__getstate__()
+                        use_lock, return_pipe = payload
+                        with MPFuture._status_lock if use_lock else nullcontext():
+                            return_pipe.send((uid, MessageType.STATE_RESPONSE, future_state))
+
+                    elif msg_type == MessageType.STATE_RESPONSE:
+                        future, state_updated_event = cls._status_requests.get(uid, (None, None))
+                        if future is None:
+                            logger.debug("Received a state update for a future that does not await status update.")
                         else:
-                            base.Future.cancel(future)
-                        state_updated_event.set()
+                            if payload is not None:
+                                future.__setstate__(payload)
+                            else:
+                                base.Future.cancel(future)
+                            state_updated_event.set()
+
+                    elif future is None:
+                        logger.debug(
+                            f"Received {msg_type} for MPFuture uid={uid}, but future is already done or destroyed"
+                        )
+                    elif msg_type == MessageType.RESULT:
+                        future.set_result(payload)
+                    elif msg_type == MessageType.EXCEPTION:
+                        future.set_exception(payload)
+                    elif msg_type == MessageType.RUNNING:
+                        try:
+                            future.set_running_or_notify_cancel()
+                        except (InvalidStateError, RuntimeError) as e:
+                            logger.debug(f"Could not set MPFuture (uid={uid}) to running due to {e}", exc_info=True)
+                    elif msg_type == MessageType.CANCEL:
+                        future.cancel()
+                    else:
+                        raise RuntimeError(f"Received unexpected update type {msg_type}")
 
-                elif future is None:
-                    logger.debug(
-                        f"Received {msg_type} for MPFuture uid={uid}, but future is already done or destroyed"
-                    )
-                elif msg_type == MessageType.RESULT:
-                    future.set_result(payload)
-                elif msg_type == MessageType.EXCEPTION:
-                    future.set_exception(payload)
-                elif msg_type == MessageType.RUNNING:
-                    try:
-                        future.set_running_or_notify_cancel()
-                    except (InvalidStateError, RuntimeError) as e:
-                        logger.debug(f"Could not set MPFuture (uid={uid}) to running due to {e}", exc_info=True)
-                elif msg_type == MessageType.CANCEL:
-                    future.cancel()
-                else:
-                    raise RuntimeError(f"Received unexpected update type {msg_type}")
-
-                if future is None or future.done():
-                    cls._active_futures.pop(uid, None)
-
-            except (BrokenPipeError, EOFError, ConnectionError):
-                logger.debug(f"Update pipe was was shut down unexpectedly (pid={pid})")
-            except Exception as e:
-                logger.exception(f"Could not retrieve update: caught {repr(e)} (pid={pid})")
+                    if future is None or future.done():
+                        cls._active_futures.pop(uid, None)
+
+                except (BrokenPipeError, EOFError, ConnectionError):
+                    logger.debug(f"Update pipe was was shut down unexpectedly (pid={pid})")
+                except Exception as e:
+                    logger.exception(f"Could not retrieve update: caught {repr(e)} (pid={pid})")
 
     def _send_update(self, update_type: MessageType, payload: Any = None):
         """This method sends result, exception or cancel to the MPFuture origin."""
         try:
             with MPFuture._update_lock if self._use_lock else nullcontext():
-                self._sender_pipe.send((self._uid, update_type, payload))
+                self._pipe_to_origin.send((self._uid, update_type, payload))
         except (ConnectionError, BrokenPipeError, EOFError) as e:
             logger.debug(f"No updates were sent: pipe to origin process was broken ({e}).", exc_info=True)
 
@@ -221,9 +228,9 @@ class MPFuture(base.Future, Generic[ResultType]):
         # otherwise create a new request for synchronization
 
         try:
-            with MPFuture._update_lock if self._use_lock else nullcontext():
-                payload = (self._use_lock, self._process_wide_pipe)
-                self._sender_pipe.send((self._uid, MessageType.STATE_REQUEST, payload))
+            with MPFuture._status_lock if self._use_lock else nullcontext():
+                payload = (self._use_lock, self._process_inner_pipe)
+                self._pipe_to_origin.send((self._uid, MessageType.STATE_REQUEST, payload))
             status_updated.wait(MPFuture.SOFT_UPDATE_TIMEOUT)
             if not status_updated.is_set():
                 logger.warning(f"Status update took over {MPFuture.SOFT_UPDATE_TIMEOUT}, expect performance issues")
@@ -336,7 +343,7 @@ class MPFuture(base.Future, Generic[ResultType]):
     def __getstate__(self):
         return dict(
             synchronize=self.synchronize,
-            _sender_pipe=self._sender_pipe,
+            _pipe_to_origin=self._pipe_to_origin,
             _state=self._state,
             _origin_pid=self._origin_pid,
             _uid=self._uid,
@@ -347,7 +354,7 @@ class MPFuture(base.Future, Generic[ResultType]):
 
     def __setstate__(self, state):
         self.synchronize = state["synchronize"]
-        self._sender_pipe = state["_sender_pipe"]
+        self._pipe_to_origin = state["_pipe_to_origin"]
         self._state, self._origin_pid, self._uid = state["_state"], state["_origin_pid"], state["_uid"]
         self._result, self._exception = state["_result"], state["_exception"]
         self._use_lock = state["_use_lock"]