
Refactor MPFuture to use a single pipe/thread per process (#298)

- Removed hivemind.utils.threading.run_in_background and the HIVEMIND_THREADS environment variable
- Refactored MPFuture to be a single object instead of a linked pair of objects (see the usage sketch below)
- MPFuture now uses a single process-wide pipe and background thread instead of spawning a new pipe/thread for each future
- MPFuture.result/exception can now only be awaited from the process that created the future
- MPFuture now returns the same exception types as a regular Future (and as asyncio.Future in __await__)
- Added more thorough tests for MPFuture
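
A minimal usage sketch of the new API, adapted from the added tests. It assumes a fork-based multiprocessing start method (e.g. Linux defaults), since MPFuture only works between processes created through inheritance:

```python
import multiprocessing as mp
import hivemind

def _worker(future: hivemind.MPFuture):
    # any process that inherited the future may fulfill it
    future.set_result(321)

future = hivemind.MPFuture()  # replaces the old MPFuture.make_pair()
p = mp.Process(target=_worker, args=(future,))
p.start()
p.join()

# ...but only the origin process may await/read the result
assert future.result() == 321
```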

Co-authored-by: Alexander Borzunov <borzunov.alexander@gmail.com>
Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
Co-authored-by: Michael Diskin <yhn1124@gmail.com>
justheuristic 4 years ago
parent
commit
200fbecdbf

+ 1 - 1
benchmarks/benchmark_averaging.py

@@ -6,7 +6,7 @@ import argparse
 import torch
 
 import hivemind
-from hivemind.utils import LOCALHOST, increase_file_limit, get_logger
+from hivemind.utils import LOCALHOST, get_logger, increase_file_limit
 from hivemind.proto import runtime_pb2
 
 

+ 1 - 1
benchmarks/benchmark_dht.py

@@ -6,7 +6,7 @@ from tqdm import trange
 
 import hivemind
 import hivemind.server.expert_uid
-from hivemind.utils.threading import increase_file_limit
+from hivemind.utils.limits import increase_file_limit
 
 logger = hivemind.get_logger(__name__)
 

+ 1 - 1
benchmarks/benchmark_throughput.py

@@ -9,7 +9,7 @@ import torch
 import hivemind
 from hivemind import find_open_port
 from hivemind.server import layers
-from hivemind.utils.threading import increase_file_limit
+from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger
 
 

+ 2 - 2
examples/albert/README.md

@@ -40,7 +40,7 @@ wandb: Run `wandb offline` to turn off syncing.
   - if necessary, specify paths: `--dataset_path ./path/to/unpacked/data --tokenizer ./path/to/tokenizer/config` (see [default paths](https://github.com/learning-at-home/hivemind/blob/collaborative_albert_example/examples/albert/run_trainer.py#L63-L69) for reference)
   - run:
 ```shell
-HIVEMIND_THREADS=64 python run_trainer.py \
+python run_trainer.py \
  --experiment_prefix SAME_AS_IN_RUN_FIRST_PEER --initial_peers ONE_OR_MORE_PEERS --seed 42 \
  --logging_first_step --logging_steps 100  --output_dir ./outputs --overwrite_output_dir --logging_dir ./logs
 ```
@@ -88,7 +88,7 @@ Here's an example of a full trainer script for Google Colab:
 !pip install transformers datasets sentencepiece torch_optimizer==0.1.0
 !git clone https://github.com/learning-at-home/hivemind && cd hivemind && pip install -e .
 !curl -L YOUR_HOSTED_DATA | tar xzf -     # example: https://hivemind-data.s3.us-east-2.amazonaws.com/wikitext103.tar.gz
-!ulimit -n 4096 && HIVEMIND_THREADS=256 python ./hivemind/examples/albert/run_trainer.py \
+!ulimit -n 4096 && python ./hivemind/examples/albert/run_trainer.py \
  --client_mode --initial_peers ONE_OR_MORE_PEERS  --averaging_expiration 10 \
  --batch_size_lead 300 --per_device_train_batch_size 4 --gradient_accumulation_steps 1 \
  --logging_first_step --logging_steps 100  --output_dir ./outputs --overwrite_output_dir --logging_dir ./logs \

+ 10 - 10
hivemind/client/averaging/__init__.py

@@ -290,9 +290,9 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             weight = float(self.mode != AveragingMode.AUX)
         assert isinstance(weight, (int, float)) and weight >= 0, f"Expected a positive int/float, got {type(weight)}"
 
-        future, _future = MPFuture.make_pair()
+        future = MPFuture()
         gather_binary = self.serializer.dumps(gather)  # serialize here to avoid loading modules in the averager process
-        self._outer_pipe.send(('_step', [], dict(future=_future, gather_binary=gather_binary, weight=weight,
+        self._outer_pipe.send(('_step', [], dict(future=future, gather_binary=gather_binary, weight=weight,
                                                  allow_retries=allow_retries, timeout=timeout)))
         return future.result() if wait else future
 
@@ -463,8 +463,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 
     async def _get_current_state_from_host_process(self):
         """ Executed in the averager process inside rpc_download_state """
-        future, _future = MPFuture.make_pair()
-        self._inner_pipe.send(('_TRIGGER_GET_CURRENT_STATE', _future))
+        future = MPFuture()
+        self._inner_pipe.send(('_TRIGGER_GET_CURRENT_STATE', future))
         return await future
 
     def load_state_from_peers(self, wait=True) -> Optional[Tuple[Any, Sequence[torch.Tensor]]]:
@@ -477,8 +477,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 
         The exact contents of both metadata and tensors are determined by get_current_state method
         """
-        future, _future = MPFuture.make_pair()
-        self._outer_pipe.send(('_load_state_from_peers', [], dict(future=_future)))
+        future = MPFuture()
+        self._outer_pipe.send(('_load_state_from_peers', [], dict(future=future)))
         return future.result() if wait else future
 
     async def _load_state_from_peers(self, future: MPFuture):
@@ -537,8 +537,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         :param wait: if True, return bits immediately. Otherwise return awaitable MPFuture
         :returns: averager's current group key bits (without prefix)
         """
-        future, _future = MPFuture.make_pair()
-        self._outer_pipe.send(('_get_group_bits', [], dict(future=_future)))
+        future = MPFuture()
+        self._outer_pipe.send(('_get_group_bits', [], dict(future=future)))
         return future.result() if wait else future
 
     async def _get_group_bits(self, future: MPFuture):
@@ -549,9 +549,9 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         :param group_bits: group bits (string of '0' or '1') to be used in averager's group key
         :param wait: if True, wait until the update is confirmed by the averager. Otherwise return immediately
         """
-        future, _future = MPFuture.make_pair()
+        future = MPFuture()
         assert all(bit in '01' for bit in group_bits)
-        self._outer_pipe.send(('_set_group_bits', [], dict(group_bits=group_bits, future=_future)))
+        self._outer_pipe.send(('_set_group_bits', [], dict(group_bits=group_bits, future=future)))
         return future.result() if wait else future
 
     async def _set_group_bits(self, group_bits: str, future: MPFuture):

+ 9 - 7
hivemind/client/averaging/training.py

@@ -1,4 +1,5 @@
 """ An extension of averager that supports common optimization use cases. """
+from concurrent.futures import ThreadPoolExecutor
 from itertools import chain
 from threading import Lock
 from typing import Sequence, Dict, Iterator, Optional
@@ -7,7 +8,7 @@ from contextlib import nullcontext
 import torch
 
 from hivemind.client.averaging import DecentralizedAverager
-from hivemind.utils import nested_flatten, nested_pack, get_logger, run_in_background
+from hivemind.utils import nested_flatten, nested_pack, get_logger
 
 logger = get_logger(__name__)
 
@@ -39,6 +40,7 @@ class TrainingAverager(DecentralizedAverager):
         self.opt, self.extra_tensors, self.local_step = opt, tuple(extra_tensors), 0
         self.opt_statistics = tuple(average_opt_statistics)
         self.average_parameters, self.average_gradients = average_parameters, average_gradients
+        self.step_executor = ThreadPoolExecutor(max_workers=1)
         self.lock_averager_step = Lock()
         if initialize_optimizer:
             initialize_optimizer_state(opt)  # note: this will run one optimizer step!
@@ -47,15 +49,15 @@ class TrainingAverager(DecentralizedAverager):
             averaged_tensors = [tensor.detach().cpu().float().clone() for tensor in self.local_tensors()]
         super().__init__(averaged_tensors=averaged_tensors, **kwargs)
 
-    @torch.no_grad()
     def step(self, data_lock: Optional[Lock] = None, wait: bool = True, **kwargs):
-        """ Average optimizer weights and gradients with peers.
+        """
+        Average optimizer weights and gradients with peers.
+
         :param data_lock: averager locks it when model parameters are modified. Otherwise it's assumed that no model
         modifications occur during averaging step
-        :param wait: if True waits, otherwise returns Future
         """
         if not wait:
-            return run_in_background(self.step, data_lock, wait=True, **kwargs)
+            return self.step_executor.submit(self.step, data_lock, wait=True, **kwargs)
 
         # if data_lock is supplied, tensors might change during averaging, so we need to copy them
         use_old_local_tensors = data_lock is not None
@@ -63,7 +65,7 @@ class TrainingAverager(DecentralizedAverager):
             data_lock = nullcontext()
 
         local_tensors = list(self.local_tensors())
-        with self.lock_averager_step:
+        with self.lock_averager_step, torch.no_grad():
             # fill averager's tensors with current local tensors
             with data_lock, self.get_tensors() as averaged_tensors:
                 if use_old_local_tensors:
@@ -73,7 +75,7 @@ class TrainingAverager(DecentralizedAverager):
                 for averaged_tensor, local_tensor in zip(averaged_tensors, local_tensors):
                     averaged_tensor[...] = local_tensor.cpu().float()
 
-            # find a group and hopefully average tensors with peers, scaled by peer's weight
+            # find a group and hopefully average tensors with peers, use batch sizes as weights
             gathered = super().step(**kwargs)
             if gathered is not None:
                 # load averaged tensors back into model

+ 8 - 8
hivemind/dht/__init__.py

@@ -127,8 +127,8 @@ class DHT(mp.Process):
         :param kwargs: parameters forwarded to DHTNode.get_many_by_id
         :returns: (value, expiration time); if value was not found, returns None
         """
-        future, _future = MPFuture.make_pair()
-        self._outer_pipe.send(('_get', [], dict(key=key, latest=latest, future=_future, **kwargs)))
+        future = MPFuture()
+        self._outer_pipe.send(('_get', [], dict(key=key, latest=latest, future=future, **kwargs)))
         return future if return_future else future.result()
 
     async def _get(self, key: DHTKey, latest: bool, future: MPFuture, **kwargs):
@@ -153,9 +153,9 @@ class DHT(mp.Process):
         :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
         :returns: True if store succeeds, False if it fails (due to no response or newer value)
         """
-        future, _future = MPFuture.make_pair()
+        future = MPFuture()
         self._outer_pipe.send(('_store', [], dict(key=key, value=value, expiration_time=expiration_time, subkey=subkey,
-                                                  future=_future, **kwargs)))
+                                                  future=future, **kwargs)))
         return future if return_future else future.result()
 
     async def _store(self, key: DHTKey, value: DHTValue, expiration_time: DHTExpiration,
@@ -184,8 +184,8 @@ class DHT(mp.Process):
           or use asyncio.get_event_loop().run_in_executor(...) to prevent coroutine from blocking background DHT tasks
         :note: when run_coroutine is called with wait=False, MPFuture can be cancelled to interrupt the task.
         """
-        future, _future = MPFuture.make_pair()
-        self._outer_pipe.send(('_run_coroutine', [], dict(coro=coro, future=_future)))
+        future = MPFuture()
+        self._outer_pipe.send(('_run_coroutine', [], dict(coro=coro, future=future)))
         return future if return_future else future.result()
 
     async def _run_coroutine(self, coro: Callable[[DHT, DHTNode], Awaitable[ReturnType]],
@@ -226,8 +226,8 @@ class DHT(mp.Process):
         """
         assert num_peers is None or peers == (), "please specify either a num_peers or the list of peers, not both"
         assert not isinstance(peers, str) and isinstance(peers, Sequence), "Please send a list / tuple of endpoints"
-        future, _future = MPFuture.make_pair()
-        self._outer_pipe.send(('_get_visible_address', [], dict(num_peers=num_peers, peers=peers, future=_future)))
+        future = MPFuture()
+        self._outer_pipe.send(('_get_visible_address', [], dict(num_peers=num_peers, peers=peers, future=future)))
         return future.result()
 
     async def _get_visible_address(self, num_peers: Optional[int], peers: Sequence[Endpoint],

+ 1 - 1
hivemind/hivemind_cli/run_server.py

@@ -6,7 +6,7 @@ import torch
 
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.server import Server
-from hivemind.utils.threading import increase_file_limit
+from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger
 from hivemind.server.layers import schedule_name_to_scheduler
 

+ 7 - 7
hivemind/server/task_pool.py

@@ -14,7 +14,8 @@ from typing import List, Tuple, Dict, Any, Generator
 
 import torch
 
-from hivemind.utils import MPFuture, get_logger, FutureStateError
+from hivemind.utils import get_logger
+from hivemind.utils.mpfuture import MPFuture, InvalidStateError
 
 logger = get_logger(__name__)
 Task = namedtuple("Task", ("future", "args"))
@@ -89,15 +90,14 @@ class TaskPool(TaskPoolBase):
 
     def submit_task(self, *args: torch.Tensor) -> Future:
         """ Add task to this pool's queue, return Future for its output """
-        future1, future2 = MPFuture.make_pair()
-        task = Task(future1, args)
+        task = Task(MPFuture(), args)
         if self.get_task_size(task) > self.max_batch_size:
             exc = ValueError(f"Task size greater than max_batch_size ({self.max_batch_size}), it can't be processed")
-            future2.set_exception(exc)
+            task.future.set_exception(exc)
         else:
             self.tasks.put(task)
             self.undispatched_task_timestamps.put(time.time())
-        return future2
+        return task.future
 
     def iterate_minibatches(self, *args, **kwargs):
         """ Form minibatches by grouping one or more tasks together up to self.max_batch_size """
@@ -127,7 +127,7 @@ class TaskPool(TaskPoolBase):
                 if task.future.set_running_or_notify_cancel():
                     batch.append(task)
                     total_size += task_size
-            except FutureStateError as e:
+            except InvalidStateError as e:
                 logger.debug(f"Failed to add task to batch: {task.future} raised {e}")
 
     def run(self, *args, **kwargs):
@@ -196,7 +196,7 @@ class TaskPool(TaskPoolBase):
             for task, task_outputs in zip(batch_tasks, outputs_per_task):
                 try:
                     task.future.set_result(tuple(task_outputs))
-                except FutureStateError as e:
+                except InvalidStateError as e:
                     logger.debug(f"Failed to send task result due to an exception: {e}")
 
     @property

+ 1 - 1
hivemind/utils/__init__.py

@@ -1,11 +1,11 @@
 from hivemind.utils.asyncio import *
 from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
 from hivemind.utils.grpc import *
+from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger
 from hivemind.utils.mpfuture import *
 from hivemind.utils.nested import *
 from hivemind.utils.networking import *
 from hivemind.utils.serializer import *
 from hivemind.utils.tensor_descr import *
-from hivemind.utils.threading import *
 from hivemind.utils.timed_storage import *

+ 6 - 3
hivemind/utils/compression.py

@@ -1,3 +1,5 @@
+import os
+from concurrent.futures import ThreadPoolExecutor
 from typing import Tuple, Sequence, Optional
 
 import numpy as np
@@ -6,7 +8,7 @@ import warnings
 
 from hivemind.proto import runtime_pb2
 from hivemind.proto.runtime_pb2 import CompressionType
-from hivemind.utils.threading import run_in_background
+
 
 FP32_EPS = 1e-06
 NUM_BYTES_FLOAT32 = 4
@@ -17,6 +19,8 @@ UNIFORM_BUCKETS_STD_RANGE = 6
 FP16_MAX = 65_504
 UINT8_RANGE = 256
 
+COMPRESSION_EXECUTOR = ThreadPoolExecutor(max_workers=int(os.environ.get("QUANTILE_COMPRESSION_THREADS", 128)))
+
 warnings.filterwarnings("ignore", message="The given NumPy array is not writeable", category=UserWarning)
 
 
@@ -48,8 +52,7 @@ def _quantile_qq_approximation(array: np.array, n_quantiles: int, min_chunk_size
     jobs = []
     for i in range(num_chunks):
         chunk = slice(chunk_size * i, chunk_size * (i + 1))
-        jobs.append(run_in_background(
-            np.quantile, array[chunk], quantiles, out=partition_quantiles[i]))
+        jobs.append(COMPRESSION_EXECUTOR.submit(np.quantile, array[chunk], quantiles, out=partition_quantiles[i]))
 
     for job in jobs:
         job.result()

+ 0 - 14
hivemind/utils/threading.py → hivemind/utils/limits.py

@@ -1,21 +1,7 @@
-import os
-from concurrent.futures import Future, ThreadPoolExecutor
-
 from hivemind.utils.logging import get_logger
 
 logger = get_logger(__name__)
 
-EXECUTOR_PID, GLOBAL_EXECUTOR = None, None
-
-
-def run_in_background(func: callable, *args, **kwargs) -> Future:
-    """ run func(*args, **kwargs) in background and return Future for its outputs """
-    global EXECUTOR_PID, GLOBAL_EXECUTOR
-    if os.getpid() != EXECUTOR_PID:
-        GLOBAL_EXECUTOR = ThreadPoolExecutor(max_workers=int(os.environ.get("HIVEMIND_THREADS", 128)))
-        EXECUTOR_PID = os.getpid()
-    return GLOBAL_EXECUTOR.submit(func, *args, **kwargs)
-
 
 def increase_file_limit(new_soft=2 ** 15, new_hard=2 ** 15):
     """ Increase the maximum number of open files. On Linux, this allows spawning more processes/threads. """

+ 212 - 121
hivemind/utils/mpfuture.py

@@ -2,171 +2,262 @@ from __future__ import annotations
 
 import asyncio
 import concurrent.futures._base as base
+from contextlib import nullcontext
 import multiprocessing as mp
 import multiprocessing.connection
-import time
-from functools import lru_cache
-from typing import Optional, Tuple, Generic, TypeVar
+import os
+import threading
+import uuid
+from enum import Enum, auto
+from typing import Generic, TypeVar, Dict, Optional, Any, Callable
 
-from hivemind.utils.threading import run_in_background
+import torch    # used for py3.7-compatible shared memory
 
+from hivemind.utils.logging import get_logger
+
+
+logger = get_logger(__name__)
+
+# flavour types
 ResultType = TypeVar('ResultType')
+PID, UID, State, PipeEnd = int, int, str, mp.connection.Connection
+ALL_STATES = base.PENDING, base.RUNNING, base.FINISHED, base.CANCELLED, base.CANCELLED_AND_NOTIFIED
+TERMINAL_STATES = {base.FINISHED, base.CANCELLED, base.CANCELLED_AND_NOTIFIED}
 
+try:
+    from concurrent.futures import InvalidStateError
+except ImportError:
+    # Python 3.7 doesn't raise concurrent.futures.InvalidStateError for repeating set_result/set_exception calls and
+    # doesn't even define this error. In this module, we simulate the Python 3.8+ behavior,
+    # defining and raising this error if necessary.
+    class InvalidStateError(Exception):
+        """Raised when attempting to change state of a future in a terminal state (e.g. finished)"""
 
-class FutureStateError(RuntimeError):
-    """Raised when attempting to change state of a future in a terminal state (e.g. finished)"""
-    pass
+
+class UpdateType(Enum):
+    RESULT = auto()
+    EXCEPTION = auto()
+    CANCEL = auto()
 
 
 class MPFuture(base.Future, Generic[ResultType]):
-    """ Multiprocessing version of concurrent.futures.Future. Can also be awaited like asyncio.Future """
+    """
+    A version of concurrent.futures.Future / asyncio.Future that can be fulfilled from a separate process.
+    Any process can check the future's state and set its result or exception.
+    However, only the original process (i.e. the process that created the future) can await the result or exception.
+
+    :param use_lock: if True, operations with MPFuture use a global lock to prevent concurrent writes to the same pipe;
+      If set to False, writing to this future ignores global lock, slightly improving performance, but making user
+      responsible for avoiding concurrent set_result / set_exception calls to futures with the same process of origin.
+    :param loop: if specified, overrides default asyncio event loop for the purpose of awaiting MPFuture
+
+    :note: This is an internal primitive that is not guaranteed to work outside of hivemind applications.
+     More specifically, there are two known limitations:
+       - MPFuture works between processes created through inheritance (e.g. fork), *not* for independent processes
+       - MPFuture is deterministic if only one process can call set_result/set_exception/set_running_or_notify_cancel
+         and only the origin process can call result/exception/cancel.
+    """
+    _initialization_lock = mp.Lock()  # global lock that prevents simultaneous initialization of two processes
+    _update_lock = mp.Lock()  # global lock that prevents simultaneous writing to the same pipe
+    _global_sender_pipe: Optional[PipeEnd] = None  # a pipe that is used to send results/exceptions to this process
+    _pipe_waiter_thread: Optional[threading.Thread] = None  # process-specific thread that receives results/exceptions
+    _active_futures: Optional[Dict[UID, MPFuture]] = None  # pending or running futures originated from current process
+    _active_pid: Optional[PID] = None  # pid of currently active process; used to handle forks natively
 
-    TERMINAL_STATES = {base.FINISHED, base.CANCELLED, base.CANCELLED_AND_NOTIFIED}
+    def __init__(self, use_lock: bool = True, loop: Optional[asyncio.BaseEventLoop] = None):
+        self._origin_pid, self._uid = os.getpid(), uuid.uuid4().int
+        self._shared_state_code = torch.empty([], dtype=torch.uint8).share_memory_()
+        self._state_cache: Dict[State, State] = {}  # mapping from global to cached local state that makes updates immediately
+        # available on setter side; dictionary-based cache works because future can visit any state at most once
 
-    def __init__(self, connection: mp.connection.Connection):
-        """ manually create MPFuture. Please use MPFuture.make_pair instead """
+        base.Future.__init__(self)   # parent init is deferred because it uses self._shared_state_code
         self._state, self._result, self._exception = base.PENDING, None, None
-        self.connection = connection
+        self._use_lock = use_lock
 
-    @classmethod
-    def make_pair(cls) -> Tuple[MPFuture, MPFuture]:
-        """ Create a pair of linked futures to be used in two processes """
-        connection1, connection2 = mp.Pipe()
-        return cls(connection1), cls(connection2)
+        if self._origin_pid != MPFuture._active_pid:
+            with MPFuture._initialization_lock:
+                if self._origin_pid != MPFuture._active_pid:
+                    # note: the second if is intentional, see https://en.wikipedia.org/wiki/Double-checked_locking
+                    self._initialize_mpfuture_backend()
+        assert self._uid not in MPFuture._active_futures
+        MPFuture._active_futures[self._uid] = self
+        self._sender_pipe = MPFuture._global_sender_pipe
 
-    def _send_updates(self):
-        """ Send updates to a paired MPFuture """
         try:
-            self.connection.send((self._state, self._result, self._exception))
-            if self._state in self.TERMINAL_STATES:
-                self._shutdown_trigger.set_result(True)
-                self.connection.close()
-            return True
-        except BrokenPipeError:
-            return False
+            self._loop = loop or asyncio.get_event_loop()
+            self._aio_event = asyncio.Event()
+        except RuntimeError:
+            self._loop, self._aio_event = None, None
 
-    def _recv_updates(self, timeout: Optional[float]):
-        """ Await updates from a paired MPFuture """
-        try:
-            future = base.wait([run_in_background(self.connection.poll, timeout), self._shutdown_trigger],
-                               return_when=base.FIRST_COMPLETED)[0].pop()
-            if future is self._shutdown_trigger:
-                raise BrokenPipeError()
-            if not future.result():
-                raise TimeoutError()
-            self._state, result, exception = self.connection.recv()
-            self._result = result if result is not None else self._result
-            self._exception = exception if exception is not None else self._exception
-            if self._state in self.TERMINAL_STATES:
-                self.connection.close()
-        except TimeoutError as e:
-            raise e
-        except (BrokenPipeError, OSError, EOFError) as e:
-            if self._state in (base.PENDING, base.RUNNING):
-                self._state, self._exception = base.FINISHED, e
-
-    def _await_terminal_state(self, timeout: Optional[float]):
-        """ Await updates until future is either finished, cancelled or got an exception """
-        time_left = float('inf') if timeout is None else timeout
-        time_before = time.monotonic()
-        while self._state not in self.TERMINAL_STATES and time_left > 0:
-            self._recv_updates(time_left if timeout else None)
-            time_spent = time.monotonic() - time_before
-            time_left, time_before = time_left - time_spent, time_before + time_spent
-
-    def _sync_updates(self):
-        """ Apply queued updates from a paired MPFuture without waiting for new ones """
+    @property
+    def _state(self) -> State:
+        shared_state = ALL_STATES[self._shared_state_code.item()]
+        return self._state_cache.get(shared_state, shared_state)
+
+    @_state.setter
+    def _state(self, new_state: State):
+        self._shared_state_code[...] = ALL_STATES.index(new_state)
+        if self._state in TERMINAL_STATES and self._loop is not None and not self._aio_event.is_set():
+            self._set_event_threadsafe()
+
+    def _set_event_threadsafe(self):
         try:
-            self._recv_updates(timeout=0)
-        except TimeoutError:
-            pass
+            loop = asyncio.get_running_loop()
+        except RuntimeError:
+            loop = None
+
+        async def _event_setter():
+            self._aio_event.set()
+
+        if loop == self.get_loop():
+            asyncio.create_task(_event_setter())
+        else:
+            asyncio.run_coroutine_threadsafe(_event_setter(), self._loop)
+
+    @classmethod
+    def _initialize_mpfuture_backend(cls):
+        pid = os.getpid()
+        logger.debug(f"Initializing MPFuture backend for pid {pid}")
+        assert pid != cls._active_pid, "already initialized"
+
+        receiver_pipe, cls._global_sender_pipe = mp.Pipe(duplex=False)
+        cls._active_pid, cls._active_futures = pid, {}
+        cls._pipe_waiter_thread = threading.Thread(target=cls._process_updates_in_background, args=[receiver_pipe],
+                                                   name=f'{__name__}.BACKEND', daemon=True)
+        cls._pipe_waiter_thread.start()
+
+    @classmethod
+    def _process_updates_in_background(cls, receiver_pipe: mp.connection.Connection):
+        pid = os.getpid()
+        while True:
+            try:
+                uid, update_type, payload = receiver_pipe.recv()
+                if uid not in cls._active_futures:
+                    logger.debug(f"Ignoring update to future with uid={uid}: the future is already done or destroyed")
+                elif update_type == UpdateType.RESULT:
+                    cls._active_futures.pop(uid).set_result(payload)
+                elif update_type == UpdateType.EXCEPTION:
+                    cls._active_futures.pop(uid).set_exception(payload)
+                elif update_type == UpdateType.CANCEL:
+                    cls._active_futures.pop(uid).cancel()
+                else:
+                    raise RuntimeError(f"Received unexpected update type {update_type}")
+            except (BrokenPipeError, EOFError):
+                logger.debug(f"Update pipe was was shut down unexpectedly (pid={pid})")
+            except Exception as e:
+                logger.exception(f"Could not retrieve update: caught {repr(e)} (pid={pid})")
+
+    def _send_update(self, update_type: UpdateType, payload: Any = None):
+        """ This method sends result, exception or cancel to the MPFuture origin. """
+        with MPFuture._update_lock if self._use_lock else nullcontext():
+            self._sender_pipe.send((self._uid, update_type, payload))
 
     def set_result(self, result: ResultType):
-        self._sync_updates()
-        if self._state in self.TERMINAL_STATES:
-            raise FutureStateError(f"Can't set_result to a future that is {self._state} ({self})")
-        self._state, self._result = base.FINISHED, result
-        return self._send_updates()
-
-    def set_exception(self, exception: BaseException):
-        self._sync_updates()
-        if self._state in self.TERMINAL_STATES:
-            raise FutureStateError(f"Can't set_exception to a future that is {self._state} ({self})")
-        self._state, self._exception = base.FINISHED, exception
-        self._send_updates()
+        if os.getpid() == self._origin_pid:
+            super().set_result(result)
+            MPFuture._active_futures.pop(self._uid, None)
+        elif self._state in TERMINAL_STATES:
+            raise InvalidStateError(f"Can't set_result to a future that is {self._state} ({self._uid})")
+        else:
+            self._state_cache[self._state], self._result = base.FINISHED, result
+            self._send_update(UpdateType.RESULT, result)
+
+    def set_exception(self, exception: Optional[BaseException]):
+        if os.getpid() == self._origin_pid:
+            super().set_exception(exception)
+            MPFuture._active_futures.pop(self._uid, None)
+        elif self._state in TERMINAL_STATES:
+            raise InvalidStateError(f"Can't set_exception to a future that is {self._state} ({self._uid})")
+        else:
+            self._state_cache[self._state], self._exception = base.FINISHED, exception
+            self._send_update(UpdateType.EXCEPTION, exception)
+
+    def cancel(self) -> bool:
+        if os.getpid() == self._origin_pid:
+            MPFuture._active_futures.pop(self._uid, None)
+            return super().cancel()
+        elif self._state in [base.RUNNING, base.FINISHED]:
+            return False
+        else:
+            self._state_cache[self._state] = base.CANCELLED
+            self._send_update(UpdateType.CANCEL)
+            return True
 
     def set_running_or_notify_cancel(self):
-        self._sync_updates()
         if self._state == base.PENDING:
             self._state = base.RUNNING
-            return self._send_updates()
+            return True
         elif self._state == base.CANCELLED:
             return False
         else:
-            raise FutureStateError(f"Can't set_running_or_notify_cancel to a future that is in {self._state} ({self})")
-
-    def cancel(self):
-        self._sync_updates()
-        if self._state in self.TERMINAL_STATES:
-            return False
-        self._state, self._exception = base.CANCELLED, base.CancelledError()
-        return self._send_updates()
+            raise InvalidStateError(f"Can't set_running_or_notify_cancel when future is in {self._state} ({self._uid})")
 
     def result(self, timeout: Optional[float] = None) -> ResultType:
-        self._await_terminal_state(timeout)
-        if self._exception is not None:
+        if self._state not in TERMINAL_STATES:
+            if os.getpid() != self._origin_pid:
+                raise RuntimeError("Only the process that created MPFuture can await result")
+            return super().result(timeout)
+        elif self._state == base.CANCELLED:
+            raise base.CancelledError()
+        elif self._exception:
             raise self._exception
-        return self._result
+        else:
+            return self._result
 
-    def exception(self, timeout=None) -> BaseException:
-        self._await_terminal_state(timeout)
-        if self._state == base.CANCELLED:
+    def exception(self, timeout: Optional[float] = None) -> Optional[BaseException]:
+        if self._state not in TERMINAL_STATES:
+            if os.getpid() != self._origin_pid:
+                raise RuntimeError("Only the process that created MPFuture can await exception")
+            return super().exception(timeout)
+        elif self._state == base.CANCELLED:
             raise base.CancelledError()
         return self._exception
 
     def done(self) -> bool:
-        self._sync_updates()
-        return self._state in self.TERMINAL_STATES
+        return self._state in TERMINAL_STATES
 
     def running(self):
-        self._sync_updates()
         return self._state == base.RUNNING
 
     def cancelled(self):
-        self._sync_updates()
         return self._state == base.CANCELLED
 
-    def add_done_callback(self, callback):
-        raise NotImplementedError(f"MPFuture doesn't support callbacks.")
-
-    def remove_done_callback(self, callback):
-        raise NotImplementedError(f"MPFuture doesn't support callbacks.")
+    def add_done_callback(self, callback: Callable[[MPFuture], None]):
+        if os.getpid() != self._origin_pid:
+            raise RuntimeError("Only the process that created MPFuture can set callbacks")
+        return super().add_done_callback(callback)
 
-    def get_loop(self):
-        raise NotImplementedError(f"MPFuture doesn't support get_loop")
-
-    @property
-    @lru_cache()
-    def _shutdown_trigger(self):
-        return base.Future()
-
-    def __repr__(self):
-        self._sync_updates()
-        if self._state == base.FINISHED:
-            if self._exception:
-                return "<MPFuture at 0x{:x} state=finished raised {}>".format(id(self), type(self._exception))
-            else:
-                return "<MPFuture at 0x{:x} state=finished returned {}>".format(id(self), type(self._result))
-        else:
-            return "<MPFuture at 0x{:x} state={}>".format(id(self), self._state)
+    def get_loop(self) -> Optional[asyncio.BaseEventLoop]:
+        return self._loop
 
     def __await__(self):
-        yield from asyncio.get_running_loop().run_in_executor(None, self._await_terminal_state, None).__await__()
-        if self._exception:
-            raise self._exception
-        return self._result
+        if not self._aio_event:
+            raise RuntimeError("Can't await: MPFuture was created with no event loop")
+        yield from self._aio_event.wait().__await__()
+        try:
+            return super().result(timeout=0)
+        except base.CancelledError:
+            raise asyncio.CancelledError()
 
     def __del__(self):
-        self._shutdown_trigger.set_result(True)
-        if hasattr(self, 'connection'):
-            self.connection.close()
+        if getattr(self, '_origin_pid', None) == os.getpid():
+            MPFuture._active_futures.pop(self._uid, None)
+        if getattr(self, '_aio_event', None):
+            self._aio_event.set()
+
+    def __getstate__(self):
+        return dict(_sender_pipe=self._sender_pipe, _shared_state_code=self._shared_state_code,
+                    _origin_pid=self._origin_pid, _uid=self._uid, _use_lock=self._use_lock,
+                    _result=self._result, _exception=self._exception)
+
+    def __setstate__(self, state):
+        self._sender_pipe = state['_sender_pipe']
+        self._shared_state_code = state['_shared_state_code']
+        self._origin_pid, self._uid = state['_origin_pid'], state['_uid']
+        self._result, self._exception = state['_result'], state['_exception']
+        self._use_lock = state['_use_lock']
+
+        self._waiters, self._done_callbacks = [], []
+        self._condition = threading.Condition()
+        self._aio_event, self._loop = None, None
+        self._state_cache = {}

+ 4 - 0
tests/test_averaging.py

@@ -423,3 +423,7 @@ def test_training_averager(n_steps: int = 10, n_dims: int = 16):
         assert torch.allclose(x2.grad, grad_avg)
         assert torch.allclose(opt1.state[x1]["exp_avg_sq"], stats_avg)
         assert torch.allclose(opt2.state[x2]["exp_avg_sq"], stats_avg)
+
+    averager1.shutdown()
+    averager2.shutdown()
+    dht.shutdown()

+ 255 - 74
tests/test_util_modules.py

@@ -1,129 +1,310 @@
 import asyncio
-from concurrent.futures import CancelledError
+import concurrent.futures
+import multiprocessing as mp
+import random
+import time
 
-import numpy as np
 import pytest
 import torch
+import numpy as np
 
+import hivemind
 from hivemind.proto.dht_pb2_grpc import DHTStub
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
-import hivemind
 from hivemind.utils import MSGPackSerializer
 from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
 from hivemind.utils.asyncio import amap_in_executor, aiter, aenumerate, achain, anext, azip
-from hivemind.utils.mpfuture import FutureStateError
+from hivemind.utils.mpfuture import InvalidStateError
 
 
+@pytest.mark.forked
 def test_mpfuture_result():
-    f1, f2 = hivemind.MPFuture.make_pair()
-    f1.set_result(321)
-    assert f2.result() == 321
-    assert f1.result() == 321
+    future = hivemind.MPFuture()
 
-    for future in [f1, f2]:
-        with pytest.raises(FutureStateError):
-            future.set_result(123)
-        with pytest.raises(FutureStateError):
-            future.set_exception(ValueError())
-        assert future.cancel() is False
-        assert future.done() and not future.running() and not future.cancelled()
+    def _proc(future):
+        with pytest.raises(RuntimeError):
+            future.result()  # only creator process can await result
+
+        future.set_result(321)
+
+    p = mp.Process(target=_proc, args=(future,))
+    p.start()
+    p.join()
 
-    f1, f2 = hivemind.MPFuture.make_pair()
-    with pytest.raises(TimeoutError):
-        f1.result(timeout=1e-3)
+    assert future.result() == 321
+    assert future.exception() is None
+    assert future.cancel() is False
+    assert future.done() and not future.running() and not future.cancelled()
 
-    f2.set_result(['abacaba', 123])
-    assert f1.result() == ['abacaba', 123]
+    future = hivemind.MPFuture()
+    with pytest.raises(concurrent.futures.TimeoutError):
+        future.result(timeout=1e-3)
 
+    future.set_result(['abacaba', 123])
+    assert future.result() == ['abacaba', 123]
 
+
+@pytest.mark.forked
 def test_mpfuture_exception():
-    f1, f2 = hivemind.MPFuture.make_pair()
-    with pytest.raises(TimeoutError):
-        f1.exception(timeout=1e-3)
+    future = hivemind.MPFuture()
+    with pytest.raises(concurrent.futures.TimeoutError):
+        future.exception(timeout=1e-3)
 
-    f2.set_exception(NotImplementedError())
+    def _proc(future):
+        future.set_exception(NotImplementedError())
 
-    for future in [f1, f2]:
-        assert isinstance(future.exception(), NotImplementedError)
-        with pytest.raises(NotImplementedError):
-            future.result()
-        assert future.cancel() is False
-        assert future.done() and not future.running() and not future.cancelled()
+    p = mp.Process(target=_proc, args=(future,))
+    p.start()
+    p.join()
+
+    assert isinstance(future.exception(), NotImplementedError)
+    with pytest.raises(NotImplementedError):
+        future.result()
+    assert future.cancel() is False
+    assert future.done() and not future.running() and not future.cancelled()
 
 
+@pytest.mark.forked
 def test_mpfuture_cancel():
-    f1, f2 = hivemind.MPFuture.make_pair()
-    assert not f2.cancelled()
-    f1.cancel()
-    for future in [f1, f2]:
-        with pytest.raises(CancelledError):
+    future = hivemind.MPFuture()
+    assert not future.cancelled()
+    future.cancel()
+    evt = mp.Event()
+
+    def _proc():
+        with pytest.raises(concurrent.futures.CancelledError):
             future.result()
-        with pytest.raises(CancelledError):
+        with pytest.raises(concurrent.futures.CancelledError):
             future.exception()
-        with pytest.raises(FutureStateError):
+        with pytest.raises(InvalidStateError):
             future.set_result(123)
-        with pytest.raises(FutureStateError):
+        with pytest.raises(InvalidStateError):
             future.set_exception(NotImplementedError())
         assert future.cancelled() and future.done() and not future.running()
+        evt.set()
 
+    p = mp.Process(target=_proc)
+    p.start()
+    p.join()
+    assert evt.is_set()
 
+
+@pytest.mark.forked
 def test_mpfuture_status():
-    f1, f2 = hivemind.MPFuture.make_pair()
-    assert f1.set_running_or_notify_cancel() is True
-    for future in [f1, f2]:
-        assert future.running() and not future.done() and not future.cancelled()
-        with pytest.raises(RuntimeError):
-            future.set_running_or_notify_cancel()
-    f2.cancel()
-    for future in [f1, f2]:
+    evt = mp.Event()
+    future = hivemind.MPFuture()
+
+    def _proc1(future):
+        assert future.set_running_or_notify_cancel() is True
+        evt.set()
+
+    p = mp.Process(target=_proc1, args=(future,))
+    p.start()
+    p.join()
+    assert evt.is_set()
+    evt.clear()
+
+    assert future.running() and not future.done() and not future.cancelled()
+    with pytest.raises(InvalidStateError):
+        future.set_running_or_notify_cancel()
+
+    future = hivemind.MPFuture()
+    assert future.cancel()
+
+    def _proc2(future):
         assert not future.running() and future.done() and future.cancelled()
         assert future.set_running_or_notify_cancel() is False
+        evt.set()
 
-    f1, f2 = hivemind.MPFuture.make_pair()
-    f1.cancel()
-    for future in [f1, f2]:
-        assert future.set_running_or_notify_cancel() is False
+    p = mp.Process(target=_proc2, args=(future,))
+    p.start()
+    p.join()
+    evt.set()
+
+    future2 = hivemind.MPFuture()
+    future2.cancel()
+    assert future2.set_running_or_notify_cancel() is False
 
 
 @pytest.mark.asyncio
 async def test_await_mpfuture():
-    # await result
-    f1, f2 = hivemind.MPFuture.make_pair()
+    # await result from the same process, but a different coroutine
+    f1, f2 = hivemind.MPFuture(), hivemind.MPFuture()
 
-    async def wait_and_assign():
+    async def wait_and_assign_async():
         assert f2.set_running_or_notify_cancel() is True
         await asyncio.sleep(0.1)
-        f2.set_result((123, 'ololo'))
+        f1.set_result((123, 'ololo'))
+        f2.set_result((456, 'pyshpysh'))
+
+    asyncio.create_task(wait_and_assign_async())
 
-    asyncio.create_task(wait_and_assign())
-    for future in [f1, f2]:
-        res = await future
-        assert res == (123, 'ololo')
+    assert (await asyncio.gather(f1, f2)) == [(123, 'ololo'), (456, 'pyshpysh')]
+
+    # await result from separate processes
+    f1, f2 = hivemind.MPFuture(), hivemind.MPFuture()
+
+    def wait_and_assign(future, value):
+        time.sleep(0.1 * random.random())
+        future.set_result(value)
+
+    p1 = mp.Process(target=wait_and_assign, args=(f1, 'abc'))
+    p2 = mp.Process(target=wait_and_assign, args=(f2, 'def'))
+    for p in p1, p2:
+        p.start()
+
+    assert (await asyncio.gather(f1, f2)) == ['abc', 'def']
+    for p in p1, p2:
+        p.join()
 
     # await cancel
-    f1, f2 = hivemind.MPFuture.make_pair()
+    f1, f2 = hivemind.MPFuture(), hivemind.MPFuture()
 
-    async def wait_and_cancel():
-        await asyncio.sleep(0.1)
+    def wait_and_cancel():
+        time.sleep(0.01)
+        f2.set_result(123456)
+        time.sleep(0.1)
         f1.cancel()
 
-    asyncio.create_task(wait_and_cancel())
-    for future in [f1, f2]:
-        with pytest.raises(CancelledError):
-            await future
+    p = mp.Process(target=wait_and_cancel)
+    p.start()
+
+    with pytest.raises(asyncio.CancelledError):
+        # note: it is intended that MPFuture raises Cancel
+        await asyncio.gather(f1, f2)
+
+    p.join()
 
     # await exception
-    f1, f2 = hivemind.MPFuture.make_pair()
+    f1, f2 = hivemind.MPFuture(), hivemind.MPFuture()
 
-    async def wait_and_raise():
-        await asyncio.sleep(0.1)
-        f1.set_exception(SystemError())
+    def wait_and_raise():
+        time.sleep(0.01)
+        f2.set_result(123456)
+        time.sleep(0.1)
+        f1.set_exception(ValueError('we messed up'))
+
+    p = mp.Process(target=wait_and_raise)
+    p.start()
+
+    with pytest.raises(ValueError):
+        # note: it is intended that MPFuture re-raises the exception set in the other process
+        await asyncio.gather(f1, f2)
+
+    p.join()
+
+
+@pytest.mark.forked
+def test_mpfuture_bidirectional():
+    evt = mp.Event()
+    future_from_main = hivemind.MPFuture()
+
+    def _future_creator():
+        future_from_fork = hivemind.MPFuture()
+        future_from_main.set_result(('abc', future_from_fork))
+
+        if future_from_fork.result() == ['we', 'need', 'to', 'go', 'deeper']:
+            evt.set()
+
+    p = mp.Process(target=_future_creator)
+    p.start()
+
+    out = future_from_main.result()
+    assert isinstance(out[1], hivemind.MPFuture)
+    out[1].set_result(['we', 'need', 'to', 'go', 'deeper'])
+
+    p.join()
+    assert evt.is_set()
+
+
+@pytest.mark.forked
+def test_mpfuture_done_callback():
+    receiver, sender = mp.Pipe(duplex=False)
+    events = [mp.Event() for _ in range(5)]
+
+    def _future_creator():
+        future1, future2, future3 = hivemind.MPFuture(), hivemind.MPFuture(), hivemind.MPFuture()
+
+        def _check_result_and_set(future):
+            assert future.done()
+            assert future.result() == 123
+            events[0].set()
+
+        future1.add_done_callback(_check_result_and_set)
+        future1.add_done_callback(lambda future: events[1].set())
+        future2.add_done_callback(lambda future: events[2].set())
+        future3.add_done_callback(lambda future: events[3].set())
+
+        sender.send((future1, future2))
+        future2.cancel()  # trigger future2 callback from the same process
+
+        events[0].wait()
+        future1.add_done_callback(lambda future: events[4].set())  # schedule callback after future1 is already finished
+
+    p = mp.Process(target=_future_creator)
+    p.start()
+
+    future1, future2 = receiver.recv()
+    future1.set_result(123)
+
+    with pytest.raises(RuntimeError):
+        future1.add_done_callback(lambda future: (1, 2, 3))
+
+    p.join()
+    events[0].wait(1)
+    events[1].wait(1)
+    assert future1.done() and not future1.cancelled()
+    assert future2.done() and future2.cancelled()
+    assert events[0].is_set() and events[1].is_set() and events[2].is_set() and events[4].is_set()
+    assert not events[3].is_set()
+
+
+@pytest.mark.forked
+def test_many_futures():
+    evt = mp.Event()
+    receiver, sender = mp.Pipe()
+    main_futures = [hivemind.MPFuture() for _ in range(1000)]
+    assert len(hivemind.MPFuture._active_futures) == 1000
+
+    def _run_peer():
+        fork_futures = [hivemind.MPFuture() for _ in range(500)]
+        assert len(hivemind.MPFuture._active_futures) == 500
+
+        for i, future in enumerate(random.sample(main_futures, 300)):
+            if random.random() < 0.5:
+                future.set_result(i)
+            else:
+                future.set_exception(ValueError(f"{i}"))
+
+        sender.send(fork_futures[:-100])
+        for future in fork_futures[-100:]:
+            future.cancel()
+
+        evt.wait()
+
+        assert len(hivemind.MPFuture._active_futures) == 200
+        for future in fork_futures:
+            future.cancel()
+        assert len(hivemind.MPFuture._active_futures) == 0
+
+    p = mp.Process(target=_run_peer)
+    p.start()
+
+    some_fork_futures = receiver.recv()
+    assert len(hivemind.MPFuture._active_futures) == 700
+
+    for future in some_fork_futures:
+        future.set_running_or_notify_cancel()
+    for future in random.sample(some_fork_futures, 200):
+        future.set_result(321)
 
-    asyncio.create_task(wait_and_raise())
-    for future in [f1, f2]:
-        with pytest.raises(SystemError):
-            await future
+    time.sleep(0.5)
+    evt.set()
+    for future in main_futures:
+        future.cancel()
+    assert len(hivemind.MPFuture._active_futures) == 0
+    p.join()
 
 
 def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008):
@@ -139,7 +320,7 @@ def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008):
     error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.UNIFORM_8BIT)) - X
     assert error.square().mean() < beta
 
-    zeros = torch.zeros(5,5)
+    zeros = torch.zeros(5, 5)
     for compression_type in CompressionType.values():
         assert deserialize_torch_tensor(serialize_torch_tensor(zeros, compression_type)).isfinite().all()