justheuristic 4 gadi atpakaļ
vecāks
revīzija
49b21cd8ab
1 mainītis faili ar 106 papildinājumiem un 214 dzēšanām
  1. 106 214
      hivemind/utils/mpfuture.py

+ 106 - 214
hivemind/utils/mpfuture.py

@@ -2,15 +2,16 @@ from __future__ import annotations
 
 
 import asyncio
 import asyncio
 import concurrent.futures._base as base
 import concurrent.futures._base as base
-from contextlib import nullcontext, suppress
+from contextlib import nullcontext
 import multiprocessing as mp
 import multiprocessing as mp
+import multiprocessing.connection
 import os
 import os
 import threading
 import threading
 import uuid
 import uuid
-from queue import SimpleQueue
-from weakref import ref
 from enum import Enum, auto
 from enum import Enum, auto
-from typing import Generic, TypeVar, Dict, Optional, Any, Callable, Type, Tuple
+from typing import Generic, TypeVar, Dict, Optional, Any, Callable
+
+import torch  # used for py3.7-compatible shared memory
 
 
 from hivemind.utils.logging import get_logger
 from hivemind.utils.logging import get_logger
 
 
@@ -33,13 +34,10 @@ except ImportError:
         """Raised when attempting to change state of a future in a terminal state (e.g. finished)"""
         """Raised when attempting to change state of a future in a terminal state (e.g. finished)"""
 
 
 
 
-class MessageType(Enum):
+class UpdateType(Enum):
     RESULT = auto()
     RESULT = auto()
     EXCEPTION = auto()
     EXCEPTION = auto()
-    RUNNING = auto()
     CANCEL = auto()
     CANCEL = auto()
-    STATE_REQUEST = auto()
-    STATE_RESPONSE = auto()
 
 
 
 
 class MPFuture(base.Future, Generic[ResultType]):
 class MPFuture(base.Future, Generic[ResultType]):
@@ -48,8 +46,6 @@ class MPFuture(base.Future, Generic[ResultType]):
     Any process can access future status and set the result / exception and check for state.
     Any process can access future status and set the result / exception and check for state.
     However, only the original process (i.e. the process that created the future) can await the result or exception.
     However, only the original process (i.e. the process that created the future) can await the result or exception.
 
 
-    :param synchronize: if True (default), future will request state from origin, otherwise it will only use local state
-      Setting synchronize=False results in slightly better performance of done or set_running_or_notify_cancel
     :param use_lock: if True, operations with MPFuture use a global lock to prevent concurrent writes to the same pipe;
     :param use_lock: if True, operations with MPFuture use a global lock to prevent concurrent writes to the same pipe;
       If set to False, writing to this future ignores global lock, slightly improving performance, but making user
       If set to False, writing to this future ignores global lock, slightly improving performance, but making user
       responsible for avoiding concurrent set_result / set_exception calls to futures with the same process of origin.
       responsible for avoiding concurrent set_result / set_exception calls to futures with the same process of origin.
@@ -64,38 +60,49 @@ class MPFuture(base.Future, Generic[ResultType]):
 
 
     _initialization_lock = mp.Lock()  # global lock that prevents simultaneous initialization of two processes
     _initialization_lock = mp.Lock()  # global lock that prevents simultaneous initialization of two processes
     _update_lock = mp.Lock()  # global lock that prevents simultaneous writing to the same pipe
     _update_lock = mp.Lock()  # global lock that prevents simultaneous writing to the same pipe
-    _process_wide_pipe: Optional[PipeEnd] = None  # a pipe that is used to send results/exceptions to this process
-    _pending_updates: Optional[SimpleQueue] = None  # a queue of updates to be processed by background thread
-    _update_reading_thread: Optional[threading.Thread] = None  # process-specific thread that reads updates from pipe
-    _update_processing_thread: Optional[threading.Thread] = None  # process-specific thread that processes updates
-    _active_futures: Optional[Dict[UID, Type[ref][MPFuture]]] = None  # non-done futures originated from this process
-    _status_requests: Optional[Dict[UID, Tuple[MPFuture, threading.Event]]] = None  # futures to be updated by origin
+    _global_sender_pipe: Optional[PipeEnd] = None  # a pipe that is used to send results/exceptions to this process
+    _pipe_waiter_thread: Optional[threading.Thread] = None  # process-specific thread that receives results/exceptions
+    _active_futures: Optional[Dict[UID, MPFuture]] = None  # pending or running futures originated from current process
     _active_pid: Optional[PID] = None  # pid of currently active process; used to handle forks natively
     _active_pid: Optional[PID] = None  # pid of currently active process; used to handle forks natively
 
 
-    SOFT_UPDATE_TIMEOUT = 0.5  # seconds spent awaiting status update before warning is printed
-    HARD_UPDATE_TIMEOUT = 10.0  # seconds spent awaiting status update before future is automatically cancelled
-
-    def __init__(self, *, synchronize: bool = True, use_lock: bool = True):
-        super().__init__()
-        self.synchronize = synchronize
+    def __init__(self, use_lock: bool = True, loop: Optional[asyncio.BaseEventLoop] = None):
         self._origin_pid, self._uid = os.getpid(), uuid.uuid4().int
         self._origin_pid, self._uid = os.getpid(), uuid.uuid4().int
+        self._shared_state_code = torch.empty([], dtype=torch.uint8).share_memory_()
+        self._state_cache: Dict[State, State] = {}
+        # mapping from the global (shared) state to a locally cached state, which makes updates immediately
+        # available on the setter side; a dictionary-based cache works because a future can visit any state at most once
+
+        base.Future.__init__(self)  # parent init is deferred because it uses self._shared_state_code
         self._state, self._result, self._exception = base.PENDING, None, None
         self._state, self._result, self._exception = base.PENDING, None, None
         self._use_lock = use_lock
         self._use_lock = use_lock
 
 
-        self._initialize_backend_if_necessary()
+        if self._origin_pid != MPFuture._active_pid:
+            with MPFuture._initialization_lock:
+                if self._origin_pid != MPFuture._active_pid:
+                    # note: the second if is intentional, see https://en.wikipedia.org/wiki/Double-checked_locking
+                    self._initialize_mpfuture_backend()
         assert self._uid not in MPFuture._active_futures
         assert self._uid not in MPFuture._active_futures
-        MPFuture._active_futures[self._uid] = ref(self)
-        self._sender_pipe = MPFuture._process_wide_pipe
+        MPFuture._active_futures[self._uid] = self
+        self._sender_pipe = MPFuture._global_sender_pipe
 
 
         try:
         try:
-            self._loop = asyncio.get_event_loop()
+            self._loop = loop or asyncio.get_event_loop()
             self._aio_event = asyncio.Event()
             self._aio_event = asyncio.Event()
         except RuntimeError:
         except RuntimeError:
             self._loop, self._aio_event = None, None
             self._loop, self._aio_event = None, None
 
 
-    def _set_event_if_necessary(self):
-        if self._aio_event is None or self._aio_event.is_set():
-            return
+    @property
+    def _state(self) -> State:
+        shared_state = ALL_STATES[self._shared_state_code.item()]
+        return self._state_cache.get(shared_state, shared_state)
+
+    @_state.setter
+    def _state(self, new_state: State):
+        self._shared_state_code[...] = ALL_STATES.index(new_state)
+        if self._state in TERMINAL_STATES and self._loop is not None and not self._aio_event.is_set():
+            self._set_event_threadsafe()
+
+    def _set_event_threadsafe(self):
         try:
         try:
             loop = asyncio.get_running_loop()
             loop = asyncio.get_running_loop()
         except RuntimeError:
         except RuntimeError:
@@ -104,235 +111,120 @@ class MPFuture(base.Future, Generic[ResultType]):
         async def _event_setter():
         async def _event_setter():
             self._aio_event.set()
             self._aio_event.set()
 
 
-        if self._loop.is_running() and loop == self.get_loop():
+        if loop == self.get_loop():
             asyncio.create_task(_event_setter())
             asyncio.create_task(_event_setter())
-        elif self._loop.is_running() and loop != self.get_loop():
-            asyncio.run_coroutine_threadsafe(_event_setter(), self._loop)
         else:
         else:
-            self._loop.run_until_complete(_event_setter())
-
-    @classmethod
-    def _initialize_backend_if_necessary(cls):
-        pid = os.getpid()
-        if MPFuture._active_pid != pid:
-            with MPFuture._initialization_lock:
-                if MPFuture._active_pid != pid:
-                    # note: the second if is intentional, see https://en.wikipedia.org/wiki/Double-checked_locking
-                    logger.debug(f"Initializing MPFuture backend for pid {pid}")
-                    receiver_pipe, cls._process_wide_pipe = mp.Pipe(duplex=False)
-                    cls._active_pid, cls._active_futures, cls._status_requests = pid, {}, {}
-                    cls._pending_updates = SimpleQueue()
-
-                    cls._update_reading_thread = threading.Thread(
-                        target=cls._read_updates_in_background,
-                        args=[receiver_pipe],
-                        name=f"{__name__}.READER",
-                        daemon=True,
-                    )
-                    cls._update_reading_thread.start()
-                    cls._update_processing_thread = threading.Thread(
-                        target=cls._process_updates_in_background,
-                        name=f"{__name__}.PROCESSOR",
-                        daemon=True,
-                    )
-                    cls._update_processing_thread.start()
-
-    @classmethod
-    def reset_backend(cls):
-        """
-        Reset the MPFuture backend. This is useful when the state may have been corrupted
-        (e.g. killing child processes may leave the locks acquired and the background thread blocked).
-
-        This method is neither thread-safe nor process-safe.
-        """
-
-        cls._initialization_lock = mp.Lock()
-        cls._update_lock = mp.Lock()
-        cls._active_pid = None
+            asyncio.run_coroutine_threadsafe(_event_setter(), self._loop)
 
 
     @classmethod
     @classmethod
-    def _read_updates_in_background(cls, receiver_pipe: mp.connection.Connection):
+    def _initialize_mpfuture_backend(cls):
         pid = os.getpid()
         pid = os.getpid()
-        while True:
-            if cls._update_reading_thread is not threading.current_thread():
-                break  # Backend was reset, a new background thread has started
+        logger.debug(f"Initializing MPFuture backend for pid {pid}")
+        assert pid != cls._active_pid, "already initialized"
 
 
-            try:
-                cls._pending_updates.put(receiver_pipe.recv())
-            except (BrokenPipeError, EOFError, ConnectionError):
-                logger.debug(f"Update pipe was was shut down unexpectedly (pid={pid})")
-            except Exception as e:
-                logger.exception(f"Could not retrieve update: caught {repr(e)} (pid={pid})")
+        receiver_pipe, cls._global_sender_pipe = mp.Pipe(duplex=False)
+        cls._active_pid, cls._active_futures = pid, {}
+        cls._pipe_waiter_thread = threading.Thread(
+            target=cls._process_updates_in_background, args=[receiver_pipe], name=f"{__name__}.BACKEND", daemon=True
+        )
+        cls._pipe_waiter_thread.start()
 
 
     @classmethod
     @classmethod
-    def _process_updates_in_background(cls):
+    def _process_updates_in_background(cls, receiver_pipe: mp.connection.Connection):
         pid = os.getpid()
         pid = os.getpid()
         while True:
         while True:
             try:
             try:
-                if cls._update_processing_thread is not threading.current_thread():
-                    break  # Backend was reset, a new background thread has started
-
-                uid, msg_type, payload = cls._pending_updates.get()
-                future = None
-                future_ref = cls._active_futures.get(uid)
-                if future_ref is not None:
-                    future = future_ref()
-
-                if msg_type == MessageType.STATE_REQUEST:
-                    future_state = None if future is None else future.__getstate__()
-                    use_lock, return_pipe = payload
-                    with MPFuture._update_lock if use_lock else nullcontext():
-                        return_pipe.send((uid, MessageType.STATE_RESPONSE, future_state))
-
-                elif msg_type == MessageType.STATE_RESPONSE:
-                    future, state_updated_event = cls._status_requests.get(uid, (None, None))
-                    if future is None:
-                        logger.debug("Received a state update for a future that does not await status update.")
-                    else:
-                        if payload is not None:
-                            future.__setstate__(payload)
-                        else:
-                            base.Future.cancel(future)
-                        state_updated_event.set()
-
-                elif future is None:
-                    logger.debug(
-                        f"Received {msg_type} for MPFuture uid={uid}, but future is already done or destroyed"
-                    )
-                elif msg_type == MessageType.RESULT:
-                    future.set_result(payload)
-                elif msg_type == MessageType.EXCEPTION:
-                    future.set_exception(payload)
-                elif msg_type == MessageType.RUNNING:
-                    try:
-                        future.set_running_or_notify_cancel()
-                    except (InvalidStateError, RuntimeError) as e:
-                        logger.debug(f"Could not set MPFuture (uid={uid}) to running due to {e}", exc_info=True)
-                elif msg_type == MessageType.CANCEL:
-                    future.cancel()
+                uid, update_type, payload = receiver_pipe.recv()
+                if uid not in cls._active_futures:
+                    logger.debug(f"Ignoring update to future with uid={uid}: the future is already done or destroyed")
+                elif update_type == UpdateType.RESULT:
+                    cls._active_futures.pop(uid).set_result(payload)
+                elif update_type == UpdateType.EXCEPTION:
+                    cls._active_futures.pop(uid).set_exception(payload)
+                elif update_type == UpdateType.CANCEL:
+                    cls._active_futures.pop(uid).cancel()
                 else:
                 else:
-                    raise RuntimeError(f"Received unexpected update type {msg_type}")
-
-                if future is None or future.done():
-                    cls._active_futures.pop(uid, None)
-
-            except (BrokenPipeError, EOFError, ConnectionError):
+                    raise RuntimeError(f"Received unexpected update type {update_type}")
+            except (BrokenPipeError, EOFError):
                 logger.debug(f"Update pipe was was shut down unexpectedly (pid={pid})")
                 logger.debug(f"Update pipe was was shut down unexpectedly (pid={pid})")
             except Exception as e:
             except Exception as e:
                 logger.exception(f"Could not retrieve update: caught {repr(e)} (pid={pid})")
                 logger.exception(f"Could not retrieve update: caught {repr(e)} (pid={pid})")
 
 
-    def _send_update(self, update_type: MessageType, payload: Any = None):
+    def _send_update(self, update_type: UpdateType, payload: Any = None):
         """This method sends result, exception or cancel to the MPFuture origin."""
         """This method sends result, exception or cancel to the MPFuture origin."""
-        try:
-            with MPFuture._update_lock if self._use_lock else nullcontext():
-                self._sender_pipe.send((self._uid, update_type, payload))
-        except (ConnectionError, BrokenPipeError, EOFError) as e:
-            logger.debug(f"No updates were sent: pipe to origin process was broken ({e}).", exc_info=True)
-
-    def _synchronize_if_necessary(self):
-        if not self.synchronize or os.getpid() == self._origin_pid or self._state in TERMINAL_STATES:
-            return
-
-        self._initialize_backend_if_necessary()
-
-        status_updated = threading.Event()
-        _, existing_status_event = self._status_requests.setdefault(self._uid, (self, status_updated))
-        # this line checks if another thread is synchronizing concurrently, assuming that setdefault to be atomic
-
-        if existing_status_event != status_updated:
-            existing_status_event.wait(MPFuture.HARD_UPDATE_TIMEOUT)
-            return
-
-        # otherwise create a new request for synchronization
-
-        try:
-            with MPFuture._update_lock if self._use_lock else nullcontext():
-                payload = (self._use_lock, self._process_wide_pipe)
-                self._sender_pipe.send((self._uid, MessageType.STATE_REQUEST, payload))
-            status_updated.wait(MPFuture.SOFT_UPDATE_TIMEOUT)
-            if not status_updated.is_set():
-                logger.warning(f"Status update took over {MPFuture.SOFT_UPDATE_TIMEOUT}, expect performance issues")
-                status_updated.wait(MPFuture.HARD_UPDATE_TIMEOUT - MPFuture.SOFT_UPDATE_TIMEOUT)
-                if not status_updated.is_set() and not self.cancel():
-                    with suppress(InvalidStateError, RuntimeError):
-                        self.set_exception(
-                            TimeoutError(
-                                f"Status update took over {MPFuture.HARD_UPDATE_TIMEOUT} seconds, "
-                                f"MPFuture is cancelled"
-                            )
-                        )
-                    status_updated.set()  # this triggers any concurrent _synchronize_if_necessary calls to finish
-        except (ConnectionError, BrokenPipeError, EOFError) as e:
-            logger.error(f"MPFuture was cancelled because sender pipe is broken. Origin process is probably down.")
-            if not self.cancel():
-                with suppress(InvalidStateError, RuntimeError):
-                    self.set_exception(e)
-        finally:
-            self._status_requests.pop(self._uid, None)
+        with MPFuture._update_lock if self._use_lock else nullcontext():
+            self._sender_pipe.send((self._uid, update_type, payload))
 
 
     def set_result(self, result: ResultType):
     def set_result(self, result: ResultType):
-        if self._state in TERMINAL_STATES:
-            raise InvalidStateError(f"Can't set_result to a future that is {self._state} ({self._uid})")
-        elif os.getpid() == self._origin_pid:
+        if os.getpid() == self._origin_pid:
+            super().set_result(result)
             MPFuture._active_futures.pop(self._uid, None)
             MPFuture._active_futures.pop(self._uid, None)
-            self._set_event_if_necessary()
+        elif self._state in TERMINAL_STATES:
+            raise InvalidStateError(f"Can't set_result to a future that is {self._state} ({self._uid})")
         else:
         else:
-            self._send_update(MessageType.RESULT, result)
-        super().set_result(result)
+            self._state_cache[self._state], self._result = base.FINISHED, result
+            self._send_update(UpdateType.RESULT, result)
 
 
     def set_exception(self, exception: Optional[BaseException]):
     def set_exception(self, exception: Optional[BaseException]):
-        if self._state in TERMINAL_STATES:
-            raise InvalidStateError(f"Can't set_exception to a future that is {self._state} ({self._uid})")
-        elif os.getpid() == self._origin_pid:
+        if os.getpid() == self._origin_pid:
+            super().set_exception(exception)
             MPFuture._active_futures.pop(self._uid, None)
             MPFuture._active_futures.pop(self._uid, None)
-            self._set_event_if_necessary()
+        elif self._state in TERMINAL_STATES:
+            raise InvalidStateError(f"Can't set_exception to a future that is {self._state} ({self._uid})")
         else:
         else:
-            self._send_update(MessageType.EXCEPTION, exception)
-        super().set_exception(exception)
+            self._state_cache[self._state], self._exception = base.FINISHED, exception
+            self._send_update(UpdateType.EXCEPTION, exception)
 
 
     def cancel(self) -> bool:
     def cancel(self) -> bool:
-        if self._state in [base.RUNNING, base.FINISHED]:
-            return False
-        elif os.getpid() == self._origin_pid:
+        if os.getpid() == self._origin_pid:
             MPFuture._active_futures.pop(self._uid, None)
             MPFuture._active_futures.pop(self._uid, None)
-            self._set_event_if_necessary()
+            return super().cancel()
+        elif self._state in [base.RUNNING, base.FINISHED]:
+            return False
         else:
         else:
-            self._send_update(MessageType.CANCEL)
-        return super().cancel()
+            self._state_cache[self._state] = base.CANCELLED
+            self._send_update(UpdateType.CANCEL)
+            return True
 
 
     def set_running_or_notify_cancel(self):
     def set_running_or_notify_cancel(self):
-        """if synchronize is set to False, this future will ignore any state changes from origin"""
-        self._synchronize_if_necessary()
-        try:
-            is_running = super().set_running_or_notify_cancel()
-            if is_running and os.getpid() != self._origin_pid:
-                self._send_update(MessageType.RUNNING)
-            return is_running
-        except RuntimeError as e:
-            raise InvalidStateError(str(e))
+        if self._state == base.PENDING:
+            self._state = base.RUNNING
+            return True
+        elif self._state == base.CANCELLED:
+            return False
+        else:
+            raise InvalidStateError(
+                f"Can't set_running_or_notify_cancel when future is in {self._state} ({self._uid})"
+            )
 
 
     def result(self, timeout: Optional[float] = None) -> ResultType:
     def result(self, timeout: Optional[float] = None) -> ResultType:
         if self._state not in TERMINAL_STATES:
         if self._state not in TERMINAL_STATES:
             if os.getpid() != self._origin_pid:
             if os.getpid() != self._origin_pid:
                 raise RuntimeError("Only the process that created MPFuture can await result")
                 raise RuntimeError("Only the process that created MPFuture can await result")
-        return super().result(timeout)
+            return super().result(timeout)
+        elif self._state == base.CANCELLED:
+            raise base.CancelledError()
+        elif self._exception:
+            raise self._exception
+        else:
+            return self._result
 
 
     def exception(self, timeout: Optional[float] = None) -> Optional[BaseException]:
     def exception(self, timeout: Optional[float] = None) -> Optional[BaseException]:
         if self._state not in TERMINAL_STATES:
         if self._state not in TERMINAL_STATES:
             if os.getpid() != self._origin_pid:
             if os.getpid() != self._origin_pid:
                 raise RuntimeError("Only the process that created MPFuture can await exception")
                 raise RuntimeError("Only the process that created MPFuture can await exception")
-        return super().exception(timeout)
+            return super().exception(timeout)
+        elif self._state == base.CANCELLED:
+            raise base.CancelledError()
+        return self._exception
 
 
     def done(self) -> bool:
     def done(self) -> bool:
-        self._synchronize_if_necessary()
         return self._state in TERMINAL_STATES
         return self._state in TERMINAL_STATES
 
 
     def running(self):
     def running(self):
-        self._synchronize_if_necessary()
         return self._state == base.RUNNING
         return self._state == base.RUNNING
 
 
     def cancelled(self):
     def cancelled(self):
-        self._synchronize_if_necessary()
         return self._state == base.CANCELLED
         return self._state == base.CANCELLED
 
 
     def add_done_callback(self, callback: Callable[[MPFuture], None]):
     def add_done_callback(self, callback: Callable[[MPFuture], None]):
@@ -348,7 +240,7 @@ class MPFuture(base.Future, Generic[ResultType]):
             raise RuntimeError("Can't await: MPFuture was created with no event loop")
             raise RuntimeError("Can't await: MPFuture was created with no event loop")
         yield from self._aio_event.wait().__await__()
         yield from self._aio_event.wait().__await__()
         try:
         try:
-            return super().result()
+            return super().result(timeout=0)
         except base.CancelledError:
         except base.CancelledError:
             raise asyncio.CancelledError()
             raise asyncio.CancelledError()
 
 
@@ -360,9 +252,8 @@ class MPFuture(base.Future, Generic[ResultType]):
 
 
     def __getstate__(self):
     def __getstate__(self):
         return dict(
         return dict(
-            synchronize=self.synchronize,
             _sender_pipe=self._sender_pipe,
             _sender_pipe=self._sender_pipe,
-            _state=self._state,
+            _shared_state_code=self._shared_state_code,
             _origin_pid=self._origin_pid,
             _origin_pid=self._origin_pid,
             _uid=self._uid,
             _uid=self._uid,
             _use_lock=self._use_lock,
             _use_lock=self._use_lock,
@@ -371,12 +262,13 @@ class MPFuture(base.Future, Generic[ResultType]):
         )
         )
 
 
     def __setstate__(self, state):
     def __setstate__(self, state):
-        self.synchronize = state["synchronize"]
         self._sender_pipe = state["_sender_pipe"]
         self._sender_pipe = state["_sender_pipe"]
-        self._state, self._origin_pid, self._uid = state["_state"], state["_origin_pid"], state["_uid"]
+        self._shared_state_code = state["_shared_state_code"]
+        self._origin_pid, self._uid = state["_origin_pid"], state["_uid"]
         self._result, self._exception = state["_result"], state["_exception"]
         self._result, self._exception = state["_result"], state["_exception"]
         self._use_lock = state["_use_lock"]
         self._use_lock = state["_use_lock"]
 
 
         self._waiters, self._done_callbacks = [], []
         self._waiters, self._done_callbacks = [], []
         self._condition = threading.Condition()
         self._condition = threading.Condition()
         self._aio_event, self._loop = None, None
         self._aio_event, self._loop = None, None
+        self._state_cache = {}