
Resolve deadlock in MPFuture (#337)

This PR partially rolls back MPFuture to an earlier design that still uses POSIX shared memory, adding safeguards that protect it against hanging on fork or termination.
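
For context, here is a minimal usage sketch of the cross-process pattern this change hardens. It is illustrative only: the worker function and values below are not part of this commit.

import multiprocessing as mp
from hivemind.utils.mpfuture import MPFuture

def _worker(future: MPFuture):
    # a non-origin process resolves the future; the update is forwarded
    # to the origin process over the backend pipe
    future.set_result(42)

if __name__ == "__main__":
    future = MPFuture()  # must be created in the origin process
    p = mp.Process(target=_worker, args=(future,))
    p.start()
    print(future.result(timeout=5.0))  # prints 42 once the update is processed
    p.join()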

Co-authored-by: Alexander Borzunov <hxrussia@gmail.com>
Co-authored-by: Michael Diskin <yhn1124@gmail.com>
justheuristic committed 4 years ago (commit 1b01a8164a)
5 changed files with 144 additions and 185 deletions:
  1. hivemind/moe/server/task_pool.py (+1 -1)
  2. hivemind/utils/mpfuture.py (+134 -172)
  3. tests/conftest.py (+2 -2)
  4. tests/test_dht_node.py (+2 -2)
  5. tests/test_util_modules.py (+5 -8)

+ 1 - 1
hivemind/moe/server/task_pool.py

@@ -100,7 +100,7 @@ class TaskPool(TaskPoolBase):
 
     def submit_task(self, *args: torch.Tensor) -> Future:
         """Add task to this pool's queue, return Future for its output"""
-        task = Task(MPFuture(synchronize=False), args)
+        task = Task(MPFuture(), args)
         if self.get_task_size(task) > self.max_batch_size:
             exc = ValueError(f"Task size greater than max_batch_size ({self.max_batch_size}), it can't be processed")
             task.future.set_exception(exc)

+ 134 - 172
hivemind/utils/mpfuture.py

@@ -2,14 +2,17 @@ from __future__ import annotations
 
 import asyncio
 import concurrent.futures._base as base
-from contextlib import nullcontext, suppress
+from weakref import ref
+from contextlib import nullcontext
 import multiprocessing as mp
+import multiprocessing.connection
 import os
 import threading
 import uuid
-from weakref import ref
 from enum import Enum, auto
-from typing import Generic, TypeVar, Dict, Optional, Any, Callable, Type, Tuple
+from typing import Generic, TypeVar, Dict, Optional, Any, Callable, Type
+
+import torch  # used for py3.7-compatible shared memory
 
 from hivemind.utils.logging import get_logger
 
@@ -32,13 +35,38 @@ except ImportError:
         """Raised when attempting to change state of a future in a terminal state (e.g. finished)"""
 
 
-class MessageType(Enum):
+class SharedBytes:
+    """
+    A process-wide object that allocates large chunks of shared memory and partitions it into individual bytes.
+
+    Note: this class is only responsible for bulk allocation; it does not manage or free individual bytes.
+    The chunks are deallocated by the garbage collector
+    when it detects that no process uses any bytes from a given chunk anymore.
+    """
+
+    _lock = mp.Lock()
+    _pid: Optional[PID] = None
+    _buffer: Optional[torch.Tensor] = None
+    _index: int = 0
+
+    @classmethod
+    def next(cls) -> torch.Tensor:
+        """Create another shared byte value, represented as a scalar uint8 tensor"""
+        with cls._lock:
+            if cls._pid != os.getpid() or cls._buffer is None or cls._index >= len(cls._buffer):
+                buffer_size = int(os.environ.get("HIVEMIND_SHM_BUFFER_SIZE", 4096))
+                cls._pid = os.getpid()
+                cls._buffer = torch.empty([buffer_size], dtype=torch.uint8).share_memory_()
+                cls._index = 0
+
+            cls._index += 1
+            return cls._buffer[cls._index - 1]
+
+
+class UpdateType(Enum):
     RESULT = auto()
     EXCEPTION = auto()
-    RUNNING = auto()
     CANCEL = auto()
-    STATE_REQUEST = auto()
-    STATE_RESPONSE = auto()
 
 
 class MPFuture(base.Future, Generic[ResultType]):
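
The shared state code above relies on the fact that a scalar slice of a share_memory_() tensor stays visible to descendant processes. A minimal sketch of that mechanism, assuming the fork start method (the Linux default); the names here are illustrative, not part of this commit:

import multiprocessing as mp
import torch

def _child(state: torch.Tensor):
    state[...] = 3  # write a new state code from the child process

if __name__ == "__main__":
    buffer = torch.zeros([16], dtype=torch.uint8).share_memory_()
    state = buffer[0]  # a scalar view, like the one SharedBytes.next() hands out
    p = mp.Process(target=_child, args=(state,))
    p.start()
    p.join()
    print(state.item())  # 3: the child's write is visible in the parent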
@@ -47,12 +75,9 @@ class MPFuture(base.Future, Generic[ResultType]):
     Any process can access future status and set the result / exception and check for state.
     However, only the original process (i.e. the process that created the future) can await the result or exception.
 
-    :param synchronize: if True (default), future will request state from origin, otherwise it will only use local state
-      Setting synchronize=False results in slightly better performance of done or set_running_or_notify_cancel
     :param use_lock: if True, operations with MPFuture use a global lock to prevent concurrent writes to the same pipe;
       If set to False, writing to this future ignores global lock, slightly improving performance, but making user
       responsible for avoiding concurrent set_result / set_exception calls to futures with the same process of origin.
-    :param loop: if specified, overrides default asyncio event loop for the purpose of awaiting MPFuture
 
     :note: This is an internal primitive that is not guaranteed to work outside of hivemind applications.
      More specifically, there are two known limitations:
@@ -63,26 +88,30 @@ class MPFuture(base.Future, Generic[ResultType]):
 
     _initialization_lock = mp.Lock()  # global lock that prevents simultaneous initialization of two processes
     _update_lock = mp.Lock()  # global lock that prevents simultaneous writing to the same pipe
-    _process_wide_pipe: Optional[PipeEnd] = None  # a pipe that is used to send results/exceptions to this process
+    _global_sender_pipe: Optional[PipeEnd] = None  # a pipe that is used to send results/exceptions to this process
     _pipe_waiter_thread: Optional[threading.Thread] = None  # process-specific thread that receives results/exceptions
-    _active_futures: Optional[Dict[UID, Type[ref][MPFuture]]] = None  # non-done futures originated from this process
-    _status_requests: Optional[Dict[UID, Tuple[MPFuture, threading.Event]]] = None  # futures to be updated by origin
+    _active_futures: Optional[Dict[UID, "ref[MPFuture]"]] = None  # non-done futures originated from this process
     _active_pid: Optional[PID] = None  # pid of currently active process; used to handle forks natively
 
-    SOFT_UPDATE_TIMEOUT = 0.5  # seconds spent awaiting status update before warning is printed
-    HARD_UPDATE_TIMEOUT = 10.0  # seconds spent awaiting status update before future is automatically cancelled
-
-    def __init__(self, *, synchronize: bool = True, use_lock: bool = True):
-        super().__init__()
-        self.synchronize = synchronize
+    def __init__(self, *, use_lock: bool = True):
         self._origin_pid, self._uid = os.getpid(), uuid.uuid4().int
+        self._shared_state_code = SharedBytes.next()
+        self._state_cache: Dict[State, State] = {}
+        # mapping from a global state to a cached local state; this makes an update visible on the
+        # setter side at once; a dictionary cache works because a future visits each state at most once
+
+        base.Future.__init__(self)  # parent init is deferred because it uses self._shared_state_code
         self._state, self._result, self._exception = base.PENDING, None, None
         self._use_lock = use_lock
 
-        self._initialize_backend_if_necessary()
+        if self._origin_pid != MPFuture._active_pid:
+            with MPFuture._initialization_lock:
+                if self._origin_pid != MPFuture._active_pid:
+                    # note: the second if is intentional, see https://en.wikipedia.org/wiki/Double-checked_locking
+                    self._initialize_mpfuture_backend()
        assert self._uid not in MPFuture._active_futures
         MPFuture._active_futures[self._uid] = ref(self)
-        self._sender_pipe = MPFuture._process_wide_pipe
+        self._sender_pipe = MPFuture._global_sender_pipe
 
         try:
             self._loop = asyncio.get_event_loop()
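
The constructor above uses double-checked locking so that the per-process backend is initialized exactly once without paying for a lock on every call. A generic sketch of the pattern with illustrative names (not this module's actual helpers):

import os
import threading

_lock = threading.Lock()
_active_pid = None

def ensure_backend_initialized():
    global _active_pid
    if _active_pid != os.getpid():  # fast path: skip the lock in the common case
        with _lock:
            if _active_pid != os.getpid():  # re-check under the lock
                # ... create pipes and start the waiter thread here ...
                _active_pid = os.getpid()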
@@ -90,54 +119,52 @@ class MPFuture(base.Future, Generic[ResultType]):
         except RuntimeError:
             self._loop, self._aio_event = None, None
 
-    def _set_event_if_necessary(self):
-        if self._aio_event is None or self._aio_event.is_set():
-            return
+    @property
+    def _state(self) -> State:
+        shared_state = ALL_STATES[self._shared_state_code.item()]
+        return self._state_cache.get(shared_state, shared_state)
+
+    @_state.setter
+    def _state(self, new_state: State):
+        self._shared_state_code[...] = ALL_STATES.index(new_state)
+        if self._state in TERMINAL_STATES and self._loop is not None and not self._aio_event.is_set():
+            self._set_event_threadsafe()
+
+    def _set_event_threadsafe(self):
         try:
-            loop = asyncio.get_running_loop()
+            running_loop = asyncio.get_running_loop()
         except RuntimeError:
-            loop = None
+            running_loop = None
 
         async def _event_setter():
             self._aio_event.set()
 
-        if self._loop.is_running() and loop == self.get_loop():
+        if self._loop.is_running() and running_loop == self._loop:
             asyncio.create_task(_event_setter())
-        elif self._loop.is_running() and loop != self.get_loop():
+        elif self._loop.is_running() and running_loop != self._loop:
             asyncio.run_coroutine_threadsafe(_event_setter(), self._loop)
         else:
             self._loop.run_until_complete(_event_setter())
 
     @classmethod
-    def _initialize_backend_if_necessary(cls):
+    def _initialize_mpfuture_backend(cls):
         pid = os.getpid()
-        if MPFuture._active_pid != pid:
-            with MPFuture._initialization_lock:
-                if MPFuture._active_pid != pid:
-                    # note: the second if is intentional, see https://en.wikipedia.org/wiki/Double-checked_locking
-                    logger.debug(f"Initializing MPFuture backend for pid {pid}")
-                    receiver_pipe, cls._process_wide_pipe = mp.Pipe(duplex=False)
-                    cls._active_pid, cls._active_futures, cls._status_requests = pid, {}, {}
-                    cls._pipe_waiter_thread = threading.Thread(
-                        target=cls._process_updates_in_background,
-                        args=[receiver_pipe],
-                        name=f"{__name__}.BACKEND",
-                        daemon=True,
-                    )
-                    cls._pipe_waiter_thread.start()
+        logger.debug(f"Initializing MPFuture backend for pid {pid}")
 
-    @classmethod
-    def reset_backend(cls):
-        """
-        Reset the MPFuture backend. This is useful when the state may have been corrupted
-        (e.g. killing child processes may leave the locks acquired and the background thread blocked).
-
-        This method is neither thread-safe nor process-safe.
-        """
+        receiver_pipe, cls._global_sender_pipe = mp.Pipe(duplex=False)
+        cls._active_pid, cls._active_futures = pid, {}
+        cls._pipe_waiter_thread = threading.Thread(
+            target=cls._process_updates_in_background, args=[receiver_pipe], name=f"{__name__}.BACKEND", daemon=True
+        )
+        cls._pipe_waiter_thread.start()
 
-        cls._initialization_lock = mp.Lock()
-        cls._update_lock = mp.Lock()
-        cls._active_pid = None
+    @staticmethod
+    def reset_backend():
+        """Last-resort function to reset internals of MPFuture. All current MPFuture instances will be broken"""
+        MPFuture._active_pid = None
+        MPFuture._initialization_lock = mp.Lock()
+        MPFuture._update_lock = mp.Lock()
+        SharedBytes._lock = mp.Lock()
 
     @classmethod
     def _process_updates_in_background(cls, receiver_pipe: mp.connection.Connection):
@@ -145,58 +172,30 @@ class MPFuture(base.Future, Generic[ResultType]):
         while True:
             try:
                 if cls._pipe_waiter_thread is not threading.current_thread():
-                    break  # Backend was reset, a new background thread has started
+                    break  # backend was reset, a new background thread has started
 
-                uid, msg_type, payload = receiver_pipe.recv()
+                uid, update_type, payload = receiver_pipe.recv()
                 future = None
-                future_ref = cls._active_futures.get(uid)
+                future_ref = cls._active_futures.pop(uid, None)
                 if future_ref is not None:
                     future = future_ref()
 
-                if msg_type == MessageType.STATE_REQUEST:
-                    future_state = None if future is None else future.__getstate__()
-                    use_lock, return_pipe = payload
-                    with MPFuture._update_lock if use_lock else nullcontext():
-                        return_pipe.send((uid, MessageType.STATE_RESPONSE, future_state))
-
-                elif msg_type == MessageType.STATE_RESPONSE:
-                    future, state_updated_event = cls._status_requests.get(uid, (None, None))
-                    if future is None:
-                        logger.debug("Received a state update for a future that does not await status update.")
-                    else:
-                        if payload is not None:
-                            future.__setstate__(payload)
-                        else:
-                            base.Future.cancel(future)
-                        state_updated_event.set()
-
-                elif future is None:
-                    logger.debug(
-                        f"Received {msg_type} for MPFuture uid={uid}, but future is already done or destroyed"
-                    )
-                elif msg_type == MessageType.RESULT:
+                if future is None:
+                    logger.debug(f"Ignoring update to future with uid={uid}: the future is already done or destroyed")
+                elif update_type == UpdateType.RESULT:
                     future.set_result(payload)
-                elif msg_type == MessageType.EXCEPTION:
+                elif update_type == UpdateType.EXCEPTION:
                     future.set_exception(payload)
-                elif msg_type == MessageType.RUNNING:
-                    try:
-                        future.set_running_or_notify_cancel()
-                    except (InvalidStateError, RuntimeError) as e:
-                        logger.debug(f"Could not set MPFuture (uid={uid}) to running due to {e}", exc_info=True)
-                elif msg_type == MessageType.CANCEL:
+                elif update_type == UpdateType.CANCEL:
                     future.cancel()
                 else:
-                    raise RuntimeError(f"Received unexpected update type {msg_type}")
-
-                if future is None or future.done():
-                    cls._active_futures.pop(uid, None)
-
+                    raise RuntimeError(f"Received unexpected update type {update_type}")
             except (BrokenPipeError, EOFError, ConnectionError):
                 logger.debug(f"Update pipe was shut down unexpectedly (pid={pid})")
             except Exception as e:
                 logger.exception(f"Could not retrieve update: caught {repr(e)} (pid={pid})")
 
-    def _send_update(self, update_type: MessageType, payload: Any = None):
+    def _send_update(self, update_type: UpdateType, payload: Any = None):
         """This method sends result, exception or cancel to the MPFuture origin."""
         try:
             with MPFuture._update_lock if self._use_lock else nullcontext():
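
Updates travel over the pipe as plain (uid, update_type, payload) tuples. A self-contained sketch of that protocol; the uid and payload values are hypothetical:

import multiprocessing as mp
from enum import Enum, auto

class UpdateType(Enum):
    RESULT = auto()
    EXCEPTION = auto()
    CANCEL = auto()

receiver, sender = mp.Pipe(duplex=False)
sender.send((1234, UpdateType.RESULT, "payload"))  # what _send_update emits
uid, update_type, payload = receiver.recv()  # what the backend thread reads
assert uid == 1234 and update_type == UpdateType.RESULT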
@@ -204,110 +203,76 @@ class MPFuture(base.Future, Generic[ResultType]):
         except (ConnectionError, BrokenPipeError, EOFError) as e:
             logger.debug(f"No updates were sent: pipe to origin process was broken ({e}).", exc_info=True)
 
-    def _synchronize_if_necessary(self):
-        if not self.synchronize or os.getpid() == self._origin_pid or self._state in TERMINAL_STATES:
-            return
-
-        self._initialize_backend_if_necessary()
-
-        status_updated = threading.Event()
-        _, existing_status_event = self._status_requests.setdefault(self._uid, (self, status_updated))
-        # this line checks if another thread is synchronizing concurrently, assuming that setdefault to be atomic
-
-        if existing_status_event != status_updated:
-            existing_status_event.wait(MPFuture.HARD_UPDATE_TIMEOUT)
-            return
-
-        # otherwise create a new request for synchronization
-
-        try:
-            with MPFuture._update_lock if self._use_lock else nullcontext():
-                payload = (self._use_lock, self._process_wide_pipe)
-                self._sender_pipe.send((self._uid, MessageType.STATE_REQUEST, payload))
-            status_updated.wait(MPFuture.SOFT_UPDATE_TIMEOUT)
-            if not status_updated.is_set():
-                logger.warning(f"Status update took over {MPFuture.SOFT_UPDATE_TIMEOUT}, expect performance issues")
-                status_updated.wait(MPFuture.HARD_UPDATE_TIMEOUT - MPFuture.SOFT_UPDATE_TIMEOUT)
-                if not status_updated.is_set() and not self.cancel():
-                    with suppress(InvalidStateError, RuntimeError):
-                        self.set_exception(
-                            TimeoutError(
-                                f"Status update took over {MPFuture.HARD_UPDATE_TIMEOUT} seconds, "
-                                f"MPFuture is cancelled"
-                            )
-                        )
-                    status_updated.set()  # this triggers any concurrent _synchronize_if_necessary calls to finish
-        except (ConnectionError, BrokenPipeError, EOFError) as e:
-            logger.error(f"MPFuture was cancelled because sender pipe is broken. Origin process is probably down.")
-            if not self.cancel():
-                with suppress(InvalidStateError, RuntimeError):
-                    self.set_exception(e)
-        finally:
-            self._status_requests.pop(self._uid, None)
-
     def set_result(self, result: ResultType):
-        if self._state in TERMINAL_STATES:
-            raise InvalidStateError(f"Can't set_result to a future that is {self._state} ({self._uid})")
-        elif os.getpid() == self._origin_pid:
+        if os.getpid() == self._origin_pid:
+            super().set_result(result)
             MPFuture._active_futures.pop(self._uid, None)
-            self._set_event_if_necessary()
+        elif self._state in TERMINAL_STATES:
+            raise InvalidStateError(f"Can't set_result to a future that is {self._state} ({self._uid})")
         else:
-            self._send_update(MessageType.RESULT, result)
-        super().set_result(result)
+            self._state_cache[self._state], self._result = base.FINISHED, result
+            self._send_update(UpdateType.RESULT, result)
 
     def set_exception(self, exception: Optional[BaseException]):
-        if self._state in TERMINAL_STATES:
-            raise InvalidStateError(f"Can't set_exception to a future that is {self._state} ({self._uid})")
-        elif os.getpid() == self._origin_pid:
+        if os.getpid() == self._origin_pid:
+            super().set_exception(exception)
             MPFuture._active_futures.pop(self._uid, None)
-            self._set_event_if_necessary()
+        elif self._state in TERMINAL_STATES:
+            raise InvalidStateError(f"Can't set_exception to a future that is {self._state} ({self._uid})")
         else:
-            self._send_update(MessageType.EXCEPTION, exception)
-        super().set_exception(exception)
+            self._state_cache[self._state], self._exception = base.FINISHED, exception
+            self._send_update(UpdateType.EXCEPTION, exception)
 
     def cancel(self) -> bool:
-        if self._state in [base.RUNNING, base.FINISHED]:
-            return False
-        elif os.getpid() == self._origin_pid:
+        if os.getpid() == self._origin_pid:
             MPFuture._active_futures.pop(self._uid, None)
-            self._set_event_if_necessary()
+            return super().cancel()
+        elif self._state in [base.RUNNING, base.FINISHED]:
+            return False
         else:
-            self._send_update(MessageType.CANCEL)
-        return super().cancel()
+            self._state_cache[self._state] = base.CANCELLED
+            self._send_update(UpdateType.CANCEL)
+            return True
 
     def set_running_or_notify_cancel(self):
-        """if synchronize is set to False, this future will ignore any state changes from origin"""
-        self._synchronize_if_necessary()
-        try:
-            is_running = super().set_running_or_notify_cancel()
-            if is_running and os.getpid() != self._origin_pid:
-                self._send_update(MessageType.RUNNING)
-            return is_running
-        except RuntimeError as e:
-            raise InvalidStateError(str(e))
+        if self._state == base.PENDING:
+            self._state = base.RUNNING
+            return True
+        elif self._state == base.CANCELLED:
+            return False
+        else:
+            raise InvalidStateError(
+                f"Can't set_running_or_notify_cancel when future is in {self._state} ({self._uid})"
+            )
 
     def result(self, timeout: Optional[float] = None) -> ResultType:
         if self._state not in TERMINAL_STATES:
             if os.getpid() != self._origin_pid:
                 raise RuntimeError("Only the process that created MPFuture can await result")
-        return super().result(timeout)
+            return super().result(timeout)
+        elif self._state == base.CANCELLED:
+            raise base.CancelledError()
+        elif self._exception:
+            raise self._exception
+        else:
+            return self._result
 
     def exception(self, timeout: Optional[float] = None) -> Optional[BaseException]:
         if self._state not in TERMINAL_STATES:
             if os.getpid() != self._origin_pid:
                 raise RuntimeError("Only the process that created MPFuture can await exception")
-        return super().exception(timeout)
+            return super().exception(timeout)
+        elif self._state == base.CANCELLED:
+            raise base.CancelledError()
+        return self._exception
 
     def done(self) -> bool:
-        self._synchronize_if_necessary()
         return self._state in TERMINAL_STATES
 
     def running(self):
-        self._synchronize_if_necessary()
         return self._state == base.RUNNING
 
     def cancelled(self):
-        self._synchronize_if_necessary()
         return self._state == base.CANCELLED
 
     def add_done_callback(self, callback: Callable[[MPFuture], None]):
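
Because a non-origin set_result records base.FINISHED in the local _state_cache, the setter observes its own write immediately, even before the origin process drains the update pipe. A hedged sketch of that behavior (the resolver function is hypothetical, not part of this commit):

import multiprocessing as mp
from hivemind.utils.mpfuture import MPFuture

def _resolver(future: MPFuture):
    future.set_result("done")
    assert future.done()  # True at once, thanks to the local state cache

if __name__ == "__main__":
    future = MPFuture()
    p = mp.Process(target=_resolver, args=(future,))
    p.start()
    p.join()
    assert future.result(timeout=5.0) == "done"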
@@ -315,9 +280,6 @@ class MPFuture(base.Future, Generic[ResultType]):
             raise RuntimeError("Only the process that created MPFuture can set callbacks")
         return super().add_done_callback(callback)
 
-    def get_loop(self) -> Optional[asyncio.BaseEventLoop]:
-        return self._loop
-
     def __await__(self):
         if not self._aio_event:
             raise RuntimeError("Can't await: MPFuture was created with no event loop")
@@ -335,9 +297,8 @@ class MPFuture(base.Future, Generic[ResultType]):
 
     def __getstate__(self):
         return dict(
-            synchronize=self.synchronize,
             _sender_pipe=self._sender_pipe,
-            _state=self._state,
+            _shared_state_code=self._shared_state_code,
             _origin_pid=self._origin_pid,
             _uid=self._uid,
             _use_lock=self._use_lock,
@@ -346,12 +307,13 @@ class MPFuture(base.Future, Generic[ResultType]):
         )
 
     def __setstate__(self, state):
-        self.synchronize = state["synchronize"]
         self._sender_pipe = state["_sender_pipe"]
-        self._state, self._origin_pid, self._uid = state["_state"], state["_origin_pid"], state["_uid"]
+        self._shared_state_code = state["_shared_state_code"]
+        self._origin_pid, self._uid = state["_origin_pid"], state["_uid"]
         self._result, self._exception = state["_result"], state["_exception"]
         self._use_lock = state["_use_lock"]
 
         self._waiters, self._done_callbacks = [], []
         self._condition = threading.Condition()
         self._aio_event, self._loop = None, None
+        self._state_cache = {}

+ 2 - 2
tests/conftest.py

@@ -1,11 +1,12 @@
 import gc
 from contextlib import suppress
+import multiprocessing as mp
 
 import psutil
 import pytest
 
+from hivemind.utils.mpfuture import MPFuture, SharedBytes
 from hivemind.utils.logging import get_logger
-from hivemind.utils.mpfuture import MPFuture
 
 
 logger = get_logger(__name__)
@@ -28,5 +29,4 @@ def cleanup_children():
             with suppress(psutil.NoSuchProcess):
                 child.kill()
 
-    # Broken code or killing of child processes may leave the MPFuture backend corrupted
     MPFuture.reset_backend()

+ 2 - 2
tests/test_dht_node.py

@@ -260,10 +260,10 @@ def test_dht_node(
         jaccard_denominator += k_nearest
 
     accuracy = accuracy_numerator / accuracy_denominator
-    logger.debug(f"Top-1 accuracy: {accuracy}")  # should be 98-100%
+    logger.debug(f"Top-1 accuracy: {accuracy}")  # should be 90-100%
     jaccard_index = jaccard_numerator / jaccard_denominator
     logger.debug(f"Jaccard index (intersection over union): {jaccard_index}")  # should be 95-100%
-    assert accuracy >= 0.9, f"Top-1 accuracy only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
+    assert accuracy >= 0.8, f"Top-1 accuracy only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
     assert jaccard_index >= 0.9, f"Jaccard index only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
 
     # test 4: find all nodes

+ 5 - 8
tests/test_util_modules.py

@@ -256,8 +256,8 @@ def test_mpfuture_done_callback():
 
     assert future1.done() and not future1.cancelled()
     assert future2.done() and future2.cancelled()
-    events[0].wait(1)
-    events[1].wait(1)
+    for i in 0, 1, 4:
+        events[i].wait(1)
     assert events[0].is_set() and events[1].is_set() and events[2].is_set() and events[4].is_set()
     assert not events[3].is_set()
 
@@ -266,15 +266,14 @@
 
 
 @pytest.mark.forked
-@pytest.mark.parametrize("synchronize", [True, False])
-def test_many_futures(synchronize: bool):
+def test_many_futures():
     evt = mp.Event()
     receiver, sender = mp.Pipe()
-    main_futures = [hivemind.MPFuture(synchronize=synchronize) for _ in range(1000)]
+    main_futures = [hivemind.MPFuture() for _ in range(1000)]
     assert len(hivemind.MPFuture._active_futures) == 1000
 
     def _run_peer():
-        fork_futures = [hivemind.MPFuture(synchronize=synchronize) for _ in range(500)]
+        fork_futures = [hivemind.MPFuture() for _ in range(500)]
         assert len(hivemind.MPFuture._active_futures) == 500
 
         for i, future in enumerate(random.sample(main_futures, 300)):
@@ -299,8 +298,6 @@
     p.start()
 
     some_fork_futures = receiver.recv()
-
-    time.sleep(0.5)  # wait for active futures to synchronize
     assert len(hivemind.MPFuture._active_futures) == 700
 
     for future in some_fork_futures: