
Simplify & explain hivemind.dht.DHT (#78)

* unused import

* remove loop.run_forever

* * explain how DHT stores stuff (hivemind/dht/__init__.py)
* implement more efficient DHT.first_k_active (use get_many)
* change uid_delimiter into a DHT instance property

* hivemind.DHT is now thread-safe on client side

* implement all mpfuture methods

* get_experts: allow returning future

* rollback run_in_background

* switch back to GLOBAL_EXECUTOR (tests should fail)

* instantiate executor in each process to avoid os locks

* rollback to minimize diff

* typo

* traverse_dht can now be cancelled

* node.store_many and node.get_many can now be cancelled

* address review by mryab@

* address review by mryab@

* address review by mryab@

* address review by mryab@

* address review by mryab@
justheuristic committed 5 years ago
commit 535318e249

+ 1 - 1
.circleci/config.yml

@@ -27,7 +27,7 @@ jobs:
           command: sudo python setup.py develop
           name: setup
       - run:
-          command: for test_file in tests/test*.py; do pytest $test_file --full-trace; done
+          command: pytest ./tests
           name: tests
       - run:
           command: python tests/benchmark_throughput.py --preset minimalistic

+ 4 - 4
hivemind/client/moe.py

@@ -104,7 +104,7 @@ class RemoteMixtureOfExperts(nn.Module):
         beam = np.array([[self.uid_prefix]] * batch_size, dtype=object)  # [batch_size, up_to_beam_size]
         scores = np.zeros([batch_size, 1], dtype=np.float64)

-        delimeters = np.array(self.dht.UID_DELIMETER)[None, None, None]  # pre-compute numpy array for fast concat
+        delimiters = np.array(self.dht.UID_DELIMITER)[None, None, None]  # pre-compute numpy array for fast concat

         for dim_index, dim_scores in enumerate(grid_scores):
             dim_scores = dim_scores.detach().cpu().numpy()
@@ -112,7 +112,7 @@ class RemoteMixtureOfExperts(nn.Module):
 
 
             # create all possible successsors from current beam
             dim_indices = np.arange(dim_scores.shape[1]).astype(str)
-            new_candidates = beam[:, :, None] + delimeters + dim_indices[None, None, :]
+            new_candidates = beam[:, :, None] + delimiters + dim_indices[None, None, :]
             new_candidates = new_candidates.reshape([batch_size, -1])

             new_scores = scores[:, :, None] + dim_scores[:, None, :]
@@ -166,8 +166,8 @@ class RemoteMixtureOfExperts(nn.Module):
 
 
         grid_indices = np.zeros([len(flat_experts), len(grid_scores)], dtype=np.int64)
         for i, expert in enumerate(flat_experts):
-            expert_indices = expert.uid[len(self.uid_prefix) + len(self.dht.UID_DELIMETER):]
-            expert_indices = list(map(int, expert_indices.split(self.dht.UID_DELIMETER)))
+            expert_indices = expert.uid[len(self.uid_prefix) + len(self.dht.UID_DELIMITER):]
+            expert_indices = list(map(int, expert_indices.split(self.dht.UID_DELIMITER)))
             grid_indices[i] = expert_indices

         scores_per_dim = [

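The candidate expansion above builds new uid strings via numpy broadcasting over object arrays, where "+" concatenates strings element-wise. Below is a standalone sketch of that trick, illustrative only and not code from this commit; batch size, beam width, grid size and the "ffn_expert" prefix are invented:

    import numpy as np

    batch_size, beam_width, grid_dim = 2, 3, 4
    # current beam of uid prefixes: shape [batch_size, beam_width], dtype=object so "+" concatenates strings
    beam = np.array([[f"ffn_expert.{i}" for i in range(beam_width)]] * batch_size, dtype=object)
    delimiters = np.array('.')[None, None, None]   # pre-computed delimiter, shape [1, 1, 1]
    dim_indices = np.arange(grid_dim).astype(str)  # indices along the next grid dimension: '0' ... '3'

    # [batch, beam, 1] + [1, 1, 1] + [1, 1, grid_dim] broadcasts to [batch, beam, grid_dim]
    new_candidates = beam[:, :, None] + delimiters + dim_indices[None, None, :]
    print(new_candidates.reshape([batch_size, -1])[0, :4])
    # ['ffn_expert.0.0' 'ffn_expert.0.1' 'ffn_expert.0.2' 'ffn_expert.0.3']
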
+ 113 - 85
hivemind/dht/__init__.py

@@ -16,6 +16,8 @@ import asyncio
 import ctypes
 import multiprocessing as mp
 import warnings
+from collections import deque
+from concurrent.futures import ThreadPoolExecutor
 from typing import List, Optional, Sequence

 import uvloop
@@ -23,12 +25,12 @@ import uvloop
 from hivemind.client import RemoteExpert
 from hivemind.dht.node import DHTNode, DHTID, DHTExpiration
 from hivemind.dht.routing import get_dht_time
-from hivemind.utils import MPFuture, Endpoint, run_in_background
+from hivemind.utils import MPFuture, Endpoint


 class DHT(mp.Process):
     """
-    A high-level interface to hivemind DHT. Runs a dht node in a background process.
+    High-level interface to hivemind.dht that is designed to allow RemoteMixtureOfExperts to select best experts.

     :param initial_peers: one or multiple endpoints pointing to active DHT peers. Similar format to listen_on.
     :param listen_on: an interface for incoming connections, e.g. "127.0.0.1:*", "0.0.0.0:1234" or "ipv6:[::]:*"
@@ -36,19 +38,45 @@ class DHT(mp.Process):
     :param daemon: if True, the background process is marked as daemon and automatically terminated after main process
     :param max_workers: declare_experts and get_experts will use up to this many parallel workers
         (but no more than one per key)
+    :param expiration: experts declared from this node expire after this many seconds (default = 5 minutes)
+    :param receiver_threads: uses this many threads to await on input pipe. Default = 1 should be enough in most cases
     :param kwargs: any other params will be forwarded to DHTNode upon creation
+
+    Each expert has an identifier in the form of {prefix}.{i}.{j}.{...}, e.g. "ffn_expert.98.76.54.32.10"
+    An expert identifier consists of:
+
+        * optional prefix that determines expert role, experiment name, etc.
+        * one or more integers that determine that expert's position in an N-dimensional grid
+
+    A hivemind.Server can ``DHT.declare_experts(expert_uids: List[str])`` to make its experts visible to everyone.
+    When declaring experts, DHT will store each expert's uid and all its prefixes until :expiration: (specified at init)
+    For instance, declaring "ffn_expert.98.76.54.32.10" will store the following keys in a DHT:
+    ``"ffn_expert", "ffn_expert.98", "ffn_expert.98.76", ..., "ffn_expert.98.76.54.32.10"``
+
+    RemoteMixtureOfExperts can use these prefixes to find top-k most suitable experts with a left-to-right beam search.
+    For instance, consider RemoteMixtureOfExperts with prefix "ffn_expert" and grid size [100, 100, 100, 100, 100].
+    This MoE can query all experts with that prefix and arbitrary indices in 0...99 along each dimension.
+    However, not every expert in such 100^5 grid can be alive at a given moment of time (the grid size is redundant).
+    In order to find k best "alive" experts, MoE first ranks indices along the first dimension with its gating function.
+    It can then check which of those indices correspond to "alive" experts by querying keys such as "ffn_expert.98".
+    This is done using DHT.first_k_active function. After selecting k best indices along first dimension, MoE moves
+    to the second dimension. It can find top-k pairs of indices (e.g. "expert.98.76") that start with one of k first
+    indices from the previous step. Finally, MoE will use DHT.get_experts(uids: List[str]) search for specific experts.
+    This beam search explores one additional dimension per step and finds k best experts from across the DHT
+    in O(k / s * log(N)) average time where s is grid sparsity rate and N is the total number of experts.
     """
     """
-    UID_DELIMETER = '.'  # splits expert uids over this delimeter
-    EXPIRATION = 120  # anything written to DHT is considered expired after this many seconds
-    make_key = "{}::{}".format
+
+    UID_DELIMITER = '.'  # when declaring experts, DHT store all prefixes of that expert's uid, split over this prefix
+    #  formally, prefixes = {uid.split(UID_DELIMITER)[:length] for length in range(1, uid.count(UID_DELIMITER) + 2)}
 
 
     def __init__(self, listen_on: Endpoint = "0.0.0.0:*", initial_peers: Sequence[Endpoint] = (), *, start: bool,
-                 daemon: bool = True, max_workers: Optional[int] = None, parallel_rpc: Optional[int] = None, **kwargs):
+                 daemon: bool = True, max_workers: Optional[int] = None, parallel_rpc: Optional[int] = None,
+                 receiver_threads: int = 1, expiration: float = 300, **kwargs):
         super().__init__()
         self.listen_on, self.initial_peers, self.kwargs = listen_on, initial_peers, kwargs
-        self.max_workers, self.parallel_rpc = max_workers, parallel_rpc
+        self.receiver_threads, self.max_workers, self.parallel_rpc = receiver_threads, max_workers, parallel_rpc
+        self.expiration = expiration
         self._port = mp.Value(ctypes.c_int32, 0)  # initialized after dht starts
-        self.node: Optional[DHTNode] = None  # initialized inside self.run only
         self._pipe, self.pipe = mp.Pipe(duplex=True)
         self.ready = mp.Event()
         self.daemon = daemon
@@ -62,16 +90,20 @@ class DHT(mp.Process):
         uvloop.install()
         loop = asyncio.new_event_loop()
         asyncio.set_event_loop(loop)
-        self.node: DHTNode = loop.run_until_complete(DHTNode.create(
-            initial_peers=list(self.initial_peers), listen_on=self.listen_on, parallel_rpc=self.parallel_rpc,
-            num_workers=self.max_workers or 1, **self.kwargs))
-        self._port.value = self.node.port
-        run_in_background(loop.run_forever)
-        self.ready.set()
+        pipe_awaiter = ThreadPoolExecutor(self.receiver_threads)
+
+        async def _run():
+            node = await DHTNode.create(
+                initial_peers=list(self.initial_peers), listen_on=self.listen_on, parallel_rpc=self.parallel_rpc,
+                num_workers=self.max_workers or 1, **self.kwargs)
+            self._port.value = node.port
+            self.ready.set()
+
+            while True:
+                method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._pipe.recv)
+                asyncio.create_task(getattr(self, method)(node, *args, **kwargs))

-        while True:
-            method, args, kwargs = self._pipe.recv()
-            getattr(self, method)(*args, **kwargs)
+        loop.run_until_complete(_run())

     def run_in_background(self, await_ready=True, timeout=None):
         """
@@ -85,7 +117,7 @@ class DHT(mp.Process):
     def shutdown(self) -> None:
         """ Shuts down the dht process """
         if self.is_alive():
-            self.kill()
+            self.terminate()
         else:
             warnings.warn("DHT shutdown has no effect: dht process is already not alive")

@@ -93,32 +125,27 @@ class DHT(mp.Process):
     def port(self) -> Optional[int]:
         return self._port.value if self._port.value != 0 else None

-    def get_experts(self, uids: List[str], expiration=None) -> List[Optional[RemoteExpert]]:
+    def get_experts(self, uids: List[str], expiration_time: Optional[DHTExpiration] = None,
+                    wait=True) -> List[Optional[RemoteExpert]]:
         """
         :param uids: find experts with these ids from across the DHT
-        :param expiration: returns experts that expire no sooner than this (based on get_dht_time), default = now
+        :param expiration_time: if specified, return experts that expire no sooner than this (based on get_dht_time)
+        :param wait: if True (default), return when experts are returned. Otherwise return a Future.
         :returns: a list of [RemoteExpert if found else None]
         """
+        assert not isinstance(uids, str), "Please send a list / tuple of expert uids."
         future, _future = MPFuture.make_pair()
-        self.pipe.send(('_get_experts', [], dict(uids=uids, expiration=expiration, future=_future)))
-        return future.result()
+        self.pipe.send(('_get_experts', [], dict(uids=uids, expiration_time=expiration_time, future=_future)))
+        return future.result() if wait else future

-    def _get_experts(self, uids: List[str], expiration: Optional[DHTExpiration], future: MPFuture):
-        loop = asyncio.get_event_loop()
-        expiration = expiration or get_dht_time()
+    async def _get_experts(
+            self, node: DHTNode, uids: List[str], expiration_time: Optional[DHTExpiration], future: MPFuture):
+        if expiration_time is None:
+            expiration_time = get_dht_time()
         num_workers = len(uids) if self.max_workers is None else min(len(uids), self.max_workers)
-        keys = [self.make_key('expert', uid) for uid in uids]
-
-        response = asyncio.run_coroutine_threadsafe(
-            self.node.get_many(keys, expiration, num_workers=num_workers), loop).result()
-
-        experts: List[Optional[RemoteExpert]] = [None] * len(uids)
-        for i, (key, uid) in enumerate(zip(keys, uids)):
-            maybe_endpoint, maybe_expiration = response[key]
-            if maybe_expiration is not None:  # if we found a value
-                experts[i] = RemoteExpert(uid=uid, endpoint=maybe_endpoint)
-
-        future.set_result(experts)
+        response = await node.get_many(uids, expiration_time, num_workers=num_workers)
+        future.set_result([RemoteExpert(uid, maybe_endpoint) if maybe_expiration_time else None
+                           for uid, (maybe_endpoint, maybe_expiration_time) in response.items()])
 
 
     def declare_experts(self, uids: List[str], endpoint: Endpoint, wait=True, timeout=None) -> Optional[List[bool]]:
         """
@@ -136,69 +163,70 @@ class DHT(mp.Process):
         if wait:
             return future.result(timeout)

-    def _declare_experts(self, uids: List[str], endpoint: Endpoint, future: Optional[MPFuture]):
-        assert self.node is not None, "This method should only be accessed from inside .run method"
+    async def _declare_experts(self, node: DHTNode, uids: List[str], endpoint: Endpoint, future: Optional[MPFuture]):
         num_workers = len(uids) if self.max_workers is None else min(len(uids), self.max_workers)
-        loop = asyncio.get_event_loop()
-        expiration_time = get_dht_time() + self.EXPIRATION
-        unique_prefixes = set()
+        expiration_time = get_dht_time() + self.expiration

-        keys, values = [], []
+        data_to_store = {}
         for uid in uids:
-            uid_parts = uid.split(self.UID_DELIMETER)
-            keys.append(self.make_key('expert', uid))
-            values.append(endpoint)
-            unique_prefixes.update([self.UID_DELIMETER.join(uid_parts[:i + 1]) for i in range(len(uid_parts))])
-
-        for prefix in unique_prefixes:
-            keys.append(self.make_key('prefix', prefix))
-            values.append(True)
-
-        store_ok = asyncio.run_coroutine_threadsafe(
-            self.node.store_many(keys, values, expiration_time, num_workers=num_workers), loop
-        ).result()
+            uid_parts = uid.split(self.UID_DELIMITER)
+            for i in range(len(uid_parts)):
+                uid_prefix_i = self.UID_DELIMITER.join(uid_parts[:i + 1])
+                data_to_store[uid_prefix_i] = endpoint
+
+        store_keys, store_values = zip(*data_to_store.items())
+        store_ok = await node.store_many(store_keys, store_values, expiration_time, num_workers=num_workers)
         if future is not None:
-            future.set_result([store_ok[key] for key in keys])
+            future.set_result([store_ok[key] for key in data_to_store.keys()])

-    def first_k_active(self, prefixes: List[str], k: int, max_prefetch=None):
+    def first_k_active(self, uid_prefixes: List[str], k: int, max_prefetch: int = 1, chunk_size: Optional[int] = None):
         """
         Find k prefixes with active experts; may return less if there aren't enough; used for DMoE beam search

-        :param prefixes: a list of uid prefixes ordered from highest to lowest priority
+        :param uid_prefixes: a list of uid prefixes ordered from highest to lowest priority
         :param k: return at most *this many* active prefixes
-        :param max_prefetch: pre-dispatch up to *this many* asynchronous expert requests, defaults to pre-dispatch = k
+        :param max_prefetch: pre-dispatch up to *this many* tasks (each for chunk_size experts)
+        :param chunk_size: dispatch this many requests in one task
         :returns: a list of at most :k: prefixes that have at least one active expert each;
         """
-        assert isinstance(prefixes, (list, tuple)), "please provide a list/tuple of prefixes as the first argument"
+        assert not isinstance(uid_prefixes, str), "please provide a list/tuple of prefixes as the first argument"
         future, _future = MPFuture.make_pair()
         self.pipe.send(('_first_k_active', [],
-                        dict(prefixes=prefixes, k=k, max_prefetch=max_prefetch or k, future=_future)))
+                        dict(uid_prefixes=uid_prefixes, k=k, max_prefetch=max_prefetch,
+                             chunk_size=chunk_size or k, future=_future)))
         return future.result()

-    def _first_k_active(self, prefixes: List[str], k: int, max_prefetch: Optional[int], future: MPFuture):
-        assert self.node is not None, "This method should only be accessed from inside .run method"
-        max_prefetch = max_prefetch or len(prefixes)
-        loop = asyncio.get_event_loop()
-        lookup_prefetch = [asyncio.run_coroutine_threadsafe(self.node.get(self.make_key('prefix', prefix)), loop)
-                           for prefix in prefixes[:max_prefetch]]
+    async def _first_k_active(
+            self, node: DHTNode, uid_prefixes: List[str], k: int, max_prefetch: int, chunk_size: int, future: MPFuture):
+        num_workers_per_chunk = min(chunk_size, self.max_workers or chunk_size)
+        total_chunks = (len(uid_prefixes) - 1) // chunk_size + 1
         active_prefixes = []

-        for i, prefix in enumerate(prefixes):
-            _, maybe_expiration = lookup_prefetch[i].result()
-
-            if maybe_expiration is not None:
-                active_prefixes.append(prefix)
-                if len(active_prefixes) >= k:
-                    future.set_result(active_prefixes)
-                    for task in lookup_prefetch[i:]:
-                        task.cancel()
-                    return
-
-            # pre-dispatch the next request in line
-            if len(lookup_prefetch) < len(prefixes):
-                lookup_prefetch.append(
-                    asyncio.run_coroutine_threadsafe(
-                        self.node.get(self.make_key('prefix', prefixes[len(lookup_prefetch)])), loop))
-
-        # could not find enough active prefixes; return what we can
+        pending_tasks = deque(
+            asyncio.create_task(node.get_many(uid_prefixes[chunk_i * chunk_size: (chunk_i + 1) * chunk_size],
+                                              num_workers=num_workers_per_chunk))
+            for chunk_i in range(min(max_prefetch + 1, total_chunks))
+        )  # pre-dispatch first task and up to max_prefetch additional tasks
+
+        for chunk_i in range(total_chunks):
+            # parse task results in chronological order, launch additional tasks on demand
+            response = await pending_tasks.popleft()
+            for uid_prefix in uid_prefixes[chunk_i * chunk_size: (chunk_i + 1) * chunk_size]:
+                if response[uid_prefix][1] is not None:  # found active peer
+                    active_prefixes.append(uid_prefix)
+                    # if we found enough active experts, finish immediately
+                    if len(active_prefixes) >= k:
+                        break
+            if len(active_prefixes) >= k:
+                for task in pending_tasks:
+                    task.cancel()
+                break
+
+            pre_dispatch_chunk_i = chunk_i + len(pending_tasks) + 1
+            if pre_dispatch_chunk_i < total_chunks:
+                pending_tasks.append(asyncio.create_task(node.get_many(
+                    uid_prefixes[pre_dispatch_chunk_i * chunk_size: (pre_dispatch_chunk_i + 1) * chunk_size],
+                    num_workers=num_workers_per_chunk)))
+
+        # return k active prefixes or as many as we could find
         future.set_result(active_prefixes)

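The new docstring describes how declare_experts stores an expert uid together with all of its prefixes. The helper below is an illustrative sketch of that scheme, not code from this commit; it mirrors the loop inside _declare_experts and the formula in the UID_DELIMITER comment:

    UID_DELIMITER = '.'

    def expert_uid_prefixes(uid: str):
        # the uid itself plus every prefix that declare_experts would store in the DHT
        parts = uid.split(UID_DELIMITER)
        return [UID_DELIMITER.join(parts[:i + 1]) for i in range(len(parts))]

    print(expert_uid_prefixes("ffn_expert.98.76.54.32.10"))
    # ['ffn_expert', 'ffn_expert.98', 'ffn_expert.98.76', 'ffn_expert.98.76.54',
    #  'ffn_expert.98.76.54.32', 'ffn_expert.98.76.54.32.10']

During beam search, first_k_active then queries such prefixes in priority order (in chunks of chunk_size) and stops as soon as k of them turn out to have at least one live expert.
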
+ 17 - 17
hivemind/dht/dht.proto

@@ -17,39 +17,39 @@ service DHT {
 message NodeInfo {
   // note: both node_id and port are optional: if specified, ask peer to add you to its routing table;
   // if either node_id or port is absent, simply request recipient info (for client-only mode)
-  bytes node_id = 1;                // sender's own node id serialized with DHTID.to_bytes()
-  int32 rpc_port = 2;               // port to which sender listens for DHT RPCs
+  bytes node_id = 1;                   // sender's own node id serialized with DHTID.to_bytes()
+  int32 rpc_port = 2;                  // port to which sender listens for DHT RPCs
 }

 message StoreRequest {
   // three lists of the same length representing dht keys, dht values and expiration
-  repeated bytes keys = 1;          // keys in the form of DHTID.generate(raw_key).to_bytes()
-  repeated bytes values = 2;        // binary-encoded value for i-th key
-  repeated double expiration = 3;   // expirations for i-th key (type = DHTExpiration)
-  repeated bool in_cache = 4;       // if in_cache[i], store i-th key in cache, else store normally
-  NodeInfo peer = 5;                // (optional) sender's own node info, same behavior as in DHT.rpc_ping
+  repeated bytes keys = 1;             // keys in the form of DHTID.generate(raw_key).to_bytes()
+  repeated bytes values = 2;           // binary-encoded value for i-th key
+  repeated double expiration_time = 3; // expirations for i-th key (type = DHTExpiration)
+  repeated bool in_cache = 4;          // if in_cache[i], store i-th key in cache, else store normally
+  NodeInfo peer = 5;                   // (optional) sender's own node info, same behavior as in DHT.rpc_ping
 }

 message StoreResponse {
-  repeated bool store_ok = 1;       // for every key, True means store accepted, False means store rejected/failed
-  NodeInfo peer = 2;                // respondent's node id, for you to update routing table
+  repeated bool store_ok = 1;          // for every key, True means store accepted, False means store rejected/failed
+  NodeInfo peer = 2;                   // respondent's node id, for you to update routing table
 }

 message FindRequest {
-  repeated bytes keys = 1;          // a list of DHTID search keys encoded as bytes
-  NodeInfo peer = 2;                // optional, same behavior as in DHT.ping
+  repeated bytes keys = 1;             // a list of DHTID search keys encoded as bytes
+  NodeInfo peer = 2;                   // optional, same behavior as in DHT.ping
 }

 message Peers {
   // two aligned arrays: DHTIDs and Endpoints, i-th endpoint corresponds to peer with i-th node id
-  repeated bytes node_ids = 1;       // DHTID serialized with node_id.to_bytes()
-  repeated string endpoints = 2;     // e.g. 123.123.123.123:1337 or [2a21:6с8:b192:2105]:8888
+  repeated bytes node_ids = 1;         // DHTID serialized with node_id.to_bytes()
+  repeated string endpoints = 2;       // e.g. 123.123.123.123:1337 or [2a21:6с8:b192:2105]:8888
 }

 message FindResponse {
-  repeated bytes values = 1;        // value for i-th key, b'' means not found locally
-  repeated double expiration = 2;   // expiration time for i-th value, only valid value is found
-  repeated Peers nearest = 3;       // peers ordered from nearest to farthest based on distance to i-th key
-  NodeInfo peer = 4;                // respondent's node id, for you to update routing table
+  repeated bytes values = 1;           // value for i-th key, b'' means not found locally
+  repeated double expiration_time = 2; // expiration time for i-th value, only valid value is found
+  repeated Peers nearest = 3;          // peers ordered from nearest to farthest based on distance to i-th key
+  NodeInfo peer = 4;                   // respondent's node id, for you to update routing table
 }


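With the rename, all four repeated fields of StoreRequest stay aligned by index, one entry per key. A hedged sketch of building such a request from Python (the generated module is referenced as dht_pb2 in protocol.py; the import path below and the placeholder key/value bytes are assumptions, not part of the commit):

    from hivemind.proto import dht_pb2  # assumed import path for the generated module

    request = dht_pb2.StoreRequest(
        keys=[b'\x01' * 20, b'\x02' * 20],             # DHTID.generate(key).to_bytes() per key (placeholders here)
        values=[b'host1:1337', b'host2:1337'],         # binary-encoded value for the i-th key
        expiration_time=[1700000000.0, 1700000060.0],  # DHTExpiration for the i-th key (field was "expiration")
        in_cache=[False, True],                        # ask the recipient to keep the second key in its cache
    )
    assert len(request.keys) == len(request.values) == len(request.expiration_time) == len(request.in_cache)
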
+ 40 - 33
hivemind/dht/node.py

@@ -28,7 +28,7 @@ class DHTNode:
     Compared to Kademlia RPC protocol, hivemind DHT has 3 RPCs:

     * ping - request peer's identifier and update routing table (same as Kademlia PING RPC)
-    * store - send several (key, value, expiration) pairs to the same peer (like Kademlia STORE, but in bulk)
+    * store - send several (key, value, expiration_time) pairs to the same peer (like Kademlia STORE, but in bulk)
     * find - request one or several keys, get values & expiration (if peer finds it locally) and :bucket_size: of
         nearest peers from recipient's routing table (ordered nearest-to-farthest, not including recipient itself)
         This RPC is a mixture between Kademlia FIND_NODE and FIND_VALUE with multiple keys per call.
@@ -37,10 +37,10 @@ class DHTNode:
 
 
     - when asked to get(key), a node must find and return a value with highest expiration time that it found across DHT
       IF that time has not come yet. if expiration time is smaller than current get_dht_time(), node may return None;
-    - when requested to store(key: value, expiration), a node must store (key => value) at until expiration time
+    - when requested to store(key: value, expiration_time), a node must store (key => value) at until expiration time
       or until DHTNode gets the same key with greater expiration time. If a node is asked to store a key but it already
       has the same key with newer expiration, the older key will not be stored. Return True if stored, False if refused;
-    - when requested to store(key: value, expiration, in_cache=True), stores (key => value) in a separate "cache".
+    - when requested to store(key: value, expiration_time, in_cache=True), stores (key => value) in a separate "cache".
       Cache operates same as regular storage, but it has a limited size and evicts least recently used nodes when full;

     """
@@ -191,15 +191,15 @@ class DHTNode:
         store_ok = await self.store_many([key], [value], [expiration_time], **kwargs)
         return store_ok[key]

-    async def store_many(
-            self, keys: List[DHTKey], values: List[DHTValue], expiration: Union[DHTExpiration, List[DHTExpiration]],
-            exclude_self: bool = False, await_all_replicas=True, **kwargs) -> Dict[DHTKey, bool]:
+    async def store_many(self, keys: List[DHTKey], values: List[DHTValue],
+                         expiration_time: Union[DHTExpiration, List[DHTExpiration]],
+                         exclude_self: bool = False, await_all_replicas=True, **kwargs) -> Dict[DHTKey, bool]:
         """
-        Traverse DHT to find up to best nodes to store multiple (key, value, expiration) pairs.
+        Traverse DHT to find up to best nodes to store multiple (key, value, expiration_time) pairs.

         :param keys: arbitrary serializable keys associated with each value
         :param values: serializable "payload" for each key
-        :param expiration: either one expiration time for all keys or individual expiration times (see class doc)
+        :param expiration_time: either one expiration time for all keys or individual expiration times (see class doc)
         :param kwargs: any additional parameters passed to traverse_dht function (e.g. num workers)
         :param exclude_self: if True, never store value locally even if you are one of the nearest nodes
         :note: if exclude_self is True and self.cache_locally == True, value will still be __cached__ locally
@@ -207,13 +207,14 @@ class DHTNode:
             if True, the function will wait for num_replicas successful stores or running out of beam_size nodes
         :returns: for each key: True if store succeeds, False if it fails (due to no response or newer value)
         """
-        expiration = [expiration] * len(keys) if isinstance(expiration, DHTExpiration) else expiration
-        assert len(keys) == len(values) == len(expiration), "Please provide equal number of keys, values and expiration"
+        if isinstance(expiration_time, DHTExpiration):
+            expiration_time = [expiration_time] * len(keys)
+        assert len(keys) == len(values) == len(expiration_time), "Number of keys, values and expiration doesn't match."

         key_ids = list(map(DHTID.generate, keys))
         id_to_original_key = dict(zip(key_ids, keys))
         binary_values_by_key_id = {key_id: self.serializer.dumps(value) for key_id, value in zip(key_ids, values)}
-        expiration_by_key_id = {key_id: expiration_time for key_id, expiration_time in zip(key_ids, expiration)}
+        expiration_by_key_id = {key_id: expiration_time for key_id, expiration_time in zip(key_ids, expiration_time)}
         unfinished_key_ids = set(key_ids)  # we use this set to ensure that each store request is finished

         store_ok = {key: False for key in keys}  # outputs, updated during search
@@ -272,12 +273,16 @@ class DHTNode:
 
 
             store_finished_events[id_to_original_key[key_id]].set()

-        asyncio.create_task(self.find_nearest_nodes(
+        store_task = asyncio.create_task(self.find_nearest_nodes(
             queries=set(key_ids), k_nearest=self.num_replicas, node_to_endpoint=node_to_endpoint,
             found_callback=on_found, exclude_self=exclude_self, **kwargs))
-        await asyncio.wait([evt.wait() for evt in store_finished_events.values()])  # await one (or all) store accepts
-        assert len(unfinished_key_ids) == 0, "Internal error: traverse_dht didn't finish search"
-        return store_ok
+        try:
+            await asyncio.wait([evt.wait() for evt in store_finished_events.values()])  # wait for items to be stored
+            assert len(unfinished_key_ids) == 0, "Internal error: traverse_dht didn't finish search"
+            return store_ok
+        except asyncio.CancelledError as e:
+            store_task.cancel()
+            raise e

     async def get(self, key: DHTKey, latest=False, **kwargs) -> Tuple[Optional[DHTValue], Optional[DHTExpiration]]:
         """
@@ -316,17 +321,17 @@ class DHTNode:
         unfinished_key_ids = set(key_ids)  # track key ids for which the search is not terminated
         node_to_endpoint: Dict[DHTID, Endpoint] = dict()  # global routing table for all queries

-        SearchResult = namedtuple("SearchResult", ["binary_value", "expiration", "source_node_id"])
+        SearchResult = namedtuple("SearchResult", ["binary_value", "expiration_time", "source_node_id"])
         latest_results = {key_id: SearchResult(b'', -float('inf'), None) for key_id in key_ids}

         # stage 1: value can be stored in our local cache
         for key_id in key_ids:
-            maybe_value, maybe_expiration = self.protocol.storage.get(key_id)
-            if maybe_expiration is None:
-                maybe_value, maybe_expiration = self.protocol.cache.get(key_id)
-            if maybe_expiration is not None and maybe_expiration > latest_results[key_id].expiration:
-                latest_results[key_id] = SearchResult(maybe_value, maybe_expiration, self.node_id)
-                if maybe_expiration >= sufficient_expiration_time:
+            maybe_value, maybe_expiration_time = self.protocol.storage.get(key_id)
+            if maybe_expiration_time is None:
+                maybe_value, maybe_expiration_time = self.protocol.cache.get(key_id)
+            if maybe_expiration_time is not None and maybe_expiration_time > latest_results[key_id].expiration_time:
+                latest_results[key_id] = SearchResult(maybe_value, maybe_expiration_time, self.node_id)
+                if maybe_expiration_time >= sufficient_expiration_time:
                     unfinished_key_ids.remove(key_id)

         # stage 2: traverse the DHT for any unfinished keys
@@ -341,11 +346,11 @@ class DHTNode:
                 return {query: ([], False) for query in queries}

             output: Dict[DHTID, Tuple[List[DHTID], bool]] = {}
-            for key_id, (maybe_value, maybe_expiration, peers) in response.items():
+            for key_id, (maybe_value, maybe_expiration_time, peers) in response.items():
                 node_to_endpoint.update(peers)
-                if maybe_expiration is not None and maybe_expiration > latest_results[key_id].expiration:
-                    latest_results[key_id] = SearchResult(maybe_value, maybe_expiration, peer)
-                should_interrupt = (latest_results[key_id].expiration >= sufficient_expiration_time)
+                if maybe_expiration_time is not None and maybe_expiration_time > latest_results[key_id].expiration_time:
+                    latest_results[key_id] = SearchResult(maybe_value, maybe_expiration_time, peer)
+                should_interrupt = (latest_results[key_id].expiration_time >= sufficient_expiration_time)
                 output[key_id] = list(peers.keys()), should_interrupt
             return output

@@ -356,10 +361,10 @@ class DHTNode:
 
 
         # stage 3: cache any new results depending on caching parameters
         for key_id, nearest_nodes in nearest_nodes_per_query.items():
-            latest_value_bytes, latest_expiration, latest_node_id = latest_results[key_id]
-            should_cache = latest_expiration >= sufficient_expiration_time  # if we found a newer value, cache it
+            latest_value_bytes, latest_expiration_time, latest_node_id = latest_results[key_id]
+            should_cache = latest_expiration_time >= sufficient_expiration_time  # if we found a newer value, cache it
             if should_cache and self.cache_locally:
-                self.protocol.cache.store(key_id, latest_value_bytes, latest_expiration)
+                self.protocol.cache.store(key_id, latest_value_bytes, latest_expiration_time)

             if should_cache and self.cache_nearest:
                 num_cached_nodes = 0
@@ -367,16 +372,18 @@ class DHTNode:
                     if node_id == latest_node_id:
                         continue
                     asyncio.create_task(self.protocol.call_store(
-                        node_to_endpoint[node_id], [key_id], [latest_value_bytes], [latest_expiration], in_cache=True))
+                        node_to_endpoint[node_id], [key_id], [latest_value_bytes], [latest_expiration_time],
+                        in_cache=True))
                     num_cached_nodes += 1
                     if num_cached_nodes >= self.cache_nearest:
                         break

         # stage 4: deserialize data and assemble function output
         find_result: Dict[DHTKey, Tuple[Optional[DHTValue], Optional[DHTExpiration]]] = {}
-        for key_id, (latest_value_bytes, latest_expiration, _) in latest_results.items():
-            if latest_expiration != -float('inf'):
-                find_result[id_to_original_key[key_id]] = self.serializer.loads(latest_value_bytes), latest_expiration
+        for key_id, (latest_value_bytes, latest_expiration_time, _) in latest_results.items():
+            if latest_expiration_time != -float('inf'):
+                latest_value = self.serializer.loads(latest_value_bytes)
+                find_result[id_to_original_key[key_id]] = (latest_value, latest_expiration_time)
             else:
                 find_result[id_to_original_key[key_id]] = None, None
         return find_result

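store_many above (and traverse_dht below) now use the same cancellation pattern: keep a handle to the background task and cancel it when the outer coroutine is cancelled, then re-raise. A minimal self-contained sketch of the pattern in plain asyncio (illustrative, not hivemind code):

    import asyncio

    async def background_search():
        await asyncio.sleep(3600)  # stand-in for find_nearest_nodes / traverse_dht workers

    async def store_like_operation():
        helper = asyncio.create_task(background_search())
        try:
            await asyncio.sleep(1)  # stand-in for awaiting store_finished_events
            return "stored"
        except asyncio.CancelledError:
            helper.cancel()         # propagate cancellation to the helper task, then re-raise
            raise

    async def main():
        op = asyncio.create_task(store_like_operation())
        await asyncio.sleep(0.1)
        op.cancel()                 # cancelling the outer task now also cancels its helper
        try:
            await op
        except asyncio.CancelledError:
            print("operation and helper cancelled")

    asyncio.run(main())
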
+ 26 - 24
hivemind/dht/protocol.py

@@ -113,15 +113,15 @@ class DHTProtocol(dht_grpc.DHTServicer):
         return self.node_info

     async def call_store(self, peer: Endpoint, keys: Sequence[DHTID], values: Sequence[BinaryDHTValue],
-                         expirations: Union[DHTExpiration, Sequence[DHTExpiration]],
+                         expiration_time: Union[DHTExpiration, Sequence[DHTExpiration]],
                          in_cache: Optional[Union[bool, Sequence[bool]]] = None) -> Sequence[bool]:
         """
-        Ask a recipient to store several (key, value : expiration) items or update their older value
+        Ask a recipient to store several (key, value : expiration_time) items or update their older value

         :param peer: request this peer to store the data
         :param keys: a list of N keys digested by DHTID.generate(source=some_dict_key)
         :param values: a list of N serialized values (bytes) for each respective key
-        :param expirations: a list of N expiration timestamps for each respective key-value pair (see get_dht_time())
+        :param expiration_time: a list of N expiration timestamps for each respective key-value pair (see get_dht_time())
         :param in_cache: a list of booleans, True = store i-th key in cache, value = store i-th key locally
         :note: the difference between storing normally and in cache is that normal storage is guaranteed to be stored
          until expiration time (best-effort), whereas cached storage can be evicted early due to limited cache size
@@ -129,13 +129,14 @@ class DHTProtocol(dht_grpc.DHTServicer):
         :return: list of [True / False] True = stored, False = failed (found newer value or no response)
          if peer did not respond (e.g. due to timeout or congestion), returns None
         """
+        if isinstance(expiration_time, DHTExpiration):
+            expiration_time = [expiration_time] * len(keys)
         in_cache = in_cache if in_cache is not None else [False] * len(keys)  # default value (None)
         in_cache = [in_cache] * len(keys) if isinstance(in_cache, bool) else in_cache  # single bool
-        expirations = [expirations] * len(keys) if isinstance(expirations, DHTExpiration) else expirations
-        keys, values, expirations, in_cache = map(list, [keys, values, expirations, in_cache])
-        assert len(keys) == len(values) == len(expirations) == len(in_cache), "Data is not aligned"
+        keys, values, expiration_time, in_cache = map(list, [keys, values, expiration_time, in_cache])
+        assert len(keys) == len(values) == len(expiration_time) == len(in_cache), "Data is not aligned"
         store_request = dht_pb2.StoreRequest(keys=list(map(DHTID.to_bytes, keys)), values=values,
-                                             expiration=expirations, in_cache=in_cache, peer=self.node_info)
+                                             expiration_time=expiration_time, in_cache=in_cache, peer=self.node_info)
         try:
             async with self.rpc_semaphore:
                 response = await self._get(peer).rpc_store(store_request, timeout=self.wait_timeout)
@@ -152,10 +153,10 @@ class DHTProtocol(dht_grpc.DHTServicer):
         """ Some node wants us to store this (key, value) pair """
         """ Some node wants us to store this (key, value) pair """
         if request.peer:  # if requested, add peer to the routing table
         if request.peer:  # if requested, add peer to the routing table
             asyncio.create_task(self.rpc_ping(request.peer, context))
             asyncio.create_task(self.rpc_ping(request.peer, context))
-        assert len(request.keys) == len(request.values) == len(request.expiration) == len(request.in_cache)
+        assert len(request.keys) == len(request.values) == len(request.expiration_time) == len(request.in_cache)
         response = dht_pb2.StoreResponse(store_ok=[], peer=self.node_info)
         response = dht_pb2.StoreResponse(store_ok=[], peer=self.node_info)
         for key_bytes, value_bytes, expiration_time, in_cache in zip(
         for key_bytes, value_bytes, expiration_time, in_cache in zip(
-                request.keys, request.values, request.expiration, request.in_cache):
+                request.keys, request.values, request.expiration_time, request.in_cache):
             local_memory = self.cache if in_cache else self.storage
             local_memory = self.cache if in_cache else self.storage
             response.store_ok.append(local_memory.store(DHTID.from_bytes(key_bytes), value_bytes, expiration_time))
             response.store_ok.append(local_memory.store(DHTID.from_bytes(key_bytes), value_bytes, expiration_time))
         return response
         return response
@@ -180,15 +181,16 @@ class DHTProtocol(dht_grpc.DHTServicer):
             if response.peer and response.peer.node_id:
                 peer_id = DHTID.from_bytes(response.peer.node_id)
                 asyncio.create_task(self.update_routing_table(peer_id, peer, responded=True))
-            assert len(response.values) == len(response.expiration) == len(response.nearest) == len(keys), \
-                "DHTProtocol: response is not aligned with keys"
+            assert len(response.values) == len(response.expiration_time) == len(response.nearest) == len(keys), \
+                "DHTProtocol: response is not aligned with keys and/or expiration times"

             output = {}  # unpack data without special NOT_FOUND_* values
-            for key, value, expiration, nearest in zip(keys, response.values, response.expiration, response.nearest):
+            for key, value, expiration_time, nearest in zip(
+                    keys, response.values, response.expiration_time, response.nearest):
                 value = value if value != _NOT_FOUND_VALUE else None
-                expiration = expiration if expiration != _NOT_FOUND_EXPIRATION else None
+                expiration_time = expiration_time if expiration_time != _NOT_FOUND_EXPIRATION else None
                 nearest = dict(zip(map(DHTID.from_bytes, nearest.node_ids), nearest.endpoints))
-                output[key] = (value, expiration, nearest)
+                output[key] = (value, expiration_time, nearest)
             return output
         except grpc.experimental.aio.AioRpcError as error:
             logger.warning(f"DHTProtocol failed to find at {peer}: {error.code()}")
@@ -202,12 +204,12 @@ class DHTProtocol(dht_grpc.DHTServicer):
         if request.peer:  # if requested, add peer to the routing table
             asyncio.create_task(self.rpc_ping(request.peer, context))

-        response = dht_pb2.FindResponse(values=[], expiration=[], nearest=[], peer=self.node_info)
+        response = dht_pb2.FindResponse(values=[], expiration_time=[], nearest=[], peer=self.node_info)
         for key_id in map(DHTID.from_bytes, request.keys):
-            maybe_value, maybe_expiration = self.storage.get(key_id)
-            cached_value, cached_expiration = self.cache.get(key_id)
-            if (cached_expiration or -float('inf')) > (maybe_expiration or -float('inf')):
-                maybe_value, maybe_expiration = cached_value, cached_expiration
+            maybe_value, maybe_expiration_time = self.storage.get(key_id)
+            cached_value, cached_expiration_time = self.cache.get(key_id)
+            if (cached_expiration_time or -float('inf')) > (maybe_expiration_time or -float('inf')):
+                maybe_value, maybe_expiration_time = cached_value, cached_expiration_time

             nearest_neighbors = self.routing_table.get_nearest_neighbors(
                 key_id, k=self.bucket_size, exclude=DHTID.from_bytes(request.peer.node_id))
@@ -217,7 +219,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
                 peer_ids, endpoints = [], []

             response.values.append(maybe_value if maybe_value is not None else _NOT_FOUND_VALUE)
-            response.expiration.append(maybe_expiration if maybe_expiration is not None else _NOT_FOUND_EXPIRATION)
+            response.expiration_time.append(maybe_expiration_time if maybe_expiration_time else _NOT_FOUND_EXPIRATION)
             response.nearest.append(dht_pb2.Peers(node_ids=list(map(DHTID.to_bytes, peer_ids)), endpoints=endpoints))
         return response

@@ -235,7 +237,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
             if node_id not in self.routing_table:
                 # we just met a new node, maybe we know some values that it *should* store
                 data_to_send: List[Tuple[DHTID, BinaryDHTValue, DHTExpiration]] = []
-                for key, value, expiration in list(self.storage.items()):
+                for key, value, expiration_time in list(self.storage.items()):
                     neighbors = self.routing_table.get_nearest_neighbors(key, self.num_replicas, exclude=self.node_id)
                     if neighbors:
                         nearest_distance = neighbors[0][0].xor_distance(key)
@@ -243,7 +245,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
                         new_node_should_store = node_id.xor_distance(key) < farthest_distance
                         this_node_is_responsible = self.node_id.xor_distance(key) < nearest_distance
                     if not neighbors or (new_node_should_store and this_node_is_responsible):
-                        data_to_send.append((key, value, expiration))
+                        data_to_send.append((key, value, expiration_time))
                 if data_to_send:
                     asyncio.create_task(self.call_store(peer_endpoint, *zip(*data_to_send), in_cache=False))

@@ -262,7 +264,7 @@ _NOT_FOUND_VALUE, _NOT_FOUND_EXPIRATION = b'', -float('inf')  # internal values
 
 
 
 
 class LocalStorage:
 class LocalStorage:
-    """ Local dictionary that maintains up to :maxsize: tuples of (key, value, expiration) """
+    """ Local dictionary that maintains up to :maxsize: tuples of (key, value, expiration_time) """

     def __init__(self, maxsize: Optional[int] = None):
         self.cache_size = maxsize or float("inf")
@@ -306,4 +308,4 @@ class LocalStorage:
     def items(self) -> Iterator[Tuple[DHTID, BinaryDHTValue, DHTExpiration]]:
         """ Iterate over (key, value, expiration_time) tuples stored in this storage """
         self.remove_outdated()
-        return ((key, value, expiration) for key, (value, expiration) in self.data.items())
+        return ((key, value, expiration_time) for key, (value, expiration_time) in self.data.items())
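
Note on the hunk above: after the rename, LocalStorage consistently exposes (key, value, expiration_time) triples. For readers who have not seen the class, here is a rough self-contained sketch of the idea behind such a store; the name ExpiringStorage, its store method and the use of time.monotonic() are illustrative assumptions rather than the library's actual API (the real LocalStorage lives in hivemind/dht/protocol.py and uses DHT-wide timestamps).

import time
from typing import Any, Dict, Iterator, Optional, Tuple

class ExpiringStorage:
    """ Toy bounded store of (key, value, expiration_time) triples -- illustration only """

    def __init__(self, maxsize: Optional[int] = None):
        self.maxsize = maxsize or float('inf')
        self.data: Dict[Any, Tuple[Any, float]] = {}

    def remove_outdated(self) -> None:
        now = time.monotonic()
        for key in [k for k, (_, expiration_time) in self.data.items() if expiration_time < now]:
            del self.data[key]

    def store(self, key: Any, value: Any, expiration_time: float) -> bool:
        self.remove_outdated()
        if expiration_time < time.monotonic():
            return False  # refuse to store an already-expired entry
        self.data[key] = (value, expiration_time)
        if len(self.data) > self.maxsize:
            # evict whichever entry expires soonest, mimicking a size-capped cache
            del self.data[min(self.data, key=lambda k: self.data[k][1])]
        return True

    def items(self) -> Iterator[Tuple[Any, Any, float]]:
        self.remove_outdated()
        return ((key, value, expiration_time) for key, (value, expiration_time) in self.data.items())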

+ 19 - 13
hivemind/dht/traverse.py

@@ -215,16 +215,22 @@ async def traverse_dht(
             active_workers.subtract(queries_to_call)
             heap_updated_event.set()

-    # spawn all workers and wait for them to terminate; workers terminate after exhausting unfinished_queries
-    await asyncio.wait([asyncio.create_task(worker()) for _ in range(num_workers)],
-                       return_when=asyncio.FIRST_COMPLETED)  # first worker finishes when the search is over
-    assert len(unfinished_queries) == 0 and search_finished_event.is_set()
-
-    if await_all_tasks:
-        await asyncio.gather(*pending_tasks)
-
-    nearest_neighbors_per_query = {
-        query: [peer for _, peer in heapq.nlargest(beam_size, nearest_nodes[query])]
-        for query in queries
-    }
-    return nearest_neighbors_per_query, visited_nodes
+    workers = [asyncio.create_task(worker()) for _ in range(num_workers)]
+    try:
+        # spawn all workers and wait for them to terminate; workers terminate after exhausting unfinished_queries
+        await asyncio.wait(workers, return_when=asyncio.FIRST_COMPLETED)
+        assert len(unfinished_queries) == 0 and search_finished_event.is_set()
+
+        if await_all_tasks:
+            await asyncio.gather(*pending_tasks)
+
+        nearest_neighbors_per_query = {
+            query: [peer for _, peer in heapq.nlargest(beam_size, nearest_nodes[query])]
+            for query in queries
+        }
+        return nearest_neighbors_per_query, visited_nodes
+
+    except asyncio.CancelledError as e:
+        for worker in workers:
+            worker.cancel()
+        raise e
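
The hunk above makes traverse_dht cancellable: if the caller's task is cancelled mid-search, every spawned worker is cancelled too instead of being orphaned. A minimal self-contained sketch of the same pattern (run_workers_cancellable, make_worker and demo are made-up names for illustration):

import asyncio

async def run_workers_cancellable(make_worker, num_workers: int):
    workers = [asyncio.create_task(make_worker()) for _ in range(num_workers)]
    try:
        # the first worker to finish signals that the overall search is over
        await asyncio.wait(workers, return_when=asyncio.FIRST_COMPLETED)
    except asyncio.CancelledError:
        for worker in workers:
            worker.cancel()  # propagate cancellation instead of leaking tasks
        raise

async def demo():
    async def make_worker():
        await asyncio.sleep(0.5)

    search = asyncio.create_task(run_workers_cancellable(make_worker, num_workers=4))
    await asyncio.sleep(0.05)  # let the workers start
    search.cancel()            # cancelling the outer task now cancels every worker as well
    try:
        await search
    except asyncio.CancelledError:
        pass

if __name__ == '__main__':
    asyncio.run(demo())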

+ 2 - 4
hivemind/server/connection_handler.py

@@ -59,13 +59,11 @@ class ConnectionHandler(mp.Process):
     async def forward(self, request: runtime_pb2.ExpertRequest, context: grpc.ServicerContext):
         inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         future = self.experts[request.uid].forward_pool.submit_task(*inputs)
-        response = await future.async_result()
-        serialized_response = [serialize_torch_tensor(tensor) for tensor in response]
+        serialized_response = [serialize_torch_tensor(tensor) for tensor in await future]
         return runtime_pb2.ExpertResponse(tensors=serialized_response)

     async def backward(self, request: runtime_pb2.ExpertRequest, context: grpc.ServicerContext):
         inputs_and_grad_outputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
         future = self.experts[request.uid].backward_pool.submit_task(*inputs_and_grad_outputs)
-        response = await future.async_result()
-        serialized_response = [serialize_torch_tensor(tensor) for tensor in response]
+        serialized_response = [serialize_torch_tensor(tensor) for tensor in await future]
         return runtime_pb2.ExpertResponse(tensors=serialized_response)

+ 116 - 99
hivemind/utils/mpfuture.py

@@ -1,21 +1,25 @@
+from __future__ import annotations
+import time
 import multiprocessing as mp
 import multiprocessing.connection
-from concurrent.futures import Future, CancelledError
-from warnings import warn
+import concurrent.futures._base as base
+
 import asyncio
+from functools import lru_cache
+from typing import Optional
+
+from hivemind.utils.threading import run_in_background


-class MPFuture(Future):
-    """ Multiprocessing version of concurrent.futures.Future, interacts between two processes via Pipe """
-    STATES = 'pending', 'running', 'cancelled', 'finished', 'exception'
-    STATE_PENDING, STATE_RUNNING, STATE_CANCELLED, STATE_FINISHED, STATE_EXCEPTION = STATES
+class MPFuture(base.Future):
+    """ Multiprocessing version of concurrent.futures.Future. Can also be awaited like asyncio.Future """
+
+    TERMINAL_STATES = {base.FINISHED, base.CANCELLED, base.CANCELLED_AND_NOTIFIED}

     def __init__(self, connection: mp.connection.Connection):
         """ manually create MPFuture. Please use MPFuture.make_pair instead """
+        self._state, self._result, self._exception = base.PENDING, None, None
         self.connection = connection
-        self.state = self.STATE_PENDING
-        self._result = None
-        self._exception = None

     @classmethod
     def make_pair(cls):
@@ -23,125 +27,138 @@ class MPFuture(Future):
         connection1, connection2 = mp.Pipe()
         return cls(connection1), cls(connection2)

-    def poll_and_recv(self, timeout):
-        available = self.connection.poll(timeout)
-        if not available:
-            raise TimeoutError
-        try:
-            status, payload = self.connection.recv()
-            self.connection.close()
-        except BrokenPipeError as e:
-            status, payload = self.STATE_EXCEPTION, e
-        return status, payload
-
-    def _recv(self, timeout):
-
-        if self.state in (self.STATE_PENDING, self.STATE_RUNNING):
-            status, payload = self.poll_and_recv(timeout)
-
-            assert status in self.STATES
-            self.state = status
-
-            if status == self.STATE_FINISHED:
-                self._result = payload
-            elif status == self.STATE_EXCEPTION:
-                self._exception = payload
-            elif status in (self.STATE_RUNNING, self.STATE_CANCELLED):
-                pass  # only update self.state
-            else:
-                raise ValueError("Result status should not be self.STATE_PENDING")
-
-    def set_result(self, result):
+    def _send_updates(self):
+        """ Send updates to a paired MPFuture """
         try:
-            self.state, self._result = self.STATE_FINISHED, result
-            self.connection.send((self.STATE_FINISHED, result))
-            self.connection.close()
+            self.connection.send((self._state, self._result, self._exception))
+            if self._state in self.TERMINAL_STATES:
+                self._shutdown_trigger.set_result(True)
+                self.connection.close()
             return True
         except BrokenPipeError:
             return False

-    def set_exception(self, exception: BaseException):
+    def _recv_updates(self, timeout: Optional[float]):
+        """ Await updates from a paired MPFuture """
         try:
-            self.state, self._exception = self.STATE_EXCEPTION, exception
-            self.connection.send((self.STATE_EXCEPTION, exception))
-            self.connection.close()
-            return True
-        except BrokenPipeError:
-            return False
+            future = base.wait([run_in_background(self.connection.poll, timeout), self._shutdown_trigger],
+                               return_when=base.FIRST_COMPLETED)[0].pop()
+            if future is self._shutdown_trigger:
+                raise BrokenPipeError()
+            if not future.result():
+                raise TimeoutError()
+            self._state, result, exception = self.connection.recv()
+            self._result = result if result is not None else self._result
+            self._exception = exception if exception is not None else self._exception
+            if self._state in self.TERMINAL_STATES:
+                self.connection.close()
+        except TimeoutError as e:
+            raise e
+        except (BrokenPipeError, OSError) as e:
+            if self._state in (base.PENDING, base.RUNNING):
+                self._state, self._exception = base.FINISHED, e
+
+    def _await_terminal_state(self, timeout: Optional[float]):
+        """ Await updates until future is either finished, cancelled or got an exception """
+        time_left = float('inf') if timeout is None else timeout
+        time_before = time.monotonic()
+        while self._state not in self.TERMINAL_STATES and time_left > 0:
+            self._recv_updates(time_left if timeout else None)
+            time_spent = time.monotonic() - time_before
+            time_left, time_before = time_left - time_spent, time_before + time_spent
+
+    def _sync_updates(self):
+        """ Apply queued updates from a paired MPFuture without waiting for new ones """
+        try:
+            self._recv_updates(timeout=0)
+        except TimeoutError:
+            pass
+
+    def set_result(self, result):
+        self._sync_updates()
+        if self._state in self.TERMINAL_STATES:
+            raise RuntimeError(f"Can't set_result to a future that is in {self._state}")
+        self._state, self._result = base.FINISHED, result
+        return self._send_updates()
+
+    def set_exception(self, exception: BaseException):
+        self._sync_updates()
+        if self._state in self.TERMINAL_STATES:
+            raise RuntimeError(f"Can't set_exception to a future that is in {self._state}")
+        self._state, self._exception = base.FINISHED, exception
+        self._send_updates()

     def set_running_or_notify_cancel(self):
-        return True
+        self._sync_updates()
+        if self._state == base.PENDING:
+            self._state = base.RUNNING
+            return self._send_updates()
+        elif self._state == base.CANCELLED:
+            return False
+        else:
+            raise RuntimeError(f"Can't set_running_or_notify_cancel to a future that is in {self._state}")

     def cancel(self):
-        raise NotImplementedError()
+        self._sync_updates()
+        if self._state in self.TERMINAL_STATES:
+            return False
+        self._state, self._exception = base.CANCELLED, base.CancelledError()
+        return self._send_updates()

     def result(self, timeout=None):
-        self._recv(timeout)
-        if self.state == self.STATE_FINISHED:
-            return self._result
-        elif self.state == self.STATE_EXCEPTION:
+        self._await_terminal_state(timeout)
+        if self._exception is not None:
             raise self._exception
-        else:
-            assert self.state == self.STATE_CANCELLED
-            raise CancelledError()
+        return self._result

     def exception(self, timeout=None):
-        self._recv(timeout)
+        self._await_terminal_state(timeout)
+        if self._state == base.CANCELLED:
+            raise base.CancelledError()
         return self._exception

     def done(self):
-        return self.state in (self.STATE_FINISHED, self.STATE_EXCEPTION, self.STATE_CANCELLED)
+        self._sync_updates()
+        return self._state in self.TERMINAL_STATES

     def running(self):
-        return self.state == self.STATE_RUNNING
+        self._sync_updates()
+        return self._state == base.RUNNING

     def cancelled(self):
-        warn("cancelled not implemented")
-        return False
+        self._sync_updates()
+        return self._state == base.CANCELLED

     def add_done_callback(self, callback):
-        raise NotImplementedError()
-
-    def __repr__(self):
-        try:
-            self._recv(timeout=0)
-        except TimeoutError:
-            pass
-        if self.state == self.STATE_FINISHED:
-            return "<MPFuture at 0x{:x} state=finished returned {}>".format(id(self), type(self._result))
-        elif self.state == self.STATE_EXCEPTION:
-            return "<MPFuture at 0x{:x} state=finished raised {}>".format(id(self), type(self._exception))
-        else:
-            return "<MPFuture at 0x{:x} state={}>".format(id(self), self.state)
+        raise NotImplementedError(f"MPFuture doesn't support callbacks.")

-    async def _async_recv(self, timeout):
-        loop = asyncio.get_running_loop()
+    def remove_done_callback(self, callback):
+        raise NotImplementedError(f"MPFuture doesn't support callbacks.")

-        if self.state in (self.STATE_PENDING, self.STATE_RUNNING):
-            status, payload = await loop.run_in_executor(None, self.poll_and_recv, timeout)
+    def get_loop(self):
+        raise NotImplementedError(f"MPFuture doesn't support get_loop")

-            assert status in self.STATES
-            self.state = status
+    @property
+    @lru_cache()
+    def _shutdown_trigger(self):
+        return base.Future()

-            if status == self.STATE_FINISHED:
-                self._result = payload
-            elif status == self.STATE_EXCEPTION:
-                self._exception = payload
-            elif status in (self.STATE_RUNNING, self.STATE_CANCELLED):
-                pass  # only update self.state
+    def __repr__(self):
+        self._sync_updates()
+        if self._state == base.FINISHED:
+            if self._exception:
+                return "<MPFuture at 0x{:x} state=finished raised {}>".format(id(self), type(self._exception))
             else:
-                raise ValueError("Result status should not be self.STATE_PENDING")
+                return "<MPFuture at 0x{:x} state=finished returned {}>".format(id(self), type(self._result))
+        else:
+            return "<MPFuture at 0x{:x} state={}>".format(id(self), self._state)

-    async def async_result(self, timeout=None):
-        await self._async_recv(timeout)
-        if self.state == self.STATE_FINISHED:
-            return self._result
-        elif self.state == self.STATE_EXCEPTION:
+    def __await__(self):
+        yield from asyncio.get_running_loop().run_in_executor(None, self._await_terminal_state, None).__await__()
+        if self._exception:
             raise self._exception
-        else:
-            assert self.state == self.STATE_CANCELLED
-            raise CancelledError()
+        return self._result

-    async def async_exception(self, timeout=None):
-        await self._async_recv(timeout)
-        return self._exception
+    def __del__(self):
+        self._shutdown_trigger.set_result(True)
+        self.connection.close()
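
Taken together with the connection_handler.py change above, the rewritten MPFuture mirrors results, exceptions and cancellation between the two ends of a pipe and can be awaited directly. A short usage sketch in the spirit of the new tests in tests/test_util_modules.py further down; in real use one end of the pair would be handed to another process:

import asyncio
import hivemind

def sync_example():
    local_end, remote_end = hivemind.MPFuture.make_pair()
    assert remote_end.set_running_or_notify_cancel() is True
    remote_end.set_result((123, 'ololo'))           # the state is mirrored through the pipe
    assert local_end.result() == (123, 'ololo')
    assert local_end.done() and not local_end.cancelled()

async def async_example():
    local_end, remote_end = hivemind.MPFuture.make_pair()

    async def assign_later():
        await asyncio.sleep(0.1)
        remote_end.set_result(42)

    asyncio.create_task(assign_later())
    assert await local_end == 42                    # MPFuture is awaitable like asyncio.Future

if __name__ == '__main__':
    sync_example()
    asyncio.run(async_example())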

+ 6 - 13
hivemind/utils/threading.py

@@ -1,27 +1,20 @@
 import os
-from concurrent.futures import Future, ThreadPoolExecutor, as_completed, TimeoutError
+from concurrent.futures import Future, as_completed, TimeoutError, ThreadPoolExecutor
 import time
 from typing import Optional, List

-GLOBAL_EXECUTOR = ThreadPoolExecutor(max_workers=os.environ.get("HIVEMIND_THREADS", float('inf')))
+EXECUTOR_PID, GLOBAL_EXECUTOR = None, None


 def run_in_background(func: callable, *args, **kwargs) -> Future:
     """ run func(*args, **kwargs) in background and return Future for its outputs """
-
+    global EXECUTOR_PID, GLOBAL_EXECUTOR
+    if os.getpid() != EXECUTOR_PID:
+        GLOBAL_EXECUTOR = ThreadPoolExecutor(max_workers=os.environ.get("HIVEMIND_THREADS", float('inf')))
+        EXECUTOR_PID = os.getpid()
     return GLOBAL_EXECUTOR.submit(func, *args, **kwargs)


-def run_forever(func: callable, *args, **kwargs):
-    """ A function that runs a :func: in background forever. Returns a future that catches exceptions """
-
-    def repeat():
-        while True:
-            func(*args, **kwargs)
-
-    return run_in_background(repeat)
-
-
 def run_and_await_k(jobs: List[callable], k: int,
                     timeout_after_k: Optional[float] = 0, timeout_total: Optional[float] = None):
     """

+ 6 - 7
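The executor is now created lazily and rebuilt whenever run_in_background is first called in a new process, so a forked worker never reuses the parent's ThreadPoolExecutor and its internal locks. A quick sanity-check sketch (hypothetical snippet, assuming the default multiprocessing start method):

import multiprocessing as mp
from hivemind.utils.threading import run_in_background

def child():
    # the first call inside the child builds a fresh executor instead of inheriting the parent's
    assert run_in_background(lambda: 6 * 7).result() == 42

if __name__ == '__main__':
    run_in_background(lambda: 'warm up the parent executor').result()
    process = mp.Process(target=child)
    process.start()
    process.join()
    assert process.exitcode == 0
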
tests/benchmark_dht.py

@@ -15,15 +15,15 @@ def random_endpoint() -> hivemind.Endpoint:
 
 
 
 
 def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_batch_size: int, random_seed: int,
 def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_batch_size: int, random_seed: int,
-                  wait_after_request: float, wait_before_read: float, wait_timeout: float, expiration_time: float):
-    old_expiration_time, hivemind.DHT.EXPIRATION = hivemind.DHT.EXPIRATION, expiration_time
+                  wait_after_request: float, wait_before_read: float, wait_timeout: float, expiration: float):
     random.seed(random_seed)
     random.seed(random_seed)

     print("Creating peers...")
     peers = []
     for _ in trange(num_peers):
         neighbors = [f'0.0.0.0:{node.port}' for node in random.sample(peers, min(initial_peers, len(peers)))]
+        peer = hivemind.DHT(initial_peers=neighbors, start=True, wait_timeout=wait_timeout,
+                            expiration=expiration, listen_on=f'0.0.0.0:*')
         peers.append(peer)
         peers.append(peer)

     store_peer, get_peer = peers[-2:]
     print(f"Mean store time: {total_store_time / total_stores:.5}, Total: {total_store_time:.5}")
     print(f"Mean store time: {total_store_time / total_stores:.5}, Total: {total_store_time:.5}")
     time.sleep(wait_before_read)

+    if time.perf_counter() - benchmark_started > expiration:
         warn("Warning: all keys expired before benchmark started getting them. Consider increasing expiration_time")
         warn("Warning: all keys expired before benchmark started getting them. Consider increasing expiration_time")

     successful_gets = total_get_time = 0
                     and expert.endpoint == endpoints[start // expert_batch_size]:
                     and expert.endpoint == endpoints[start // expert_batch_size]:
                 successful_gets += 1

+    if time.perf_counter() - benchmark_started > expiration:
         warn("Warning: keys expired midway during get requests. If that is not desired, increase expiration_time param")
         warn("Warning: keys expired midway during get requests. If that is not desired, increase expiration_time param")

     print(f"Get success rate: {successful_gets / len(expert_uids) * 100:.1f} ({successful_gets} / {len(expert_uids)})")
@@ -75,7 +75,6 @@ def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_b

     alive_peers = [peer.is_alive() for peer in peers]
     print(f"Node survival rate: {len(alive_peers) / len(peers) * 100:.3f}%")
-    hivemind.DHT.EXPIRATION = old_expiration_time


 if __name__ == "__main__":
     parser.add_argument('--initial_peers', type=int, default=1, required=False)
     parser.add_argument('--initial_peers', type=int, default=1, required=False)
     parser.add_argument('--num_experts', type=int, default=256, required=False)
     parser.add_argument('--expert_batch_size', type=int, default=32, required=False)
+    parser.add_argument('--expiration', type=float, default=300, required=False)
     parser.add_argument('--wait_after_request', type=float, default=0, required=False)
     parser.add_argument('--wait_after_request', type=float, default=0, required=False)
     parser.add_argument('--wait_before_read', type=float, default=0, required=False)
     parser.add_argument('--wait_timeout', type=float, default=5, required=False)
+ 2 - 2
tests/test_dht.py

@@ -5,7 +5,7 @@ import random
 import heapq
 import heapq
 import uuid
 from itertools import chain
+from typing import Optional
 import numpy as np
 import numpy as np

 import hivemind
         # test 7: bulk store and bulk get
         # test 7: bulk store and bulk get
         keys = 'foo', 'bar', 'baz', 'zzz'
         values = 3, 2, 'batman', [1, 2, 3]
+        store_ok = loop.run_until_complete(me.store_many(keys, values, expiration_time=get_dht_time() + 999))
         assert all(store_ok.values()), "failed to store one or more keys"
         assert all(store_ok.values()), "failed to store one or more keys"
         response = loop.run_until_complete(me.get_many(keys[::-1]))
         for key, value in zip(keys, values):
+ 116 - 0
tests/test_util_modules.py

@@ -0,0 +1,116 @@
+import asyncio
+
+import pytest
+import hivemind
+
+from concurrent.futures import CancelledError
+
+
+def test_mpfuture_result():
+    f1, f2 = hivemind.MPFuture.make_pair()
+    f1.set_result(321)
+    assert f2.result() == 321
+    assert f1.result() == 321
+
+    for future in [f1, f2]:
+        with pytest.raises(RuntimeError):
+            future.set_result(123)
+        with pytest.raises(RuntimeError):
+            future.set_exception(ValueError())
+        assert future.cancel() is False
+        assert future.done() and not future.running() and not future.cancelled()
+
+    f1, f2 = hivemind.MPFuture.make_pair()
+    with pytest.raises(TimeoutError):
+        f1.result(timeout=1e-3)
+
+    f2.set_result(['abacaba', 123])
+    assert f1.result() == ['abacaba', 123]
+
+
+def test_mpfuture_exception():
+    f1, f2 = hivemind.MPFuture.make_pair()
+    with pytest.raises(TimeoutError):
+        f1.exception(timeout=1e-3)
+
+    f2.set_exception(NotImplementedError())
+
+    for future in [f1, f2]:
+        assert isinstance(future.exception(), NotImplementedError)
+        with pytest.raises(NotImplementedError):
+            future.result()
+        assert future.cancel() is False
+        assert future.done() and not future.running() and not future.cancelled()
+
+
+def test_mpfuture_cancel():
+    f1, f2 = hivemind.MPFuture.make_pair()
+    assert not f2.cancelled()
+    f1.cancel()
+    for future in [f1, f2]:
+        with pytest.raises(CancelledError):
+            future.result()
+        with pytest.raises(CancelledError):
+            future.exception()
+        with pytest.raises(RuntimeError):
+            future.set_result(123)
+        with pytest.raises(RuntimeError):
+            future.set_exception(NotImplementedError)
+        assert future.cancelled() and future.done() and not future.running()
+
+
+def test_mpfuture_status():
+    f1, f2 = hivemind.MPFuture.make_pair()
+    assert f1.set_running_or_notify_cancel() is True
+    for future in [f1, f2]:
+        assert future.running() and not future.done() and not future.cancelled()
+        with pytest.raises(RuntimeError):
+            future.set_running_or_notify_cancel()
+    f2.cancel()
+    for future in [f1, f2]:
+        assert not future.running() and future.done() and future.cancelled()
+        assert future.set_running_or_notify_cancel() is False
+
+    f1, f2 = hivemind.MPFuture.make_pair()
+    f1.cancel()
+    for future in [f1, f2]:
+        assert future.set_running_or_notify_cancel() is False
+
+
+def test_await_mpfuture():
+    async def _run():
+        # await result
+        f1, f2 = hivemind.MPFuture.make_pair()
+        async def wait_and_assign():
+            assert f2.set_running_or_notify_cancel() is True
+            await asyncio.sleep(0.1)
+            f2.set_result((123, 'ololo'))
+
+        asyncio.create_task(wait_and_assign())
+        for future in [f1, f2]:
+            res = await future
+            assert res == (123, 'ololo')
+
+        # await cancel
+        f1, f2 = hivemind.MPFuture.make_pair()
+        async def wait_and_cancel():
+            await asyncio.sleep(0.1)
+            f1.cancel()
+
+        asyncio.create_task(wait_and_cancel())
+        for future in [f1, f2]:
+            with pytest.raises(CancelledError):
+                await future
+
+        # await exception
+        f1, f2 = hivemind.MPFuture.make_pair()
+        async def wait_and_raise():
+            await asyncio.sleep(0.1)
+            f1.set_exception(SystemError())
+
+        asyncio.create_task(wait_and_raise())
+        for future in [f1, f2]:
+            with pytest.raises(SystemError):
+                await future
+
+    asyncio.new_event_loop().run_until_complete(_run())

+ 9 - 11
tests/test_utils/run_server.py

@@ -14,7 +14,7 @@ from test_utils.layers import name_to_block, name_to_input
 def make_dummy_server(listen_on='0.0.0.0:*', num_experts=1, expert_cls='ffn', hidden_dim=1024,
                       num_handlers=None, expert_prefix='expert', expert_offset=0, max_batch_size=16384, device=None,
                       no_optimizer=False, no_dht=False, initial_peers=(), dht_port=None, root_port=None, verbose=True,
-                      UID_DELIMETER=hivemind.DHT.UID_DELIMETER, start=False, **kwargs) -> hivemind.Server:
+                      start=False, **kwargs) -> hivemind.Server:
     """
     Instantiate a server with several identical experts. See argparse comments below for details
     :param listen_on: network interface with address and (optional) port, e.g. "127.0.0.1:1337" or "[::]:80"
@@ -47,9 +47,8 @@ def make_dummy_server(listen_on='0.0.0.0:*', num_experts=1, expert_cls='ffn', hi
     if not no_dht:
         if not len(initial_peers):
             print("No initial peers provided. Starting additional dht as an initial peer.")
-            dht_root = hivemind.DHT(initial_peers=initial_peers,
-                                    listen_on=f"{hivemind.LOCALHOST}:{root_port or hivemind.find_open_port()}",
-                                    start=True)
+            dht_root = hivemind.DHT(initial_peers=initial_peers, start=True,
+                                    listen_on=f"{hivemind.LOCALHOST}:{root_port or hivemind.find_open_port()}")
             print(f"Initializing DHT with port {dht_root.port}")
             initial_peers = [f"{hivemind.LOCALHOST}:{dht_root.port}"]
         else:
@@ -57,9 +56,8 @@ def make_dummy_server(listen_on='0.0.0.0:*', num_experts=1, expert_cls='ffn', hi
             if root_port is not None:
                 print(f"Warning: root_port={root_port} will not be used since we already have peers.")

-        dht = hivemind.DHT(initial_peers=initial_peers,
-                           listen_on=f"{hivemind.LOCALHOST}:{dht_port or hivemind.find_open_port()}",
-                           start=True)
+        dht = hivemind.DHT(initial_peers=initial_peers, start=True,
+                           listen_on=f"{hivemind.LOCALHOST}:{dht_port or hivemind.find_open_port()}")
         if verbose:
             print(f"Running dht node on port {dht.port}")

@@ -74,7 +72,7 @@ def make_dummy_server(listen_on='0.0.0.0:*', num_experts=1, expert_cls='ffn', hi
     for i in range(num_experts):
         expert = name_to_block[expert_cls](hidden_dim)
         opt = torch.optim.SGD(expert.parameters(), 0.0 if no_optimizer else 0.05)
-        expert_uid = f'{expert_prefix}{UID_DELIMETER}{i + expert_offset}'
+        expert_uid = f'{expert_prefix}{hivemind.DHT.UID_DELIMITER}{i + expert_offset}'
         experts[expert_uid] = hivemind.ExpertBackend(name=expert_uid, expert=expert, opt=opt,
                                                      args_schema=args_schema,
                                                      outputs_schema=hivemind.BatchTensorDescriptor(hidden_dim),
@@ -154,12 +152,12 @@ if __name__ == '__main__':
     parser.add_argument('--no_optimizer', action='store_true', help='if specified, all optimizers use learning rate=0')
     parser.add_argument('--no_dht', action='store_true', help='if specified, the server will not be attached to a dht')
     parser.add_argument('--initial_peers', type=str, default="[]", required=False, help='a list of peers that will'
-                                                                                        ' introduce this node to the dht, e.g. [("1.2.3.4", 1337), ("127.0.0.1", 4321)]')
+                        ' introduce this node to the dht, e.g. [("1.2.3.4", 1337), ("127.0.0.1", 4321)]')
     parser.add_argument('--dht_port', type=int, default=None, required=False, help='DHT node will listen on this port')
     parser.add_argument('--root_port', type=int, default=None, required=False, help='If this server does not have peers'
-                                                                                    ', it will create a virtual dht node on this port. You can then use this node as initial peer.')
+                        ', it will create a virtual dht node on this port. You can then use this node as initial peer.')
     parser.add_argument('--increase_file_limit', action='store_true', help='On *nix, this will increase the max number'
-                                                                           ' of processes a server can spawn before hitting "Too many open files"; Use at your own risk.')
+                        ' of processes a server can spawn before hitting "Too many open files"; Use at your own risk.')

     args = vars(parser.parse_args())