
Remove shared memory from MPFuture, fix minor bugs (#317)

* rewrote MPFuture to use pipe-only communication instead of SharedMemory
    * rationale: each shared memory object is backed by a file, so using thousands of them floods the open-files limit (see the usage sketch below)
* also: fixed a bug where the coroutine in set_event_threadsafe was never awaited

Co-authored-by: Alexander Borzunov <borzunov.alexander@gmail.com>
Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
Co-authored-by: Alexander Borzunov <hxrussia@gmail.com>
justheuristic, 4 years ago
Commit 197666c2b6
4 changed files with 193 additions and 114 deletions
  1. hivemind/moe/server/task_pool.py (+1, -1)
  2. hivemind/utils/mpfuture.py (+172, -104)
  3. tests/test_training.py (+5, -0)
  4. tests/test_util_modules.py (+15, -9)
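
For context, here is a minimal usage sketch (not part of the diff) of the pipe-based MPFuture this commit introduces: a single process-wide pipe now carries updates for all futures of a process, so creating thousands of futures no longer opens thousands of shared-memory files. Assumes hivemind's public API as of this commit and a fork-capable platform.

import multiprocessing as mp

import hivemind


def worker(future: hivemind.MPFuture):
    # runs in the child process; the result travels to the origin process
    # over the shared process-wide pipe rather than via shared memory
    future.set_result(42)


if __name__ == "__main__":
    future = hivemind.MPFuture()
    p = mp.Process(target=worker, args=(future,))
    p.start()
    print(future.result())  # 42, delivered by the origin's background pipe-reader thread
    p.join()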

+ 1 - 1
hivemind/moe/server/task_pool.py

@@ -100,7 +100,7 @@ class TaskPool(TaskPoolBase):
 
     def submit_task(self, *args: torch.Tensor) -> Future:
         """Add task to this pool's queue, return Future for its output"""
-        task = Task(MPFuture(), args)
+        task = Task(MPFuture(synchronize=False), args)
         if self.get_task_size(task) > self.max_batch_size:
             exc = ValueError(f"Task size greater than max_batch_size ({self.max_batch_size}), it can't be processed")
             task.future.set_exception(exc)
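
Side note (not part of the diff): the pool's worker process only writes results and exceptions into these futures and never polls their state across processes, so skipping the new state-synchronization round-trip is safe here. A hedged sketch of the difference:

from hivemind import MPFuture

# synchronize=False: done()/running()/cancelled() called from a non-origin
# process consult local state only and skip the STATE_REQUEST round-trip
fast_future = MPFuture(synchronize=False)

# synchronize=True (default): the same calls from a non-origin process first
# ask the origin for authoritative state, waiting up to HARD_UPDATE_TIMEOUT
safe_future = MPFuture()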

+ 172 - 104
hivemind/utils/mpfuture.py

@@ -2,16 +2,15 @@ from __future__ import annotations
 
 import asyncio
 import concurrent.futures._base as base
-from contextlib import nullcontext
+from contextlib import nullcontext, suppress
 import multiprocessing as mp
 import multiprocessing.connection
 import os
 import threading
 import uuid
+from weakref import ref
 from enum import Enum, auto
-from typing import Generic, TypeVar, Dict, Optional, Any, Callable
-
-import torch  # used for py3.7-compatible shared memory
+from typing import Generic, TypeVar, Dict, Optional, Any, Callable, Type, Tuple
 
 from hivemind.utils.logging import get_logger
 
@@ -34,10 +33,13 @@ except ImportError:
         """Raised when attempting to change state of a future in a terminal state (e.g. finished)"""
 
 
-class UpdateType(Enum):
+class MessageType(Enum):
     RESULT = auto()
     EXCEPTION = auto()
+    RUNNING = auto()
     CANCEL = auto()
+    STATE_REQUEST = auto()
+    STATE_RESPONSE = auto()
 
 
 class MPFuture(base.Future, Generic[ResultType]):
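
The enum above is the message vocabulary: every update travels over the process-wide pipe as a (uid, MessageType, payload) triple, exactly as the background thread below unpacks it. A self-contained sketch of that wire format (enum copied from the diff, wiring simplified):

import multiprocessing as mp
from enum import Enum, auto


class MessageType(Enum):
    RESULT = auto()
    EXCEPTION = auto()
    RUNNING = auto()
    CANCEL = auto()
    STATE_REQUEST = auto()
    STATE_RESPONSE = auto()


receiver, sender = mp.Pipe(duplex=False)
sender.send((1234, MessageType.RESULT, 42))  # (uid, message type, payload)
uid, msg_type, payload = receiver.recv()
assert uid == 1234 and msg_type is MessageType.RESULT and payload == 42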
@@ -46,6 +48,8 @@ class MPFuture(base.Future, Generic[ResultType]):
     Any process can access future status and set the result / exception and check for state.
     However, only the original process (i.e. the process that created the future) can await the result or exception.
 
+    :param synchronize: if True (default), the future requests its actual state from the origin process;
+      otherwise it relies on local state only. synchronize=False makes done() and set_running_or_notify_cancel() slightly faster.
     :param use_lock: if True, operations with MPFuture use a global lock to prevent concurrent writes to the same pipe;
       If set to False, writing to this future ignores global lock, slightly improving performance, but making user
       responsible for avoiding concurrent set_result / set_exception calls to futures with the same process of origin.
@@ -60,49 +64,36 @@ class MPFuture(base.Future, Generic[ResultType]):
 
     _initialization_lock = mp.Lock()  # global lock that prevents simultaneous initialization of two processes
     _update_lock = mp.Lock()  # global lock that prevents simultaneous writing to the same pipe
-    _global_sender_pipe: Optional[PipeEnd] = None  # a pipe that is used to send results/exceptions to this process
+    _process_wide_pipe: Optional[PipeEnd] = None  # a pipe that is used to send results/exceptions to this process
     _pipe_waiter_thread: Optional[threading.Thread] = None  # process-specific thread that receives results/exceptions
-    _active_futures: Optional[Dict[UID, MPFuture]] = None  # pending or running futures originated from current process
+    _active_futures: Optional[Dict[UID, ref[MPFuture]]] = None  # non-done futures originated from this process
+    _status_requests: Optional[Dict[UID, Tuple[MPFuture, threading.Event]]] = None  # futures to be updated by origin
     _active_pid: Optional[PID] = None  # pid of currently active process; used to handle forks natively
 
-    def __init__(self, use_lock: bool = True, loop: Optional[asyncio.BaseEventLoop] = None):
-        self._origin_pid, self._uid = os.getpid(), uuid.uuid4().int
-        self._shared_state_code = torch.empty([], dtype=torch.uint8).share_memory_()
-        self._state_cache: Dict[State, State] = {}
-        # mapping from global to cached local future used that makes updates immediately
-        # available on setter side; dictionary-based cache works because future can visit any state at most once
+    SOFT_UPDATE_TIMEOUT = 0.1  # seconds spent awaiting status update before a warning is printed
+    HARD_UPDATE_TIMEOUT = 10.0  # seconds spent awaiting status update before the future is automatically cancelled
 
-        base.Future.__init__(self)  # parent init is deferred because it uses self._shared_state_code
+    def __init__(self, *, synchronize: bool = True, use_lock: bool = True):
+        super().__init__()
+        self.synchronize = synchronize
+        self._origin_pid, self._uid = os.getpid(), uuid.uuid4().int
         self._state, self._result, self._exception = base.PENDING, None, None
         self._use_lock = use_lock
 
-        if self._origin_pid != MPFuture._active_pid:
-            with MPFuture._initialization_lock:
-                if self._origin_pid != MPFuture._active_pid:
-                    # note: the second if is intentional, see https://en.wikipedia.org/wiki/Double-checked_locking
-                    self._initialize_mpfuture_backend()
+        self._initialize_backend_if_necessary()
         assert self._uid not in MPFuture._active_futures
-        MPFuture._active_futures[self._uid] = self
-        self._sender_pipe = MPFuture._global_sender_pipe
+        MPFuture._active_futures[self._uid] = ref(self)
+        self._sender_pipe = MPFuture._process_wide_pipe
 
         try:
-            self._loop = loop or asyncio.get_event_loop()
+            self._loop = asyncio.get_event_loop()
             self._aio_event = asyncio.Event()
         except RuntimeError:
             self._loop, self._aio_event = None, None
 
-    @property
-    def _state(self) -> State:
-        shared_state = ALL_STATES[self._shared_state_code.item()]
-        return self._state_cache.get(shared_state, shared_state)
-
-    @_state.setter
-    def _state(self, new_state: State):
-        self._shared_state_code[...] = ALL_STATES.index(new_state)
-        if self._state in TERMINAL_STATES and self._loop is not None and not self._aio_event.is_set():
-            self._set_event_threadsafe()
-
-    def _set_event_threadsafe(self):
+    def _set_event_if_necessary(self):
+        if self._aio_event is None or self._aio_event.is_set():
+            return
         try:
             loop = asyncio.get_running_loop()
         except RuntimeError:
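
Note on the hunk above: _active_futures now stores weak references, so a future that user code discards can be garbage-collected even if no update ever arrives for it. A minimal illustration of the weakref behavior the backend relies on (toy class, not from the diff):

from weakref import ref


class Dummy:
    """Stand-in for MPFuture (weakref.ref needs a class with a __weakref__ slot)."""


registry = {}
future = Dummy()
registry[42] = ref(future)       # store a weak reference, as _active_futures does

assert registry[42]() is future  # dereferencing works while the future is alive
del future                       # user code drops the last strong reference
assert registry[42]() is None    # dead ref: the update handler treats it as destroyed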
@@ -111,120 +102,197 @@ class MPFuture(base.Future, Generic[ResultType]):
         async def _event_setter():
             self._aio_event.set()
 
-        if loop == self.get_loop():
+        if self._loop.is_running() and loop == self.get_loop():
             asyncio.create_task(_event_setter())
-        else:
+        elif self._loop.is_running() and loop != self.get_loop():
             asyncio.run_coroutine_threadsafe(_event_setter(), self._loop)
+        else:
+            self._loop.run_until_complete(_event_setter())
 
     @classmethod
-    def _initialize_mpfuture_backend(cls):
+    def _initialize_backend_if_necessary(cls):
         pid = os.getpid()
-        logger.debug(f"Initializing MPFuture backend for pid {pid}")
-        assert pid != cls._active_pid, "already initialized"
-
-        receiver_pipe, cls._global_sender_pipe = mp.Pipe(duplex=False)
-        cls._active_pid, cls._active_futures = pid, {}
-        cls._pipe_waiter_thread = threading.Thread(
-            target=cls._process_updates_in_background, args=[receiver_pipe], name=f"{__name__}.BACKEND", daemon=True
-        )
-        cls._pipe_waiter_thread.start()
+        if MPFuture._active_pid != pid:
+            with MPFuture._initialization_lock:
+                if MPFuture._active_pid != pid:
+                    # note: the second if is intentional, see https://en.wikipedia.org/wiki/Double-checked_locking
+                    logger.debug(f"Initializing MPFuture backend for pid {pid}")
+                    receiver_pipe, cls._process_wide_pipe = mp.Pipe(duplex=False)
+                    cls._active_pid, cls._active_futures, cls._status_requests = pid, {}, {}
+                    cls._pipe_waiter_thread = threading.Thread(
+                        target=cls._process_updates_in_background,
+                        args=[receiver_pipe],
+                        name=f"{__name__}.BACKEND",
+                        daemon=True,
+                    )
+                    cls._pipe_waiter_thread.start()
 
     @classmethod
     def _process_updates_in_background(cls, receiver_pipe: mp.connection.Connection):
         pid = os.getpid()
         while True:
             try:
-                uid, update_type, payload = receiver_pipe.recv()
-                if uid not in cls._active_futures:
-                    logger.debug(f"Ignoring update to future with uid={uid}: the future is already done or destroyed")
-                elif update_type == UpdateType.RESULT:
-                    cls._active_futures.pop(uid).set_result(payload)
-                elif update_type == UpdateType.EXCEPTION:
-                    cls._active_futures.pop(uid).set_exception(payload)
-                elif update_type == UpdateType.CANCEL:
-                    cls._active_futures.pop(uid).cancel()
+                uid, msg_type, payload = receiver_pipe.recv()
+                future = None
+                future_ref = cls._active_futures.get(uid)
+                if future_ref is not None:
+                    future = future_ref()
+
+                if msg_type == MessageType.STATE_REQUEST:
+                    future_state = None if future is None else future.__getstate__()
+                    use_lock, return_pipe = payload
+                    with MPFuture._update_lock if use_lock else nullcontext():
+                        return_pipe.send((uid, MessageType.STATE_RESPONSE, future_state))
+
+                elif msg_type == MessageType.STATE_RESPONSE:
+                    future, state_updated_event = cls._status_requests.get(uid, (None, None))
+                    if future is None:
+                        logger.debug("Received a state update for a future that is not awaiting a status update.")
+                    else:
+                        if payload is not None:
+                            future.__setstate__(payload)
+                        else:
+                            base.Future.cancel(future)
+                        state_updated_event.set()
+
+                elif future is None:
+                    logger.debug(
+                        f"Received {msg_type} for MPFuture uid={uid}, but future is already done or destroyed"
+                    )
+                elif msg_type == MessageType.RESULT:
+                    future.set_result(payload)
+                elif msg_type == MessageType.EXCEPTION:
+                    future.set_exception(payload)
+                elif msg_type == MessageType.RUNNING:
+                    try:
+                        future.set_running_or_notify_cancel()
+                    except (InvalidStateError, RuntimeError) as e:
+                        logger.debug(f"Could not set MPFuture (uid={uid}) to running due to {e}", exc_info=True)
+                elif msg_type == MessageType.CANCEL:
+                    future.cancel()
                 else:
-                    raise RuntimeError(f"Received unexpected update type {update_type}")
-            except (BrokenPipeError, EOFError):
+                    raise RuntimeError(f"Received unexpected update type {msg_type}")
+
+                if future is None or future.done():
+                    cls._active_futures.pop(uid, None)
+
+            except (BrokenPipeError, EOFError, ConnectionError):
                 logger.debug(f"Update pipe was shut down unexpectedly (pid={pid})")
             except Exception as e:
                 logger.exception(f"Could not retrieve update: caught {repr(e)} (pid={pid})")
 
-    def _send_update(self, update_type: UpdateType, payload: Any = None):
+    def _send_update(self, update_type: MessageType, payload: Any = None):
         """This method sends result, exception or cancel to the MPFuture origin."""
-        with MPFuture._update_lock if self._use_lock else nullcontext():
-            self._sender_pipe.send((self._uid, update_type, payload))
+        try:
+            with MPFuture._update_lock if self._use_lock else nullcontext():
+                self._sender_pipe.send((self._uid, update_type, payload))
+        except (ConnectionError, BrokenPipeError, EOFError) as e:
+            logger.debug(f"No updates were sent: pipe to origin process was broken ({e}).", exc_info=True)
+
+    def _synchronize_if_necessary(self):
+        if not self.synchronize or os.getpid() == self._origin_pid or self._state in TERMINAL_STATES:
+            return
+
+        self._initialize_backend_if_necessary()
+
+        status_updated = threading.Event()
+        _, existing_status_event = self._status_requests.setdefault(self._uid, (self, status_updated))
+        # this checks whether another thread is synchronizing concurrently, relying on setdefault being atomic
+
+        if existing_status_event != status_updated:
+            existing_status_event.wait(MPFuture.HARD_UPDATE_TIMEOUT)
+            return
+
+        # otherwise create a new request for synchronization
+
+        try:
+            with MPFuture._update_lock if self._use_lock else nullcontext():
+                payload = (self._use_lock, self._process_wide_pipe)
+                self._sender_pipe.send((self._uid, MessageType.STATE_REQUEST, payload))
+            status_updated.wait(MPFuture.SOFT_UPDATE_TIMEOUT)
+            if not status_updated.is_set():
+                logger.warning(f"Status update took over {MPFuture.SOFT_UPDATE_TIMEOUT} seconds, expect performance issues")
+                status_updated.wait(MPFuture.HARD_UPDATE_TIMEOUT - MPFuture.SOFT_UPDATE_TIMEOUT)
+                if not status_updated.is_set() and not self.cancel():
+                    with suppress(InvalidStateError, RuntimeError):
+                        self.set_exception(
+                            TimeoutError(
+                                f"Status update took over {MPFuture.HARD_UPDATE_TIMEOUT} seconds, "
+                                f"MPFuture is cancelled"
+                            )
+                        )
+                    status_updated.set()  # this triggers any concurrent _synchronize_if_necessary calls to finish
+        except (ConnectionError, BrokenPipeError, EOFError) as e:
+            logger.error("MPFuture was cancelled because the sender pipe is broken. Origin process is probably down.")
+            if not self.cancel():
+                with suppress(InvalidStateError, RuntimeError):
+                    self.set_exception(e)
+        finally:
+            self._status_requests.pop(self._uid, None)
 
     def set_result(self, result: ResultType):
-        if os.getpid() == self._origin_pid:
-            super().set_result(result)
-            MPFuture._active_futures.pop(self._uid, None)
-        elif self._state in TERMINAL_STATES:
+        if self._state in TERMINAL_STATES:
             raise InvalidStateError(f"Can't set_result to a future that is {self._state} ({self._uid})")
+        elif os.getpid() == self._origin_pid:
+            MPFuture._active_futures.pop(self._uid, None)
+            self._set_event_if_necessary()
         else:
-            self._state_cache[self._state], self._result = base.FINISHED, result
-            self._send_update(UpdateType.RESULT, result)
+            self._send_update(MessageType.RESULT, result)
+        super().set_result(result)
 
     def set_exception(self, exception: Optional[BaseException]):
-        if os.getpid() == self._origin_pid:
-            super().set_exception(exception)
-            MPFuture._active_futures.pop(self._uid, None)
-        elif self._state in TERMINAL_STATES:
+        if self._state in TERMINAL_STATES:
             raise InvalidStateError(f"Can't set_exception to a future that is {self._state} ({self._uid})")
+        elif os.getpid() == self._origin_pid:
+            MPFuture._active_futures.pop(self._uid, None)
+            self._set_event_if_necessary()
         else:
-            self._state_cache[self._state], self._exception = base.FINISHED, exception
-            self._send_update(UpdateType.EXCEPTION, exception)
+            self._send_update(MessageType.EXCEPTION, exception)
+        super().set_exception(exception)
 
     def cancel(self) -> bool:
-        if os.getpid() == self._origin_pid:
-            MPFuture._active_futures.pop(self._uid, None)
-            return super().cancel()
-        elif self._state in [base.RUNNING, base.FINISHED]:
+        if self._state in [base.RUNNING, base.FINISHED]:
             return False
+        elif os.getpid() == self._origin_pid:
+            MPFuture._active_futures.pop(self._uid, None)
+            self._set_event_if_necessary()
         else:
-            self._state_cache[self._state] = base.CANCELLED
-            self._send_update(UpdateType.CANCEL)
-            return True
+            self._send_update(MessageType.CANCEL)
+        return super().cancel()
 
     def set_running_or_notify_cancel(self):
-        if self._state == base.PENDING:
-            self._state = base.RUNNING
-            return True
-        elif self._state == base.CANCELLED:
-            return False
-        else:
-            raise InvalidStateError(
-                f"Can't set_running_or_notify_cancel when future is in {self._state} ({self._uid})"
-            )
+        """if synchronize is set to False, this future will ignore any state changes from origin"""
+        self._synchronize_if_necessary()
+        try:
+            is_running = super().set_running_or_notify_cancel()
+            if is_running and os.getpid() != self._origin_pid:
+                self._send_update(MessageType.RUNNING)
+            return is_running
+        except RuntimeError as e:
+            raise InvalidStateError(str(e))
 
     def result(self, timeout: Optional[float] = None) -> ResultType:
         if self._state not in TERMINAL_STATES:
             if os.getpid() != self._origin_pid:
                 raise RuntimeError("Only the process that created MPFuture can await result")
-            return super().result(timeout)
-        elif self._state == base.CANCELLED:
-            raise base.CancelledError()
-        elif self._exception:
-            raise self._exception
-        else:
-            return self._result
+        return super().result(timeout)
 
     def exception(self, timeout: Optional[float] = None) -> Optional[BaseException]:
         if self._state not in TERMINAL_STATES:
             if os.getpid() != self._origin_pid:
                 raise RuntimeError("Only the process that created MPFuture can await exception")
-            return super().exception(timeout)
-        elif self._state == base.CANCELLED:
-            raise base.CancelledError()
-        return self._exception
+        return super().exception(timeout)
 
     def done(self) -> bool:
+        self._synchronize_if_necessary()
         return self._state in TERMINAL_STATES
 
     def running(self):
+        self._synchronize_if_necessary()
         return self._state == base.RUNNING
 
     def cancelled(self):
+        self._synchronize_if_necessary()
         return self._state == base.CANCELLED
 
     def add_done_callback(self, callback: Callable[[MPFuture], None]):
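
The _synchronize_if_necessary method above relies on dict.setdefault being atomic under the GIL: exactly one thread becomes the requester, while concurrent callers simply wait on the requester's event. A stripped-down sketch of that pattern (function and helper names are hypothetical):

import threading

_status_requests = {}  # uid -> (requester, event), mirroring MPFuture._status_requests


def synchronize(uid, send_state_request, timeout=10.0):
    my_event = threading.Event()
    _, existing_event = _status_requests.setdefault(uid, (None, my_event))
    if existing_event is not my_event:
        existing_event.wait(timeout)   # another thread already sent the request
        return
    try:
        send_state_request(uid)        # e.g. push a STATE_REQUEST into the pipe
        my_event.wait(timeout)         # set by the STATE_RESPONSE handler
    finally:
        _status_requests.pop(uid, None)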
@@ -240,7 +308,7 @@ class MPFuture(base.Future, Generic[ResultType]):
             raise RuntimeError("Can't await: MPFuture was created with no event loop")
         yield from self._aio_event.wait().__await__()
         try:
-            return super().result(timeout=0)
+            return super().result()
         except base.CancelledError:
             raise asyncio.CancelledError()
 
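Since __await__ now waits on the internal asyncio.Event and then re-reads the result, an MPFuture can be awaited directly in its origin process. A hedged usage sketch, assuming the API as of this commit:

import asyncio

import hivemind


async def main():
    future = hivemind.MPFuture()
    # set the result shortly after from the same loop; in real use it would
    # arrive from another process via the background pipe-reader thread
    asyncio.get_event_loop().call_later(0.1, future.set_result, "done")
    print(await future)  # prints "done"


asyncio.run(main())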
@@ -252,8 +320,9 @@ class MPFuture(base.Future, Generic[ResultType]):
 
     def __getstate__(self):
         return dict(
+            synchronize=self.synchronize,
             _sender_pipe=self._sender_pipe,
-            _shared_state_code=self._shared_state_code,
+            _state=self._state,
             _origin_pid=self._origin_pid,
             _uid=self._uid,
             _use_lock=self._use_lock,
@@ -262,13 +331,12 @@ class MPFuture(base.Future, Generic[ResultType]):
         )
 
     def __setstate__(self, state):
+        self.synchronize = state["synchronize"]
         self._sender_pipe = state["_sender_pipe"]
-        self._shared_state_code = state["_shared_state_code"]
-        self._origin_pid, self._uid = state["_origin_pid"], state["_uid"]
+        self._state, self._origin_pid, self._uid = state["_state"], state["_origin_pid"], state["_uid"]
         self._result, self._exception = state["_result"], state["_exception"]
         self._use_lock = state["_use_lock"]
 
         self._waiters, self._done_callbacks = [], []
         self._condition = threading.Condition()
         self._aio_event, self._loop = None, None
-        self._state_cache = {}

+ 5 - 0
tests/test_training.py

@@ -47,6 +47,8 @@ def test_training(max_steps: int = 100, threshold: float = 0.9):
 def test_moe_training(max_steps: int = 100, threshold: float = 0.9, num_experts=2):
     dataset = load_digits(n_class=2)
     X_train, y_train = torch.tensor(dataset["data"], dtype=torch.float), torch.tensor(dataset["target"])
+    subsample_ix = torch.randint(0, len(X_train), (32,))
+    X_train, y_train = X_train[subsample_ix], y_train[subsample_ix]
     SGD = partial(torch.optim.SGD, lr=0.05)
 
     all_expert_uids = [f"expert.{i}" for i in range(num_experts)]
@@ -97,6 +99,9 @@ class SwitchNetwork(nn.Module):
 def test_switch_training(max_steps: int = 10, threshold: float = 0.9, num_experts=5):
     dataset = load_digits(n_class=2)
     X_train, y_train = torch.tensor(dataset["data"], dtype=torch.float), torch.tensor(dataset["target"])
+    subsample_ix = torch.randint(0, len(X_train), (32,))
+    X_train, y_train = X_train[subsample_ix], y_train[subsample_ix]
+
     SGD = partial(torch.optim.SGD, lr=0.05)
 
     all_expert_uids = [f"expert.{i}" for i in range(num_experts)]

+ 15 - 9
tests/test_util_modules.py

@@ -221,7 +221,7 @@ def test_mpfuture_bidirectional():
 @pytest.mark.forked
 def test_mpfuture_done_callback():
     receiver, sender = mp.Pipe(duplex=False)
-    events = [mp.Event() for _ in range(5)]
+    events = [mp.Event() for _ in range(6)]
 
     def _future_creator():
         future1, future2, future3 = hivemind.MPFuture(), hivemind.MPFuture(), hivemind.MPFuture()
@@ -243,6 +243,7 @@ def test_mpfuture_done_callback():
         future1.add_done_callback(
             lambda future: events[4].set()
         )  # schedule callback after future1 is already finished
+        events[5].wait()
 
     p = mp.Process(target=_future_creator)
     p.start()
@@ -253,24 +254,27 @@ def test_mpfuture_done_callback():
     with pytest.raises(RuntimeError):
         future1.add_done_callback(lambda future: (1, 2, 3))
 
-    p.join()
-    events[0].wait(1)
-    events[1].wait(1)
     assert future1.done() and not future1.cancelled()
     assert future2.done() and future2.cancelled()
+    events[0].wait(1)
+    events[1].wait(1)
     assert events[0].is_set() and events[1].is_set() and events[2].is_set() and events[4].is_set()
     assert not events[3].is_set()
 
+    events[5].set()
+    p.join()
+
 
 @pytest.mark.forked
-def test_many_futures():
+@pytest.mark.parametrize("synchronize", [True, False])
+def test_many_futures(synchronize: bool):
     evt = mp.Event()
     receiver, sender = mp.Pipe()
-    main_futures = [hivemind.MPFuture() for _ in range(1000)]
+    main_futures = [hivemind.MPFuture(synchronize=synchronize) for _ in range(1000)]
     assert len(hivemind.MPFuture._active_futures) == 1000
 
     def _run_peer():
-        fork_futures = [hivemind.MPFuture() for _ in range(500)]
+        fork_futures = [hivemind.MPFuture(synchronize=synchronize) for _ in range(500)]
         assert len(hivemind.MPFuture._active_futures) == 500
 
         for i, future in enumerate(random.sample(main_futures, 300)):
@@ -287,13 +291,16 @@ def test_many_futures():
 
         assert len(hivemind.MPFuture._active_futures) == 200
         for future in fork_futures:
-            future.cancel()
+            if not future.done():
+                future.set_result(123)
        assert len(hivemind.MPFuture._active_futures) == 0
 
     p = mp.Process(target=_run_peer)
     p.start()
 
     some_fork_futures = receiver.recv()
+
+    time.sleep(0.5)  # wait for active futures to synchronize
     assert len(hivemind.MPFuture._active_futures) == 700
 
     for future in some_fork_futures:
@@ -301,7 +308,6 @@
     for future in random.sample(some_fork_futures, 200):
         future.set_result(321)
 
-    time.sleep(0.5)
     evt.set()
     for future in main_futures:
         future.cancel()
         future.cancel()