Simplify & explain hivemind.dht.DHT (#78)

* unused import

* remove loop.run_forever

* explain how DHT stores stuff (hivemind/dht/__init__.py)
* implement more efficient DHT.first_k_active (use get_many)
* change uid_delimiter into a DHT instance property

* hivemind.DHT is now thread-safe on client side

* implement all mpfuture methods

* get_experts: allow returning future

* rollback run_in_background

* switch back to GLOBAL_EXECUTOR (tests should fail)

* instantiate executor in each process to avoid os locks

* rollback to minimize diff

* typo

* traverse_dht can now be cancelled

* node.store_many and node.get_many can now be cancelled

* address review by mryab@

* address review by mryab@

* address review by mryab@

* address review by mryab@

* address review by mryab@
justheuristic, 5 years ago
commit 535318e249

+ 1 - 1
.circleci/config.yml

@@ -27,7 +27,7 @@ jobs:
           command: sudo python setup.py develop
           name: setup
       - run:
-          command: for test_file in tests/test*.py; do pytest $test_file --full-trace; done
+          command: pytest ./tests
           name: tests
       - run:
           command: python tests/benchmark_throughput.py --preset minimalistic

+ 4 - 4
hivemind/client/moe.py

@@ -104,7 +104,7 @@ class RemoteMixtureOfExperts(nn.Module):
         beam = np.array([[self.uid_prefix]] * batch_size, dtype=object)  # [batch_size, up_to_beam_size]
         scores = np.zeros([batch_size, 1], dtype=np.float64)
 
-        delimeters = np.array(self.dht.UID_DELIMETER)[None, None, None]  # pre-compute numpy array for fast concat
+        delimiters = np.array(self.dht.UID_DELIMITER)[None, None, None]  # pre-compute numpy array for fast concat
 
         for dim_index, dim_scores in enumerate(grid_scores):
             dim_scores = dim_scores.detach().cpu().numpy()
@@ -112,7 +112,7 @@ class RemoteMixtureOfExperts(nn.Module):
 
             # create all possible successors from current beam
             dim_indices = np.arange(dim_scores.shape[1]).astype(str)
-            new_candidates = beam[:, :, None] + delimeters + dim_indices[None, None, :]
+            new_candidates = beam[:, :, None] + delimiters + dim_indices[None, None, :]
             new_candidates = new_candidates.reshape([batch_size, -1])
 
             new_scores = scores[:, :, None] + dim_scores[:, None, :]
@@ -166,8 +166,8 @@ class RemoteMixtureOfExperts(nn.Module):
 
         grid_indices = np.zeros([len(flat_experts), len(grid_scores)], dtype=np.int64)
         for i, expert in enumerate(flat_experts):
-            expert_indices = expert.uid[len(self.uid_prefix) + len(self.dht.UID_DELIMETER):]
-            expert_indices = list(map(int, expert_indices.split(self.dht.UID_DELIMETER)))
+            expert_indices = expert.uid[len(self.uid_prefix) + len(self.dht.UID_DELIMITER):]
+            expert_indices = list(map(int, expert_indices.split(self.dht.UID_DELIMITER)))
             grid_indices[i] = expert_indices
 
         scores_per_dim = [
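
The hunk above grows beam-search candidates by broadcasting numpy object arrays of uid strings against the delimiter and the next dimension's indices. A minimal standalone sketch of that concatenation trick (the prefix, delimiter and grid size below are illustrative, not taken from this diff; the real module uses self.uid_prefix and dht.UID_DELIMITER):

    import numpy as np

    beam = np.array([["ffn_expert.3", "ffn_expert.7"]], dtype=object)  # [batch_size=1, beam_size=2]
    delimiters = np.array(".")[None, None, None]                       # shape [1, 1, 1] for broadcasting
    dim_indices = np.arange(3).astype(str)                             # "0", "1", "2" along the next grid dimension

    # [batch, beam, 1] + [1, 1, 1] + [1, 1, dim] -> [batch, beam, dim], elementwise string concatenation
    new_candidates = beam[:, :, None] + delimiters + dim_indices[None, None, :]
    print(new_candidates.reshape([1, -1]))
    # [['ffn_expert.3.0' 'ffn_expert.3.1' 'ffn_expert.3.2' 'ffn_expert.7.0' 'ffn_expert.7.1' 'ffn_expert.7.2']]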

+ 113 - 85
hivemind/dht/__init__.py

@@ -16,6 +16,8 @@ import asyncio
 import ctypes
 import multiprocessing as mp
 import warnings
+from collections import deque
+from concurrent.futures import ThreadPoolExecutor
 from typing import List, Optional, Sequence
 
 import uvloop
@@ -23,12 +25,12 @@ import uvloop
 from hivemind.client import RemoteExpert
 from hivemind.dht.node import DHTNode, DHTID, DHTExpiration
 from hivemind.dht.routing import get_dht_time
-from hivemind.utils import MPFuture, Endpoint, run_in_background
+from hivemind.utils import MPFuture, Endpoint
 
 
 class DHT(mp.Process):
     """
-    A high-level interface to hivemind DHT. Runs a dht node in a background process.
+    High-level interface to hivemind.dht that is designed to allow RemoteMixtureOfExperts to select best experts.
 
     :param initial_peers: one or multiple endpoints pointing to active DHT peers. Similar format to listen_on.
     :param listen_on: an interface for incoming connections, e.g. "127.0.0.1:*", "0.0.0.0:1234" or "ipv6:[::]:*"
@@ -36,19 +38,45 @@ class DHT(mp.Process):
     :param daemon: if True, the background process is marked as daemon and automatically terminated after main process
     :param max_workers: declare_experts and get_experts will use up to this many parallel workers
         (but no more than one per key)
+    :param expiration: experts declared from this node expire after this many seconds (default = 5 minutes)
+    :param receiver_threads: uses this many threads to await on input pipe. Default = 1 should be enough in most cases
     :param kwargs: any other params will be forwarded to DHTNode upon creation
+
+    Each expert has an identifier in the form of {prefix}.{i}.{j}.{...}, e.g. "ffn_expert.98.76.54.32.10"
+    An expert identifier consists of:
+
+        * optional prefix that determines expert role, experiment name, etc.
+        * one or more integers that determine that expert's position in an N-dimensional grid
+
+    A hivemind.Server can call ``DHT.declare_experts(expert_uids: List[str])`` to make its experts visible to everyone.
+    When declaring experts, DHT will store each expert's uid and all its prefixes until :expiration: (specified at init)
+    For instance, declaring "ffn_expert.98.76.54.32.10" will store the following keys in a DHT:
+    ``"ffn_expert", "ffn_expert.98", "ffn_expert.98.76", ..., "ffn_expert.98.76.54.32.10"``
+
+    RemoteMixtureOfExperts can use these prefixes to find top-k most suitable experts with a left-to-right beam search.
+    For instance, consider RemoteMixtureOfExperts with prefix "ffn_expert" and grid size [100, 100, 100, 100, 100].
+    This MoE can query all experts with that prefix and arbitrary indices in 0...99 along each dimension.
+    However, not every expert in such 100^5 grid can be alive at a given moment of time (the grid size is redundant).
+    In order to find k best "alive" experts, MoE first ranks indices along the first dimension with its gating function.
+    It can then check which of those indices correspond to "alive" experts by querying keys such as "ffn_expert.98".
+    This is done using DHT.first_k_active function. After selecting k best indices along first dimension, MoE moves
+    to the second dimension. It can find top-k pairs of indices (e.g. "expert.98.76") that start with one of k first
+    indices from the previous step. Finally, MoE will use DHT.get_experts(uids: List[str]) to search for specific experts.
+    This beam search explores one additional dimension per step and finds k best experts from across the DHT
+    in O(k / s * log(N)) average time where s is grid sparsity rate and N is the total number of experts.
     """
-    UID_DELIMETER = '.'  # splits expert uids over this delimeter
-    EXPIRATION = 120  # anything written to DHT is considered expired after this many seconds
-    make_key = "{}::{}".format
+
+    UID_DELIMITER = '.'  # when declaring experts, DHT stores all prefixes of that expert's uid, split over this delimiter
+    #  formally, prefixes = {uid.split(UID_DELIMITER)[:length] for length in range(1, uid.count(UID_DELIMITER) + 2)}
 
     def __init__(self, listen_on: Endpoint = "0.0.0.0:*", initial_peers: Sequence[Endpoint] = (), *, start: bool,
-                 daemon: bool = True, max_workers: Optional[int] = None, parallel_rpc: Optional[int] = None, **kwargs):
+                 daemon: bool = True, max_workers: Optional[int] = None, parallel_rpc: Optional[int] = None,
+                 receiver_threads: int = 1, expiration: float = 300, **kwargs):
         super().__init__()
         self.listen_on, self.initial_peers, self.kwargs = listen_on, initial_peers, kwargs
-        self.max_workers, self.parallel_rpc = max_workers, parallel_rpc
+        self.receiver_threads, self.max_workers, self.parallel_rpc = receiver_threads, max_workers, parallel_rpc
+        self.expiration = expiration
         self._port = mp.Value(ctypes.c_int32, 0)  # initialized after dht starts
-        self.node: Optional[DHTNode] = None  # initialized inside self.run only
         self._pipe, self.pipe = mp.Pipe(duplex=True)
         self.ready = mp.Event()
         self.daemon = daemon
@@ -62,16 +90,20 @@ class DHT(mp.Process):
         uvloop.install()
         loop = asyncio.new_event_loop()
         asyncio.set_event_loop(loop)
-        self.node: DHTNode = loop.run_until_complete(DHTNode.create(
-            initial_peers=list(self.initial_peers), listen_on=self.listen_on, parallel_rpc=self.parallel_rpc,
-            num_workers=self.max_workers or 1, **self.kwargs))
-        self._port.value = self.node.port
-        run_in_background(loop.run_forever)
-        self.ready.set()
+        pipe_awaiter = ThreadPoolExecutor(self.receiver_threads)
+
+        async def _run():
+            node = await DHTNode.create(
+                initial_peers=list(self.initial_peers), listen_on=self.listen_on, parallel_rpc=self.parallel_rpc,
+                num_workers=self.max_workers or 1, **self.kwargs)
+            self._port.value = node.port
+            self.ready.set()
+
+            while True:
+                method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv)
+                asyncio.create_task(getattr(self, method)(node, *args, **kwargs))
 
-        while True:
-            method, args, kwargs = self._pipe.recv()
-            getattr(self, method)(*args, **kwargs)
+        loop.run_until_complete(_run())
 
     def run_in_background(self, await_ready=True, timeout=None):
         """
@@ -85,7 +117,7 @@ class DHT(mp.Process):
     def shutdown(self) -> None:
         """ Shuts down the dht process """
         if self.is_alive():
-            self.kill()
+            self.terminate()
         else:
             warnings.warn("DHT shutdown has no effect: dht process is already not alive")
 
@@ -93,32 +125,27 @@ class DHT(mp.Process):
     def port(self) -> Optional[int]:
         return self._port.value if self._port.value != 0 else None
 
-    def get_experts(self, uids: List[str], expiration=None) -> List[Optional[RemoteExpert]]:
+    def get_experts(self, uids: List[str], expiration_time: Optional[DHTExpiration] = None,
+                    wait=True) -> List[Optional[RemoteExpert]]:
         """
         :param uids: find experts with these ids from across the DHT
-        :param expiration: returns experts that expire no sooner than this (based on get_dht_time), default = now
+        :param expiration_time: if specified, return experts that expire no sooner than this (based on get_dht_time)
+        :param wait: if True (default), block until the experts are retrieved and return them. Otherwise return a Future.
         :returns: a list of [RemoteExpert if found else None]
         """
+        assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
         future, _future = MPFuture.make_pair()
-        self.pipe.send(('_get_experts', [], dict(uids=uids, expiration=expiration, future=_future)))
-        return future.result()
+        self.pipe.send(('_get_experts', [], dict(uids=uids, expiration_time=expiration_time, future=_future)))
+        return future.result() if wait else future
 
-    def _get_experts(self, uids: List[str], expiration: Optional[DHTExpiration], future: MPFuture):
-        loop = asyncio.get_event_loop()
-        expiration = expiration or get_dht_time()
+    async def _get_experts(
+            self, node: DHTNode, uids: List[str], expiration_time: Optional[DHTExpiration], future: MPFuture):
+        if expiration_time is None:
+            expiration_time = get_dht_time()
         num_workers = len(uids) if self.max_workers is None else min(len(uids), self.max_workers)
-        keys = [self.make_key('expert', uid) for uid in uids]
-
-        response = asyncio.run_coroutine_threadsafe(
-            self.node.get_many(keys, expiration, num_workers=num_workers), loop).result()
-
-        experts: List[Optional[RemoteExpert]] = [None] * len(uids)
-        for i, (key, uid) in enumerate(zip(keys, uids)):
-            maybe_endpoint, maybe_expiration = response[key]
-            if maybe_expiration is not None:  # if we found a value
-                experts[i] = RemoteExpert(uid=uid, endpoint=maybe_endpoint)
-
-        future.set_result(experts)
+        response = await node.get_many(uids, expiration_time, num_workers=num_workers)
+        future.set_result([RemoteExpert(uid, maybe_endpoint) if maybe_expiration_time else None
+                           for uid, (maybe_endpoint, maybe_expiration_time) in response.items()])
 
     def declare_experts(self, uids: List[str], endpoint: Endpoint, wait=True, timeout=None) -> Optional[List[bool]]:
         """
@@ -136,69 +163,70 @@ class DHT(mp.Process):
         if wait:
             return future.result(timeout)
 
-    def _declare_experts(self, uids: List[str], endpoint: Endpoint, future: Optional[MPFuture]):
-        assert self.node is not None, "This method should only be accessed from inside .run method"
+    async def _declare_experts(self, node: DHTNode, uids: List[str], endpoint: Endpoint, future: Optional[MPFuture]):
         num_workers = len(uids) if self.max_workers is None else min(len(uids), self.max_workers)
-        loop = asyncio.get_event_loop()
-        expiration_time = get_dht_time() + self.EXPIRATION
-        unique_prefixes = set()
+        expiration_time = get_dht_time() + self.expiration
 
-        keys, values = [], []
+        data_to_store = {}
         for uid in uids:
-            uid_parts = uid.split(self.UID_DELIMETER)
-            keys.append(self.make_key('expert', uid))
-            values.append(endpoint)
-            unique_prefixes.update([self.UID_DELIMETER.join(uid_parts[:i + 1]) for i in range(len(uid_parts))])
-
-        for prefix in unique_prefixes:
-            keys.append(self.make_key('prefix', prefix))
-            values.append(True)
-
-        store_ok = asyncio.run_coroutine_threadsafe(
-            self.node.store_many(keys, values, expiration_time, num_workers=num_workers), loop
-        ).result()
+            uid_parts = uid.split(self.UID_DELIMITER)
+            for i in range(len(uid_parts)):
+                uid_prefix_i = self.UID_DELIMITER.join(uid_parts[:i + 1])
+                data_to_store[uid_prefix_i] = endpoint
+
+        store_keys, store_values = zip(*data_to_store.items())
+        store_ok = await node.store_many(store_keys, store_values, expiration_time, num_workers=num_workers)
         if future is not None:
-            future.set_result([store_ok[key] for key in keys])
+            future.set_result([store_ok[key] for key in data_to_store.keys()])
 
-    def first_k_active(self, prefixes: List[str], k: int, max_prefetch=None):
+    def first_k_active(self, uid_prefixes: List[str], k: int, max_prefetch: int = 1, chunk_size: Optional[int] = None):
         """
         Find k prefixes with active experts; may return less if there aren't enough; used for DMoE beam search
 
-        :param prefixes: a list of uid prefixes ordered from highest to lowest priority
+        :param uid_prefixes: a list of uid prefixes ordered from highest to lowest priority
         :param k: return at most *this many* active prefixes
-        :param max_prefetch: pre-dispatch up to *this many* asynchronous expert requests, defaults to pre-dispatch = k
+        :param max_prefetch: pre-dispatch up to *this many* tasks (each for chunk_size experts)
+        :param chunk_size: dispatch this many requests in one task
         :returns: a list of at most :k: prefixes that have at least one active expert each;
         """
-        assert isinstance(prefixes, (list, tuple)), "please provide a list/tuple of prefixes as the first argument"
+        assert not isinstance(uid_prefixes, str), "please provide a list/tuple of prefixes as the first argument"
         future, _future = MPFuture.make_pair()
         self.pipe.send(('_first_k_active', [],
-                        dict(prefixes=prefixes, k=k, max_prefetch=max_prefetch or k, future=_future)))
+                        dict(uid_prefixes=uid_prefixes, k=k, max_prefetch=max_prefetch,
+                             chunk_size=chunk_size or k, future=_future)))
         return future.result()
 
-    def _first_k_active(self, prefixes: List[str], k: int, max_prefetch: Optional[int], future: MPFuture):
-        assert self.node is not None, "This method should only be accessed from inside .run method"
-        max_prefetch = max_prefetch or len(prefixes)
-        loop = asyncio.get_event_loop()
-        lookup_prefetch = [asyncio.run_coroutine_threadsafe(self.node.get(self.make_key('prefix', prefix)), loop)
-                           for prefix in prefixes[:max_prefetch]]
+    async def _first_k_active(
+            self, node: DHTNode, uid_prefixes: List[str], k: int, max_prefetch: int, chunk_size: int, future: MPFuture):
+        num_workers_per_chunk = min(chunk_size, self.max_workers or chunk_size)
+        total_chunks = (len(uid_prefixes) - 1) // chunk_size + 1
         active_prefixes = []
 
-        for i, prefix in enumerate(prefixes):
-            _, maybe_expiration = lookup_prefetch[i].result()
-
-            if maybe_expiration is not None:
-                active_prefixes.append(prefix)
-                if len(active_prefixes) >= k:
-                    future.set_result(active_prefixes)
-                    for task in lookup_prefetch[i:]:
-                        task.cancel()
-                    return
-
-            # pre-dispatch the next request in line
-            if len(lookup_prefetch) < len(prefixes):
-                lookup_prefetch.append(
-                    asyncio.run_coroutine_threadsafe(
-                        self.node.get(self.make_key('prefix', prefixes[len(lookup_prefetch)])), loop))
-
-        # could not find enough active prefixes; return what we can
+        pending_tasks = deque(
+            asyncio.create_task(node.get_many(uid_prefixes[chunk_i * chunk_size: (chunk_i + 1) * chunk_size],
+                                              num_workers=num_workers_per_chunk))
+            for chunk_i in range(min(max_prefetch + 1, total_chunks))
+        )  # pre-dispatch first task and up to max_prefetch additional tasks
+
+        for chunk_i in range(total_chunks):
+            # parse task results in chronological order, launch additional tasks on demand
+            response = await pending_tasks.popleft()
+            for uid_prefix in uid_prefixes[chunk_i * chunk_size: (chunk_i + 1) * chunk_size]:
+                if response[uid_prefix][1] is not None:  # found active peer
+                    active_prefixes.append(uid_prefix)
+                    # if we found enough active experts, finish immediately
+                    if len(active_prefixes) >= k:
+                        break
+            if len(active_prefixes) >= k:
+                for task in pending_tasks:
+                    task.cancel()
+                break
+
+            pre_dispatch_chunk_i = chunk_i + len(pending_tasks) + 1
+            if pre_dispatch_chunk_i < total_chunks:
+                pending_tasks.append(asyncio.create_task(node.get_many(
+                    uid_prefixes[pre_dispatch_chunk_i * chunk_size: (pre_dispatch_chunk_i + 1) * chunk_size],
+                    num_workers=num_workers_per_chunk)))
+
+        # return k active prefixes or as many as we could find
         future.set_result(active_prefixes)
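
A minimal sketch of the client-side workflow that the new docstring above describes, using only the methods shown in this file (the uid, endpoint and candidate prefixes are illustrative; in practice a hivemind.Server declares its own experts):

    import hivemind

    dht = hivemind.DHT(listen_on="0.0.0.0:*", start=True)

    # declaring an expert also stores every uid prefix ("ffn_expert", "ffn_expert.98", ...)
    # for `expiration` seconds, so beam search can discover the expert level by level
    dht.declare_experts(["ffn_expert.98.76.54.32.10"], endpoint="127.0.0.1:8080")

    # beam search over the first grid dimension: which candidate prefixes have live experts?
    active = dht.first_k_active(["ffn_expert.98", "ffn_expert.11", "ffn_expert.42"], k=2)

    # once full uids are assembled, resolve them into RemoteExpert handles (None if not found)
    experts = dht.get_experts(["ffn_expert.98.76.54.32.10"])

    dht.shutdown()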

+ 17 - 17
hivemind/dht/dht.proto

@@ -17,39 +17,39 @@ service DHT {
 message NodeInfo {
   // note: both node_id and port are optional: if specified, ask peer to add you to its routing table;
   // if either node_id or port is absent, simply request recipient info (for client-only mode)
-  bytes node_id = 1;                // sender's own node id serialized with DHTID.to_bytes()
-  int32 rpc_port = 2;               // port to which sender listens for DHT RPCs
+  bytes node_id = 1;                   // sender's own node id serialized with DHTID.to_bytes()
+  int32 rpc_port = 2;                  // port to which sender listens for DHT RPCs
 }
 
 message StoreRequest {
   // three lists of the same length representing dht keys, dht values and expiration
-  repeated bytes keys = 1;          // keys in the form of DHTID.generate(raw_key).to_bytes()
-  repeated bytes values = 2;        // binary-encoded value for i-th key
-  repeated double expiration = 3;   // expirations for i-th key (type = DHTExpiration)
-  repeated bool in_cache = 4;       // if in_cache[i], store i-th key in cache, else store normally
-  NodeInfo peer = 5;                // (optional) sender's own node info, same behavior as in DHT.rpc_ping
+  repeated bytes keys = 1;             // keys in the form of DHTID.generate(raw_key).to_bytes()
+  repeated bytes values = 2;           // binary-encoded value for i-th key
+  repeated double expiration_time = 3; // expirations for i-th key (type = DHTExpiration)
+  repeated bool in_cache = 4;          // if in_cache[i], store i-th key in cache, else store normally
+  NodeInfo peer = 5;                   // (optional) sender's own node info, same behavior as in DHT.rpc_ping
 }
 
 message StoreResponse {
-  repeated bool store_ok = 1;       // for every key, True means store accepted, False means store rejected/failed
-  NodeInfo peer = 2;                // respondent's node id, for you to update routing table
+  repeated bool store_ok = 1;          // for every key, True means store accepted, False means store rejected/failed
+  NodeInfo peer = 2;                   // respondent's node id, for you to update routing table
 }
 
 message FindRequest {
-  repeated bytes keys = 1;          // a list of DHTID search keys encoded as bytes
-  NodeInfo peer = 2;                // optional, same behavior as in DHT.ping
+  repeated bytes keys = 1;             // a list of DHTID search keys encoded as bytes
+  NodeInfo peer = 2;                   // optional, same behavior as in DHT.ping
 }
 
 message Peers {
   // two aligned arrays: DHTIDs and Endpoints, i-th endpoint corresponds to peer with i-th node id
-  repeated bytes node_ids = 1;       // DHTID serialized with node_id.to_bytes()
-  repeated string endpoints = 2;     // e.g. 123.123.123.123:1337 or [2a21:6с8:b192:2105]:8888
+  repeated bytes node_ids = 1;         // DHTID serialized with node_id.to_bytes()
+  repeated string endpoints = 2;       // e.g. 123.123.123.123:1337 or [2a21:6с8:b192:2105]:8888
 }
 
 message FindResponse {
-  repeated bytes values = 1;        // value for i-th key, b'' means not found locally
-  repeated double expiration = 2;   // expiration time for i-th value, only valid value is found
-  repeated Peers nearest = 3;       // peers ordered from nearest to farthest based on distance to i-th key
-  NodeInfo peer = 4;                // respondent's node id, for you to update routing table
+  repeated bytes values = 1;           // value for i-th key, b'' means not found locally
+  repeated double expiration_time = 2; // expiration time for i-th value, only valid value is found
+  repeated Peers nearest = 3;          // peers ordered from nearest to farthest based on distance to i-th key
+  NodeInfo peer = 4;                   // respondent's node id, for you to update routing table
 }
 

+ 40 - 33
hivemind/dht/node.py

@@ -28,7 +28,7 @@ class DHTNode:
     Compared to Kademlia RPC protocol, hivemind DHT has 3 RPCs:
 
     * ping - request peer's identifier and update routing table (same as Kademlia PING RPC)
-    * store - send several (key, value, expiration) pairs to the same peer (like Kademlia STORE, but in bulk)
+    * store - send several (key, value, expiration_time) pairs to the same peer (like Kademlia STORE, but in bulk)
     * find - request one or several keys, get values & expiration (if peer finds it locally) and :bucket_size: of
         nearest peers from recipient's routing table (ordered nearest-to-farthest, not including recipient itself)
         This RPC is a mixture between Kademlia FIND_NODE and FIND_VALUE with multiple keys per call.
@@ -37,10 +37,10 @@ class DHTNode:
 
     - when asked to get(key), a node must find and return a value with highest expiration time that it found across DHT
       IF that time has not come yet. if expiration time is smaller than current get_dht_time(), node may return None;
-    - when requested to store(key: value, expiration), a node must store (key => value) at until expiration time
+    - when requested to store(key: value, expiration_time), a node must store (key => value) until expiration time
       or until DHTNode gets the same key with greater expiration time. If a node is asked to store a key but it already
       has the same key with newer expiration, the older key will not be stored. Return True if stored, False if refused;
-    - when requested to store(key: value, expiration, in_cache=True), stores (key => value) in a separate "cache".
+    - when requested to store(key: value, expiration_time, in_cache=True), stores (key => value) in a separate "cache".
       Cache operates same as regular storage, but it has a limited size and evicts least recently used nodes when full;
 
     """
@@ -191,15 +191,15 @@ class DHTNode:
         store_ok = await self.store_many([key], [value], [expiration_time], **kwargs)
         return store_ok[key]
 
-    async def store_many(
-            self, keys: List[DHTKey], values: List[DHTValue], expiration: Union[DHTExpiration, List[DHTExpiration]],
-            exclude_self: bool = False, await_all_replicas=True, **kwargs) -> Dict[DHTKey, bool]:
+    async def store_many(self, keys: List[DHTKey], values: List[DHTValue],
+                         expiration_time: Union[DHTExpiration, List[DHTExpiration]],
+                         exclude_self: bool = False, await_all_replicas=True, **kwargs) -> Dict[DHTKey, bool]:
         """
-        Traverse DHT to find up to best nodes to store multiple (key, value, expiration) pairs.
+        Traverse DHT to find up to best nodes to store multiple (key, value, expiration_time) pairs.
 
         :param keys: arbitrary serializable keys associated with each value
         :param values: serializable "payload" for each key
-        :param expiration: either one expiration time for all keys or individual expiration times (see class doc)
+        :param expiration_time: either one expiration time for all keys or individual expiration times (see class doc)
         :param kwargs: any additional parameters passed to traverse_dht function (e.g. num workers)
         :param exclude_self: if True, never store value locally even if you are one of the nearest nodes
         :note: if exclude_self is True and self.cache_locally == True, value will still be __cached__ locally
@@ -207,13 +207,14 @@ class DHTNode:
             if True, the function will wait for num_replicas successful stores or running out of beam_size nodes
         :returns: for each key: True if store succeeds, False if it fails (due to no response or newer value)
         """
-        expiration = [expiration] * len(keys) if isinstance(expiration, DHTExpiration) else expiration
-        assert len(keys) == len(values) == len(expiration), "Please provide equal number of keys, values and expiration"
+        if isinstance(expiration_time, DHTExpiration):
+            expiration_time = [expiration_time] * len(keys)
+        assert len(keys) == len(values) == len(expiration_time), "Number of keys, values and expiration doesn't match."
 
         key_ids = list(map(DHTID.generate, keys))
         id_to_original_key = dict(zip(key_ids, keys))
         binary_values_by_key_id = {key_id: self.serializer.dumps(value) for key_id, value in zip(key_ids, values)}
-        expiration_by_key_id = {key_id: expiration_time for key_id, expiration_time in zip(key_ids, expiration)}
+        expiration_by_key_id = {key_id: expiration_time for key_id, expiration_time in zip(key_ids, expiration_time)}
         unfinished_key_ids = set(key_ids)  # we use this set to ensure that each store request is finished
 
         store_ok = {key: False for key in keys}  # outputs, updated during search
@@ -272,12 +273,16 @@ class DHTNode:
 
             store_finished_events[id_to_original_key[key_id]].set()
 
-        asyncio.create_task(self.find_nearest_nodes(
+        store_task = asyncio.create_task(self.find_nearest_nodes(
             queries=set(key_ids), k_nearest=self.num_replicas, node_to_endpoint=node_to_endpoint,
             found_callback=on_found, exclude_self=exclude_self, **kwargs))
-        await asyncio.wait([evt.wait() for evt in store_finished_events.values()])  # await one (or all) store accepts
-        assert len(unfinished_key_ids) == 0, "Internal error: traverse_dht didn't finish search"
-        return store_ok
+        try:
+            await asyncio.wait([evt.wait() for evt in store_finished_events.values()])  # wait for items to be stored
+            assert len(unfinished_key_ids) == 0, "Internal error: traverse_dht didn't finish search"
+            return store_ok
+        except asyncio.CancelledError as e:
+            store_task.cancel()
+            raise e
 
     async def get(self, key: DHTKey, latest=False, **kwargs) -> Tuple[Optional[DHTValue], Optional[DHTExpiration]]:
         """
@@ -316,17 +321,17 @@ class DHTNode:
         unfinished_key_ids = set(key_ids)  # track key ids for which the search is not terminated
         node_to_endpoint: Dict[DHTID, Endpoint] = dict()  # global routing table for all queries
 
-        SearchResult = namedtuple("SearchResult", ["binary_value", "expiration", "source_node_id"])
+        SearchResult = namedtuple("SearchResult", ["binary_value", "expiration_time", "source_node_id"])
         latest_results = {key_id: SearchResult(b'', -float('inf'), None) for key_id in key_ids}
 
         # stage 1: value can be stored in our local cache
         for key_id in key_ids:
-            maybe_value, maybe_expiration = self.protocol.storage.get(key_id)
-            if maybe_expiration is None:
-                maybe_value, maybe_expiration = self.protocol.cache.get(key_id)
-            if maybe_expiration is not None and maybe_expiration > latest_results[key_id].expiration:
-                latest_results[key_id] = SearchResult(maybe_value, maybe_expiration, self.node_id)
-                if maybe_expiration >= sufficient_expiration_time:
+            maybe_value, maybe_expiration_time = self.protocol.storage.get(key_id)
+            if maybe_expiration_time is None:
+                maybe_value, maybe_expiration_time = self.protocol.cache.get(key_id)
+            if maybe_expiration_time is not None and maybe_expiration_time > latest_results[key_id].expiration_time:
+                latest_results[key_id] = SearchResult(maybe_value, maybe_expiration_time, self.node_id)
+                if maybe_expiration_time >= sufficient_expiration_time:
                     unfinished_key_ids.remove(key_id)
 
         # stage 2: traverse the DHT for any unfinished keys
@@ -341,11 +346,11 @@ class DHTNode:
                 return {query: ([], False) for query in queries}
 
             output: Dict[DHTID, Tuple[List[DHTID], bool]] = {}
-            for key_id, (maybe_value, maybe_expiration, peers) in response.items():
+            for key_id, (maybe_value, maybe_expiration_time, peers) in response.items():
                 node_to_endpoint.update(peers)
-                if maybe_expiration is not None and maybe_expiration > latest_results[key_id].expiration:
-                    latest_results[key_id] = SearchResult(maybe_value, maybe_expiration, peer)
-                should_interrupt = (latest_results[key_id].expiration >= sufficient_expiration_time)
+                if maybe_expiration_time is not None and maybe_expiration_time > latest_results[key_id].expiration_time:
+                    latest_results[key_id] = SearchResult(maybe_value, maybe_expiration_time, peer)
+                should_interrupt = (latest_results[key_id].expiration_time >= sufficient_expiration_time)
                 output[key_id] = list(peers.keys()), should_interrupt
             return output
 
@@ -356,10 +361,10 @@ class DHTNode:
 
         # stage 3: cache any new results depending on caching parameters
         for key_id, nearest_nodes in nearest_nodes_per_query.items():
-            latest_value_bytes, latest_expiration, latest_node_id = latest_results[key_id]
-            should_cache = latest_expiration >= sufficient_expiration_time  # if we found a newer value, cache it
+            latest_value_bytes, latest_expiration_time, latest_node_id = latest_results[key_id]
+            should_cache = latest_expiration_time >= sufficient_expiration_time  # if we found a newer value, cache it
             if should_cache and self.cache_locally:
-                self.protocol.cache.store(key_id, latest_value_bytes, latest_expiration)
+                self.protocol.cache.store(key_id, latest_value_bytes, latest_expiration_time)
 
             if should_cache and self.cache_nearest:
                 num_cached_nodes = 0
@@ -367,16 +372,18 @@ class DHTNode:
                     if node_id == latest_node_id:
                         continue
                     asyncio.create_task(self.protocol.call_store(
-                        node_to_endpoint[node_id], [key_id], [latest_value_bytes], [latest_expiration], in_cache=True))
+                        node_to_endpoint[node_id], [key_id], [latest_value_bytes], [latest_expiration_time],
+                        in_cache=True))
                     num_cached_nodes += 1
                     if num_cached_nodes >= self.cache_nearest:
                         break
 
         # stage 4: deserialize data and assemble function output
         find_result: Dict[DHTKey, Tuple[Optional[DHTValue], Optional[DHTExpiration]]] = {}
-        for key_id, (latest_value_bytes, latest_expiration, _) in latest_results.items():
-            if latest_expiration != -float('inf'):
-                find_result[id_to_original_key[key_id]] = self.serializer.loads(latest_value_bytes), latest_expiration
+        for key_id, (latest_value_bytes, latest_expiration_time, _) in latest_results.items():
+            if latest_expiration_time != -float('inf'):
+                latest_value = self.serializer.loads(latest_value_bytes)
+                find_result[id_to_original_key[key_id]] = (latest_value, latest_expiration_time)
             else:
                 find_result[id_to_original_key[key_id]] = None, None
         return find_result
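
The docstring above spells out the store/get contract; here is a short sketch of a two-node round trip under that contract (assuming DHTNode.create accepts default arguments; the key and value are illustrative):

    import asyncio
    from hivemind.dht.node import DHTNode
    from hivemind.dht.routing import get_dht_time

    async def demo():
        alice = await DHTNode.create()  # first node, empty routing table
        bob = await DHTNode.create(initial_peers=[f"127.0.0.1:{alice.port}"])

        # store returns True if accepted, False if a newer value for this key already exists
        assert await bob.store("key", "value", get_dht_time() + 60)

        # get returns (value, expiration_time), or (None, None) if nothing unexpired was found
        value, expiration_time = await alice.get("key", latest=True)
        assert value == "value"

    asyncio.run(demo())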

+ 26 - 24
hivemind/dht/protocol.py

@@ -113,15 +113,15 @@ class DHTProtocol(dht_grpc.DHTServicer):
         return self.node_info
 
     async def call_store(self, peer: Endpoint, keys: Sequence[DHTID], values: Sequence[BinaryDHTValue],
-                         expirations: Union[DHTExpiration, Sequence[DHTExpiration]],
+                         expiration_time: Union[DHTExpiration, Sequence[DHTExpiration]],
                          in_cache: Optional[Union[bool, Sequence[bool]]] = None) -> Sequence[bool]:
         """
-        Ask a recipient to store several (key, value : expiration) items or update their older value
+        Ask a recipient to store several (key, value : expiration_time) items or update their older value
 
         :param peer: request this peer to store the data
         :param keys: a list of N keys digested by DHTID.generate(source=some_dict_key)
         :param values: a list of N serialized values (bytes) for each respective key
-        :param expirations: a list of N expiration timestamps for each respective key-value pair (see get_dht_time())
+        :param expiration_time: a list of N expiration timestamps for each respective key-value pair (see get_dht_time())
         :param in_cache: a list of booleans, True = store i-th key in cache, value = store i-th key locally
         :note: the difference between storing normally and in cache is that normal storage is guaranteed to be stored
          until expiration time (best-effort), whereas cached storage can be evicted early due to limited cache size
@@ -129,13 +129,14 @@ class DHTProtocol(dht_grpc.DHTServicer):
         :return: list of [True / False] True = stored, False = failed (found newer value or no response)
          if peer did not respond (e.g. due to timeout or congestion), returns None
         """
+        if isinstance(expiration_time, DHTExpiration):
+            expiration_time = [expiration_time] * len(keys)
         in_cache = in_cache if in_cache is not None else [False] * len(keys)  # default value (None)
         in_cache = [in_cache] * len(keys) if isinstance(in_cache, bool) else in_cache  # single bool
-        expirations = [expirations] * len(keys) if isinstance(expirations, DHTExpiration) else expirations
-        keys, values, expirations, in_cache = map(list, [keys, values, expirations, in_cache])
-        assert len(keys) == len(values) == len(expirations) == len(in_cache), "Data is not aligned"
+        keys, values, expiration_time, in_cache = map(list, [keys, values, expiration_time, in_cache])
+        assert len(keys) == len(values) == len(expiration_time) == len(in_cache), "Data is not aligned"
         store_request = dht_pb2.StoreRequest(keys=list(map(DHTID.to_bytes, keys)), values=values,
-                                             expiration=expirations, in_cache=in_cache, peer=self.node_info)
+                                             expiration_time=expiration_time, in_cache=in_cache, peer=self.node_info)
         try:
             async with self.rpc_semaphore:
                 response = await self._get(peer).rpc_store(store_request, timeout=self.wait_timeout)
@@ -152,10 +153,10 @@ class DHTProtocol(dht_grpc.DHTServicer):
         """ Some node wants us to store this (key, value) pair """
         if request.peer:  # if requested, add peer to the routing table
             asyncio.create_task(self.rpc_ping(request.peer, context))
-        assert len(request.keys) == len(request.values) == len(request.expiration) == len(request.in_cache)
+        assert len(request.keys) == len(request.values) == len(request.expiration_time) == len(request.in_cache)
         response = dht_pb2.StoreResponse(store_ok=[], peer=self.node_info)
         for key_bytes, value_bytes, expiration_time, in_cache in zip(
-                request.keys, request.values, request.expiration, request.in_cache):
+                request.keys, request.values, request.expiration_time, request.in_cache):
             local_memory = self.cache if in_cache else self.storage
             response.store_ok.append(local_memory.store(DHTID.from_bytes(key_bytes), value_bytes, expiration_time))
         return response
@@ -180,15 +181,16 @@ class DHTProtocol(dht_grpc.DHTServicer):
             if response.peer and response.peer.node_id:
                 peer_id = DHTID.from_bytes(response.peer.node_id)
                 asyncio.create_task(self.update_routing_table(peer_id, peer, responded=True))
-            assert len(response.values) == len(response.expiration) == len(response.nearest) == len(keys), \
-                "DHTProtocol: response is not aligned with keys"
+            assert len(response.values) == len(response.expiration_time) == len(response.nearest) == len(keys), \
+                "DHTProtocol: response is not aligned with keys and/or expiration times"
 
             output = {}  # unpack data without special NOT_FOUND_* values
-            for key, value, expiration, nearest in zip(keys, response.values, response.expiration, response.nearest):
+            for key, value, expiration_time, nearest in zip(
+                    keys, response.values, response.expiration_time, response.nearest):
                 value = value if value != _NOT_FOUND_VALUE else None
-                expiration = expiration if expiration != _NOT_FOUND_EXPIRATION else None
+                expiration_time = expiration_time if expiration_time != _NOT_FOUND_EXPIRATION else None
                 nearest = dict(zip(map(DHTID.from_bytes, nearest.node_ids), nearest.endpoints))
-                output[key] = (value, expiration, nearest)
+                output[key] = (value, expiration_time, nearest)
             return output
         except grpc.experimental.aio.AioRpcError as error:
             logger.warning(f"DHTProtocol failed to find at {peer}: {error.code()}")
@@ -202,12 +204,12 @@ class DHTProtocol(dht_grpc.DHTServicer):
         if request.peer:  # if requested, add peer to the routing table
             asyncio.create_task(self.rpc_ping(request.peer, context))
 
-        response = dht_pb2.FindResponse(values=[], expiration=[], nearest=[], peer=self.node_info)
+        response = dht_pb2.FindResponse(values=[], expiration_time=[], nearest=[], peer=self.node_info)
         for key_id in map(DHTID.from_bytes, request.keys):
-            maybe_value, maybe_expiration = self.storage.get(key_id)
-            cached_value, cached_expiration = self.cache.get(key_id)
-            if (cached_expiration or -float('inf')) > (maybe_expiration or -float('inf')):
-                maybe_value, maybe_expiration = cached_value, cached_expiration
+            maybe_value, maybe_expiration_time = self.storage.get(key_id)
+            cached_value, cached_expiration_time = self.cache.get(key_id)
+            if (cached_expiration_time or -float('inf')) > (maybe_expiration_time or -float('inf')):
+                maybe_value, maybe_expiration_time = cached_value, cached_expiration_time
 
             nearest_neighbors = self.routing_table.get_nearest_neighbors(
                 key_id, k=self.bucket_size, exclude=DHTID.from_bytes(request.peer.node_id))
@@ -217,7 +219,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
                 peer_ids, endpoints = [], []
 
             response.values.append(maybe_value if maybe_value is not None else _NOT_FOUND_VALUE)
-            response.expiration.append(maybe_expiration if maybe_expiration is not None else _NOT_FOUND_EXPIRATION)
+            response.expiration_time.append(maybe_expiration_time if maybe_expiration_time else _NOT_FOUND_EXPIRATION)
             response.nearest.append(dht_pb2.Peers(node_ids=list(map(DHTID.to_bytes, peer_ids)), endpoints=endpoints))
         return response
 
@@ -235,7 +237,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
             if node_id not in self.routing_table:
                 # we just met a new node, maybe we know some values that it *should* store
                 data_to_send: List[Tuple[DHTID, BinaryDHTValue, DHTExpiration]] = []
-                for key, value, expiration in list(self.storage.items()):
+                for key, value, expiration_time in list(self.storage.items()):
                     neighbors = self.routing_table.get_nearest_neighbors(key, self.num_replicas, exclude=self.node_id)
                     if neighbors:
                         nearest_distance = neighbors[0][0].xor_distance(key)
@@ -243,7 +245,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
                         new_node_should_store = node_id.xor_distance(key) < farthest_distance
                         this_node_is_responsible = self.node_id.xor_distance(key) < nearest_distance
                     if not neighbors or (new_node_should_store and this_node_is_responsible):
-                        data_to_send.append((key, value, expiration))
+                        data_to_send.append((key, value, expiration_time))
                 if data_to_send:
                     asyncio.create_task(self.call_store(peer_endpoint, *zip(*data_to_send), in_cache=False))
 
@@ -262,7 +264,7 @@ _NOT_FOUND_VALUE, _NOT_FOUND_EXPIRATION = b'', -float('inf')  # internal values
 
 
 class LocalStorage:
-    """ Local dictionary that maintains up to :maxsize: tuples of (key, value, expiration) """
+    """ Local dictionary that maintains up to :maxsize: tuples of (key, value, expiration_time) """
 
     def __init__(self, maxsize: Optional[int] = None):
         self.cache_size = maxsize or float("inf")
@@ -306,4 +308,4 @@ class LocalStorage:
     def items(self) -> Iterator[Tuple[DHTID, BinaryDHTValue, DHTExpiration]]:
         """ Iterate over (key, value, expiration_time) tuples stored in this storage """
         self.remove_outdated()
-        return ((key, value, expiration) for key, (value, expiration) in self.data.items())
+        return ((key, value, expiration_time) for key, (value, expiration_time) in self.data.items())
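
LocalStorage, whose docstring and items() are touched above, acts as a small expiring dictionary keyed by DHTID. A brief sketch (maxsize, key and payload are illustrative):

    from hivemind.dht.node import DHTID
    from hivemind.dht.protocol import LocalStorage
    from hivemind.dht.routing import get_dht_time

    storage = LocalStorage(maxsize=1000)
    key_id = DHTID.generate("ffn_expert.98")

    stored = storage.store(key_id, b"127.0.0.1:8080", get_dht_time() + 30)  # True if accepted
    value, expiration_time = storage.get(key_id)                            # (None, None) once expired or evicted
    assert stored and value == b"127.0.0.1:8080"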

+ 19 - 13
hivemind/dht/traverse.py

@@ -215,16 +215,22 @@ async def traverse_dht(
             active_workers.subtract(queries_to_call)
             heap_updated_event.set()
 
-    # spawn all workers and wait for them to terminate; workers terminate after exhausting unfinished_queries
-    await asyncio.wait([asyncio.create_task(worker()) for _ in range(num_workers)],
-                       return_when=asyncio.FIRST_COMPLETED)  # first worker finishes when the search is over
-    assert len(unfinished_queries) == 0 and search_finished_event.is_set()
-
-    if await_all_tasks:
-        await asyncio.gather(*pending_tasks)
-
-    nearest_neighbors_per_query = {
-        query: [peer for _, peer in heapq.nlargest(beam_size, nearest_nodes[query])]
-        for query in queries
-    }
-    return nearest_neighbors_per_query, visited_nodes
+    workers = [asyncio.create_task(worker()) for _ in range(num_workers)]
+    try:
+        # spawn all workers and wait for them to terminate; workers terminate after exhausting unfinished_queries
+        await asyncio.wait(workers, return_when=asyncio.FIRST_COMPLETED)
+        assert len(unfinished_queries) == 0 and search_finished_event.is_set()
+
+        if await_all_tasks:
+            await asyncio.gather(*pending_tasks)
+
+        nearest_neighbors_per_query = {
+            query: [peer for _, peer in heapq.nlargest(beam_size, nearest_nodes[query])]
+            for query in queries
+        }
+        return nearest_neighbors_per_query, visited_nodes
+
+    except asyncio.CancelledError as e:
+        for worker in workers:
+            worker.cancel()
+        raise e
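
The change above makes traverse_dht cancellable by catching asyncio.CancelledError, cancelling its worker tasks and re-raising. The same pattern in isolation (a generic asyncio sketch, not hivemind code):

    import asyncio

    async def parent(num_workers: int = 4):
        workers = [asyncio.create_task(asyncio.sleep(3600)) for _ in range(num_workers)]
        try:
            await asyncio.wait(workers, return_when=asyncio.FIRST_COMPLETED)
        except asyncio.CancelledError:
            for worker in workers:  # child tasks are not cancelled automatically, do it explicitly
                worker.cancel()
            raise

    async def main():
        task = asyncio.create_task(parent())
        await asyncio.sleep(0.1)
        task.cancel()  # cancelling the parent now also cancels its workers instead of leaking them
        try:
            await task
        except asyncio.CancelledError:
            pass

    asyncio.run(main())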

+ 2 - 4
hivemind/server/connection_handler.py

@@ -59,13 +59,11 @@ class ConnectionHandler(mp.Process):
     async def forward(self, request: runtime_pb2.ExpertRequest, context: grpc.ServicerContext):
         inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         future = self.experts[request.uid].forward_pool.submit_task(*inputs)
-        response = await future.async_result()
-        serialized_response = [serialize_torch_tensor(tensor) for tensor in response]
+        serialized_response = [serialize_torch_tensor(tensor) for tensor in await future]
         return runtime_pb2.ExpertResponse(tensors=serialized_response)
 
     async def backward(self, request: runtime_pb2.ExpertRequest, context: grpc.ServicerContext):
         inputs_and_grad_outputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         future = self.experts[request.uid].backward_pool.submit_task(*inputs_and_grad_outputs)
-        response = await future.async_result()
-        serialized_response = [serialize_torch_tensor(tensor) for tensor in response]
+        serialized_response = [serialize_torch_tensor(tensor) for tensor in await future]
         return runtime_pb2.ExpertResponse(tensors=serialized_response)

+ 116 - 99
hivemind/utils/mpfuture.py

@@ -1,21 +1,25 @@
+from __future__ import annotations
+import time
 import multiprocessing as mp
 import multiprocessing.connection
-from concurrent.futures import Future, CancelledError
-from warnings import warn
+import concurrent.futures._base as base
+
 import asyncio
+from functools import lru_cache
+from typing import Optional
+
+from hivemind.utils.threading import run_in_background
 
 
-class MPFuture(Future):
-    """ Multiprocessing version of concurrent.futures.Future, interacts between two processes via Pipe """
-    STATES = 'pending', 'running', 'cancelled', 'finished', 'exception'
-    STATE_PENDING, STATE_RUNNING, STATE_CANCELLED, STATE_FINISHED, STATE_EXCEPTION = STATES
+class MPFuture(base.Future):
+    """ Multiprocessing version of concurrent.futures.Future. Can also be awaited like asyncio.Future """
+
+    TERMINAL_STATES = {base.FINISHED, base.CANCELLED, base.CANCELLED_AND_NOTIFIED}
 
     def __init__(self, connection: mp.connection.Connection):
         """ manually create MPFuture. Please use MPFuture.make_pair instead """
+        self._state, self._result, self._exception = base.PENDING, None, None
         self.connection = connection
-        self.state = self.STATE_PENDING
-        self._result = None
-        self._exception = None
 
     @classmethod
     def make_pair(cls):
@@ -23,125 +27,138 @@ class MPFuture(Future):
         connection1, connection2 = mp.Pipe()
         return cls(connection1), cls(connection2)
 
-    def poll_and_recv(self, timeout):
-        available = self.connection.poll(timeout)
-        if not available:
-            raise TimeoutError
-        try:
-            status, payload = self.connection.recv()
-            self.connection.close()
-        except BrokenPipeError as e:
-            status, payload = self.STATE_EXCEPTION, e
-        return status, payload
-
-    def _recv(self, timeout):
-
-        if self.state in (self.STATE_PENDING, self.STATE_RUNNING):
-            status, payload = self.poll_and_recv(timeout)
-
-            assert status in self.STATES
-            self.state = status
-
-            if status == self.STATE_FINISHED:
-                self._result = payload
-            elif status == self.STATE_EXCEPTION:
-                self._exception = payload
-            elif status in (self.STATE_RUNNING, self.STATE_CANCELLED):
-                pass  # only update self.state
-            else:
-                raise ValueError("Result status should not be self.STATE_PENDING")
-
-    def set_result(self, result):
+    def _send_updates(self):
+        """ Send updates to a paired MPFuture """
         try:
-            self.state, self._result = self.STATE_FINISHED, result
-            self.connection.send((self.STATE_FINISHED, result))
-            self.connection.close()
+            self.connection.send((self._state, self._result, self._exception))
+            if self._state in self.TERMINAL_STATES:
+                self._shutdown_trigger.set_result(True)
+                self.connection.close()
             return True
         except BrokenPipeError:
             return False
 
-    def set_exception(self, exception: BaseException):
+    def _recv_updates(self, timeout: Optional[float]):
+        """ Await updates from a paired MPFuture """
         try:
-            self.state, self._exception = self.STATE_EXCEPTION, exception
-            self.connection.send((self.STATE_EXCEPTION, exception))
-            self.connection.close()
-            return True
-        except BrokenPipeError:
-            return False
+            future = base.wait([run_in_background(self.connection.poll, timeout), self._shutdown_trigger],
+                               return_when=base.FIRST_COMPLETED)[0].pop()
+            if future is self._shutdown_trigger:
+                raise BrokenPipeError()
+            if not future.result():
+                raise TimeoutError()
+            self._state, result, exception = self.connection.recv()
+            self._result = result if result is not None else self._result
+            self._exception = exception if exception is not None else self._exception
+            if self._state in self.TERMINAL_STATES:
+                self.connection.close()
+        except TimeoutError as e:
+            raise e
+        except (BrokenPipeError, OSError) as e:
+            if self._state in (base.PENDING, base.RUNNING):
+                self._state, self._exception = base.FINISHED, e
+
+    def _await_terminal_state(self, timeout: Optional[float]):
+        """ Await updates until future is either finished, cancelled or got an exception """
+        time_left = float('inf') if timeout is None else timeout
+        time_before = time.monotonic()
+        while self._state not in self.TERMINAL_STATES and time_left > 0:
+            self._recv_updates(time_left if timeout else None)
+            time_spent = time.monotonic() - time_before
+            time_left, time_before = time_left - time_spent, time_before + time_spent
+
+    def _sync_updates(self):
+        """ Apply queued updates from a paired MPFuture without waiting for new ones """
+        try:
+            self._recv_updates(timeout=0)
+        except TimeoutError:
+            pass
+
+    def set_result(self, result):
+        self._sync_updates()
+        if self._state in self.TERMINAL_STATES:
+            raise RuntimeError(f"Can't set_result to a future that is in {self._state}")
+        self._state, self._result = base.FINISHED, result
+        return self._send_updates()
+
+    def set_exception(self, exception: BaseException):
+        self._sync_updates()
+        if self._state in self.TERMINAL_STATES:
+            raise RuntimeError(f"Can't set_exception to a future that is in {self._state}")
+        self._state, self._exception = base.FINISHED, exception
+        self._send_updates()
 
     def set_running_or_notify_cancel(self):
-        return True
+        self._sync_updates()
+        if self._state == base.PENDING:
+            self._state = base.RUNNING
+            return self._send_updates()
+        elif self._state == base.CANCELLED:
+            return False
+        else:
+            raise RuntimeError(f"Can't set_running_or_notify_cancel to a future that is in {self._state}")
 
     def cancel(self):
-        raise NotImplementedError()
+        self._sync_updates()
+        if self._state in self.TERMINAL_STATES:
+            return False
+        self._state, self._exception = base.CANCELLED, base.CancelledError()
+        return self._send_updates()
 
     def result(self, timeout=None):
-        self._recv(timeout)
-        if self.state == self.STATE_FINISHED:
-            return self._result
-        elif self.state == self.STATE_EXCEPTION:
+        self._await_terminal_state(timeout)
+        if self._exception is not None:
             raise self._exception
-        else:
-            assert self.state == self.STATE_CANCELLED
-            raise CancelledError()
+        return self._result
 
     def exception(self, timeout=None):
-        self._recv(timeout)
+        self._await_terminal_state(timeout)
+        if self._state == base.CANCELLED:
+            raise base.CancelledError()
         return self._exception
 
     def done(self):
-        return self.state in (self.STATE_FINISHED, self.STATE_EXCEPTION, self.STATE_CANCELLED)
+        self._sync_updates()
+        return self._state in self.TERMINAL_STATES
 
     def running(self):
-        return self.state == self.STATE_RUNNING
+        self._sync_updates()
+        return self._state == base.RUNNING
 
     def cancelled(self):
-        warn("cancelled not implemented")
-        return False
+        self._sync_updates()
+        return self._state == base.CANCELLED
 
     def add_done_callback(self, callback):
-        raise NotImplementedError()
-
-    def __repr__(self):
-        try:
-            self._recv(timeout=0)
-        except TimeoutError:
-            pass
-        if self.state == self.STATE_FINISHED:
-            return "<MPFuture at 0x{:x} state=finished returned {}>".format(id(self), type(self._result))
-        elif self.state == self.STATE_EXCEPTION:
-            return "<MPFuture at 0x{:x} state=finished raised {}>".format(id(self), type(self._exception))
-        else:
-            return "<MPFuture at 0x{:x} state={}>".format(id(self), self.state)
+        raise NotImplementedError(f"MPFuture doesn't support callbacks.")
 
-    async def _async_recv(self, timeout):
-        loop = asyncio.get_running_loop()
+    def remove_done_callback(self, callback):
+        raise NotImplementedError(f"MPFuture doesn't support callbacks.")
 
-        if self.state in (self.STATE_PENDING, self.STATE_RUNNING):
-            status, payload = await loop.run_in_executor(None, self.poll_and_recv, timeout)
+    def get_loop(self):
+        raise NotImplementedError(f"MPFuture doesn't support get_loop")
 
-            assert status in self.STATES
-            self.state = status
+    @property
+    @lru_cache()
+    def _shutdown_trigger(self):
+        return base.Future()
 
-            if status == self.STATE_FINISHED:
-                self._result = payload
-            elif status == self.STATE_EXCEPTION:
-                self._exception = payload
-            elif status in (self.STATE_RUNNING, self.STATE_CANCELLED):
-                pass  # only update self.state
+    def __repr__(self):
+        self._sync_updates()
+        if self._state == base.FINISHED:
+            if self._exception:
+                return "<MPFuture at 0x{:x} state=finished raised {}>".format(id(self), type(self._exception))
             else:
-                raise ValueError("Result status should not be self.STATE_PENDING")
+                return "<MPFuture at 0x{:x} state=finished returned {}>".format(id(self), type(self._result))
+        else:
+            return "<MPFuture at 0x{:x} state={}>".format(id(self), self._state)
 
-    async def async_result(self, timeout=None):
-        await self._async_recv(timeout)
-        if self.state == self.STATE_FINISHED:
-            return self._result
-        elif self.state == self.STATE_EXCEPTION:
+    def __await__(self):
+        yield from asyncio.get_running_loop().run_in_executor(None, self._await_terminal_state, None).__await__()
+        if self._exception:
             raise self._exception
-        else:
-            assert self.state == self.STATE_CANCELLED
-            raise CancelledError()
+        return self._result
 
-    async def async_exception(self, timeout=None):
-        await self._async_recv(timeout)
-        return self._exception
+    def __del__(self):
+        self._shutdown_trigger.set_result(True)
+        self.connection.close()
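
The rewritten MPFuture now exposes the full concurrent.futures.Future interface on both ends of the shared pipe. Below is a minimal same-process sketch of the resulting behaviour, modelled on the tests added later in this commit; in real use one half of the pair is typically handed to another process.

    import hivemind

    # two futures that mirror each other's state through a pipe
    caller_side, worker_side = hivemind.MPFuture.make_pair()

    # the worker marks the future as running, then delivers a result
    assert worker_side.set_running_or_notify_cancel() is True
    worker_side.set_result(42)

    # the caller observes the same terminal state
    assert caller_side.result() == 42
    assert caller_side.done() and not caller_side.cancelled()

    # a future that already reached a terminal state can no longer be cancelled
    assert caller_side.cancel() is False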

+ 6 - 13
hivemind/utils/threading.py

@@ -1,27 +1,20 @@
 import os
-from concurrent.futures import Future, ThreadPoolExecutor, as_completed, TimeoutError
+from concurrent.futures import Future, as_completed, TimeoutError, ThreadPoolExecutor
 import time
 from typing import Optional, List
 
-GLOBAL_EXECUTOR = ThreadPoolExecutor(max_workers=os.environ.get("HIVEMIND_THREADS", float('inf')))
+EXECUTOR_PID, GLOBAL_EXECUTOR = None, None
 
 
 def run_in_background(func: callable, *args, **kwargs) -> Future:
     """ run func(*args, **kwargs) in background and return Future for its outputs """
-
+    global EXECUTOR_PID, GLOBAL_EXECUTOR
+    if os.getpid() != EXECUTOR_PID:
+        GLOBAL_EXECUTOR = ThreadPoolExecutor(max_workers=os.environ.get("HIVEMIND_THREADS", float('inf')))
+        EXECUTOR_PID = os.getpid()
     return GLOBAL_EXECUTOR.submit(func, *args, **kwargs)
 
 
-def run_forever(func: callable, *args, **kwargs):
-    """ A function that runs a :func: in background forever. Returns a future that catches exceptions """
-
-    def repeat():
-        while True:
-            func(*args, **kwargs)
-
-    return run_in_background(repeat)
-
-
 def run_and_await_k(jobs: List[callable], k: int,
                     timeout_after_k: Optional[float] = 0, timeout_total: Optional[float] = None):
     """

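run_in_background now re-creates its thread pool the first time it is called in each process, so a forked worker never inherits the parent's executor or its OS-level locks. Below is a minimal sketch of that behaviour, assuming the helper stays importable from hivemind.utils.threading; child_job is a throwaway function defined here purely for illustration.

    import multiprocessing as mp

    from hivemind.utils.threading import run_in_background

    def child_job(queue):
        # the first call inside the child process builds a fresh executor for its pid
        queue.put(run_in_background(lambda: "child ok").result())

    if __name__ == "__main__":
        # the first call in the parent instantiates the executor for the parent pid
        assert run_in_background(lambda: "parent ok").result() == "parent ok"

        queue = mp.Queue()
        worker = mp.Process(target=child_job, args=(queue,))
        worker.start()
        assert queue.get() == "child ok"
        worker.join()
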
+ 6 - 7
tests/benchmark_dht.py

@@ -15,15 +15,15 @@ def random_endpoint() -> hivemind.Endpoint:
 
 
 def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_batch_size: int, random_seed: int,
-                  wait_after_request: float, wait_before_read: float, wait_timeout: float, expiration_time: float):
-    old_expiration_time, hivemind.DHT.EXPIRATION = hivemind.DHT.EXPIRATION, expiration_time
+                  wait_after_request: float, wait_before_read: float, wait_timeout: float, expiration: float):
     random.seed(random_seed)
 
     print("Creating peers...")
     peers = []
     for _ in trange(num_peers):
         neighbors = [f'0.0.0.0:{node.port}' for node in random.sample(peers, min(initial_peers, len(peers)))]
-        peer = hivemind.DHT(initial_peers=neighbors, start=True, wait_timeout=wait_timeout, listen_on=f'0.0.0.0:*')
+        peer = hivemind.DHT(initial_peers=neighbors, start=True, wait_timeout=wait_timeout,
+                            expiration=expiration, listen_on=f'0.0.0.0:*')
         peers.append(peer)
 
     store_peer, get_peer = peers[-2:]
@@ -52,7 +52,7 @@ def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_b
     print(f"Mean store time: {total_store_time / total_stores:.5}, Total: {total_store_time:.5}")
     time.sleep(wait_before_read)
 
-    if time.perf_counter() - benchmark_started > expiration_time:
+    if time.perf_counter() - benchmark_started > expiration:
         warn("Warning: all keys expired before benchmark started getting them. Consider increasing expiration_time")
 
     successful_gets = total_get_time = 0
@@ -67,7 +67,7 @@ def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_b
                     and expert.endpoint == endpoints[start // expert_batch_size]:
                 successful_gets += 1
 
-    if time.perf_counter() - benchmark_started > expiration_time:
+    if time.perf_counter() - benchmark_started > expiration:
         warn("Warning: keys expired midway during get requests. If that is not desired, increase expiration_time param")
 
     print(f"Get success rate: {successful_gets / len(expert_uids) * 100:.1f} ({successful_gets} / {len(expert_uids)})")
@@ -75,7 +75,6 @@ def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_b
 
     alive_peers = [peer.is_alive() for peer in peers]
     print(f"Node survival rate: {sum(alive_peers) / len(peers) * 100:.3f}%")
-    hivemind.DHT.EXPIRATION = old_expiration_time
 
 
 if __name__ == "__main__":
@@ -84,7 +83,7 @@ if __name__ == "__main__":
     parser.add_argument('--initial_peers', type=int, default=1, required=False)
     parser.add_argument('--num_experts', type=int, default=256, required=False)
     parser.add_argument('--expert_batch_size', type=int, default=32, required=False)
-    parser.add_argument('--expiration_time', type=float, default=300, required=False)
+    parser.add_argument('--expiration', type=float, default=300, required=False)
     parser.add_argument('--wait_after_request', type=float, default=0, required=False)
     parser.add_argument('--wait_before_read', type=float, default=0, required=False)
     parser.add_argument('--wait_timeout', type=float, default=5, required=False)
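
The benchmark now passes the key lifetime to each peer directly instead of monkey-patching hivemind.DHT.EXPIRATION for the whole process. A short sketch of the per-instance configuration (the peer list and ports below are placeholders):

    import hivemind

    # expiration is a constructor argument, so different peers may use different lifetimes
    first_peer = hivemind.DHT(initial_peers=[], start=True, wait_timeout=5,
                              expiration=300, listen_on='0.0.0.0:*')
    second_peer = hivemind.DHT(initial_peers=[f'0.0.0.0:{first_peer.port}'], start=True,
                               wait_timeout=5, expiration=300, listen_on='0.0.0.0:*')
    print(f"peers listen on ports {first_peer.port} and {second_peer.port}")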

+ 2 - 2
tests/test_dht.py

@@ -5,7 +5,7 @@ import random
 import heapq
 import uuid
 from itertools import chain
-from typing import Optional, Tuple
+from typing import Optional
 import numpy as np
 
 import hivemind
@@ -249,7 +249,7 @@ def test_dht_node():
         # test 7: bulk store and bulk get
         keys = 'foo', 'bar', 'baz', 'zzz'
         values = 3, 2, 'batman', [1, 2, 3]
-        store_ok = loop.run_until_complete(me.store_many(keys, values, expiration=get_dht_time() + 999))
+        store_ok = loop.run_until_complete(me.store_many(keys, values, expiration_time=get_dht_time() + 999))
         assert all(store_ok.values()), "failed to store one or more keys"
         response = loop.run_until_complete(me.get_many(keys[::-1]))
         for key, value in zip(keys, values):
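
Bulk operations on DHTNode now take the timestamp under the expiration_time keyword rather than expiration. A minimal sketch of the round trip, assuming node is an already-created DHTNode running inside the current event loop (its construction is elided here):

    from hivemind.dht.routing import get_dht_time

    async def bulk_roundtrip(node):
        # store several keys at once; all of them share one expiration timestamp
        store_ok = await node.store_many(['foo', 'bar'], [3, 2],
                                         expiration_time=get_dht_time() + 999)
        assert all(store_ok.values()), "failed to store one or more keys"

        # fetch the same keys back with a single bulk request
        return await node.get_many(['foo', 'bar'])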

+ 116 - 0
tests/test_util_modules.py

@@ -0,0 +1,116 @@
+import asyncio
+
+import pytest
+import hivemind
+
+from concurrent.futures import CancelledError
+
+
+def test_mpfuture_result():
+    f1, f2 = hivemind.MPFuture.make_pair()
+    f1.set_result(321)
+    assert f2.result() == 321
+    assert f1.result() == 321
+
+    for future in [f1, f2]:
+        with pytest.raises(RuntimeError):
+            future.set_result(123)
+        with pytest.raises(RuntimeError):
+            future.set_exception(ValueError())
+        assert future.cancel() is False
+        assert future.done() and not future.running() and not future.cancelled()
+
+    f1, f2 = hivemind.MPFuture.make_pair()
+    with pytest.raises(TimeoutError):
+        f1.result(timeout=1e-3)
+
+    f2.set_result(['abacaba', 123])
+    assert f1.result() == ['abacaba', 123]
+
+
+def test_mpfuture_exception():
+    f1, f2 = hivemind.MPFuture.make_pair()
+    with pytest.raises(TimeoutError):
+        f1.exception(timeout=1e-3)
+
+    f2.set_exception(NotImplementedError())
+
+    for future in [f1, f2]:
+        assert isinstance(future.exception(), NotImplementedError)
+        with pytest.raises(NotImplementedError):
+            future.result()
+        assert future.cancel() is False
+        assert future.done() and not future.running() and not future.cancelled()
+
+
+def test_mpfuture_cancel():
+    f1, f2 = hivemind.MPFuture.make_pair()
+    assert not f2.cancelled()
+    f1.cancel()
+    for future in [f1, f2]:
+        with pytest.raises(CancelledError):
+            future.result()
+        with pytest.raises(CancelledError):
+            future.exception()
+        with pytest.raises(RuntimeError):
+            future.set_result(123)
+        with pytest.raises(RuntimeError):
+            future.set_exception(NotImplementedError)
+        assert future.cancelled() and future.done() and not future.running()
+
+
+def test_mpfuture_status():
+    f1, f2 = hivemind.MPFuture.make_pair()
+    assert f1.set_running_or_notify_cancel() is True
+    for future in [f1, f2]:
+        assert future.running() and not future.done() and not future.cancelled()
+        with pytest.raises(RuntimeError):
+            future.set_running_or_notify_cancel()
+    f2.cancel()
+    for future in [f1, f2]:
+        assert not future.running() and future.done() and future.cancelled()
+        assert future.set_running_or_notify_cancel() is False
+
+    f1, f2 = hivemind.MPFuture.make_pair()
+    f1.cancel()
+    for future in [f1, f2]:
+        assert future.set_running_or_notify_cancel() is False
+
+
+def test_await_mpfuture():
+    async def _run():
+        # await result
+        f1, f2 = hivemind.MPFuture.make_pair()
+        async def wait_and_assign():
+            assert f2.set_running_or_notify_cancel() is True
+            await asyncio.sleep(0.1)
+            f2.set_result((123, 'ololo'))
+
+        asyncio.create_task(wait_and_assign())
+        for future in [f1, f2]:
+            res = await future
+            assert res == (123, 'ololo')
+
+        # await cancel
+        f1, f2 = hivemind.MPFuture.make_pair()
+        async def wait_and_cancel():
+            await asyncio.sleep(0.1)
+            f1.cancel()
+
+        asyncio.create_task(wait_and_cancel())
+        for future in [f1, f2]:
+            with pytest.raises(CancelledError):
+                await future
+
+        # await exception
+        f1, f2 = hivemind.MPFuture.make_pair()
+        async def wait_and_raise():
+            await asyncio.sleep(0.1)
+            f1.set_exception(SystemError())
+
+        asyncio.create_task(wait_and_raise())
+        for future in [f1, f2]:
+            with pytest.raises(SystemError):
+                await future
+
+    asyncio.new_event_loop().run_until_complete(_run())

+ 9 - 11
tests/test_utils/run_server.py

@@ -14,7 +14,7 @@ from test_utils.layers import name_to_block, name_to_input
 def make_dummy_server(listen_on='0.0.0.0:*', num_experts=1, expert_cls='ffn', hidden_dim=1024,
                       num_handlers=None, expert_prefix='expert', expert_offset=0, max_batch_size=16384, device=None,
                       no_optimizer=False, no_dht=False, initial_peers=(), dht_port=None, root_port=None, verbose=True,
-                      UID_DELIMETER=hivemind.DHT.UID_DELIMETER, start=False, **kwargs) -> hivemind.Server:
+                      start=False, **kwargs) -> hivemind.Server:
     """
     Instantiate a server with several identical experts. See argparse comments below for details
     :param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
@@ -47,9 +47,8 @@ def make_dummy_server(listen_on='0.0.0.0:*', num_experts=1, expert_cls='ffn', hi
     if not no_dht:
         if not len(initial_peers):
             print("No initial peers provided. Starting additional dht as an initial peer.")
-            dht_root = hivemind.DHT(initial_peers=initial_peers,
-                                    listen_on=f"{hivemind.LOCALHOST}:{root_port or hivemind.find_open_port()}",
-                                    start=True)
+            dht_root = hivemind.DHT(initial_peers=initial_peers, start=True,
+                                    listen_on=f"{hivemind.LOCALHOST}:{root_port or hivemind.find_open_port()}")
             print(f"Initializing DHT with port {dht_root.port}")
             initial_peers = [f"{hivemind.LOCALHOST}:{dht_root.port}"]
         else:
@@ -57,9 +56,8 @@ def make_dummy_server(listen_on='0.0.0.0:*', num_experts=1, expert_cls='ffn', hi
             if root_port is not None:
                 print(f"Warning: root_port={root_port} will not be used since we already have peers.")
 
-        dht = hivemind.DHT(initial_peers=initial_peers,
-                           listen_on=f"{hivemind.LOCALHOST}:{dht_port or hivemind.find_open_port()}",
-                           start=True)
+        dht = hivemind.DHT(initial_peers=initial_peers, start=True,
+                           listen_on=f"{hivemind.LOCALHOST}:{dht_port or hivemind.find_open_port()}")
         if verbose:
             print(f"Running dht node on port {dht.port}")
 
@@ -74,7 +72,7 @@ def make_dummy_server(listen_on='0.0.0.0:*', num_experts=1, expert_cls='ffn', hi
     for i in range(num_experts):
         expert = name_to_block[expert_cls](hidden_dim)
         opt = torch.optim.SGD(expert.parameters(), 0.0 if no_optimizer else 0.05)
-        expert_uid = f'{expert_prefix}{UID_DELIMETER}{i + expert_offset}'
+        expert_uid = f'{expert_prefix}{hivemind.DHT.UID_DELIMITER}{i + expert_offset}'
         experts[expert_uid] = hivemind.ExpertBackend(name=expert_uid, expert=expert, opt=opt,
                                                      args_schema=args_schema,
                                                      outputs_schema=hivemind.BatchTensorDescriptor(hidden_dim),
@@ -154,12 +152,12 @@ if __name__ == '__main__':
     parser.add_argument('--no_optimizer', action='store_true', help='if specified, all optimizers use learning rate=0')
     parser.add_argument('--no_dht', action='store_true', help='if specified, the server will not be attached to a dht')
     parser.add_argument('--initial_peers', type=str, default="[]", required=False, help='a list of peers that will'
-                                                                                        ' introduce this node to the dht, e.g. [("1.2.3.4", 1337), ("127.0.0.1", 4321)]')
+                        ' introduce this node to the dht, e.g. [("1.2.3.4", 1337), ("127.0.0.1", 4321)]')
     parser.add_argument('--dht_port', type=int, default=None, required=False, help='DHT node will listen on this port')
     parser.add_argument('--root_port', type=int, default=None, required=False, help='If this server does not have peers'
-                                                                                    ', it will create a virtual dht node on this port. You can then use this node as initial peer.')
+                        ', it will create a virtual dht node on this port. You can then use this node as an initial peer.')
     parser.add_argument('--increase_file_limit', action='store_true', help='On *nix, this will increase the max number'
-                                                                           ' of processes a server can spawn before hitting "Too many open files"; Use at your own risk.')
+                        ' of processes a server can spawn before hitting "Too many open files"; Use at your own risk.')
 
     args = vars(parser.parse_args())
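
Expert uids in the dummy server are now assembled with the single delimiter exposed on the DHT class instead of a per-call UID_DELIMETER argument. A small illustration of the resulting naming scheme (the prefix, offset and count are placeholder values):

    import hivemind

    expert_prefix, expert_offset, num_experts = 'expert', 0, 4
    expert_uids = [f'{expert_prefix}{hivemind.DHT.UID_DELIMITER}{i + expert_offset}'
                   for i in range(num_experts)]
    print(expert_uids)  # e.g. ['expert.0', 'expert.1', ...] if the delimiter is '.'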