ソースを参照

Remove shared memory from MPFuture, fix minor bugs (#317)

* re-written MPFuture to use pipe-only communication instead of SharedMemory
    * rationale: each shared memory object is a file, using thousands of them floods the open files limit
* also: fixed coroutine set_event_threadsafe is never awaited

Co-authored-by: Alexander Borzunov <borzunov.alexander@gmail.com>
Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
Co-authored-by: Alexander Borzunov <hxrussia@gmail.com>
justheuristic 4 年 前
コミット
197666c2b6
4 ファイル変更193 行追加114 行削除
  1. 1 1
      hivemind/moe/server/task_pool.py
  2. 172 104
      hivemind/utils/mpfuture.py
  3. 5 0
      tests/test_training.py
  4. 15 9
      tests/test_util_modules.py

+ 1 - 1
hivemind/moe/server/task_pool.py

@@ -100,7 +100,7 @@ class TaskPool(TaskPoolBase):
 
     def submit_task(self, *args: torch.Tensor) -> Future:
         """Add task to this pool's queue, return Future for its output"""
-        task = Task(MPFuture(), args)
+        task = Task(MPFuture(synchronize=False), args)
         if self.get_task_size(task) > self.max_batch_size:
             exc = ValueError(f"Task size greater than max_batch_size ({self.max_batch_size}), it can't be processed")
             task.future.set_exception(exc)

+ 172 - 104
hivemind/utils/mpfuture.py

@@ -2,16 +2,15 @@ from __future__ import annotations
 
 import asyncio
 import concurrent.futures._base as base
-from contextlib import nullcontext
+from contextlib import nullcontext, suppress
 import multiprocessing as mp
 import multiprocessing.connection
 import os
 import threading
 import uuid
+from weakref import ref
 from enum import Enum, auto
-from typing import Generic, TypeVar, Dict, Optional, Any, Callable
-
-import torch  # used for py3.7-compatible shared memory
+from typing import Generic, TypeVar, Dict, Optional, Any, Callable, Type, Tuple
 
 from hivemind.utils.logging import get_logger
 
@@ -34,10 +33,13 @@ except ImportError:
         """Raised when attempting to change state of a future in a terminal state (e.g. finished)"""
 
 
-class UpdateType(Enum):
+class MessageType(Enum):
     RESULT = auto()
     EXCEPTION = auto()
+    RUNNING = auto()
     CANCEL = auto()
+    STATE_REQUEST = auto()
+    STATE_RESPONSE = auto()
 
 
 class MPFuture(base.Future, Generic[ResultType]):
@@ -46,6 +48,8 @@ class MPFuture(base.Future, Generic[ResultType]):
     Any process can access future status and set the result / exception and check for state.
     However, only the original process (i.e. the process that created the future) can await the result or exception.
 
+    :param synchronize: if True (default), future will request state from origin, otherwise it will only use local state
+      Setting synchronize=False results in slightly better performance of done or set_running_or_notify_cancel
     :param use_lock: if True, operations with MPFuture use a global lock to prevent concurrent writes to the same pipe;
       If set to False, writing to this future ignores global lock, slightly improving performance, but making user
       responsible for avoiding concurrent set_result / set_exception calls to futures with the same process of origin.
@@ -60,49 +64,36 @@ class MPFuture(base.Future, Generic[ResultType]):
 
     _initialization_lock = mp.Lock()  # global lock that prevents simultaneous initialization of two processes
     _update_lock = mp.Lock()  # global lock that prevents simultaneous writing to the same pipe
-    _global_sender_pipe: Optional[PipeEnd] = None  # a pipe that is used to send results/exceptions to this process
+    _process_wide_pipe: Optional[PipeEnd] = None  # a pipe that is used to send results/exceptions to this process
     _pipe_waiter_thread: Optional[threading.Thread] = None  # process-specific thread that receives results/exceptions
-    _active_futures: Optional[Dict[UID, MPFuture]] = None  # pending or running futures originated from current process
+    _active_futures: Optional[Dict[UID, Type[ref][MPFuture]]] = None  # non-done futures originated from this process
+    _status_requests: Optional[Dict[UID, Tuple[MPFuture, threading.Event]]] = None  # futures to be updated by origin
     _active_pid: Optional[PID] = None  # pid of currently active process; used to handle forks natively
 
-    def __init__(self, use_lock: bool = True, loop: Optional[asyncio.BaseEventLoop] = None):
-        self._origin_pid, self._uid = os.getpid(), uuid.uuid4().int
-        self._shared_state_code = torch.empty([], dtype=torch.uint8).share_memory_()
-        self._state_cache: Dict[State, State] = {}
-        # mapping from global to cached local future used that makes updates immediately
-        # available on setter side; dictionary-based cache works because future can visit any state at most once
+    SOFT_UPDATE_TIMEOUT = 0.1  # seconds spent awaiting status update before warning is printed
+    HARD_UPDATE_TIMEOUT = 10.0  # seconds spent awaiting status update before future is automatically cancelled
 
-        base.Future.__init__(self)  # parent init is deferred because it uses self._shared_state_code
+    def __init__(self, *, synchronize: bool = True, use_lock: bool = True):
+        super().__init__()
+        self.synchronize = synchronize
+        self._origin_pid, self._uid = os.getpid(), uuid.uuid4().int
         self._state, self._result, self._exception = base.PENDING, None, None
         self._use_lock = use_lock
 
-        if self._origin_pid != MPFuture._active_pid:
-            with MPFuture._initialization_lock:
-                if self._origin_pid != MPFuture._active_pid:
-                    # note: the second if is intentional, see https://en.wikipedia.org/wiki/Double-checked_locking
-                    self._initialize_mpfuture_backend()
+        self._initialize_backend_if_necessary()
         assert self._uid not in MPFuture._active_futures
-        MPFuture._active_futures[self._uid] = self
-        self._sender_pipe = MPFuture._global_sender_pipe
+        MPFuture._active_futures[self._uid] = ref(self)
+        self._sender_pipe = MPFuture._process_wide_pipe
 
         try:
-            self._loop = loop or asyncio.get_event_loop()
+            self._loop = asyncio.get_event_loop()
             self._aio_event = asyncio.Event()
         except RuntimeError:
             self._loop, self._aio_event = None, None
 
-    @property
-    def _state(self) -> State:
-        shared_state = ALL_STATES[self._shared_state_code.item()]
-        return self._state_cache.get(shared_state, shared_state)
-
-    @_state.setter
-    def _state(self, new_state: State):
-        self._shared_state_code[...] = ALL_STATES.index(new_state)
-        if self._state in TERMINAL_STATES and self._loop is not None and not self._aio_event.is_set():
-            self._set_event_threadsafe()
-
-    def _set_event_threadsafe(self):
+    def _set_event_if_necessary(self):
+        if self._aio_event is None or self._aio_event.is_set():
+            return
         try:
             loop = asyncio.get_running_loop()
         except RuntimeError:
@@ -111,120 +102,197 @@ class MPFuture(base.Future, Generic[ResultType]):
         async def _event_setter():
             self._aio_event.set()
 
-        if loop == self.get_loop():
+        if self._loop.is_running() and loop == self.get_loop():
             asyncio.create_task(_event_setter())
-        else:
+        elif self._loop.is_running() and loop != self.get_loop():
             asyncio.run_coroutine_threadsafe(_event_setter(), self._loop)
+        else:
+            self._loop.run_until_complete(_event_setter())
 
     @classmethod
-    def _initialize_mpfuture_backend(cls):
+    def _initialize_backend_if_necessary(cls):
         pid = os.getpid()
-        logger.debug(f"Initializing MPFuture backend for pid {pid}")
-        assert pid != cls._active_pid, "already initialized"
-
-        receiver_pipe, cls._global_sender_pipe = mp.Pipe(duplex=False)
-        cls._active_pid, cls._active_futures = pid, {}
-        cls._pipe_waiter_thread = threading.Thread(
-            target=cls._process_updates_in_background, args=[receiver_pipe], name=f"{__name__}.BACKEND", daemon=True
-        )
-        cls._pipe_waiter_thread.start()
+        if MPFuture._active_pid != pid:
+            with MPFuture._initialization_lock:
+                if MPFuture._active_pid != pid:
+                    # note: the second if is intentional, see https://en.wikipedia.org/wiki/Double-checked_locking
+                    logger.debug(f"Initializing MPFuture backend for pid {pid}")
+                    receiver_pipe, cls._process_wide_pipe = mp.Pipe(duplex=False)
+                    cls._active_pid, cls._active_futures, cls._status_requests = pid, {}, {}
+                    cls._pipe_waiter_thread = threading.Thread(
+                        target=cls._process_updates_in_background,
+                        args=[receiver_pipe],
+                        name=f"{__name__}.BACKEND",
+                        daemon=True,
+                    )
+                    cls._pipe_waiter_thread.start()
 
     @classmethod
     def _process_updates_in_background(cls, receiver_pipe: mp.connection.Connection):
         pid = os.getpid()
         while True:
             try:
-                uid, update_type, payload = receiver_pipe.recv()
-                if uid not in cls._active_futures:
-                    logger.debug(f"Ignoring update to future with uid={uid}: the future is already done or destroyed")
-                elif update_type == UpdateType.RESULT:
-                    cls._active_futures.pop(uid).set_result(payload)
-                elif update_type == UpdateType.EXCEPTION:
-                    cls._active_futures.pop(uid).set_exception(payload)
-                elif update_type == UpdateType.CANCEL:
-                    cls._active_futures.pop(uid).cancel()
+                uid, msg_type, payload = receiver_pipe.recv()
+                future = None
+                future_ref = cls._active_futures.get(uid)
+                if future_ref is not None:
+                    future = future_ref()
+
+                if msg_type == MessageType.STATE_REQUEST:
+                    future_state = None if future is None else future.__getstate__()
+                    use_lock, return_pipe = payload
+                    with MPFuture._update_lock if use_lock else nullcontext():
+                        return_pipe.send((uid, MessageType.STATE_RESPONSE, future_state))
+
+                elif msg_type == MessageType.STATE_RESPONSE:
+                    future, state_updated_event = cls._status_requests.get(uid, (None, None))
+                    if future is None:
+                        logger.debug("Received a state update for a future that does not await status update.")
+                    else:
+                        if payload is not None:
+                            future.__setstate__(payload)
+                        else:
+                            base.Future.cancel(future)
+                        state_updated_event.set()
+
+                elif future is None:
+                    logger.debug(
+                        f"Received {msg_type} for MPFuture uid={uid}, but future is already done or destroyed"
+                    )
+                elif msg_type == MessageType.RESULT:
+                    future.set_result(payload)
+                elif msg_type == MessageType.EXCEPTION:
+                    future.set_exception(payload)
+                elif msg_type == MessageType.RUNNING:
+                    try:
+                        future.set_running_or_notify_cancel()
+                    except (InvalidStateError, RuntimeError) as e:
+                        logger.debug(f"Could not set MPFuture (uid={uid}) to running due to {e}", exc_info=True)
+                elif msg_type == MessageType.CANCEL:
+                    future.cancel()
                 else:
-                    raise RuntimeError(f"Received unexpected update type {update_type}")
-            except (BrokenPipeError, EOFError):
+                    raise RuntimeError(f"Received unexpected update type {msg_type}")
+
+                if future is None or future.done():
+                    cls._active_futures.pop(uid, None)
+
+            except (BrokenPipeError, EOFError, ConnectionError):
                 logger.debug(f"Update pipe was was shut down unexpectedly (pid={pid})")
             except Exception as e:
                 logger.exception(f"Could not retrieve update: caught {repr(e)} (pid={pid})")
 
-    def _send_update(self, update_type: UpdateType, payload: Any = None):
+    def _send_update(self, update_type: MessageType, payload: Any = None):
         """This method sends result, exception or cancel to the MPFuture origin."""
-        with MPFuture._update_lock if self._use_lock else nullcontext():
-            self._sender_pipe.send((self._uid, update_type, payload))
+        try:
+            with MPFuture._update_lock if self._use_lock else nullcontext():
+                self._sender_pipe.send((self._uid, update_type, payload))
+        except (ConnectionError, BrokenPipeError, EOFError) as e:
+            logger.debug(f"No updates were sent: pipe to origin process was broken ({e}).", exc_info=True)
+
+    def _synchronize_if_necessary(self):
+        if not self.synchronize or os.getpid() == self._origin_pid or self._state in TERMINAL_STATES:
+            return
+
+        self._initialize_backend_if_necessary()
+
+        status_updated = threading.Event()
+        _, existing_status_event = self._status_requests.setdefault(self._uid, (self, status_updated))
+        # this line checks if another thread is synchronizing concurrently, assuming that setdefault to be atomic
+
+        if existing_status_event != status_updated:
+            existing_status_event.wait(MPFuture.HARD_UPDATE_TIMEOUT)
+            return
+
+        # otherwise create a new request for synchronization
+
+        try:
+            with MPFuture._update_lock if self._use_lock else nullcontext():
+                payload = (self._use_lock, self._process_wide_pipe)
+                self._sender_pipe.send((self._uid, MessageType.STATE_REQUEST, payload))
+            status_updated.wait(MPFuture.SOFT_UPDATE_TIMEOUT)
+            if not status_updated.is_set():
+                logger.warning(f"Status update took over {MPFuture.SOFT_UPDATE_TIMEOUT}, expect performance issues")
+                status_updated.wait(MPFuture.HARD_UPDATE_TIMEOUT - MPFuture.SOFT_UPDATE_TIMEOUT)
+                if not status_updated.is_set() and not self.cancel():
+                    with suppress(InvalidStateError, RuntimeError):
+                        self.set_exception(
+                            TimeoutError(
+                                f"Status update took over {MPFuture.HARD_UPDATE_TIMEOUT} seconds, "
+                                f"MPFuture is cancelled"
+                            )
+                        )
+                    status_updated.set()  # this triggers any concurrent _synchronize_if_necessary calls to finish
+        except (ConnectionError, BrokenPipeError, EOFError) as e:
+            logger.error(f"MPFuture was cancelled because sender pipe is broken. Origin process is probably down.")
+            if not self.cancel():
+                with suppress(InvalidStateError, RuntimeError):
+                    self.set_exception(e)
+        finally:
+            self._status_requests.pop(self._uid, None)
 
     def set_result(self, result: ResultType):
-        if os.getpid() == self._origin_pid:
-            super().set_result(result)
-            MPFuture._active_futures.pop(self._uid, None)
-        elif self._state in TERMINAL_STATES:
+        if self._state in TERMINAL_STATES:
             raise InvalidStateError(f"Can't set_result to a future that is {self._state} ({self._uid})")
+        elif os.getpid() == self._origin_pid:
+            MPFuture._active_futures.pop(self._uid, None)
+            self._set_event_if_necessary()
         else:
-            self._state_cache[self._state], self._result = base.FINISHED, result
-            self._send_update(UpdateType.RESULT, result)
+            self._send_update(MessageType.RESULT, result)
+        super().set_result(result)
 
     def set_exception(self, exception: Optional[BaseException]):
-        if os.getpid() == self._origin_pid:
-            super().set_exception(exception)
-            MPFuture._active_futures.pop(self._uid, None)
-        elif self._state in TERMINAL_STATES:
+        if self._state in TERMINAL_STATES:
             raise InvalidStateError(f"Can't set_exception to a future that is {self._state} ({self._uid})")
+        elif os.getpid() == self._origin_pid:
+            MPFuture._active_futures.pop(self._uid, None)
+            self._set_event_if_necessary()
         else:
-            self._state_cache[self._state], self._exception = base.FINISHED, exception
-            self._send_update(UpdateType.EXCEPTION, exception)
+            self._send_update(MessageType.EXCEPTION, exception)
+        super().set_exception(exception)
 
     def cancel(self) -> bool:
-        if os.getpid() == self._origin_pid:
-            MPFuture._active_futures.pop(self._uid, None)
-            return super().cancel()
-        elif self._state in [base.RUNNING, base.FINISHED]:
+        if self._state in [base.RUNNING, base.FINISHED]:
             return False
+        elif os.getpid() == self._origin_pid:
+            MPFuture._active_futures.pop(self._uid, None)
+            self._set_event_if_necessary()
         else:
-            self._state_cache[self._state] = base.CANCELLED
-            self._send_update(UpdateType.CANCEL)
-            return True
+            self._send_update(MessageType.CANCEL)
+        return super().cancel()
 
     def set_running_or_notify_cancel(self):
-        if self._state == base.PENDING:
-            self._state = base.RUNNING
-            return True
-        elif self._state == base.CANCELLED:
-            return False
-        else:
-            raise InvalidStateError(
-                f"Can't set_running_or_notify_cancel when future is in {self._state} ({self._uid})"
-            )
+        """if synchronize is set to False, this future will ignore any state changes from origin"""
+        self._synchronize_if_necessary()
+        try:
+            is_running = super().set_running_or_notify_cancel()
+            if is_running and os.getpid() != self._origin_pid:
+                self._send_update(MessageType.RUNNING)
+            return is_running
+        except RuntimeError as e:
+            raise InvalidStateError(str(e))
 
     def result(self, timeout: Optional[float] = None) -> ResultType:
         if self._state not in TERMINAL_STATES:
             if os.getpid() != self._origin_pid:
                 raise RuntimeError("Only the process that created MPFuture can await result")
-            return super().result(timeout)
-        elif self._state == base.CANCELLED:
-            raise base.CancelledError()
-        elif self._exception:
-            raise self._exception
-        else:
-            return self._result
+        return super().result(timeout)
 
     def exception(self, timeout: Optional[float] = None) -> Optional[BaseException]:
         if self._state not in TERMINAL_STATES:
             if os.getpid() != self._origin_pid:
                 raise RuntimeError("Only the process that created MPFuture can await exception")
-            return super().exception(timeout)
-        elif self._state == base.CANCELLED:
-            raise base.CancelledError()
-        return self._exception
+        return super().exception(timeout)
 
     def done(self) -> bool:
+        self._synchronize_if_necessary()
         return self._state in TERMINAL_STATES
 
     def running(self):
+        self._synchronize_if_necessary()
         return self._state == base.RUNNING
 
     def cancelled(self):
+        self._synchronize_if_necessary()
         return self._state == base.CANCELLED
 
     def add_done_callback(self, callback: Callable[[MPFuture], None]):
@@ -240,7 +308,7 @@ class MPFuture(base.Future, Generic[ResultType]):
             raise RuntimeError("Can't await: MPFuture was created with no event loop")
         yield from self._aio_event.wait().__await__()
         try:
-            return super().result(timeout=0)
+            return super().result()
         except base.CancelledError:
             raise asyncio.CancelledError()
 
@@ -252,8 +320,9 @@ class MPFuture(base.Future, Generic[ResultType]):
 
     def __getstate__(self):
         return dict(
+            synchronize=self.synchronize,
             _sender_pipe=self._sender_pipe,
-            _shared_state_code=self._shared_state_code,
+            _state=self._state,
             _origin_pid=self._origin_pid,
             _uid=self._uid,
             _use_lock=self._use_lock,
@@ -262,13 +331,12 @@ class MPFuture(base.Future, Generic[ResultType]):
         )
 
     def __setstate__(self, state):
+        self.synchronize = state["synchronize"]
         self._sender_pipe = state["_sender_pipe"]
-        self._shared_state_code = state["_shared_state_code"]
-        self._origin_pid, self._uid = state["_origin_pid"], state["_uid"]
+        self._state, self._origin_pid, self._uid = state["_state"], state["_origin_pid"], state["_uid"]
         self._result, self._exception = state["_result"], state["_exception"]
         self._use_lock = state["_use_lock"]
 
         self._waiters, self._done_callbacks = [], []
         self._condition = threading.Condition()
         self._aio_event, self._loop = None, None
-        self._state_cache = {}

+ 5 - 0
tests/test_training.py

@@ -47,6 +47,8 @@ def test_training(max_steps: int = 100, threshold: float = 0.9):
 def test_moe_training(max_steps: int = 100, threshold: float = 0.9, num_experts=2):
     dataset = load_digits(n_class=2)
     X_train, y_train = torch.tensor(dataset["data"], dtype=torch.float), torch.tensor(dataset["target"])
+    subsample_ix = torch.randint(0, len(X_train), (32,))
+    X_train, y_train = X_train[subsample_ix], y_train[subsample_ix]
     SGD = partial(torch.optim.SGD, lr=0.05)
 
     all_expert_uids = [f"expert.{i}" for i in range(num_experts)]
@@ -97,6 +99,9 @@ class SwitchNetwork(nn.Module):
 def test_switch_training(max_steps: int = 10, threshold: float = 0.9, num_experts=5):
     dataset = load_digits(n_class=2)
     X_train, y_train = torch.tensor(dataset["data"], dtype=torch.float), torch.tensor(dataset["target"])
+    subsample_ix = torch.randint(0, len(X_train), (32,))
+    X_train, y_train = X_train[subsample_ix], y_train[subsample_ix]
+
     SGD = partial(torch.optim.SGD, lr=0.05)
 
     all_expert_uids = [f"expert.{i}" for i in range(num_experts)]

+ 15 - 9
tests/test_util_modules.py

@@ -221,7 +221,7 @@ def test_mpfuture_bidirectional():
 @pytest.mark.forked
 def test_mpfuture_done_callback():
     receiver, sender = mp.Pipe(duplex=False)
-    events = [mp.Event() for _ in range(5)]
+    events = [mp.Event() for _ in range(6)]
 
     def _future_creator():
         future1, future2, future3 = hivemind.MPFuture(), hivemind.MPFuture(), hivemind.MPFuture()
@@ -243,6 +243,7 @@ def test_mpfuture_done_callback():
         future1.add_done_callback(
             lambda future: events[4].set()
         )  # schedule callback after future1 is already finished
+        events[5].wait()
 
     p = mp.Process(target=_future_creator)
     p.start()
@@ -253,24 +254,27 @@ def test_mpfuture_done_callback():
     with pytest.raises(RuntimeError):
         future1.add_done_callback(lambda future: (1, 2, 3))
 
-    p.join()
-    events[0].wait(1)
-    events[1].wait(1)
     assert future1.done() and not future1.cancelled()
     assert future2.done() and future2.cancelled()
+    events[0].wait(1)
+    events[1].wait(1)
     assert events[0].is_set() and events[1].is_set() and events[2].is_set() and events[4].is_set()
     assert not events[3].is_set()
 
+    events[5].set()
+    p.join()
+
 
 @pytest.mark.forked
-def test_many_futures():
+@pytest.mark.parametrize("synchronize", [True, False])
+def test_many_futures(synchronize: bool):
     evt = mp.Event()
     receiver, sender = mp.Pipe()
-    main_futures = [hivemind.MPFuture() for _ in range(1000)]
+    main_futures = [hivemind.MPFuture(synchronize=synchronize) for _ in range(1000)]
     assert len(hivemind.MPFuture._active_futures) == 1000
 
     def _run_peer():
-        fork_futures = [hivemind.MPFuture() for _ in range(500)]
+        fork_futures = [hivemind.MPFuture(synchronize=synchronize) for _ in range(500)]
         assert len(hivemind.MPFuture._active_futures) == 500
 
         for i, future in enumerate(random.sample(main_futures, 300)):
@@ -287,13 +291,16 @@ def test_many_futures():
 
         assert len(hivemind.MPFuture._active_futures) == 200
         for future in fork_futures:
-            future.cancel()
+            if not future.done():
+                future.set_result(123)
         assert len(hivemind.MPFuture._active_futures) == 0
 
     p = mp.Process(target=_run_peer)
     p.start()
 
     some_fork_futures = receiver.recv()
+
+    time.sleep(0.5)  # wait for active futures to synchronize
     assert len(hivemind.MPFuture._active_futures) == 700
 
     for future in some_fork_futures:
@@ -301,7 +308,6 @@ def test_many_futures():
     for future in random.sample(some_fork_futures, 200):
         future.set_result(321)
 
-    time.sleep(0.5)
     evt.set()
     for future in main_futures:
         future.cancel()