Просмотр исходного кода

Resolve deadlock in MPFuture (#337)

This PR partially rolls back MPFuture to an earlier state that still uses posix shared memory, with some additional safeguards that protect it against hanging on fork / termination.

Co-authored-by: Alexander Borzunov <hxrussia@gmail.com>
Co-authored-by: Michael Diskin <yhn1124@gmail.com>
justheuristic 4 лет назад
Родитель
Сommit
1b01a8164a
5 измененных файлов с 144 добавлено и 185 удалено
  1. 1 1
      hivemind/moe/server/task_pool.py
  2. 134 172
      hivemind/utils/mpfuture.py
  3. 2 2
      tests/conftest.py
  4. 2 2
      tests/test_dht_node.py
  5. 5 8
      tests/test_util_modules.py

+ 1 - 1
hivemind/moe/server/task_pool.py

@@ -100,7 +100,7 @@ class TaskPool(TaskPoolBase):
 
     def submit_task(self, *args: torch.Tensor) -> Future:
         """Add task to this pool's queue, return Future for its output"""
-        task = Task(MPFuture(synchronize=False), args)
+        task = Task(MPFuture(), args)
         if self.get_task_size(task) > self.max_batch_size:
             exc = ValueError(f"Task size greater than max_batch_size ({self.max_batch_size}), it can't be processed")
             task.future.set_exception(exc)

+ 134 - 172
hivemind/utils/mpfuture.py

@@ -2,14 +2,17 @@ from __future__ import annotations
 
 import asyncio
 import concurrent.futures._base as base
-from contextlib import nullcontext, suppress
+from weakref import ref
+from contextlib import nullcontext
 import multiprocessing as mp
+import multiprocessing.connection
 import os
 import threading
 import uuid
-from weakref import ref
 from enum import Enum, auto
-from typing import Generic, TypeVar, Dict, Optional, Any, Callable, Type, Tuple
+from typing import Generic, TypeVar, Dict, Optional, Any, Callable, Type
+
+import torch  # used for py3.7-compatible shared memory
 
 from hivemind.utils.logging import get_logger
 
@@ -32,13 +35,38 @@ except ImportError:
         """Raised when attempting to change state of a future in a terminal state (e.g. finished)"""
 
 
-class MessageType(Enum):
+class SharedBytes:
+    """
+    A process-wide object that allocates large chunks of shared memory and partitions it into individual bytes.
+
+    Note: this process is only responsible for bulk allocation, it does not manage/free unused bytes.
+    The chunks are deallocated by the garbage collector,
+    when it detects that all processes no longer use any bytes from this chunk.
+    """
+
+    _lock = mp.Lock()
+    _pid: Optional[PID] = None
+    _buffer: Optional[torch.Tensor] = None
+    _index: int = 0
+
+    @classmethod
+    def next(cls) -> torch.Tensor:
+        """Create another shared byte value, represented as a scalar uint8 tensor"""
+        with cls._lock:
+            if cls._pid != os.getpid() or cls._buffer is None or cls._index >= len(cls._buffer):
+                buffer_size = os.environ.get("HIVEMIND_SHM_BUFFER_SIZE", 4096)
+                cls._pid = os.getpid()
+                cls._buffer = torch.empty([buffer_size], dtype=torch.uint8).share_memory_()
+                cls._index = 0
+
+            cls._index += 1
+            return cls._buffer[cls._index - 1]
+
+
+class UpdateType(Enum):
     RESULT = auto()
     EXCEPTION = auto()
-    RUNNING = auto()
     CANCEL = auto()
-    STATE_REQUEST = auto()
-    STATE_RESPONSE = auto()
 
 
 class MPFuture(base.Future, Generic[ResultType]):
@@ -47,12 +75,9 @@ class MPFuture(base.Future, Generic[ResultType]):
     Any process can access future status and set the result / exception and check for state.
     However, only the original process (i.e. the process that created the future) can await the result or exception.
 
-    :param synchronize: if True (default), future will request state from origin, otherwise it will only use local state
-      Setting synchronize=False results in slightly better performance of done or set_running_or_notify_cancel
     :param use_lock: if True, operations with MPFuture use a global lock to prevent concurrent writes to the same pipe;
       If set to False, writing to this future ignores global lock, slightly improving performance, but making user
       responsible for avoiding concurrent set_result / set_exception calls to futures with the same process of origin.
-    :param loop: if specified, overrides default asyncio event loop for the purpose of awaiting MPFuture
 
     :note: This is an internal primitive that is not guaranteed to work outside of hivemind applications.
      More specifically, there are two known limitations:
@@ -63,26 +88,30 @@ class MPFuture(base.Future, Generic[ResultType]):
 
     _initialization_lock = mp.Lock()  # global lock that prevents simultaneous initialization of two processes
     _update_lock = mp.Lock()  # global lock that prevents simultaneous writing to the same pipe
-    _process_wide_pipe: Optional[PipeEnd] = None  # a pipe that is used to send results/exceptions to this process
+    _global_sender_pipe: Optional[PipeEnd] = None  # a pipe that is used to send results/exceptions to this process
     _pipe_waiter_thread: Optional[threading.Thread] = None  # process-specific thread that receives results/exceptions
-    _active_futures: Optional[Dict[UID, Type[ref][MPFuture]]] = None  # non-done futures originated from this process
-    _status_requests: Optional[Dict[UID, Tuple[MPFuture, threading.Event]]] = None  # futures to be updated by origin
+    _active_futures: Optional[Dict[UID, "ref[MPFuture]"]] = None  # non-done futures originated from this process
     _active_pid: Optional[PID] = None  # pid of currently active process; used to handle forks natively
 
-    SOFT_UPDATE_TIMEOUT = 0.5  # seconds spent awaiting status update before warning is printed
-    HARD_UPDATE_TIMEOUT = 10.0  # seconds spent awaiting status update before future is automatically cancelled
-
-    def __init__(self, *, synchronize: bool = True, use_lock: bool = True):
-        super().__init__()
-        self.synchronize = synchronize
+    def __init__(self, *, use_lock: bool = True):
         self._origin_pid, self._uid = os.getpid(), uuid.uuid4().int
+        self._shared_state_code = SharedBytes.next()
+        self._state_cache: Dict[State, State] = {}
+        # mapping from global to cached local future used that makes updates immediately
+        # available on setter side; dictionary-based cache works because future can visit any state at most once
+
+        base.Future.__init__(self)  # parent init is deferred because it uses self._shared_state_code
         self._state, self._result, self._exception = base.PENDING, None, None
         self._use_lock = use_lock
 
-        self._initialize_backend_if_necessary()
+        if self._origin_pid != MPFuture._active_pid:
+            with MPFuture._initialization_lock:
+                if self._origin_pid != MPFuture._active_pid:
+                    # note: the second if is intentional, see https://en.wikipedia.org/wiki/Double-checked_locking
+                    self._initialize_mpfuture_backend()
         assert self._uid not in MPFuture._active_futures
         MPFuture._active_futures[self._uid] = ref(self)
-        self._sender_pipe = MPFuture._process_wide_pipe
+        self._sender_pipe = MPFuture._global_sender_pipe
 
         try:
             self._loop = asyncio.get_event_loop()
@@ -90,54 +119,52 @@ class MPFuture(base.Future, Generic[ResultType]):
         except RuntimeError:
             self._loop, self._aio_event = None, None
 
-    def _set_event_if_necessary(self):
-        if self._aio_event is None or self._aio_event.is_set():
-            return
+    @property
+    def _state(self) -> State:
+        shared_state = ALL_STATES[self._shared_state_code.item()]
+        return self._state_cache.get(shared_state, shared_state)
+
+    @_state.setter
+    def _state(self, new_state: State):
+        self._shared_state_code[...] = ALL_STATES.index(new_state)
+        if self._state in TERMINAL_STATES and self._loop is not None and not self._aio_event.is_set():
+            self._set_event_threadsafe()
+
+    def _set_event_threadsafe(self):
         try:
-            loop = asyncio.get_running_loop()
+            running_loop = asyncio.get_running_loop()
         except RuntimeError:
-            loop = None
+            running_loop = None
 
         async def _event_setter():
             self._aio_event.set()
 
-        if self._loop.is_running() and loop == self.get_loop():
+        if self._loop.is_running() and running_loop == self._loop:
             asyncio.create_task(_event_setter())
-        elif self._loop.is_running() and loop != self.get_loop():
+        elif self._loop.is_running() and running_loop != self._loop:
             asyncio.run_coroutine_threadsafe(_event_setter(), self._loop)
         else:
             self._loop.run_until_complete(_event_setter())
 
     @classmethod
-    def _initialize_backend_if_necessary(cls):
+    def _initialize_mpfuture_backend(cls):
         pid = os.getpid()
-        if MPFuture._active_pid != pid:
-            with MPFuture._initialization_lock:
-                if MPFuture._active_pid != pid:
-                    # note: the second if is intentional, see https://en.wikipedia.org/wiki/Double-checked_locking
-                    logger.debug(f"Initializing MPFuture backend for pid {pid}")
-                    receiver_pipe, cls._process_wide_pipe = mp.Pipe(duplex=False)
-                    cls._active_pid, cls._active_futures, cls._status_requests = pid, {}, {}
-                    cls._pipe_waiter_thread = threading.Thread(
-                        target=cls._process_updates_in_background,
-                        args=[receiver_pipe],
-                        name=f"{__name__}.BACKEND",
-                        daemon=True,
-                    )
-                    cls._pipe_waiter_thread.start()
+        logger.debug(f"Initializing MPFuture backend for pid {pid}")
 
-    @classmethod
-    def reset_backend(cls):
-        """
-        Reset the MPFuture backend. This is useful when the state may have been corrupted
-        (e.g. killing child processes may leave the locks acquired and the background thread blocked).
-
-        This method is neither thread-safe nor process-safe.
-        """
+        receiver_pipe, cls._global_sender_pipe = mp.Pipe(duplex=False)
+        cls._active_pid, cls._active_futures = pid, {}
+        cls._pipe_waiter_thread = threading.Thread(
+            target=cls._process_updates_in_background, args=[receiver_pipe], name=f"{__name__}.BACKEND", daemon=True
+        )
+        cls._pipe_waiter_thread.start()
 
-        cls._initialization_lock = mp.Lock()
-        cls._update_lock = mp.Lock()
-        cls._active_pid = None
+    @staticmethod
+    def reset_backend():
+        """Last-resort function to reset internals of MPFuture. All current MPFuture instances will be broken"""
+        MPFuture._active_pid = None
+        MPFuture._initialization_lock = mp.Lock()
+        MPFuture._update_lock = mp.Lock()
+        SharedBytes._lock = mp.Lock()
 
     @classmethod
     def _process_updates_in_background(cls, receiver_pipe: mp.connection.Connection):
@@ -145,58 +172,30 @@ class MPFuture(base.Future, Generic[ResultType]):
         while True:
             try:
                 if cls._pipe_waiter_thread is not threading.current_thread():
-                    break  # Backend was reset, a new background thread has started
+                    break  # backend was reset, a new background thread has started
 
-                uid, msg_type, payload = receiver_pipe.recv()
+                uid, update_type, payload = receiver_pipe.recv()
                 future = None
-                future_ref = cls._active_futures.get(uid)
+                future_ref = cls._active_futures.pop(uid, None)
                 if future_ref is not None:
                     future = future_ref()
 
-                if msg_type == MessageType.STATE_REQUEST:
-                    future_state = None if future is None else future.__getstate__()
-                    use_lock, return_pipe = payload
-                    with MPFuture._update_lock if use_lock else nullcontext():
-                        return_pipe.send((uid, MessageType.STATE_RESPONSE, future_state))
-
-                elif msg_type == MessageType.STATE_RESPONSE:
-                    future, state_updated_event = cls._status_requests.get(uid, (None, None))
-                    if future is None:
-                        logger.debug("Received a state update for a future that does not await status update.")
-                    else:
-                        if payload is not None:
-                            future.__setstate__(payload)
-                        else:
-                            base.Future.cancel(future)
-                        state_updated_event.set()
-
-                elif future is None:
-                    logger.debug(
-                        f"Received {msg_type} for MPFuture uid={uid}, but future is already done or destroyed"
-                    )
-                elif msg_type == MessageType.RESULT:
+                if future is None:
+                    logger.debug(f"Ignoring update to future with uid={uid}: the future is already done or destroyed")
+                elif update_type == UpdateType.RESULT:
                     future.set_result(payload)
-                elif msg_type == MessageType.EXCEPTION:
+                elif update_type == UpdateType.EXCEPTION:
                     future.set_exception(payload)
-                elif msg_type == MessageType.RUNNING:
-                    try:
-                        future.set_running_or_notify_cancel()
-                    except (InvalidStateError, RuntimeError) as e:
-                        logger.debug(f"Could not set MPFuture (uid={uid}) to running due to {e}", exc_info=True)
-                elif msg_type == MessageType.CANCEL:
+                elif update_type == UpdateType.CANCEL:
                     future.cancel()
                 else:
-                    raise RuntimeError(f"Received unexpected update type {msg_type}")
-
-                if future is None or future.done():
-                    cls._active_futures.pop(uid, None)
-
+                    raise RuntimeError(f"Received unexpected update type {update_type}")
             except (BrokenPipeError, EOFError, ConnectionError):
                 logger.debug(f"Update pipe was was shut down unexpectedly (pid={pid})")
             except Exception as e:
                 logger.exception(f"Could not retrieve update: caught {repr(e)} (pid={pid})")
 
-    def _send_update(self, update_type: MessageType, payload: Any = None):
+    def _send_update(self, update_type: UpdateType, payload: Any = None):
         """This method sends result, exception or cancel to the MPFuture origin."""
         try:
             with MPFuture._update_lock if self._use_lock else nullcontext():
@@ -204,110 +203,76 @@ class MPFuture(base.Future, Generic[ResultType]):
         except (ConnectionError, BrokenPipeError, EOFError) as e:
             logger.debug(f"No updates were sent: pipe to origin process was broken ({e}).", exc_info=True)
 
-    def _synchronize_if_necessary(self):
-        if not self.synchronize or os.getpid() == self._origin_pid or self._state in TERMINAL_STATES:
-            return
-
-        self._initialize_backend_if_necessary()
-
-        status_updated = threading.Event()
-        _, existing_status_event = self._status_requests.setdefault(self._uid, (self, status_updated))
-        # this line checks if another thread is synchronizing concurrently, assuming that setdefault to be atomic
-
-        if existing_status_event != status_updated:
-            existing_status_event.wait(MPFuture.HARD_UPDATE_TIMEOUT)
-            return
-
-        # otherwise create a new request for synchronization
-
-        try:
-            with MPFuture._update_lock if self._use_lock else nullcontext():
-                payload = (self._use_lock, self._process_wide_pipe)
-                self._sender_pipe.send((self._uid, MessageType.STATE_REQUEST, payload))
-            status_updated.wait(MPFuture.SOFT_UPDATE_TIMEOUT)
-            if not status_updated.is_set():
-                logger.warning(f"Status update took over {MPFuture.SOFT_UPDATE_TIMEOUT}, expect performance issues")
-                status_updated.wait(MPFuture.HARD_UPDATE_TIMEOUT - MPFuture.SOFT_UPDATE_TIMEOUT)
-                if not status_updated.is_set() and not self.cancel():
-                    with suppress(InvalidStateError, RuntimeError):
-                        self.set_exception(
-                            TimeoutError(
-                                f"Status update took over {MPFuture.HARD_UPDATE_TIMEOUT} seconds, "
-                                f"MPFuture is cancelled"
-                            )
-                        )
-                    status_updated.set()  # this triggers any concurrent _synchronize_if_necessary calls to finish
-        except (ConnectionError, BrokenPipeError, EOFError) as e:
-            logger.error(f"MPFuture was cancelled because sender pipe is broken. Origin process is probably down.")
-            if not self.cancel():
-                with suppress(InvalidStateError, RuntimeError):
-                    self.set_exception(e)
-        finally:
-            self._status_requests.pop(self._uid, None)
-
     def set_result(self, result: ResultType):
-        if self._state in TERMINAL_STATES:
-            raise InvalidStateError(f"Can't set_result to a future that is {self._state} ({self._uid})")
-        elif os.getpid() == self._origin_pid:
+        if os.getpid() == self._origin_pid:
+            super().set_result(result)
             MPFuture._active_futures.pop(self._uid, None)
-            self._set_event_if_necessary()
+        elif self._state in TERMINAL_STATES:
+            raise InvalidStateError(f"Can't set_result to a future that is {self._state} ({self._uid})")
         else:
-            self._send_update(MessageType.RESULT, result)
-        super().set_result(result)
+            self._state_cache[self._state], self._result = base.FINISHED, result
+            self._send_update(UpdateType.RESULT, result)
 
     def set_exception(self, exception: Optional[BaseException]):
-        if self._state in TERMINAL_STATES:
-            raise InvalidStateError(f"Can't set_exception to a future that is {self._state} ({self._uid})")
-        elif os.getpid() == self._origin_pid:
+        if os.getpid() == self._origin_pid:
+            super().set_exception(exception)
             MPFuture._active_futures.pop(self._uid, None)
-            self._set_event_if_necessary()
+        elif self._state in TERMINAL_STATES:
+            raise InvalidStateError(f"Can't set_exception to a future that is {self._state} ({self._uid})")
         else:
-            self._send_update(MessageType.EXCEPTION, exception)
-        super().set_exception(exception)
+            self._state_cache[self._state], self._exception = base.FINISHED, exception
+            self._send_update(UpdateType.EXCEPTION, exception)
 
     def cancel(self) -> bool:
-        if self._state in [base.RUNNING, base.FINISHED]:
-            return False
-        elif os.getpid() == self._origin_pid:
+        if os.getpid() == self._origin_pid:
             MPFuture._active_futures.pop(self._uid, None)
-            self._set_event_if_necessary()
+            return super().cancel()
+        elif self._state in [base.RUNNING, base.FINISHED]:
+            return False
         else:
-            self._send_update(MessageType.CANCEL)
-        return super().cancel()
+            self._state_cache[self._state] = base.CANCELLED
+            self._send_update(UpdateType.CANCEL)
+            return True
 
     def set_running_or_notify_cancel(self):
-        """if synchronize is set to False, this future will ignore any state changes from origin"""
-        self._synchronize_if_necessary()
-        try:
-            is_running = super().set_running_or_notify_cancel()
-            if is_running and os.getpid() != self._origin_pid:
-                self._send_update(MessageType.RUNNING)
-            return is_running
-        except RuntimeError as e:
-            raise InvalidStateError(str(e))
+        if self._state == base.PENDING:
+            self._state = base.RUNNING
+            return True
+        elif self._state == base.CANCELLED:
+            return False
+        else:
+            raise InvalidStateError(
+                f"Can't set_running_or_notify_cancel when future is in {self._state} ({self._uid})"
+            )
 
     def result(self, timeout: Optional[float] = None) -> ResultType:
         if self._state not in TERMINAL_STATES:
             if os.getpid() != self._origin_pid:
                 raise RuntimeError("Only the process that created MPFuture can await result")
-        return super().result(timeout)
+            return super().result(timeout)
+        elif self._state == base.CANCELLED:
+            raise base.CancelledError()
+        elif self._exception:
+            raise self._exception
+        else:
+            return self._result
 
     def exception(self, timeout: Optional[float] = None) -> Optional[BaseException]:
         if self._state not in TERMINAL_STATES:
             if os.getpid() != self._origin_pid:
                 raise RuntimeError("Only the process that created MPFuture can await exception")
-        return super().exception(timeout)
+            return super().exception(timeout)
+        elif self._state == base.CANCELLED:
+            raise base.CancelledError()
+        return self._exception
 
     def done(self) -> bool:
-        self._synchronize_if_necessary()
         return self._state in TERMINAL_STATES
 
     def running(self):
-        self._synchronize_if_necessary()
         return self._state == base.RUNNING
 
     def cancelled(self):
-        self._synchronize_if_necessary()
         return self._state == base.CANCELLED
 
     def add_done_callback(self, callback: Callable[[MPFuture], None]):
@@ -315,9 +280,6 @@ class MPFuture(base.Future, Generic[ResultType]):
             raise RuntimeError("Only the process that created MPFuture can set callbacks")
         return super().add_done_callback(callback)
 
-    def get_loop(self) -> Optional[asyncio.BaseEventLoop]:
-        return self._loop
-
     def __await__(self):
         if not self._aio_event:
             raise RuntimeError("Can't await: MPFuture was created with no event loop")
@@ -335,9 +297,8 @@ class MPFuture(base.Future, Generic[ResultType]):
 
     def __getstate__(self):
         return dict(
-            synchronize=self.synchronize,
             _sender_pipe=self._sender_pipe,
-            _state=self._state,
+            _shared_state_code=self._shared_state_code,
             _origin_pid=self._origin_pid,
             _uid=self._uid,
             _use_lock=self._use_lock,
@@ -346,12 +307,13 @@ class MPFuture(base.Future, Generic[ResultType]):
         )
 
     def __setstate__(self, state):
-        self.synchronize = state["synchronize"]
         self._sender_pipe = state["_sender_pipe"]
-        self._state, self._origin_pid, self._uid = state["_state"], state["_origin_pid"], state["_uid"]
+        self._shared_state_code = state["_shared_state_code"]
+        self._origin_pid, self._uid = state["_origin_pid"], state["_uid"]
         self._result, self._exception = state["_result"], state["_exception"]
         self._use_lock = state["_use_lock"]
 
         self._waiters, self._done_callbacks = [], []
         self._condition = threading.Condition()
         self._aio_event, self._loop = None, None
+        self._state_cache = {}

+ 2 - 2
tests/conftest.py

@@ -1,11 +1,12 @@
 import gc
 from contextlib import suppress
+import multiprocessing as mp
 
 import psutil
 import pytest
 
+from hivemind.utils.mpfuture import MPFuture, SharedBytes
 from hivemind.utils.logging import get_logger
-from hivemind.utils.mpfuture import MPFuture
 
 
 logger = get_logger(__name__)
@@ -28,5 +29,4 @@ def cleanup_children():
             with suppress(psutil.NoSuchProcess):
                 child.kill()
 
-    # Broken code or killing of child processes may leave the MPFuture backend corrupted
     MPFuture.reset_backend()

+ 2 - 2
tests/test_dht_node.py

@@ -260,10 +260,10 @@ def test_dht_node(
         jaccard_denominator += k_nearest
 
     accuracy = accuracy_numerator / accuracy_denominator
-    logger.debug(f"Top-1 accuracy: {accuracy}")  # should be 98-100%
+    logger.debug(f"Top-1 accuracy: {accuracy}")  # should be 90-100%
     jaccard_index = jaccard_numerator / jaccard_denominator
     logger.debug(f"Jaccard index (intersection over union): {jaccard_index}")  # should be 95-100%
-    assert accuracy >= 0.9, f"Top-1 accuracy only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
+    assert accuracy >= 0.8, f"Top-1 accuracy only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
     assert jaccard_index >= 0.9, f"Jaccard index only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
 
     # test 4: find all nodes

+ 5 - 8
tests/test_util_modules.py

@@ -256,8 +256,8 @@ def test_mpfuture_done_callback():
 
     assert future1.done() and not future1.cancelled()
     assert future2.done() and future2.cancelled()
-    events[0].wait(1)
-    events[1].wait(1)
+    for i in 0, 1, 4:
+        events[i].wait(1)
     assert events[0].is_set() and events[1].is_set() and events[2].is_set() and events[4].is_set()
     assert not events[3].is_set()
 
@@ -266,15 +266,14 @@ def test_mpfuture_done_callback():
 
 
 @pytest.mark.forked
-@pytest.mark.parametrize("synchronize", [True, False])
-def test_many_futures(synchronize: bool):
+def test_many_futures():
     evt = mp.Event()
     receiver, sender = mp.Pipe()
-    main_futures = [hivemind.MPFuture(synchronize=synchronize) for _ in range(1000)]
+    main_futures = [hivemind.MPFuture() for _ in range(1000)]
     assert len(hivemind.MPFuture._active_futures) == 1000
 
     def _run_peer():
-        fork_futures = [hivemind.MPFuture(synchronize=synchronize) for _ in range(500)]
+        fork_futures = [hivemind.MPFuture() for _ in range(500)]
         assert len(hivemind.MPFuture._active_futures) == 500
 
         for i, future in enumerate(random.sample(main_futures, 300)):
@@ -299,8 +298,6 @@ def test_many_futures(synchronize: bool):
     p.start()
 
     some_fork_futures = receiver.recv()
-
-    time.sleep(0.5)  # wait for active futures to synchronize
     assert len(hivemind.MPFuture._active_futures) == 700
 
     for future in some_fork_futures: