
Refactor MPFuture to use a single pipe/thread per process (#298)

- Removed hivemind.utils.threading.run_in_background and the HIVEMIND_THREADS environment variable
- Refactored MPFuture to be a single object instead of a linked pair of objects
- MPFuture now uses a single process-wide pipe and background thread instead of spawning a new pipe/thread for each future
- MPFuture.result/exception can now only be awaited from the process that created the future (see the usage sketch below)
- MPFuture now returns the same exception types as a regular concurrent.futures.Future (and as asyncio.Future in __await__)
- Added more thorough tests for MPFuture

Co-authored-by: Alexander Borzunov <borzunov.alexander@gmail.com>
Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
Co-authored-by: Michael Diskin <yhn1124@gmail.com>
justheuristic 4 years ago
parent
commit
200fbecdbf
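
For reference, here is a minimal usage sketch of the refactored API. It is not part of the diff; it mirrors the pattern exercised by the new tests in this commit (one MPFuture created in the origin process, fulfilled from a forked worker, awaited only by the origin) and assumes fork-based multiprocessing, as noted in the MPFuture docstring.

```python
import multiprocessing as mp

import hivemind


def worker(future: hivemind.MPFuture):
    # any process that inherited the future (e.g. via fork) may fulfill it...
    future.set_result(42)
    # ...but only the origin process is allowed to await or read the result


if __name__ == '__main__':
    future = hivemind.MPFuture()  # replaces the old MPFuture.make_pair()
    p = mp.Process(target=worker, args=(future,))
    p.start()
    print(future.result())  # blocks in the origin process until the worker sets a result
    p.join()
```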

+ 1 - 1
benchmarks/benchmark_averaging.py

@@ -6,7 +6,7 @@ import argparse
 import torch
 
 import hivemind
-from hivemind.utils import LOCALHOST, increase_file_limit, get_logger
+from hivemind.utils import LOCALHOST, get_logger, increase_file_limit
 from hivemind.proto import runtime_pb2
 
 

+ 1 - 1
benchmarks/benchmark_dht.py

@@ -6,7 +6,7 @@ from tqdm import trange
 
 import hivemind
 import hivemind.server.expert_uid
-from hivemind.utils.threading import increase_file_limit
+from hivemind.utils.limits import increase_file_limit
 
 logger = hivemind.get_logger(__name__)
 

+ 1 - 1
benchmarks/benchmark_throughput.py

@@ -9,7 +9,7 @@ import torch
 import hivemind
 from hivemind import find_open_port
 from hivemind.server import layers
-from hivemind.utils.threading import increase_file_limit
+from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger
 
 

+ 2 - 2
examples/albert/README.md

@@ -40,7 +40,7 @@ wandb: Run `wandb offline` to turn off syncing.
   - if necessary, specify paths: `--dataset_path ./path/to/unpacked/data --tokenizer ./path/to/tokenizer/config` (see [default paths](https://github.com/learning-at-home/hivemind/blob/collaborative_albert_example/examples/albert/run_trainer.py#L63-L69) for reference)
   - run:
 ```shell
-HIVEMIND_THREADS=64 python run_trainer.py \
+python run_trainer.py \
  --experiment_prefix SAME_AS_IN_RUN_FIRST_PEER --initial_peers ONE_OR_MORE_PEERS --seed 42 \
  --logging_first_step --logging_steps 100  --output_dir ./outputs --overwrite_output_dir --logging_dir ./logs
 ```
@@ -88,7 +88,7 @@ Here's an example of a full trainer script for Google Colab:
 !pip install transformers datasets sentencepiece torch_optimizer==0.1.0
 !git clone https://github.com/learning-at-home/hivemind && cd hivemind && pip install -e .
 !curl -L YOUR_HOSTED_DATA | tar xzf -     # example: https://hivemind-data.s3.us-east-2.amazonaws.com/wikitext103.tar.gz
-!ulimit -n 4096 && HIVEMIND_THREADS=256 python ./hivemind/examples/albert/run_trainer.py \
+!ulimit -n 4096 && python ./hivemind/examples/albert/run_trainer.py \
 --client_mode --initial_peers ONE_OR_MORE_PEERS  --averaging_expiration 10 \
 --batch_size_lead 300 --per_device_train_batch_size 4 --gradient_accumulation_steps 1 \
 --logging_first_step --logging_steps 100  --output_dir ./outputs --overwrite_output_dir --logging_dir ./logs \

+ 10 - 10
hivemind/client/averaging/__init__.py

@@ -290,9 +290,9 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             weight = float(self.mode != AveragingMode.AUX)
         assert isinstance(weight, (int, float)) and weight >= 0, f"Expected a positive int/float, got {type(weight)}"
 
-        future, _future = MPFuture.make_pair()
+        future = MPFuture()
         gather_binary = self.serializer.dumps(gather)  # serialize here to avoid loading modules in the averager process
-        self._outer_pipe.send(('_step', [], dict(future=_future, gather_binary=gather_binary, weight=weight,
+        self._outer_pipe.send(('_step', [], dict(future=future, gather_binary=gather_binary, weight=weight,
                                                  allow_retries=allow_retries, timeout=timeout)))
         return future.result() if wait else future
 
@@ -463,8 +463,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 
     async def _get_current_state_from_host_process(self):
         """ Executed in the averager process inside rpc_download_state """
-        future, _future = MPFuture.make_pair()
-        self._inner_pipe.send(('_TRIGGER_GET_CURRENT_STATE', _future))
+        future = MPFuture()
+        self._inner_pipe.send(('_TRIGGER_GET_CURRENT_STATE', future))
         return await future
 
     def load_state_from_peers(self, wait=True) -> Optional[Tuple[Any, Sequence[torch.Tensor]]]:
@@ -477,8 +477,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 
         The exact contents of both metadata and tensors are determined by get_current_state method
         """
-        future, _future = MPFuture.make_pair()
-        self._outer_pipe.send(('_load_state_from_peers', [], dict(future=_future)))
+        future = MPFuture()
+        self._outer_pipe.send(('_load_state_from_peers', [], dict(future=future)))
         return future.result() if wait else future
 
     async def _load_state_from_peers(self, future: MPFuture):
@@ -537,8 +537,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         :param wait: if True, return bits immediately. Otherwise return awaitable MPFuture
         :returns: averager's current group key bits (without prefix)
         """
-        future, _future = MPFuture.make_pair()
-        self._outer_pipe.send(('_get_group_bits', [], dict(future=_future)))
+        future = MPFuture()
+        self._outer_pipe.send(('_get_group_bits', [], dict(future=future)))
         return future.result() if wait else future
 
     async def _get_group_bits(self, future: MPFuture):
@@ -549,9 +549,9 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
         :param group_bits: group bits (string of '0' or '1') to be used in averager's group key
         :param wait: if True, wait until the update is confirmed by the averager. Otherwise return immediately
         """
-        future, _future = MPFuture.make_pair()
+        future = MPFuture()
         assert all(bit in '01' for bit in group_bits)
-        self._outer_pipe.send(('_set_group_bits', [], dict(group_bits=group_bits, future=_future)))
+        self._outer_pipe.send(('_set_group_bits', [], dict(group_bits=group_bits, future=future)))
         return future.result() if wait else future
 
     async def _set_group_bits(self, group_bits: str, future: MPFuture):
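
The call sites above all follow the same request pattern: the caller creates a single MPFuture, sends it through the command pipe to the background averager process, and that process fulfills it. A simplified, hypothetical sketch of this pattern (names such as `_worker` and `'_do_something'` are illustrative, not actual hivemind internals; assumes fork-based multiprocessing):

```python
import multiprocessing as mp

from hivemind.utils.mpfuture import MPFuture


def _worker(inner_pipe):
    # hypothetical background process: receives (method, kwargs) commands and fulfills the attached future
    method, kwargs = inner_pipe.recv()
    future = kwargs.pop('future')
    future.set_result(f"handled {method}")


if __name__ == '__main__':
    outer_pipe, inner_pipe = mp.Pipe()
    proc = mp.Process(target=_worker, args=(inner_pipe,), daemon=True)
    proc.start()

    future = MPFuture()  # a single future object; no more linked pair
    outer_pipe.send(('_do_something', dict(future=future)))
    print(future.result())  # awaited only in the process that created the future
    proc.join()
```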

+ 9 - 7
hivemind/client/averaging/training.py

@@ -1,4 +1,5 @@
 """ An extension of averager that supports common optimization use cases. """
+from concurrent.futures import ThreadPoolExecutor
 from itertools import chain
 from threading import Lock
 from typing import Sequence, Dict, Iterator, Optional
@@ -7,7 +8,7 @@ from contextlib import nullcontext
 import torch
 
 from hivemind.client.averaging import DecentralizedAverager
-from hivemind.utils import nested_flatten, nested_pack, get_logger, run_in_background
+from hivemind.utils import nested_flatten, nested_pack, get_logger
 
 logger = get_logger(__name__)
 
@@ -39,6 +40,7 @@ class TrainingAverager(DecentralizedAverager):
         self.opt, self.extra_tensors, self.local_step = opt, tuple(extra_tensors), 0
         self.opt_statistics = tuple(average_opt_statistics)
         self.average_parameters, self.average_gradients = average_parameters, average_gradients
+        self.step_executor = ThreadPoolExecutor(max_workers=1)
         self.lock_averager_step = Lock()
         if initialize_optimizer:
             initialize_optimizer_state(opt)  # note: this will run one optimizer step!
@@ -47,15 +49,15 @@ class TrainingAverager(DecentralizedAverager):
             averaged_tensors = [tensor.detach().cpu().float().clone() for tensor in self.local_tensors()]
         super().__init__(averaged_tensors=averaged_tensors, **kwargs)
 
-    @torch.no_grad()
     def step(self, data_lock: Optional[Lock] = None, wait: bool = True, **kwargs):
-        """ Average optimizer weights and gradients with peers.
+        """
+        Average optimizer weights and gradients with peers.
+
         :param data_lock: averager locks it when model parameters are modified. Otherwise it's assumed that no model
         modifications occur during averaging step
-        :param wait: if True waits, otherwise returns Future
         """
         if not wait:
-            return run_in_background(self.step, data_lock, wait=True, **kwargs)
+            return self.step_executor.submit(self.step, data_lock, wait=True, **kwargs)
 
         # if data_lock is supplied, tensors might change during averaging, so we need to copy them
         use_old_local_tensors = data_lock is not None
@@ -63,7 +65,7 @@ class TrainingAverager(DecentralizedAverager):
             data_lock = nullcontext()
 
         local_tensors = list(self.local_tensors())
-        with self.lock_averager_step:
+        with self.lock_averager_step, torch.no_grad():
             # fill averager's tensors with current local tensors
             with data_lock, self.get_tensors() as averaged_tensors:
                 if use_old_local_tensors:
@@ -73,7 +75,7 @@ class TrainingAverager(DecentralizedAverager):
                 for averaged_tensor, local_tensor in zip(averaged_tensors, local_tensors):
                     averaged_tensor[...] = local_tensor.cpu().float()
 
-            # find a group and hopefully average tensors with peers, scaled by peer's weight
+            # find a group and hopefully average tensors with peers, use batch sizes as weights
             gathered = super().step(**kwargs)
             if gathered is not None:
                 # load averaged tensors back into model

+ 8 - 8
hivemind/dht/__init__.py

@@ -127,8 +127,8 @@ class DHT(mp.Process):
         :param kwargs: parameters forwarded to DHTNode.get_many_by_id
         :returns: (value, expiration time); if value was not found, returns None
         """
-        future, _future = MPFuture.make_pair()
-        self._outer_pipe.send(('_get', [], dict(key=key, latest=latest, future=_future, **kwargs)))
+        future = MPFuture()
+        self._outer_pipe.send(('_get', [], dict(key=key, latest=latest, future=future, **kwargs)))
         return future if return_future else future.result()
 
     async def _get(self, key: DHTKey, latest: bool, future: MPFuture, **kwargs):
@@ -153,9 +153,9 @@ class DHT(mp.Process):
         :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
         :returns: True if store succeeds, False if it fails (due to no response or newer value)
         """
-        future, _future = MPFuture.make_pair()
+        future = MPFuture()
         self._outer_pipe.send(('_store', [], dict(key=key, value=value, expiration_time=expiration_time, subkey=subkey,
-                                                  future=_future, **kwargs)))
+                                                  future=future, **kwargs)))
         return future if return_future else future.result()
 
     async def _store(self, key: DHTKey, value: DHTValue, expiration_time: DHTExpiration,
@@ -184,8 +184,8 @@ class DHT(mp.Process):
           or use asyncio.get_event_loop().run_in_executor(...) to prevent coroutine from blocking background DHT tasks
         :note: when run_coroutine is called with wait=False, MPFuture can be cancelled to interrupt the task.
         """
-        future, _future = MPFuture.make_pair()
-        self._outer_pipe.send(('_run_coroutine', [], dict(coro=coro, future=_future)))
+        future = MPFuture()
+        self._outer_pipe.send(('_run_coroutine', [], dict(coro=coro, future=future)))
         return future if return_future else future.result()
 
     async def _run_coroutine(self, coro: Callable[[DHT, DHTNode], Awaitable[ReturnType]],
@@ -226,8 +226,8 @@ class DHT(mp.Process):
         """
         assert num_peers is None or peers == (), "please specify either a num_peers or the list of peers, not both"
         assert not isinstance(peers, str) and isinstance(peers, Sequence), "Please send a list / tuple of endpoints"
-        future, _future = MPFuture.make_pair()
-        self._outer_pipe.send(('_get_visible_address', [], dict(num_peers=num_peers, peers=peers, future=_future)))
+        future = MPFuture()
+        self._outer_pipe.send(('_get_visible_address', [], dict(num_peers=num_peers, peers=peers, future=future)))
         return future.result()
 
     async def _get_visible_address(self, num_peers: Optional[int], peers: Sequence[Endpoint],

+ 1 - 1
hivemind/hivemind_cli/run_server.py

@@ -6,7 +6,7 @@ import torch
 
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.server import Server
-from hivemind.utils.threading import increase_file_limit
+from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger
 from hivemind.server.layers import schedule_name_to_scheduler
 

+ 7 - 7
hivemind/server/task_pool.py

@@ -14,7 +14,8 @@ from typing import List, Tuple, Dict, Any, Generator
 
 import torch
 
-from hivemind.utils import MPFuture, get_logger, FutureStateError
+from hivemind.utils import get_logger
+from hivemind.utils.mpfuture import MPFuture, InvalidStateError
 
 logger = get_logger(__name__)
 Task = namedtuple("Task", ("future", "args"))
@@ -89,15 +90,14 @@ class TaskPool(TaskPoolBase):
 
     def submit_task(self, *args: torch.Tensor) -> Future:
         """ Add task to this pool's queue, return Future for its output """
-        future1, future2 = MPFuture.make_pair()
-        task = Task(future1, args)
+        task = Task(MPFuture(), args)
         if self.get_task_size(task) > self.max_batch_size:
             exc = ValueError(f"Task size greater than max_batch_size ({self.max_batch_size}), it can't be processed")
-            future2.set_exception(exc)
+            task.future.set_exception(exc)
         else:
             self.tasks.put(task)
             self.undispatched_task_timestamps.put(time.time())
-        return future2
+        return task.future
 
     def iterate_minibatches(self, *args, **kwargs):
         """ Form minibatches by grouping one or more tasks together up to self.max_batch_size """
@@ -127,7 +127,7 @@ class TaskPool(TaskPoolBase):
                 if task.future.set_running_or_notify_cancel():
                     batch.append(task)
                     total_size += task_size
-            except FutureStateError as e:
+            except InvalidStateError as e:
                 logger.debug(f"Failed to add task to batch: {task.future} raised {e}")
 
     def run(self, *args, **kwargs):
@@ -196,7 +196,7 @@ class TaskPool(TaskPoolBase):
             for task, task_outputs in zip(batch_tasks, outputs_per_task):
                 try:
                     task.future.set_result(tuple(task_outputs))
-                except FutureStateError as e:
+                except InvalidStateError as e:
                     logger.debug(f"Failed to send task result due to an exception: {e}")
 
     @property

+ 1 - 1
hivemind/utils/__init__.py

@@ -1,11 +1,11 @@
 from hivemind.utils.asyncio import *
 from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
 from hivemind.utils.grpc import *
+from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger
 from hivemind.utils.mpfuture import *
 from hivemind.utils.nested import *
 from hivemind.utils.networking import *
 from hivemind.utils.serializer import *
 from hivemind.utils.tensor_descr import *
-from hivemind.utils.threading import *
 from hivemind.utils.timed_storage import *

+ 6 - 3
hivemind/utils/compression.py

@@ -1,3 +1,5 @@
+import os
+from concurrent.futures import ThreadPoolExecutor
 from typing import Tuple, Sequence, Optional
 
 import numpy as np
@@ -6,7 +8,7 @@ import warnings
 
 from hivemind.proto import runtime_pb2
 from hivemind.proto.runtime_pb2 import CompressionType
-from hivemind.utils.threading import run_in_background
+
 
 FP32_EPS = 1e-06
 NUM_BYTES_FLOAT32 = 4
@@ -17,6 +19,8 @@ UNIFORM_BUCKETS_STD_RANGE = 6
 FP16_MAX = 65_504
 UINT8_RANGE = 256
 
+COMPRESSION_EXECUTOR = ThreadPoolExecutor(max_workers=int(os.environ.get("QUANTILE_COMPRESSION_THREADS", 128)))
+
 warnings.filterwarnings("ignore", message="The given NumPy array is not writeable", category=UserWarning)
 
 
@@ -48,8 +52,7 @@ def _quantile_qq_approximation(array: np.array, n_quantiles: int, min_chunk_size
     jobs = []
     for i in range(num_chunks):
         chunk = slice(chunk_size * i, chunk_size * (i + 1))
-        jobs.append(run_in_background(
-            np.quantile, array[chunk], quantiles, out=partition_quantiles[i]))
+        jobs.append(COMPRESSION_EXECUTOR.submit(np.quantile, array[chunk], quantiles, out=partition_quantiles[i]))
 
     for job in jobs:
         job.result()

+ 0 - 14
hivemind/utils/threading.py → hivemind/utils/limits.py

@@ -1,21 +1,7 @@
-import os
-from concurrent.futures import Future, ThreadPoolExecutor
-
 from hivemind.utils.logging import get_logger
 
 logger = get_logger(__name__)
 
-EXECUTOR_PID, GLOBAL_EXECUTOR = None, None
-
-
-def run_in_background(func: callable, *args, **kwargs) -> Future:
-    """ run func(*args, **kwargs) in background and return Future for its outputs """
-    global EXECUTOR_PID, GLOBAL_EXECUTOR
-    if os.getpid() != EXECUTOR_PID:
-        GLOBAL_EXECUTOR = ThreadPoolExecutor(max_workers=int(os.environ.get("HIVEMIND_THREADS", 128)))
-        EXECUTOR_PID = os.getpid()
-    return GLOBAL_EXECUTOR.submit(func, *args, **kwargs)
-
 
 def increase_file_limit(new_soft=2 ** 15, new_hard=2 ** 15):
     """ Increase the maximum number of open files. On Linux, this allows spawning more processes/threads. """

+ 212 - 121
hivemind/utils/mpfuture.py

@@ -2,171 +2,262 @@ from __future__ import annotations
 
 import asyncio
 import concurrent.futures._base as base
+from contextlib import nullcontext
 import multiprocessing as mp
 import multiprocessing.connection
-import time
-from functools import lru_cache
-from typing import Optional, Tuple, Generic, TypeVar
+import os
+import threading
+import uuid
+from enum import Enum, auto
+from typing import Generic, TypeVar, Dict, Optional, Any, Callable
 
-from hivemind.utils.threading import run_in_background
+import torch    # used for py3.7-compatible shared memory
 
+from hivemind.utils.logging import get_logger
+
+
+logger = get_logger(__name__)
+
+# flavour types
 ResultType = TypeVar('ResultType')
+PID, UID, State, PipeEnd = int, int, str, mp.connection.Connection
+ALL_STATES = base.PENDING, base.RUNNING, base.FINISHED, base.CANCELLED, base.CANCELLED_AND_NOTIFIED
+TERMINAL_STATES = {base.FINISHED, base.CANCELLED, base.CANCELLED_AND_NOTIFIED}
 
+try:
+    from concurrent.futures import InvalidStateError
+except ImportError:
+    # Python 3.7 doesn't raise concurrent.futures.InvalidStateError for repeating set_result/set_exception calls and
+    # doesn't even define this error. In this module, we simulate the Python 3.8+ behavior,
+    # defining and raising this error if necessary.
+    class InvalidStateError(Exception):
+        """Raised when attempting to change state of a future in a terminal state (e.g. finished)"""
 
-class FutureStateError(RuntimeError):
-    """Raised when attempting to change state of a future in a terminal state (e.g. finished)"""
-    pass
+
+class UpdateType(Enum):
+    RESULT = auto()
+    EXCEPTION = auto()
+    CANCEL = auto()
 
 
 class MPFuture(base.Future, Generic[ResultType]):
-    """ Multiprocessing version of concurrent.futures.Future. Can also be awaited like asyncio.Future """
+    """
+    A version of concurrent.futures.Future / asyncio.Future that can be fulfilled from a separate process.
+    Any process can access future status and set the result / exception and check for state.
+    However, only the original process (i.e. the process that created the future) can await the result or exception.
+
+    :param use_lock: if True, operations with MPFuture use a global lock to prevent concurrent writes to the same pipe;
+      If set to False, writing to this future ignores global lock, slightly improving performance, but making user
+      responsible for avoiding concurrent set_result / set_exception calls to futures with the same process of origin.
+    :param loop: if specified, overrides default asyncio event loop for the purpose of awaiting MPFuture
+
+    :note: This is an internal primitive that is not guaranteed to work outside of hivemind applications.
+     More specifically, there are two known limitations:
+       - MPFuture works between processes created through inheritance (e.g. fork), *not* for independent processes
+       - MPFuture is deterministic if only one process can call set_result/set_exception/set_running_or_notify_cancel
+         and only the origin process can call result/exception/cancel.
+    """
+    _initialization_lock = mp.Lock()  # global lock that prevents simultaneous initialization of two processes
+    _update_lock = mp.Lock()  # global lock that prevents simultaneous writing to the same pipe
+    _global_sender_pipe: Optional[PipeEnd] = None  # a pipe that is used to send results/exceptions to this process
+    _pipe_waiter_thread: Optional[threading.Thread] = None  # process-specific thread that receives results/exceptions
+    _active_futures: Optional[Dict[UID, MPFuture]] = None  # pending or running futures originated from current process
+    _active_pid: Optional[PID] = None  # pid of currently active process; used to handle forks natively
 
-    TERMINAL_STATES = {base.FINISHED, base.CANCELLED, base.CANCELLED_AND_NOTIFIED}
+    def __init__(self, use_lock: bool = True, loop: Optional[asyncio.BaseEventLoop] = None):
+        self._origin_pid, self._uid = os.getpid(), uuid.uuid4().int
+        self._shared_state_code = torch.empty([], dtype=torch.uint8).share_memory_()
+        self._state_cache:  Dict[State, State] = {}  # mapping from global to cached local future used that makes updates immediately
+        # available on setter side; dictionary-based cache works because future can visit any state at most once
 
-    def __init__(self, connection: mp.connection.Connection):
-        """ manually create MPFuture. Please use MPFuture.make_pair instead """
+        base.Future.__init__(self)   # parent init is deferred because it uses self._shared_state_code
         self._state, self._result, self._exception = base.PENDING, None, None
-        self.connection = connection
+        self._use_lock = use_lock
 
-    @classmethod
-    def make_pair(cls) -> Tuple[MPFuture, MPFuture]:
-        """ Create a pair of linked futures to be used in two processes """
-        connection1, connection2 = mp.Pipe()
-        return cls(connection1), cls(connection2)
+        if self._origin_pid != MPFuture._active_pid:
+            with MPFuture._initialization_lock:
+                if self._origin_pid != MPFuture._active_pid:
+                    # note: the second if is intentional, see https://en.wikipedia.org/wiki/Double-checked_locking
+                    self._initialize_mpfuture_backend()
+        assert self._uid not in MPFuture._active_futures
+        MPFuture._active_futures[self._uid] = self
+        self._sender_pipe = MPFuture._global_sender_pipe
 
-    def _send_updates(self):
-        """ Send updates to a paired MPFuture """
         try:
-            self.connection.send((self._state, self._result, self._exception))
-            if self._state in self.TERMINAL_STATES:
-                self._shutdown_trigger.set_result(True)
-                self.connection.close()
-            return True
-        except BrokenPipeError:
-            return False
+            self._loop = loop or asyncio.get_event_loop()
+            self._aio_event = asyncio.Event()
+        except RuntimeError:
+            self._loop, self._aio_event = None, None
 
-    def _recv_updates(self, timeout: Optional[float]):
-        """ Await updates from a paired MPFuture """
-        try:
-            future = base.wait([run_in_background(self.connection.poll, timeout), self._shutdown_trigger],
-                               return_when=base.FIRST_COMPLETED)[0].pop()
-            if future is self._shutdown_trigger:
-                raise BrokenPipeError()
-            if not future.result():
-                raise TimeoutError()
-            self._state, result, exception = self.connection.recv()
-            self._result = result if result is not None else self._result
-            self._exception = exception if exception is not None else self._exception
-            if self._state in self.TERMINAL_STATES:
-                self.connection.close()
-        except TimeoutError as e:
-            raise e
-        except (BrokenPipeError, OSError, EOFError) as e:
-            if self._state in (base.PENDING, base.RUNNING):
-                self._state, self._exception = base.FINISHED, e
-
-    def _await_terminal_state(self, timeout: Optional[float]):
-        """ Await updates until future is either finished, cancelled or got an exception """
-        time_left = float('inf') if timeout is None else timeout
-        time_before = time.monotonic()
-        while self._state not in self.TERMINAL_STATES and time_left > 0:
-            self._recv_updates(time_left if timeout else None)
-            time_spent = time.monotonic() - time_before
-            time_left, time_before = time_left - time_spent, time_before + time_spent
-
-    def _sync_updates(self):
-        """ Apply queued updates from a paired MPFuture without waiting for new ones """
+    @property
+    def _state(self) -> State:
+        shared_state = ALL_STATES[self._shared_state_code.item()]
+        return self._state_cache.get(shared_state, shared_state)
+
+    @_state.setter
+    def _state(self, new_state: State):
+        self._shared_state_code[...] = ALL_STATES.index(new_state)
+        if self._state in TERMINAL_STATES and self._loop is not None and not self._aio_event.is_set():
+            self._set_event_threadsafe()
+
+    def _set_event_threadsafe(self):
         try:
-            self._recv_updates(timeout=0)
-        except TimeoutError:
-            pass
+            loop = asyncio.get_running_loop()
+        except RuntimeError:
+            loop = None
+
+        async def _event_setter():
+            self._aio_event.set()
+
+        if loop == self.get_loop():
+            asyncio.create_task(_event_setter())
+        else:
+            asyncio.run_coroutine_threadsafe(_event_setter(), self._loop)
+
+    @classmethod
+    def _initialize_mpfuture_backend(cls):
+        pid = os.getpid()
+        logger.debug(f"Initializing MPFuture backend for pid {pid}")
+        assert pid != cls._active_pid, "already initialized"
+
+        receiver_pipe, cls._global_sender_pipe = mp.Pipe(duplex=False)
+        cls._active_pid, cls._active_futures = pid, {}
+        cls._pipe_waiter_thread = threading.Thread(target=cls._process_updates_in_background, args=[receiver_pipe],
+                                                   name=f'{__name__}.BACKEND', daemon=True)
+        cls._pipe_waiter_thread.start()
+
+    @classmethod
+    def _process_updates_in_background(cls, receiver_pipe: mp.connection.Connection):
+        pid = os.getpid()
+        while True:
+            try:
+                uid, update_type, payload = receiver_pipe.recv()
+                if uid not in cls._active_futures:
+                    logger.debug(f"Ignoring update to future with uid={uid}: the future is already done or destroyed")
+                elif update_type == UpdateType.RESULT:
+                    cls._active_futures.pop(uid).set_result(payload)
+                elif update_type == UpdateType.EXCEPTION:
+                    cls._active_futures.pop(uid).set_exception(payload)
+                elif update_type == UpdateType.CANCEL:
+                    cls._active_futures.pop(uid).cancel()
+                else:
+                    raise RuntimeError(f"Received unexpected update type {update_type}")
+            except (BrokenPipeError, EOFError):
+                logger.debug(f"Update pipe was was shut down unexpectedly (pid={pid})")
+            except Exception as e:
+                logger.exception(f"Could not retrieve update: caught {repr(e)} (pid={pid})")
+
+    def _send_update(self, update_type: UpdateType, payload: Any = None):
+        """ This method sends result, exception or cancel to the MPFuture origin. """
+        with MPFuture._update_lock if self._use_lock else nullcontext():
+            self._sender_pipe.send((self._uid, update_type, payload))
 
     def set_result(self, result: ResultType):
-        self._sync_updates()
-        if self._state in self.TERMINAL_STATES:
-            raise FutureStateError(f"Can't set_result to a future that is {self._state} ({self})")
-        self._state, self._result = base.FINISHED, result
-        return self._send_updates()
-
-    def set_exception(self, exception: BaseException):
-        self._sync_updates()
-        if self._state in self.TERMINAL_STATES:
-            raise FutureStateError(f"Can't set_exception to a future that is {self._state} ({self})")
-        self._state, self._exception = base.FINISHED, exception
-        self._send_updates()
+        if os.getpid() == self._origin_pid:
+            super().set_result(result)
+            MPFuture._active_futures.pop(self._uid, None)
+        elif self._state in TERMINAL_STATES:
+            raise InvalidStateError(f"Can't set_result to a future that is {self._state} ({self._uid})")
+        else:
+            self._state_cache[self._state], self._result = base.FINISHED, result
+            self._send_update(UpdateType.RESULT, result)
+
+    def set_exception(self, exception: Optional[BaseException]):
+        if os.getpid() == self._origin_pid:
+            super().set_exception(exception)
+            MPFuture._active_futures.pop(self._uid, None)
+        elif self._state in TERMINAL_STATES:
+            raise InvalidStateError(f"Can't set_exception to a future that is {self._state} ({self._uid})")
+        else:
+            self._state_cache[self._state], self._exception = base.FINISHED, exception
+            self._send_update(UpdateType.EXCEPTION, exception)
+
+    def cancel(self) -> bool:
+        if os.getpid() == self._origin_pid:
+            MPFuture._active_futures.pop(self._uid, None)
+            return super().cancel()
+        elif self._state in [base.RUNNING, base.FINISHED]:
+            return False
+        else:
+            self._state_cache[self._state] = base.CANCELLED
+            self._send_update(UpdateType.CANCEL)
+            return True
 
     def set_running_or_notify_cancel(self):
-        self._sync_updates()
         if self._state == base.PENDING:
             self._state = base.RUNNING
-            return self._send_updates()
+            return True
         elif self._state == base.CANCELLED:
             return False
         else:
-            raise FutureStateError(f"Can't set_running_or_notify_cancel to a future that is in {self._state} ({self})")
-
-    def cancel(self):
-        self._sync_updates()
-        if self._state in self.TERMINAL_STATES:
-            return False
-        self._state, self._exception = base.CANCELLED, base.CancelledError()
-        return self._send_updates()
+            raise InvalidStateError(f"Can't set_running_or_notify_cancel when future is in {self._state} ({self._uid})")
 
     def result(self, timeout: Optional[float] = None) -> ResultType:
-        self._await_terminal_state(timeout)
-        if self._exception is not None:
+        if self._state not in TERMINAL_STATES:
+            if os.getpid() != self._origin_pid:
+                raise RuntimeError("Only the process that created MPFuture can await result")
+            return super().result(timeout)
+        elif self._state == base.CANCELLED:
+            raise base.CancelledError()
+        elif self._exception:
             raise self._exception
-        return self._result
+        else:
+            return self._result
 
-    def exception(self, timeout=None) -> BaseException:
-        self._await_terminal_state(timeout)
-        if self._state == base.CANCELLED:
+    def exception(self, timeout: Optional[float] = None) -> Optional[BaseException]:
+        if self._state not in TERMINAL_STATES:
+            if os.getpid() != self._origin_pid:
+                raise RuntimeError("Only the process that created MPFuture can await exception")
+            return super().exception(timeout)
+        elif self._state == base.CANCELLED:
            raise base.CancelledError()
         return self._exception
 
     def done(self) -> bool:
-        self._sync_updates()
-        return self._state in self.TERMINAL_STATES
+        return self._state in TERMINAL_STATES
 
     def running(self):
-        self._sync_updates()
         return self._state == base.RUNNING
 
     def cancelled(self):
-        self._sync_updates()
         return self._state == base.CANCELLED
 
-    def add_done_callback(self, callback):
-        raise NotImplementedError(f"MPFuture doesn't support callbacks.")
-
-    def remove_done_callback(self, callback):
-        raise NotImplementedError(f"MPFuture doesn't support callbacks.")
+    def add_done_callback(self, callback: Callable[[MPFuture], None]):
+        if os.getpid() != self._origin_pid:
+            raise RuntimeError("Only the process that created MPFuture can set callbacks")
+        return super().add_done_callback(callback)
 
-    def get_loop(self):
-        raise NotImplementedError(f"MPFuture doesn't support get_loop")
-
-    @property
-    @lru_cache()
-    def _shutdown_trigger(self):
-        return base.Future()
-
-    def __repr__(self):
-        self._sync_updates()
-        if self._state == base.FINISHED:
-            if self._exception:
-                return "<MPFuture at 0x{:x} state=finished raised {}>".format(id(self), type(self._exception))
-            else:
-                return "<MPFuture at 0x{:x} state=finished returned {}>".format(id(self), type(self._result))
-        else:
-            return "<MPFuture at 0x{:x} state={}>".format(id(self), self._state)
+    def get_loop(self) -> Optional[asyncio.BaseEventLoop]:
+        return self._loop
 
     def __await__(self):
-        yield from asyncio.get_running_loop().run_in_executor(None, self._await_terminal_state, None).__await__()
-        if self._exception:
-            raise self._exception
-        return self._result
+        if not self._aio_event:
+            raise RuntimeError("Can't await: MPFuture was created with no event loop")
+        yield from self._aio_event.wait().__await__()
+        try:
+            return super().result(timeout=0)
+        except base.CancelledError:
+            raise asyncio.CancelledError()
 
     def __del__(self):
-        self._shutdown_trigger.set_result(True)
-        if hasattr(self, 'connection'):
-            self.connection.close()
+        if getattr(self, '_origin_pid', None) == os.getpid():
+            MPFuture._active_futures.pop(self._uid, None)
+        if getattr(self, '_aio_event', None):
+            self._aio_event.set()
+
+    def __getstate__(self):
+        return dict(_sender_pipe=self._sender_pipe, _shared_state_code=self._shared_state_code,
+                    _origin_pid=self._origin_pid, _uid=self._uid, _use_lock=self._use_lock,
+                    _result=self._result, _exception=self._exception)
+
+    def __setstate__(self, state):
+        self._sender_pipe = state['_sender_pipe']
+        self._shared_state_code = state['_shared_state_code']
+        self._origin_pid, self._uid = state['_origin_pid'], state['_uid']
+        self._result, self._exception = state['_result'], state['_exception']
+        self._use_lock = state['_use_lock']
+
+        self._waiters, self._done_callbacks = [], []
+        self._condition = threading.Condition()
+        self._aio_event, self._loop = None, None
+        self._state_cache = {}
+ 4 - 0
tests/test_averaging.py

@@ -423,3 +423,7 @@ def test_training_averager(n_steps: int = 10, n_dims: int = 16):
         assert torch.allclose(x2.grad, grad_avg)
         assert torch.allclose(opt1.state[x1]["exp_avg_sq"], stats_avg)
         assert torch.allclose(opt2.state[x2]["exp_avg_sq"], stats_avg)
+
+    averager1.shutdown()
+    averager2.shutdown()
+    dht.shutdown()

+ 255 - 74
tests/test_util_modules.py

@@ -1,129 +1,310 @@
 import asyncio
-from concurrent.futures import CancelledError
+import concurrent.futures
+import multiprocessing as mp
+import random
+import time
 
-import numpy as np
 import pytest
 import torch
+import numpy as np
 
+import hivemind
 from hivemind.proto.dht_pb2_grpc import DHTStub
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.proto.runtime_pb2_grpc import ConnectionHandlerStub
-import hivemind
 from hivemind.utils import MSGPackSerializer
 from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
 from hivemind.utils.asyncio import amap_in_executor, aiter, aenumerate, achain, anext, azip
-from hivemind.utils.mpfuture import FutureStateError
+from hivemind.utils.mpfuture import InvalidStateError
 
 
+@pytest.mark.forked
 def test_mpfuture_result():
-    f1, f2 = hivemind.MPFuture.make_pair()
-    f1.set_result(321)
-    assert f2.result() == 321
-    assert f1.result() == 321
+    future = hivemind.MPFuture()
 
-    for future in [f1, f2]:
-        with pytest.raises(FutureStateError):
-            future.set_result(123)
-        with pytest.raises(FutureStateError):
-            future.set_exception(ValueError())
-        assert future.cancel() is False
-        assert future.done() and not future.running() and not future.cancelled()
+    def _proc(future):
+        with pytest.raises(RuntimeError):
+            future.result()  # only creator process can await result
+
+        future.set_result(321)
+
+    p = mp.Process(target=_proc, args=(future,))
+    p.start()
+    p.join()
 
-    f1, f2 = hivemind.MPFuture.make_pair()
-    with pytest.raises(TimeoutError):
-        f1.result(timeout=1e-3)
+    assert future.result() == 321
+    assert future.exception() is None
+    assert future.cancel() is False
+    assert future.done() and not future.running() and not future.cancelled()
 
-    f2.set_result(['abacaba', 123])
-    assert f1.result() == ['abacaba', 123]
+    future = hivemind.MPFuture()
+    with pytest.raises(concurrent.futures.TimeoutError):
+        future.result(timeout=1e-3)
 
+    future.set_result(['abacaba', 123])
+    assert future.result() == ['abacaba', 123]
 
+
+@pytest.mark.forked
 def test_mpfuture_exception():
-    f1, f2 = hivemind.MPFuture.make_pair()
-    with pytest.raises(TimeoutError):
-        f1.exception(timeout=1e-3)
+    future = hivemind.MPFuture()
+    with pytest.raises(concurrent.futures.TimeoutError):
+        future.exception(timeout=1e-3)
 
-    f2.set_exception(NotImplementedError())
+    def _proc(future):
+        future.set_exception(NotImplementedError())
 
-    for future in [f1, f2]:
-        assert isinstance(future.exception(), NotImplementedError)
-        with pytest.raises(NotImplementedError):
-            future.result()
-        assert future.cancel() is False
-        assert future.done() and not future.running() and not future.cancelled()
+    p = mp.Process(target=_proc, args=(future,))
+    p.start()
+    p.join()
+
+    assert isinstance(future.exception(), NotImplementedError)
+    with pytest.raises(NotImplementedError):
+        future.result()
+    assert future.cancel() is False
+    assert future.done() and not future.running() and not future.cancelled()
 
 
+@pytest.mark.forked
 def test_mpfuture_cancel():
-    f1, f2 = hivemind.MPFuture.make_pair()
-    assert not f2.cancelled()
-    f1.cancel()
-    for future in [f1, f2]:
-        with pytest.raises(CancelledError):
+    future = hivemind.MPFuture()
+    assert not future.cancelled()
+    future.cancel()
+    evt = mp.Event()
+
+    def _proc():
+        with pytest.raises(concurrent.futures.CancelledError):
             future.result()
-        with pytest.raises(CancelledError):
+        with pytest.raises(concurrent.futures.CancelledError):
             future.exception()
-        with pytest.raises(FutureStateError):
+        with pytest.raises(InvalidStateError):
             future.set_result(123)
-        with pytest.raises(FutureStateError):
+        with pytest.raises(InvalidStateError):
             future.set_exception(NotImplementedError())
         assert future.cancelled() and future.done() and not future.running()
+        evt.set()
 
+    p = mp.Process(target=_proc)
+    p.start()
+    p.join()
+    assert evt.is_set()
 
+
+@pytest.mark.forked
 def test_mpfuture_status():
-    f1, f2 = hivemind.MPFuture.make_pair()
-    assert f1.set_running_or_notify_cancel() is True
-    for future in [f1, f2]:
-        assert future.running() and not future.done() and not future.cancelled()
-        with pytest.raises(RuntimeError):
-            future.set_running_or_notify_cancel()
-    f2.cancel()
-    for future in [f1, f2]:
+    evt = mp.Event()
+    future = hivemind.MPFuture()
+
+    def _proc1(future):
+        assert future.set_running_or_notify_cancel() is True
+        evt.set()
+
+    p = mp.Process(target=_proc1, args=(future,))
+    p.start()
+    p.join()
+    assert evt.is_set()
+    evt.clear()
+
+    assert future.running() and not future.done() and not future.cancelled()
+    with pytest.raises(InvalidStateError):
+        future.set_running_or_notify_cancel()
+
+    future = hivemind.MPFuture()
+    assert future.cancel()
+
+    def _proc2(future):
         assert not future.running() and future.done() and future.cancelled()
         assert future.set_running_or_notify_cancel() is False
+        evt.set()
 
-    f1, f2 = hivemind.MPFuture.make_pair()
-    f1.cancel()
-    for future in [f1, f2]:
-        assert future.set_running_or_notify_cancel() is False
+    p = mp.Process(target=_proc2, args=(future,))
+    p.start()
+    p.join()
+    evt.set()
+
+    future2 = hivemind.MPFuture()
+    future2.cancel()
+    assert future2.set_running_or_notify_cancel() is False
 
 
 @pytest.mark.asyncio
 async def test_await_mpfuture():
-    # await result
-    f1, f2 = hivemind.MPFuture.make_pair()
+    # await result from the same process, but a different coroutine
+    f1, f2 = hivemind.MPFuture(), hivemind.MPFuture()
 
-    async def wait_and_assign():
+    async def wait_and_assign_async():
         assert f2.set_running_or_notify_cancel() is True
         await asyncio.sleep(0.1)
-        f2.set_result((123, 'ololo'))
+        f1.set_result((123, 'ololo'))
+        f2.set_result((456, 'pyshpysh'))
+
+    asyncio.create_task(wait_and_assign_async())
 
-    asyncio.create_task(wait_and_assign())
-    for future in [f1, f2]:
-        res = await future
-        assert res == (123, 'ololo')
+    assert (await asyncio.gather(f1, f2)) == [(123, 'ololo'), (456, 'pyshpysh')]
+
+    # await result from separate processes
+    f1, f2 = hivemind.MPFuture(), hivemind.MPFuture()
+
+    def wait_and_assign(future, value):
+        time.sleep(0.1 * random.random())
+        future.set_result(value)
+
+    p1 = mp.Process(target=wait_and_assign, args=(f1, 'abc'))
+    p2 = mp.Process(target=wait_and_assign, args=(f2, 'def'))
+    for p in p1, p2:
+        p.start()
+
+    assert (await asyncio.gather(f1, f2)) == ['abc', 'def']
+    for p in p1, p2:
+        p.join()
 
     # await cancel
-    f1, f2 = hivemind.MPFuture.make_pair()
+    f1, f2 = hivemind.MPFuture(), hivemind.MPFuture()
 
-    async def wait_and_cancel():
-        await asyncio.sleep(0.1)
+    def wait_and_cancel():
+        time.sleep(0.01)
+        f2.set_result(123456)
+        time.sleep(0.1)
         f1.cancel()
 
-    asyncio.create_task(wait_and_cancel())
-    for future in [f1, f2]:
-        with pytest.raises(CancelledError):
-            await future
+    p = mp.Process(target=wait_and_cancel)
+    p.start()
+
+    with pytest.raises(asyncio.CancelledError):
+        # note: it is intended that MPFuture raises Cancel
+        await asyncio.gather(f1, f2)
+
+    p.join()
 
     # await exception
-    f1, f2 = hivemind.MPFuture.make_pair()
+    f1, f2 = hivemind.MPFuture(), hivemind.MPFuture()
 
-    async def wait_and_raise():
-        await asyncio.sleep(0.1)
-        f1.set_exception(SystemError())
+    def wait_and_raise():
+        time.sleep(0.01)
+        f2.set_result(123456)
+        time.sleep(0.1)
+        f1.set_exception(ValueError('we messed up'))
+
+    p = mp.Process(target=wait_and_raise)
+    p.start()
+
+    with pytest.raises(ValueError):
+        # note: it is intended that MPFuture raises Cancel
+        await asyncio.gather(f1, f2)
+
+    p.join()
+
+
+@pytest.mark.forked
+def test_mpfuture_bidirectional():
+    evt = mp.Event()
+    future_from_main = hivemind.MPFuture()
+
+    def _future_creator():
+        future_from_fork = hivemind.MPFuture()
+        future_from_main.set_result(('abc', future_from_fork))
+
+        if future_from_fork.result() == ['we', 'need', 'to', 'go', 'deeper']:
+            evt.set()
+
+    p = mp.Process(target=_future_creator)
+    p.start()
+
+    out = future_from_main.result()
+    assert isinstance(out[1], hivemind.MPFuture)
+    out[1].set_result(['we', 'need', 'to', 'go', 'deeper'])
+
+    p.join()
+    assert evt.is_set()
+
+
+@pytest.mark.forked
+def test_mpfuture_done_callback():
+    receiver, sender = mp.Pipe(duplex=False)
+    events = [mp.Event() for _ in range(5)]
+
+    def _future_creator():
+        future1, future2, future3 = hivemind.MPFuture(), hivemind.MPFuture(), hivemind.MPFuture()
+
+        def _check_result_and_set(future):
+            assert future.done()
+            assert future.result() == 123
+            events[0].set()
+
+        future1.add_done_callback(_check_result_and_set)
+        future1.add_done_callback(lambda future: events[1].set())
+        future2.add_done_callback(lambda future: events[2].set())
+        future3.add_done_callback(lambda future: events[3].set())
+
+        sender.send((future1, future2))
+        future2.cancel()  # trigger future2 callback from the same process
+
+        events[0].wait()
+        future1.add_done_callback(lambda future: events[4].set())  # schedule callback after future1 is already finished
+
+    p = mp.Process(target=_future_creator)
+    p.start()
+
+    future1, future2 = receiver.recv()
+    future1.set_result(123)
+
+    with pytest.raises(RuntimeError):
+        future1.add_done_callback(lambda future: (1, 2, 3))
+
+    p.join()
+    events[0].wait(1)
+    events[1].wait(1)
+    assert future1.done() and not future1.cancelled()
+    assert future2.done() and future2.cancelled()
+    assert events[0].is_set() and events[1].is_set() and events[2].is_set() and events[4].is_set()
+    assert not events[3].is_set()
+
+
+@pytest.mark.forked
+def test_many_futures():
+    evt = mp.Event()
+    receiver, sender = mp.Pipe()
+    main_futures = [hivemind.MPFuture() for _ in range(1000)]
+    assert len(hivemind.MPFuture._active_futures) == 1000
+
+    def _run_peer():
+        fork_futures = [hivemind.MPFuture() for _ in range(500)]
+        assert len(hivemind.MPFuture._active_futures) == 500
+
+        for i, future in enumerate(random.sample(main_futures, 300)):
+            if random.random() < 0.5:
+                future.set_result(i)
+            else:
+                future.set_exception(ValueError(f"{i}"))
+
+        sender.send(fork_futures[:-100])
+        for future in fork_futures[-100:]:
+            future.cancel()
+
+        evt.wait()
+
+        assert len(hivemind.MPFuture._active_futures) == 200
+        for future in fork_futures:
+            future.cancel()
+        assert len(hivemind.MPFuture._active_futures) == 0
+
+    p = mp.Process(target=_run_peer)
+    p.start()
+
+    some_fork_futures = receiver.recv()
+    assert len(hivemind.MPFuture._active_futures) == 700
+
+    for future in some_fork_futures:
+        future.set_running_or_notify_cancel()
+    for future in random.sample(some_fork_futures, 200):
+        future.set_result(321)
 
-    asyncio.create_task(wait_and_raise())
-    for future in [f1, f2]:
-        with pytest.raises(SystemError):
-            await future
+    time.sleep(0.5)
+    evt.set()
+    for future in main_futures:
+        future.cancel()
+    assert len(hivemind.MPFuture._active_futures) == 0
+    p.join()
 
 
 def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008):
@@ -139,7 +320,7 @@ def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008):
     error = deserialize_torch_tensor(serialize_torch_tensor(X, CompressionType.UNIFORM_8BIT)) - X
     assert error.square().mean() < beta
 
-    zeros = torch.zeros(5,5)
+    zeros = torch.zeros(5, 5)
     for compression_type in CompressionType.values():
         assert deserialize_torch_tensor(serialize_torch_tensor(zeros, compression_type)).isfinite().all()