Browse Source

use a separate pipe

justheuristic 4 years ago
parent
commit
2229f7ec08
1 changed file with 70 additions and 63 deletions
  1. 70 63
      hivemind/utils/mpfuture.py

+ 70 - 63
hivemind/utils/mpfuture.py

@@ -7,6 +7,7 @@ import multiprocessing as mp
 import os
 import os
 import threading
 import threading
 import uuid
 import uuid
+from selectors import DefaultSelector, EVENT_READ
 from weakref import ref
 from weakref import ref
 from enum import Enum, auto
 from enum import Enum, auto
 from typing import Generic, TypeVar, Dict, Optional, Any, Callable, Type, Tuple
 from typing import Generic, TypeVar, Dict, Optional, Any, Callable, Type, Tuple
@@ -62,8 +63,10 @@ class MPFuture(base.Future, Generic[ResultType]):
     """
     """
 
 
     _initialization_lock = mp.Lock()  # global lock that prevents simultaneous initialization of two processes
     _initialization_lock = mp.Lock()  # global lock that prevents simultaneous initialization of two processes
-    _update_lock = mp.Lock()  # global lock that prevents simultaneous writing to the same pipe
-    _process_wide_pipe: Optional[PipeEnd] = None  # a pipe that is used to send results/exceptions to this process
+    _update_lock = mp.Lock()  # global lock that prevents simultaneous writing of results/exceptions through same pipe
+    _status_lock = mp.Lock()  # global lock that prevents simultaneous sending of status updates through same pipe
+    _process_inner_pipe: Optional[PipeEnd] = None  # a pipe that is used to read results and send status updates
+    _process_outer_pipe: Optional[PipeEnd] = None  # a pipe that is used to send results and receive status updates
     _pipe_waiter_thread: Optional[threading.Thread] = None  # process-specific thread that receives results/exceptions
     _pipe_waiter_thread: Optional[threading.Thread] = None  # process-specific thread that receives results/exceptions
     _active_futures: Optional[Dict[UID, Type[ref][MPFuture]]] = None  # non-done futures originated from this process
     _active_futures: Optional[Dict[UID, Type[ref][MPFuture]]] = None  # non-done futures originated from this process
     _status_requests: Optional[Dict[UID, Tuple[MPFuture, threading.Event]]] = None  # futures to be updated by origin
     _status_requests: Optional[Dict[UID, Tuple[MPFuture, threading.Event]]] = None  # futures to be updated by origin
@@ -82,7 +85,7 @@ class MPFuture(base.Future, Generic[ResultType]):
         self._initialize_backend_if_necessary()
         self._initialize_backend_if_necessary()
         assert self._uid not in MPFuture._active_futures
         assert self._uid not in MPFuture._active_futures
         MPFuture._active_futures[self._uid] = ref(self)
         MPFuture._active_futures[self._uid] = ref(self)
-        self._sender_pipe = MPFuture._process_wide_pipe
+        self._pipe_to_origin = MPFuture._process_outer_pipe
 
 
         try:
         try:
             self._loop = asyncio.get_event_loop()
             self._loop = asyncio.get_event_loop()
@@ -116,11 +119,10 @@ class MPFuture(base.Future, Generic[ResultType]):
                 if MPFuture._active_pid != pid:
                 if MPFuture._active_pid != pid:
                     # note: the second if is intentional, see https://en.wikipedia.org/wiki/Double-checked_locking
                     # note: the second if is intentional, see https://en.wikipedia.org/wiki/Double-checked_locking
                     logger.debug(f"Initializing MPFuture backend for pid {pid}")
                     logger.debug(f"Initializing MPFuture backend for pid {pid}")
-                    receiver_pipe, cls._process_wide_pipe = mp.Pipe(duplex=False)
+                    cls._process_inner_pipe, cls._process_outer_pipe = mp.Pipe(duplex=True)
                     cls._active_pid, cls._active_futures, cls._status_requests = pid, {}, {}
                     cls._active_pid, cls._active_futures, cls._status_requests = pid, {}, {}
                     cls._pipe_waiter_thread = threading.Thread(
                     cls._pipe_waiter_thread = threading.Thread(
                         target=cls._process_updates_in_background,
                         target=cls._process_updates_in_background,
-                        args=[receiver_pipe],
                         name=f"{__name__}.BACKEND",
                         name=f"{__name__}.BACKEND",
                         daemon=True,
                         daemon=True,
                     )
                     )
@@ -140,67 +142,72 @@ class MPFuture(base.Future, Generic[ResultType]):
         cls._active_pid = None
         cls._active_pid = None
 
 
     @classmethod
     @classmethod
-    def _process_updates_in_background(cls, receiver_pipe: mp.connection.Connection):
+    def _process_updates_in_background(cls):
         pid = os.getpid()
         pid = os.getpid()
-        while True:
-            try:
-                if cls._pipe_waiter_thread is not threading.current_thread():
-                    break  # Backend was reset, a new background thread has started
-
-                uid, msg_type, payload = receiver_pipe.recv()
-                future = None
-                future_ref = cls._active_futures.get(uid)
-                if future_ref is not None:
-                    future = future_ref()
-
-                if msg_type == MessageType.STATE_REQUEST:
-                    future_state = None if future is None else future.__getstate__()
-                    use_lock, return_pipe = payload
-                    with MPFuture._update_lock if use_lock else nullcontext():
-                        return_pipe.send((uid, MessageType.STATE_RESPONSE, future_state))
-
-                elif msg_type == MessageType.STATE_RESPONSE:
-                    future, state_updated_event = cls._status_requests.get(uid, (None, None))
-                    if future is None:
-                        logger.debug("Received a state update for a future that does not await status update.")
-                    else:
-                        if payload is not None:
-                            future.__setstate__(payload)
+        with DefaultSelector() as selector:
+            selector.register(cls._process_inner_pipe, EVENT_READ, data=cls._process_inner_pipe)
+            selector.register(cls._process_outer_pipe, EVENT_READ, data=cls._process_outer_pipe)
+
+            while True:
+                try:
+                    if cls._pipe_waiter_thread is not threading.current_thread():
+                        break  # Backend was reset, a new background thread has started
+
+                    pipe = next((key.data for (key, events) in selector.select()))
+                    uid, msg_type, payload = pipe.recv()
+                    future = None
+                    future_ref = cls._active_futures.get(uid)
+                    if future_ref is not None:
+                        future = future_ref()
+
+                    if msg_type == MessageType.STATE_REQUEST:
+                        future_state = None if future is None else future.__getstate__()
+                        use_lock, return_pipe = payload
+                        with MPFuture._status_lock if use_lock else nullcontext():
+                            return_pipe.send((uid, MessageType.STATE_RESPONSE, future_state))
+
+                    elif msg_type == MessageType.STATE_RESPONSE:
+                        future, state_updated_event = cls._status_requests.get(uid, (None, None))
+                        if future is None:
+                            logger.debug("Received a state update for a future that does not await status update.")
                         else:
                         else:
-                            base.Future.cancel(future)
-                        state_updated_event.set()
+                            if payload is not None:
+                                future.__setstate__(payload)
+                            else:
+                                base.Future.cancel(future)
+                            state_updated_event.set()
+
+                    elif future is None:
+                        logger.debug(
+                            f"Received {msg_type} for MPFuture uid={uid}, but future is already done or destroyed"
+                        )
+                    elif msg_type == MessageType.RESULT:
+                        future.set_result(payload)
+                    elif msg_type == MessageType.EXCEPTION:
+                        future.set_exception(payload)
+                    elif msg_type == MessageType.RUNNING:
+                        try:
+                            future.set_running_or_notify_cancel()
+                        except (InvalidStateError, RuntimeError) as e:
+                            logger.debug(f"Could not set MPFuture (uid={uid}) to running due to {e}", exc_info=True)
+                    elif msg_type == MessageType.CANCEL:
+                        future.cancel()
+                    else:
+                        raise RuntimeError(f"Received unexpected update type {msg_type}")
 
 
-                elif future is None:
-                    logger.debug(
-                        f"Received {msg_type} for MPFuture uid={uid}, but future is already done or destroyed"
-                    )
-                elif msg_type == MessageType.RESULT:
-                    future.set_result(payload)
-                elif msg_type == MessageType.EXCEPTION:
-                    future.set_exception(payload)
-                elif msg_type == MessageType.RUNNING:
-                    try:
-                        future.set_running_or_notify_cancel()
-                    except (InvalidStateError, RuntimeError) as e:
-                        logger.debug(f"Could not set MPFuture (uid={uid}) to running due to {e}", exc_info=True)
-                elif msg_type == MessageType.CANCEL:
-                    future.cancel()
-                else:
-                    raise RuntimeError(f"Received unexpected update type {msg_type}")
-
-                if future is None or future.done():
-                    cls._active_futures.pop(uid, None)
-
-            except (BrokenPipeError, EOFError, ConnectionError):
-                logger.debug(f"Update pipe was was shut down unexpectedly (pid={pid})")
-            except Exception as e:
-                logger.exception(f"Could not retrieve update: caught {repr(e)} (pid={pid})")
+                    if future is None or future.done():
+                        cls._active_futures.pop(uid, None)
+
+                except (BrokenPipeError, EOFError, ConnectionError):
+                    logger.debug(f"Update pipe was was shut down unexpectedly (pid={pid})")
+                except Exception as e:
+                    logger.exception(f"Could not retrieve update: caught {repr(e)} (pid={pid})")
 
 
     def _send_update(self, update_type: MessageType, payload: Any = None):
     def _send_update(self, update_type: MessageType, payload: Any = None):
         """This method sends result, exception or cancel to the MPFuture origin."""
         """This method sends result, exception or cancel to the MPFuture origin."""
         try:
         try:
             with MPFuture._update_lock if self._use_lock else nullcontext():
             with MPFuture._update_lock if self._use_lock else nullcontext():
-                self._sender_pipe.send((self._uid, update_type, payload))
+                self._pipe_to_origin.send((self._uid, update_type, payload))
         except (ConnectionError, BrokenPipeError, EOFError) as e:
         except (ConnectionError, BrokenPipeError, EOFError) as e:
             logger.debug(f"No updates were sent: pipe to origin process was broken ({e}).", exc_info=True)
             logger.debug(f"No updates were sent: pipe to origin process was broken ({e}).", exc_info=True)
 
 
@@ -221,9 +228,9 @@ class MPFuture(base.Future, Generic[ResultType]):
         # otherwise create a new request for synchronization
         # otherwise create a new request for synchronization
 
 
         try:
         try:
-            with MPFuture._update_lock if self._use_lock else nullcontext():
-                payload = (self._use_lock, self._process_wide_pipe)
-                self._sender_pipe.send((self._uid, MessageType.STATE_REQUEST, payload))
+            with MPFuture._status_lock if self._use_lock else nullcontext():
+                payload = (self._use_lock, self._process_inner_pipe)
+                self._pipe_to_origin.send((self._uid, MessageType.STATE_REQUEST, payload))
             status_updated.wait(MPFuture.SOFT_UPDATE_TIMEOUT)
             status_updated.wait(MPFuture.SOFT_UPDATE_TIMEOUT)
             if not status_updated.is_set():
             if not status_updated.is_set():
                 logger.warning(f"Status update took over {MPFuture.SOFT_UPDATE_TIMEOUT}, expect performance issues")
                 logger.warning(f"Status update took over {MPFuture.SOFT_UPDATE_TIMEOUT}, expect performance issues")
@@ -336,7 +343,7 @@ class MPFuture(base.Future, Generic[ResultType]):
     def __getstate__(self):
     def __getstate__(self):
         return dict(
         return dict(
             synchronize=self.synchronize,
             synchronize=self.synchronize,
-            _sender_pipe=self._sender_pipe,
+            _pipe_to_origin=self._pipe_to_origin,
             _state=self._state,
             _state=self._state,
             _origin_pid=self._origin_pid,
             _origin_pid=self._origin_pid,
             _uid=self._uid,
             _uid=self._uid,
@@ -347,7 +354,7 @@ class MPFuture(base.Future, Generic[ResultType]):
 
 
     def __setstate__(self, state):
     def __setstate__(self, state):
         self.synchronize = state["synchronize"]
         self.synchronize = state["synchronize"]
-        self._sender_pipe = state["_sender_pipe"]
+        self._pipe_to_origin = state["_pipe_to_origin"]
         self._state, self._origin_pid, self._uid = state["_state"], state["_origin_pid"], state["_uid"]
         self._state, self._origin_pid, self._uid = state["_state"], state["_origin_pid"], state["_uid"]
         self._result, self._exception = state["_result"], state["_exception"]
         self._result, self._exception = state["_result"], state["_exception"]
         self._use_lock = state["_use_lock"]
         self._use_lock = state["_use_lock"]