Parallel traverse + add benchmark_dht + update docs + minor fixes (#53)

* Work in progress parallel DHT

* add heap priority based on visited distance to queries

* add heap priority based on visited distance to queries

* compute priority on unfinished queries only

* typo

* add docstring for new traverse_dht

* trigger rtfd

* typo

* edge case: no friends

* propagate default num replicas to the protocol

* pep

* use listen_on in DHT

* use listen_on in DHT

* [personal] Parallel traverse - rework dht benchmark (#62)

* use listen_on in DHT

* priority: handle empty heap

* priority: handle empty heap

* im

* update tests (tuple endpoint -> str endpoint, new find_nearest_nodes)

* typo

* Updated dht benchmark

* fixed dependencies

* pre-review

* pre-review

* fixed benchmark dht

* mid-review

* fixed benchmark dht

* set default num_replicas = 3

* add benchmark_dht.py to circleci

* add benchmark_dht.py to circleci

Co-authored-by: Vsevolod-pl and Yozh <vsevolod-pl@yandex.ru>

* implement get_many and store_many, add tests

* traverse_dht - make sure all queries are finished eventually

* edge case: (-inf <= inf) == True

* todo: strict type check

* minor: break early on first success

* faster worker termination

* always finish search

* always finish search

* always finish search

* always finish search

* always finish search

* always finish search

* update documentation

* traverse_dht: simpler priority heuristic, better docstrings

* sphinx, pep8

* handle empty queries

* handle empty queries

* review: remove task.cancel to avoid potential side-effects

* await all tasks by default

* await all tasks by default

* add subtitle

* typo hidemind -> hivemind

* docstring

* Update hivemind/dht/node.py

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>

* review: tuple(queries)

* review: move increase_file_limit to a separate function

* review: remove hard-to-read ifelse for num_replicas

* unused imports

* review: replace tuple with flavour names with namedtuple

* remove \*

* increase num replicas (success rate 99.93 -> 99.98 on 1000 nodes)

* make benchmark_dht into a function

* typo

Co-authored-by: justheuristic <justheuristic@gmail.com>
Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
Vsevolod-pl 5 years ago
parent
commit
6b6aded54a

+ 5 - 2
.circleci/config.yml

@@ -20,7 +20,7 @@ jobs:
               cd -
             fi
           name: compile-grpc  # remove this command when v1.31 becomes available via pip install -r requirements.txt
-      - run: sudo pip install codecov pytest grpcio-tools
+      - run: sudo pip install codecov pytest grpcio-tools tqdm
       - python/install-deps
       - python/save-cache
       - run:
@@ -31,7 +31,10 @@ jobs:
           name: tests
       - run:
           command: python tests/benchmark_throughput.py --preset minimalistic
-          name: benchmark
+          name: benchmark_throughput
+      - run:
+          command: python tests/benchmark_dht.py
+          name: benchmark_dht
       - run:
           command: codecov
           name: codecov

+ 19 - 6
docs/modules/dht.rst

@@ -1,13 +1,17 @@
-``hidemind.dht``
+**Hivemind DHT**
 ====================
 
-.. image:: ../_static/dht.png
-   :width: 800
-
 .. automodule:: hivemind.dht
-
 .. currentmodule:: hivemind.dht
 
+Here's a high-level scheme of how these components interact with one another:
+
+.. image:: ../_static/dht.png
+   :width: 640
+   :align: center
+
+DHT and DHTNode
+###############
 
 .. autoclass:: DHT
    :members:
@@ -18,6 +22,9 @@
    :members:
    :member-order: bysource
 
+DHT communication protocol
+##########################
+.. automodule:: hivemind.dht.protocol
 .. currentmodule:: hivemind.dht.protocol
 
 .. autoclass:: DHTProtocol
@@ -39,6 +46,12 @@
    :exclude-members: HASH_FUNC
    :member-order: bysource
 
-.. currentmodule:: hivemind.dht.search
+Traverse (crawl) DHT
+####################
+
+.. automodule:: hivemind.dht.traverse
+.. currentmodule:: hivemind.dht.traverse
+
+.. autofunction:: simple_traverse_dht
 
 .. autofunction:: traverse_dht

+ 47 - 27
hivemind/dht/__init__.py

@@ -1,13 +1,19 @@
 """
-This sub-module implements a node in a Kademlia-based DHT. The code is organized as follows:
- * class DHT (below) - high-level class for model training. Runs DHTNode in a background process.
- * class DHTNode (node.py) - an asyncio implementation of dht server, stores AND gets keys. Asyncio-based.
- * class DHTProtocol (protocol.py) - an rpc protocol to request data from dht nodes. Asyncio-based.
+This is a Distributed Hash Table optimized for rapidly accessing a lot of lightweight metadata.
+Hivemind DHT is based on Kademlia [1] with added support for improved bulk store/get operations and caching.
 
-The code in this module is a modified version of https://github.com/bmuller/kademlia
-Brian, if you're reading this: THANK YOU! you're awesome :)
+The code is organized as follows:
+
+ * **class DHT (__init__.py)** - high-level class for model training. Runs DHTNode in a background process.
+ * **class DHTNode (node.py)** - an asyncio implementation of dht server, stores AND gets keys.
+ * **class DHTProtocol (protocol.py)** - an RPC protocol to request data from dht nodes.
+ * **async def traverse_dht (traverse.py)** - a search algorithm that crawls DHT peers.
+
+- [1] Maymounkov P., Mazieres D. (2002) Kademlia: A Peer-to-Peer Information System Based on the XOR Metric.
+- [2] https://github.com/bmuller/kademlia , Brian, if you're reading this: THANK YOU! you're awesome :)
 """
 import asyncio
+import ctypes
 import multiprocessing as mp
 import warnings
 from typing import List, Optional
@@ -25,22 +31,25 @@ class DHT(mp.Process):
     A high-level interface to hivemind DHT. Runs a dht node in a background process.
 
     :param initial_peers: one or multiple pairs of (host, port) pointing to active DHT peers. Default: no peers
-    :param port: a port where DHT node will listen to incoming connections. Defaults to hivemind.utils.find_open_port
+    :param listen_on: an interface for incoming connections, e.g. "127.0.0.1:*", "0.0.0.0:1234" or "ipv6:[::]:*"
     :param start: if True, automatically starts the background process on creation. Otherwise await manual start
     :param daemon: if True, the background process is marked as daemon and automatically terminated after main process
-    :param node_params: any other params will be forwarded to DHTNode upon creation
+    :param max_workers: declare_experts and get_experts will use up to this many parallel workers
+        (but no more than one per key)
+    :param kwargs: any other params will be forwarded to DHTNode upon creation
     """
     UID_DELIMETER = '.'  # splits expert uids over this delimeter
     EXPIRATION = 120  # anything written to DHT is considered expired after this many seconds
     make_key = "{}::{}".format
 
-    def __init__(self, *initial_peers: Endpoint, port: Optional[Port] = None,
-                 start: bool, daemon: bool = True, **node_params):
+    def __init__(self, *initial_peers: Endpoint, listen_on: Endpoint = "0.0.0.0:*", start: bool, daemon: bool = True,
+                 max_workers: Optional[int] = None, parallel_rpc: Optional[int] = None, **kwargs):
         super().__init__()
-        port = find_open_port() if port is None else port
-        self.node: Optional[DHTNode] = None  # to be initialized in self.run
-        self.port, self.initial_peers, self.node_params = port, initial_peers, node_params
-        self._pipe, self.pipe = mp.Pipe(duplex=False)
+        self.listen_on, self.initial_peers, self.kwargs = listen_on, initial_peers, kwargs
+        self.max_workers, self.parallel_rpc = max_workers, parallel_rpc
+        self._port = mp.Value(ctypes.c_int32, 0)  # initialized after server starts
+        self.node: Optional[DHTNode] = None  # initialized inside self.run only
+        self._pipe, self.pipe = mp.Pipe(duplex=True)
         self.ready = mp.Event()
         self.daemon = daemon
         if start:
@@ -53,8 +62,10 @@ class DHT(mp.Process):
         uvloop.install()
         loop = asyncio.new_event_loop()
         asyncio.set_event_loop(loop)
-        self.node = loop.run_until_complete(DHTNode.create(
-            initial_peers=list(self.initial_peers), listen_on=f"{LOCALHOST}:{self.port}", **self.node_params))
+        self.node: DHTNode = loop.run_until_complete(DHTNode.create(
+            initial_peers=list(self.initial_peers), listen_on=self.listen_on, parallel_rpc=self.parallel_rpc,
+            num_workers=self.max_workers or 1, **self.kwargs))
+        self._port.value = self.node.port
         run_in_background(loop.run_forever)
         self.ready.set()
 
@@ -78,6 +89,10 @@ class DHT(mp.Process):
         else:
             warnings.warn("DHT shutdown has no effect: dht process is already not alive")
 
+    @property
+    def port(self) -> Optional[int]:
+        return self._port.value if self._port.value != 0 else None
+
     def get_experts(self, uids: List[str], expiration=None) -> List[Optional[RemoteExpert]]:
         """
         :param uids: find experts with these ids from across the DHT
@@ -91,13 +106,15 @@ class DHT(mp.Process):
     def _get_experts(self, uids: List[str], expiration: Optional[DHTExpiration], future: SharedFuture):
         loop = asyncio.get_event_loop()
         expiration = expiration or get_dht_time()
+        num_workers = len(uids) if self.max_workers is None else min(len(uids), self.max_workers)
+        keys = [self.make_key('expert', uid) for uid in uids]
 
-        lookup_futures = [asyncio.run_coroutine_threadsafe(
-            self.node.get(self.make_key('expert', uid), expiration), loop) for uid in uids]
+        response = asyncio.run_coroutine_threadsafe(
+            self.node.get_many(keys, expiration, num_workers=num_workers), loop).result()
 
         experts: List[Optional[RemoteExpert]] = [None] * len(uids)
-        for i, (uid, lookup) in enumerate(zip(uids, lookup_futures)):
-            maybe_result, maybe_expiration = lookup.result()
+        for i, (key, uid) in enumerate(zip(keys, uids)):
+            maybe_result, maybe_expiration = response[key]
             if maybe_expiration is not None:  # if we found a value
                 experts[i] = RemoteExpert(uid=uid, host=maybe_result[0], port=maybe_result[1])
 
@@ -121,25 +138,28 @@ class DHT(mp.Process):
 
     def _declare_experts(self, uids: List[str], addr: str, port: int, future: Optional[SharedFuture]):
         assert self.node is not None, "This method should only be accessed from inside .run method"
+        num_workers = len(uids) if self.max_workers is None else min(len(uids), self.max_workers)
         loop = asyncio.get_event_loop()
         expiration_time = get_dht_time() + self.EXPIRATION
         unique_prefixes = set()
         coroutines = []
 
+        keys, values = [], []
         for uid in uids:
-            coroutines.append(asyncio.run_coroutine_threadsafe(
-                self.node.store(self.make_key('expert', uid), value=(addr, port),
-                                expiration_time=expiration_time),
-                loop))
             uid_parts = uid.split(self.UID_DELIMETER)
+            keys.append(self.make_key('expert', uid))
+            values.append((addr, port))
             unique_prefixes.update([self.UID_DELIMETER.join(uid_parts[:i + 1]) for i in range(len(uid_parts))])
 
         for prefix in unique_prefixes:
-            coroutines.append(asyncio.run_coroutine_threadsafe(
-                self.node.store(self.make_key('prefix', prefix), True, expiration_time), loop))
+            keys.append(self.make_key('prefix', prefix))
+            values.append(True)
 
+        store_ok = asyncio.run_coroutine_threadsafe(
+            self.node.store_many(keys, values, expiration_time, num_workers=num_workers), loop
+        ).result()
         if future is not None:
-            future.set_result([coro.result() for coro in coroutines])  # wait for all coroutings to finish
+            future.set_result([store_ok[key] for key in keys])
 
     def first_k_active(self, prefixes: List[str], k: int, max_prefetch=None):
         """

+ 221 - 119
hivemind/dht/node.py

@@ -1,13 +1,13 @@
 from __future__ import annotations
 import asyncio
 import random
-from collections import OrderedDict
-from typing import Optional, Tuple, List, Dict
+from collections import namedtuple
+from typing import Optional, Tuple, List, Dict, Collection, Union, Set
 from warnings import warn
 
 from .protocol import DHTProtocol
-from .routing import DHTID, BinaryDHTValue, DHTExpiration, DHTKey, get_dht_time, DHTValue
-from .search import traverse_dht
+from .routing import DHTID, DHTExpiration, DHTKey, get_dht_time, DHTValue
+from .traverse import traverse_dht
 from ..utils import Endpoint, LOCALHOST, MSGPackSerializer
 
 
@@ -43,16 +43,18 @@ class DHTNode:
       Cache operates same as regular storage, but it has a limited size and evicts least recently used nodes when full;
 
     """
-    node_id: int; port: int; num_replicas: int; cache_locally: bool; cache_nearest: int; refresh_timeout: float
-    protocol: DHTProtocol
+    # fmt:off
+    node_id: DHTID; port: int; num_replicas: int; cache_locally: bool; cache_nearest: int; num_workers: int
+    refresh_timeout: float; protocol: DHTProtocol
     serializer = MSGPackSerializer  # used to pack/unpack DHT Values for transfer over network
+    # fmt:on
 
     @classmethod
     async def create(
             cls, node_id: Optional[DHTID] = None, initial_peers: List[Endpoint] = (),
-            bucket_size: int = 20, num_replicas: Optional[int] = None, depth_modulo: int = 5, parallel_rpc: int = None,
+            bucket_size: int = 20, num_replicas: int = 5, depth_modulo: int = 5, parallel_rpc: int = None,
             wait_timeout: float = 5, refresh_timeout: Optional[float] = None, bootstrap_timeout: Optional[float] = None,
-            cache_locally: bool = True, cache_nearest: int = 1, cache_size=None,
+            num_workers: int = 1, cache_locally: bool = True, cache_nearest: int = 1, cache_size=None,
             listen: bool = True, listen_on: Endpoint = "0.0.0.0:*", **kwargs) -> DHTNode:
         """
         :param node_id: current node's identifier, determines which keys it will store locally, defaults to random id
@@ -68,20 +70,21 @@ class DHTNode:
         :param refresh_timeout: refresh buckets if no node from that bucket was updated in this many seconds
           if staleness_timeout is None, DHTNode will not refresh stale buckets (which is usually okay)
         :param bootstrap_timeout: after one of peers responds, await other peers for at most this many seconds
+        :param num_workers: concurrent workers in traverse_dht (see traverse_dht num_workers param)
         :param cache_locally: if True, caches all values (stored or found) in a node-local cache
         :param cache_nearest: whenever DHTNode finds a value, it will also store (cache) this value on this many
           nodes nearest nodes visited by search algorithm. Prefers nodes that are nearest to :key: but have no value yet
         :param cache_size: if specified, local cache will store up to this many records (as in LRU cache)
         :param listen: if True (default), this node will accept incoming request and otherwise be a DHT "citzen"
           if False, this node will refuse any incoming request, effectively being only a "client"
-        :param listen_on: network interface for incoming RPCs, e.g. "0.0.0.0:1337" or "localhost:\*" or "[::]:7654"
+        :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
         :param channel_options: options for grpc.aio.insecure_channel, e.g. [('grpc.enable_retries', 0)]
           see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
         :param kwargs: extra parameters used in grpc.aio.server
         """
         self = cls(_initialized_with_create=True)
         self.node_id = node_id = node_id if node_id is not None else DHTID.generate()
-        self.num_replicas = num_replicas if num_replicas is not None else bucket_size
+        self.num_replicas, self.num_workers = num_replicas, num_workers
         self.cache_locally, self.cache_nearest = cache_locally, cache_nearest
         self.refresh_timeout = refresh_timeout
 
@@ -89,7 +92,6 @@ class DHTNode:
                                                  parallel_rpc, cache_size, listen, listen_on, **kwargs)
         self.port = self.protocol.port
 
-
         if initial_peers:
             # stage 1: ping initial_peers, add each other to the routing table
             bootstrap_timeout = bootstrap_timeout if bootstrap_timeout is not None else wait_timeout
@@ -111,7 +113,7 @@ class DHTNode:
             # stage 3: traverse dht to find my own nearest neighbors and populate the routing table
             # ... maybe receive some values that we are meant to store (see protocol.update_routing_table)
             # note: using asyncio.wait instead of wait_for because wait_for cancels task on timeout
-            await asyncio.wait([asyncio.create_task(self.find_nearest_nodes(key_id=self.node_id)),
+            await asyncio.wait([asyncio.create_task(self.find_nearest_nodes([self.node_id])),
                                 asyncio.sleep(bootstrap_timeout - get_dht_time() + start_time)],
                                return_when=asyncio.FIRST_COMPLETED)
 
@@ -128,145 +130,245 @@ class DHTNode:
         """ Process existing requests, close all connections and stop the server """
         await self.protocol.shutdown(timeout)
 
-    async def find_nearest_nodes(self, key_id: DHTID, k_nearest: Optional[int] = None,
-                                 beam_size: Optional[int] = None, exclude_self: bool = False) -> Dict[DHTID, Endpoint]:
+    async def find_nearest_nodes(
+            self, queries: Collection[DHTID], k_nearest: Optional[int] = None, beam_size: Optional[int] = None,
+            num_workers: Optional[int] = None, node_to_endpoint: Optional[Dict[DHTID, Endpoint]] = None,
+            exclude_self: bool = False, **kwargs) -> Dict[DHTID, Dict[DHTID, Endpoint]]:
         """
-        Traverse the DHT and find :k_nearest: nodes to a given :query_id:, optionally :exclude_self: from the results.
-
-        :returns: an ordered dictionary of [peer DHTID -> network Endpoint], ordered from nearest to farthest neighbor
-        :note: this is a thin wrapper over dht.search.traverse_dht, look there for more details
+        :param queries: find k nearest nodes for each of these DHTIDs
+        :param k_nearest: return this many nearest nodes for every query (if there are enough nodes)
+        :param beam_size: replacement for self.beam_size, see traverse_dht beam_size param
+        :param num_workers: replacement for self.num_workers, see traverse_dht num_workers param
+        :param node_to_endpoint: if specified, uses this dict[node_id => endpoint] as initial peers
+        :param exclude_self: if True, nearest nodes will not contain self.node_id (default = use local peers)
+        :param kwargs: additional params passed to traverse_dht
+        :returns: for every query, return nearest peers ordered dict[peer DHTID -> network Endpoint], nearest-first
         """
+        queries = tuple(queries)
         k_nearest = k_nearest if k_nearest is not None else self.protocol.bucket_size
+        num_workers = num_workers if num_workers is not None else self.num_workers
         beam_size = beam_size if beam_size is not None else max(self.protocol.bucket_size, k_nearest)
-        node_to_addr = dict(
-            self.protocol.routing_table.get_nearest_neighbors(key_id, beam_size, exclude=self.node_id))
-
-        async def get_neighbors(node_id: DHTID) -> Tuple[List[DHTID], bool]:
-            response = await self.protocol.call_find(node_to_addr[node_id], [key_id])
-            if not response or key_id not in response:
-                return [], False  # False means "do not interrupt search"
-
-            peers: Dict[DHTID, Endpoint] = response[key_id][-1]
-            node_to_addr.update(peers)
-            return list(peers.keys()), False  # False means "do not interrupt search"
+        if k_nearest > beam_size:
+            warn("Warning: beam_size is too small, beam search is not guaranteed to find enough nodes")
+        if node_to_endpoint is None:
+            node_to_endpoint: Dict[DHTID, Endpoint] = dict()
+            for query in queries:
+                node_to_endpoint.update(
+                    self.protocol.routing_table.get_nearest_neighbors(query, beam_size, exclude=self.node_id))
+
+        async def get_neighbors(peer: DHTID, queries: Collection[DHTID]) -> Dict[DHTID, Tuple[List[DHTID], bool]]:
+            response = await self.protocol.call_find(node_to_endpoint[peer], queries)
+            if not response:
+                return {query: ([], False) for query in queries}
+
+            output: Dict[DHTID, Tuple[List[DHTID], bool]] = {}
+            for query, (_, _, peers) in response.items():
+                node_to_endpoint.update(peers)
+                output[query] = list(peers.keys()), False  # False means "do not interrupt search"
+            return output
 
         nearest_nodes, visited_nodes = await traverse_dht(
-            query_id=key_id, initial_nodes=list(node_to_addr), k_nearest=k_nearest, beam_size=beam_size,
-            get_neighbors=get_neighbors, visited_nodes=(self.node_id,))
-
-        if not exclude_self:
-            nearest_nodes = sorted(nearest_nodes + [self.node_id], key=key_id.xor_distance)[:k_nearest]
-            node_to_addr[self.node_id] = (LOCALHOST, self.port)
-
-        return OrderedDict((node, node_to_addr[node]) for node in nearest_nodes)
-
-    async def store(self, key: DHTKey, value: BinaryDHTValue, expiration_time: DHTExpiration) -> bool:
+            queries, initial_nodes=list(node_to_endpoint), beam_size=beam_size, num_workers=num_workers,
+            queries_per_call=int(len(queries) ** 0.5), get_neighbors=get_neighbors,
+            visited_nodes={query: {self.node_id} for query in queries}, **kwargs)
+
+        nearest_nodes_per_query = {}
+        for query, nearest_nodes in nearest_nodes.items():
+            if not exclude_self:
+                nearest_nodes = sorted(nearest_nodes + [self.node_id], key=query.xor_distance)
+                node_to_endpoint[self.node_id] = f"{LOCALHOST}:{self.port}"
+            nearest_nodes_per_query[query] = {node: node_to_endpoint[node] for node in nearest_nodes[:k_nearest]}
+        return nearest_nodes_per_query
+
+    async def store(self, key: DHTKey, value: DHTValue, expiration_time: DHTExpiration, **kwargs) -> bool:
         """
-        Find beam_size best nodes to store (key, value) and store it there at least until expiration time.
-        Optionally cache (key, value, expiration) on nodes you met along the way (see Section 2.1 end) TODO(jheuristic)
+        Find num_replicas best nodes to store (key, value) and store it there at least until expiration time.
 
+        :note: store is a simplified interface to store_many, all kwargs are forwarded there
         :returns: True if store succeeds, False if it fails (due to no response or newer value)
         """
-        key_id, value_bytes = DHTID.generate(source=key), self.serializer.dumps(value)
-        nearest_node_to_addr = await self.find_nearest_nodes(key_id, k_nearest=self.num_replicas, exclude_self=True)
-        tasks = [asyncio.create_task(self.protocol.call_store(endpoint, [key_id], [value_bytes], [expiration_time]))
-                 for endpoint in nearest_node_to_addr.values()]
-        done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
+        store_ok = await self.store_many([key], [value], [expiration_time], **kwargs)
+        return store_ok[key]
 
-        return any(store_ok for response in done for store_ok in response.result())
+    async def store_many(
+            self, keys: List[DHTKey], values: List[DHTValue], expiration: Union[DHTExpiration, List[DHTExpiration]],
+            exclude_self: bool = False, await_all_replicas=True, **kwargs) -> Dict[DHTKey, bool]:
+        """
+        Traverse DHT to find the best nodes to store multiple (key, value, expiration) pairs.
+
+        :param keys: arbitrary serializable keys associated with each value
+        :param values: serializable "payload" for each key
+        :param expiration: either one expiration time for all keys or individual expiration times (see class doc)
+        :param kwargs: any additional parameters passed to traverse_dht function (e.g. num workers)
+        :param exclude_self: if True, never store value locally even if you are one of the nearest nodes
+        :note: if exclude_self is True and self.cache_locally == True, value will still be __cached__ locally
+        :param await_all_replicas: if False, this function returns after first store_ok and proceeds in background
+            if True, the function will wait for num_replicas successful stores or running out of beam_size nodes
+        :returns: for each key: True if store succeeds, False if it fails (due to no response or newer value)
+        """
+        expiration = [expiration] * len(keys) if isinstance(expiration, DHTExpiration) else expiration
+        assert len(keys) == len(values) == len(expiration), "Please provide equal number of keys, values and expiration"
+
+        key_ids = list(map(DHTID.generate, keys))
+        id_to_original_key = dict(zip(key_ids, keys))
+        binary_values_by_key_id = {key_id: self.serializer.dumps(value) for key_id, value in zip(key_ids, values)}
+        expiration_by_key_id = {key_id: expiration_time for key_id, expiration_time in zip(key_ids, expiration)}
+        unfinished_key_ids = set(key_ids)  # we use this set to ensure that each store request is finished
+
+        store_ok = {key: False for key in keys}  # outputs, updated during search
+        store_finished_events = {key: asyncio.Event() for key in keys}
+
+        if self.cache_locally:
+            for key_id in key_ids:
+                self.protocol.cache.store(key_id, binary_values_by_key_id[key_id], expiration_by_key_id[key_id])
+
+        # pre-populate node_to_endpoint
+        node_to_endpoint: Dict[DHTID, Endpoint] = dict()
+        for key_id in key_ids:
+            node_to_endpoint.update(self.protocol.routing_table.get_nearest_neighbors(
+                key_id, self.protocol.bucket_size, exclude=self.node_id))
 
-    async def get(self, key: DHTKey, sufficient_expiration_time: Optional[DHTExpiration] = None,
-                  beam_size: Optional[int] = None) -> Tuple[Optional[DHTValue], Optional[DHTExpiration]]:
+        async def on_found(key_id: DHTID, nearest_nodes: List[DHTID], visited_nodes: Set[DHTID]) -> None:
+            """ This will be called once per key when find_nearest_nodes is done for a particular node """
+            # note: we use callbacks instead of returned values to call store immediately without waiting for stragglers
+            assert key_id in unfinished_key_ids, "Internal error: traverse_dht finished the same query twice"
+            unfinished_key_ids.remove(key_id)
+
+            # ensure k nodes and (optionally) exclude self
+            nearest_nodes = [node_id for node_id in nearest_nodes if (not exclude_self or node_id != self.node_id)]
+            store_args = [key_id], [binary_values_by_key_id[key_id]], [expiration_by_key_id[key_id]]
+            store_tasks = {asyncio.create_task(self.protocol.call_store(node_to_endpoint[nearest_node_id], *store_args))
+                           for nearest_node_id in nearest_nodes[:self.num_replicas]}
+            backup_nodes = nearest_nodes[self.num_replicas:]  # used in case previous nodes didn't respond
+
+            # parse responses and issue additional stores if someone fails
+            while store_tasks:
+                finished_store_tasks, store_tasks = await asyncio.wait(store_tasks, return_when=asyncio.FIRST_COMPLETED)
+                for task in finished_store_tasks:
+                    if task.result()[0]:  # if store succeeded
+                        store_ok[id_to_original_key[key_id]] = True
+                        if not await_all_replicas:
+                            store_finished_events[id_to_original_key[key_id]].set()
+                    elif backup_nodes:
+                        store_tasks.add(asyncio.create_task(
+                            self.protocol.call_store(node_to_endpoint[backup_nodes.pop(0)], *store_args)))
+
+                store_finished_events[id_to_original_key[key_id]].set()
+
+        asyncio.create_task(self.find_nearest_nodes(
+            queries=set(key_ids), k_nearest=self.num_replicas, node_to_endpoint=node_to_endpoint,
+            found_callback=on_found, exclude_self=exclude_self, **kwargs))
+        await asyncio.wait([evt.wait() for evt in store_finished_events.values()])  # await one (or all) store accepts
+        assert len(unfinished_key_ids) == 0, "Internal error: traverse_dht didn't finish search"
+        return store_ok
+
+    async def get(self, key: DHTKey, latest=False, **kwargs) -> Tuple[Optional[DHTValue], Optional[DHTExpiration]]:
         """
-        :param key: traverse the DHT and find the value for this key (or return None if it does not exist)
+        Search for a key across DHT and return either first or latest entry.
+        :param key: same key as in node.store(...)
+        :param latest: if True, finds the latest value, otherwise finds any non-expired value (which is much faster)
+        :param kwargs: parameters forwarded to get_many
+        :returns: (value, expiration time); if value was not found, returns (None, None)
+        """
+        if latest:
+            kwargs["sufficient_expiration_time"] = float('inf')
+        result = await self.get_many([key])
+        return result[key]
+
+    async def get_many(
+            self, keys: Collection[DHTKey], sufficient_expiration_time: Optional[DHTExpiration] = None,
+            num_workers: Optional[int] = None, beam_size: Optional[int] = None
+    ) -> Dict[DHTKey, Tuple[Optional[DHTValue], Optional[DHTExpiration]]]:
+        """
+        :param keys: traverse the DHT and find the value for each of these keys (or (None, None) if a key is not found)
         :param sufficient_expiration_time: if the search finds a value that expires after this time,
             default = time of call, find any value that did not expire by the time of call
             If min_expiration_time=float('inf'), this method will find a value with _latest_ expiration
         :param beam_size: maintains up to this many nearest nodes when crawling dht, default beam_size = bucket_size
-        :returns: value and its expiration time. If nothing is found , returns (None, None).
+        :param num_workers: override for default num_workers, see traverse_dht num_workers param
+        :returns: for each key: value and its expiration time. If nothing is found, returns (None, None) for that key
         :note: in order to check if get returned a value, please check (expiration_time is None)
         """
-        key_id = DHTID.generate(key)
+        key_ids = [DHTID.generate(key) for key in keys]
+        id_to_original_key = dict(zip(key_ids, keys))
         sufficient_expiration_time = sufficient_expiration_time or get_dht_time()
         beam_size = beam_size if beam_size is not None else self.protocol.bucket_size
-        latest_value_bytes, latest_expiration, latest_node_id = b'', -float('inf'), None
-        node_to_addr, nodes_checked_for_value, nearest_nodes = dict(), set(), []
-        should_cache = False  # True if found value in DHT that is newer than local value
-
-        # Option A: value can be stored in our local cache
-        maybe_value, maybe_expiration = self.protocol.storage.get(key_id)
-        if maybe_expiration is None:
-            maybe_value, maybe_expiration = self.protocol.cache.get(key_id)
-        if maybe_expiration is not None and maybe_expiration > latest_expiration:
-            latest_value_bytes, latest_expiration, latest_node_id = maybe_value, maybe_expiration, self.node_id
-            # TODO(jheuristic) we may want to run background beam search to update our cache
-        nodes_checked_for_value.add(self.node_id)
-
-        # Option B: go beam search the DHT
-        if latest_expiration < sufficient_expiration_time:
+        num_workers = num_workers if num_workers is not None else self.num_workers
+
+        # search metadata
+        unfinished_key_ids = set(key_ids)  # track key ids for which the search is not terminated
+        node_to_addr: Dict[DHTID, Endpoint] = dict()  # global routing table for all queries
+
+        SearchResult = namedtuple("SearchResult", ["binary_value", "expiration", "source_node_id"])
+        latest_results = {key_id: SearchResult(b'', -float('inf'), None) for key_id in key_ids}
+
+        # stage 1: value can be stored in our local cache
+        for key_id in key_ids:
+            maybe_value, maybe_expiration = self.protocol.storage.get(key_id)
+            if maybe_expiration is None:
+                maybe_value, maybe_expiration = self.protocol.cache.get(key_id)
+            if maybe_expiration is not None and maybe_expiration > latest_results[key_id].expiration:
+                latest_results[key_id] = SearchResult(maybe_value, maybe_expiration, self.node_id)
+                if maybe_expiration >= sufficient_expiration_time:
+                    unfinished_key_ids.remove(key_id)
+
+        # stage 2: traverse the DHT for any unfinished keys
+        for key_id in unfinished_key_ids:
             node_to_addr.update(self.protocol.routing_table.get_nearest_neighbors(
                 key_id, self.protocol.bucket_size, exclude=self.node_id))
 
-            async def get_neighbors(node: DHTID) -> Tuple[List[DHTID], bool]:
-                nonlocal latest_value_bytes, latest_expiration, latest_node_id, node_to_addr, nodes_checked_for_value
-                response = await self.protocol.call_find(node_to_addr[node], [key_id])
-                nodes_checked_for_value.add(node)
-                if not response or key_id not in response:
-                    return [], False
+        async def get_neighbors(peer: DHTID, queries: Collection[DHTID]) -> Dict[DHTID, Tuple[List[DHTID], bool]]:
+            queries = list(queries)
+            response = await self.protocol.call_find(node_to_addr[peer], queries)
+            if not response:
+                return {query: ([], False) for query in queries}
 
-                maybe_value, maybe_expiration, peers = response[key_id]
+            output: Dict[DHTID, Tuple[List[DHTID], bool]] = {}
+            for key_id, (maybe_value, maybe_expiration, peers) in response.items():
                 node_to_addr.update(peers)
-                if maybe_expiration is not None and maybe_expiration > latest_expiration:
-                    latest_value_bytes, latest_expiration, latest_node_id = maybe_value, maybe_expiration, node
-                should_interrupt = (latest_expiration >= sufficient_expiration_time)
-                return list(peers.keys()), should_interrupt
-
-            nearest_nodes, visited_nodes = await traverse_dht(
-                query_id=key_id, initial_nodes=list(node_to_addr), k_nearest=beam_size, beam_size=beam_size,
-                get_neighbors=get_neighbors, visited_nodes=nodes_checked_for_value)
-            # normally, by this point we will have found a sufficiently recent value in one of get_neighbors calls
-            should_cache = latest_expiration >= sufficient_expiration_time  # if we found a newer value, cache it later
-
-        # Option C: didn't find good-enough value in beam search, make a last-ditch effort to find it in unvisited nodes
-        if latest_expiration < sufficient_expiration_time:
-            nearest_unvisited = [node_id for node_id in nearest_nodes if node_id not in nodes_checked_for_value]
-            tasks = [self.protocol.call_find(node_to_addr[node_id], [key_id]) for node_id in nearest_unvisited]
-            pending_tasks = set(tasks)
-            for task in asyncio.as_completed(tasks):
-                pending_tasks.remove(task)
-                if not task.result() or key_id not in task.result():
-                    maybe_value, maybe_expiration, _ = task.result()[key_id]
-                if maybe_expiration is not None and maybe_expiration > latest_expiration:
-                    latest_value_bytes, latest_expiration = maybe_value, maybe_expiration
-                    if latest_expiration >= sufficient_expiration_time:
+                if maybe_expiration is not None and maybe_expiration > latest_results[key_id].expiration:
+                    latest_results[key_id] = SearchResult(maybe_value, maybe_expiration, peer)
+                should_interrupt = (latest_results[key_id].expiration >= sufficient_expiration_time)
+                output[key_id] = list(peers.keys()), should_interrupt
+            return output
+
+        nearest_nodes_per_query, visited_nodes = await traverse_dht(
+            queries=list(unfinished_key_ids), initial_nodes=list(node_to_addr),
+            beam_size=beam_size, num_workers=num_workers, queries_per_call=int(len(unfinished_key_ids) ** 0.5),
+            get_neighbors=get_neighbors, visited_nodes={key_id: {self.node_id} for key_id in unfinished_key_ids})
+
+        # stage 3: cache any new results depending on caching parameters
+        for key_id, nearest_nodes in nearest_nodes_per_query.items():
+            latest_value_bytes, latest_expiration, latest_node_id = latest_results[key_id]
+            should_cache = latest_expiration >= sufficient_expiration_time  # if we found a newer value, cache it
+            if should_cache and self.cache_locally:
+                self.protocol.cache.store(key_id, latest_value_bytes, latest_expiration)
+
+            if should_cache and self.cache_nearest:
+                num_cached_nodes = 0
+                for node_id in nearest_nodes:
+                    if node_id == latest_node_id:
+                        continue
+                    asyncio.create_task(self.protocol.call_store(
+                        node_to_addr[node_id], [key_id], [latest_value_bytes], [latest_expiration], in_cache=True))
+                    num_cached_nodes += 1
+                    if num_cached_nodes >= self.cache_nearest:
                         break
-            for task in pending_tasks:
-                task.close()
-            should_cache = latest_expiration >= sufficient_expiration_time  # if we found a newer value, cache it later
-
-        # step 4: we have not found entry with sufficient_expiration_time, but we may have found *something* older
-        if should_cache and self.cache_locally:
-            self.protocol.cache.store(key_id, latest_value_bytes, latest_expiration)
-        if should_cache and self.cache_nearest:
-            num_cached_nodes = 0
-            for node_id in nearest_nodes:
-                if node_id == latest_node_id:
-                    continue
-                asyncio.create_task(self.protocol.call_store(
-                    node_to_addr[node_id], [key_id], [latest_value_bytes], [latest_expiration], in_cache=True))
-                num_cached_nodes += 1
-                if num_cached_nodes >= self.cache_nearest:
-                    break
-        if latest_expiration != -float('inf'):
-            return self.serializer.loads(latest_value_bytes), latest_expiration
-        else:
-            return None, None
+
+        # stage 4: deserialize data and assemble function output
+        find_result: Dict[DHTKey, Tuple[Optional[DHTValue], Optional[DHTExpiration]]] = {}
+        for key_id, (latest_value_bytes, latest_expiration, _) in latest_results.items():
+            if latest_expiration != -float('inf'):
+                find_result[id_to_original_key[key_id]] = self.serializer.loads(latest_value_bytes), latest_expiration
+            else:
+                find_result[id_to_original_key[key_id]] = None, None
+        return find_result
 
     async def _refresh_routing_table(self, *, period: Optional[float]) -> None:
         """ Tries to find new nodes for buckets that were unused for more than self.staleness_timeout """
         while period is not None:  # if None run once, otherwise run forever
             refresh_time = get_dht_time()
-            staleness_threshold = refresh_time - self.staleness_timeout
+            staleness_threshold = refresh_time - period
             stale_buckets = [bucket for bucket in self.protocol.routing_table.buckets
                              if bucket.last_updated < staleness_threshold]
             for bucket in stale_buckets:
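
To make the new bulk interface concrete, here is a minimal asyncio sketch built only from the signatures visible in this diff (DHTNode.create, store_many, get_many); treat it as an illustration under those assumptions rather than a tested snippet.

# Illustrative sketch of the new DHTNode bulk API (not part of this commit).
import asyncio
from hivemind.dht.node import DHTNode
from hivemind.dht.routing import get_dht_time

async def demo():
    alice = await DHTNode.create(listen_on="0.0.0.0:*")
    bob = await DHTNode.create(initial_peers=[f"127.0.0.1:{alice.port}"], listen_on="0.0.0.0:*")

    # store several keys in a single traversal; returns {key: True/False}
    store_ok = await bob.store_many(keys=["k1", "k2"], values=[123, "abc"],
                                    expiration=get_dht_time() + 60)
    print(store_ok)  # e.g. {'k1': True, 'k2': True}

    # fetch them back in a single traversal; returns {key: (value, expiration) or (None, None)}
    found = await alice.get_many(["k1", "k2", "missing"])
    print(found["k1"], found["missing"])  # e.g. (123, <expiration>) (None, None)
    # a real application would also shut the nodes down afterwards (DHTNode.shutdown)

asyncio.run(demo())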

+ 3 - 2
hivemind/dht/protocol.py

@@ -1,10 +1,11 @@
+""" RPC protocol that provides nodes a way to communicate with each other. Based on gRPC.AIO. """
 from __future__ import annotations
 
 import asyncio
 import heapq
 import os
 import urllib.parse
-from typing import Optional, List, Tuple, Dict, Iterator, Any, Sequence, Union
+from typing import Optional, List, Tuple, Dict, Iterator, Any, Sequence, Union, Collection
 from warnings import warn
 
 import grpc
@@ -161,7 +162,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
             response.store_ok.append(local_memory.store(DHTID.from_bytes(key_bytes), value_bytes, expiration_time))
         return response
 
-    async def call_find(self, peer: Endpoint, keys: Sequence[DHTID]) -> \
+    async def call_find(self, peer: Endpoint, keys: Collection[DHTID]) -> \
             Optional[Dict[DHTID, Tuple[Optional[BinaryDHTValue], Optional[DHTExpiration], Dict[DHTID, Endpoint]]]]:
         """
         Request keys from a peer. For each key, look for its (value, expiration time) locally and

+ 4 - 3
hivemind/dht/routing.py

@@ -1,3 +1,4 @@
+""" Utlity data structures to represent DHT nodes (peers), data keys, and routing tables. """
 from __future__ import annotations
 
 import hashlib
@@ -15,13 +16,13 @@ from ..utils import Endpoint, PickleSerializer
 
 class RoutingTable:
     """
-    A data structure that contains DHT peers bucketed according to their distance to node_id
+    A data structure that contains DHT peers bucketed according to their distance to node_id.
+    Follows Kademlia routing table as described in https://pdos.csail.mit.edu/~petar/papers/maymounkov-kademlia-lncs.pdf
 
     :param node_id: node id used to measure distance
     :param bucket_size: parameter $k$ from Kademlia paper Section 2.2
     :param depth_modulo: parameter $b$ from Kademlia paper Section 2.2.
-    :note: you can find a more detailed docstring for Node class, see node.py
-    :note: kademlia paper refers to https://pdos.csail.mit.edu/~petar/papers/maymounkov-kademlia-lncs.pdf
+    :note: you can find a more detailed description of parameters in DHTNode, see node.py
     """
 
     def __init__(self, node_id: DHTID, bucket_size: int, depth_modulo: int):
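
As a quick illustration of the identifiers these routing structures operate on, below is a tiny sketch using only DHTID.generate and DHTID.xor_distance, both of which appear throughout this diff; the "expert::..." key format mirrors DHT.make_key above. It is an example added for this write-up, not part of the commit.

# Tiny DHTID illustration (not part of this commit); no networking required.
from hivemind.dht.routing import DHTID

key_id = DHTID.generate(source="expert::ffn.0.3")  # deterministic id derived from a DHT key
node_ids = [DHTID.generate() for _ in range(4)]    # random node identifiers
# xor_distance accepts a single id or a list of ids; smaller distance means "closer" in Kademlia
nearest = min(node_ids, key=key_id.xor_distance)
print(key_id.xor_distance(nearest), key_id.xor_distance(node_ids))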

+ 0 - 62
hivemind/dht/search.py

@@ -1,62 +0,0 @@
-import heapq
-from typing import Collection, Callable, Tuple, List, Awaitable, Set
-from warnings import warn
-
-from .routing import DHTID
-
-
-async def traverse_dht(query_id: DHTID, initial_nodes: Collection[DHTID], k_nearest: int, beam_size: int,
-                       get_neighbors: Callable[[DHTID], Awaitable[Tuple[Collection[DHTID], bool]]],
-                       visited_nodes: Collection[DHTID] = ()) -> Tuple[List[DHTID], Set[DHTID]]:
-    """
-    Asynchronous beam search over the DHT. Not meant to be called by the user, please use DHTNode.store/get instead.
-    Traverse the DHT graph using get_neighbors function, find up to k_nearest nodes according to DHTID.xor_distance.
-    Approximate time complexity: O(T * log T) where T = (path_to_true_nearest + beam_size) * mean_num_neighbors
-
-    :param query_id: search query, find k_nearest neighbors of this DHTID
-    :param initial_nodes: nodes used to pre-populate beam search heap, e.g. [my_own_DHTID, ...maybe_some_peers]
-    :param k_nearest: find up to this many nearest neighbors. If there are less nodes in the DHT, return all nodes
-    :param beam_size: beam search will not give up until it exhausts this many nearest nodes (to query_id) from the heap
-        Recommended value: A beam size of k_nearest * (2-5) will yield near-perfect results.
-
-    :param get_neighbors: A function that returns neighbors of a given node and controls beam search stopping criteria.
-        async def get_neighbors(node: DHTID) -> neighbors_of_that_node: List[DHTID], should_continue: bool
-        If should_continue is False, beam search will halt and return k_nearest of whatever it found by then.
-
-    :param visited_nodes: beam search will neither call get_neighbors on these nodes, nor return them as nearest
-    :returns: a list of k nearest nodes (nearest to farthest), and a set of all visited nodes (including visited_nodes)
-    """
-    if beam_size < k_nearest:
-        warn(f"beam search: beam_size({beam_size}) is too small, beam search may fail to find k neighbors.")
-    visited_nodes = set(visited_nodes)  # note: copy visited_nodes because we will add more nodes to this collection.
-    initial_nodes = [node_id for node_id in initial_nodes if node_id not in visited_nodes]
-    if not initial_nodes:
-        return [], visited_nodes
-
-    unvisited_nodes = [(distance, uid) for uid, distance in zip(initial_nodes, query_id.xor_distance(initial_nodes))]
-    heapq.heapify(unvisited_nodes)  # nearest-first heap of candidates, unlimited size
-
-    nearest_nodes = [(-distance, node_id) for distance, node_id in heapq.nsmallest(beam_size, unvisited_nodes)]
-    heapq.heapify(nearest_nodes)  # farthest-first heap of size beam_size, used for early-stopping and to select results
-    while len(nearest_nodes) > beam_size:
-        heapq.heappop(nearest_nodes)
-
-    visited_nodes |= set(initial_nodes)
-    upper_bound = -nearest_nodes[0][0]  # distance to farthest element that is still in beam
-    was_interrupted = False  # will set to True if host triggered beam search to stop via get_neighbors
-
-    while (not was_interrupted) and len(unvisited_nodes) != 0 and unvisited_nodes[0][0] <= upper_bound:
-        _, node_id = heapq.heappop(unvisited_nodes)  # note: this  --^ is the smallest element in heap (see heapq)
-        neighbors, was_interrupted = await get_neighbors(node_id)
-        neighbors = [node_id for node_id in neighbors if node_id not in visited_nodes]
-        visited_nodes.update(neighbors)
-
-        for neighbor_id, distance in zip(neighbors, query_id.xor_distance(neighbors)):
-            if distance <= upper_bound or len(nearest_nodes) < beam_size:
-                heapq.heappush(unvisited_nodes, (distance, neighbor_id))
-
-                heapq_add_or_replace = heapq.heappush if len(nearest_nodes) < beam_size else heapq.heappushpop
-                heapq_add_or_replace(nearest_nodes, (-distance, neighbor_id))
-                upper_bound = -nearest_nodes[0][0]  # distance to beam_size-th nearest element found so far
-
-    return [node_id for _, node_id in heapq.nlargest(k_nearest, nearest_nodes)], visited_nodes
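
The single-query beam search deleted above is superseded by the multi-query traverse_dht in the new hivemind/dht/traverse.py below. As a rough illustration of the new call signature, here is a toy sketch that crawls an invented in-memory graph instead of real peers; the graph and the get_neighbors stub exist only for this example.

# Toy illustration of the generalized traverse_dht (not part of this commit):
# get_neighbors answers from an invented in-memory adjacency map instead of calling peers.
import asyncio
import random
from hivemind.dht.routing import DHTID
from hivemind.dht.traverse import traverse_dht

async def demo():
    nodes = [DHTID.generate() for _ in range(100)]
    neighbors = {node: random.sample(nodes, 10) for node in nodes}  # each node "knows" 10 others

    async def get_neighbors(peer, queries):
        # a real implementation would call DHTProtocol.call_find(peer_endpoint, queries) here
        return {query: (list(neighbors[peer]), False) for query in queries}

    queries = [DHTID.generate() for _ in range(5)]
    nearest, visited = await traverse_dht(
        queries, initial_nodes=nodes[:3], beam_size=8, num_workers=2,
        queries_per_call=2, get_neighbors=get_neighbors)

    for query in queries:
        print(len(nearest[query]), "nearest and", len(visited[query]), "visited for", query)

asyncio.run(demo())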

+ 232 - 0
hivemind/dht/traverse.py

@@ -0,0 +1,232 @@
+""" Utility functions for crawling DHT nodes, used to get and store keys in a DHT """
+import asyncio
+import heapq
+from collections import Counter
+from warnings import warn
+from typing import Dict, Awaitable, Callable, Any, Tuple, List, Set, Collection, Optional
+from .routing import DHTID
+
+ROOT = 0  # alias for heap root
+
+
+async def simple_traverse_dht(query_id: DHTID, initial_nodes: Collection[DHTID], beam_size: int,
+                              get_neighbors: Callable[[DHTID], Awaitable[Tuple[Collection[DHTID], bool]]],
+                              visited_nodes: Collection[DHTID] = ()) -> Tuple[List[DHTID], Set[DHTID]]:
+    """
+    Traverse the DHT graph using get_neighbors function, find :beam_size: nearest nodes according to DHTID.xor_distance.
+
+    :note: This is a simplified (but working) algorithm provided for documentation purposes. Actual DHTNode uses
+       `traverse_dht` - a generalization of this algorithm that allows multiple queries and concurrent workers.
+
+    :param query_id: search query, find k_nearest neighbors of this DHTID
+    :param initial_nodes: nodes used to pre-populate beam search heap, e.g. [my_own_DHTID, ...maybe_some_peers]
+    :param beam_size: beam search will not give up until it exhausts this many nearest nodes (to query_id) from the heap
+        Recommended value: A beam size of k_nearest * (2-5) will yield near-perfect results.
+    :param get_neighbors: A function that returns neighbors of a given node and controls beam search stopping criteria.
+        async def get_neighbors(node: DHTID) -> neighbors_of_that_node: List[DHTID], should_continue: bool
+        If should_continue is False, beam search will halt and return k_nearest of whatever it found by then.
+    :param visited_nodes: beam search will neither call get_neighbors on these nodes, nor return them as nearest
+    :returns: a list of k nearest nodes (nearest to farthest), and a set of all visited nodes (including visited_nodes)
+    """
+    visited_nodes = set(visited_nodes)  # note: copy visited_nodes because we will add more nodes to this collection.
+    initial_nodes = [node_id for node_id in initial_nodes if node_id not in visited_nodes]
+    if not initial_nodes:
+        return [], visited_nodes
+
+    unvisited_nodes = [(distance, uid) for uid, distance in zip(initial_nodes, query_id.xor_distance(initial_nodes))]
+    heapq.heapify(unvisited_nodes)  # nearest-first heap of candidates, unlimited size
+
+    nearest_nodes = [(-distance, node_id) for distance, node_id in heapq.nsmallest(beam_size, unvisited_nodes)]
+    heapq.heapify(nearest_nodes)  # farthest-first heap of size beam_size, used for early-stopping and to select results
+    while len(nearest_nodes) > beam_size:
+        heapq.heappop(nearest_nodes)
+
+    visited_nodes |= set(initial_nodes)
+    upper_bound = -nearest_nodes[0][0]  # distance to farthest element that is still in beam
+    was_interrupted = False  # will set to True if host triggered beam search to stop via get_neighbors
+
+    while (not was_interrupted) and len(unvisited_nodes) != 0 and unvisited_nodes[0][0] <= upper_bound:
+        _, node_id = heapq.heappop(unvisited_nodes)  # note: this  --^ is the smallest element in heap (see heapq)
+        neighbors, was_interrupted = await get_neighbors(node_id)
+        neighbors = [node_id for node_id in neighbors if node_id not in visited_nodes]
+        visited_nodes.update(neighbors)
+
+        for neighbor_id, distance in zip(neighbors, query_id.xor_distance(neighbors)):
+            if distance <= upper_bound or len(nearest_nodes) < beam_size:
+                heapq.heappush(unvisited_nodes, (distance, neighbor_id))
+
+                heapq_add_or_replace = heapq.heappush if len(nearest_nodes) < beam_size else heapq.heappushpop
+                heapq_add_or_replace(nearest_nodes, (-distance, neighbor_id))
+                upper_bound = -nearest_nodes[0][0]  # distance to beam_size-th nearest element found so far
+
+    return [node_id for _, node_id in heapq.nlargest(beam_size, nearest_nodes)], visited_nodes
+
+
+async def traverse_dht(
+        queries: Collection[DHTID], initial_nodes: List[DHTID], beam_size: int, num_workers: int, queries_per_call: int,
+        get_neighbors: Callable[[DHTID, Collection[DHTID]], Awaitable[Dict[DHTID, Tuple[List[DHTID], bool]]]],
+        found_callback: Optional[Callable[[DHTID, List[DHTID], Set[DHTID]], Awaitable[Any]]] = None,
+        await_all_tasks: bool = True, visited_nodes: Optional[Dict[DHTID, Set[DHTID]]] = (),
+) -> Tuple[Dict[DHTID, List[DHTID]], Dict[DHTID, Set[DHTID]]]:
+    """
+    Search the DHT for nearest neighbors to :queries: (based on DHTID.xor_distance). Use get_neighbors to request peers.
+    The algorithm can reuse intermediate results from each query to speed up search for other (similar) queries.
+
+    :param queries: a list of search queries, find beam_size neighbors for these DHTIDs
+    :param initial_nodes: nodes used to pre-populate beam search heap, e.g. [my_own_DHTID, ...maybe_some_peers]
+    :param beam_size: beam search will not give up until it visits this many nearest nodes (to query_id) from the heap
+    :param num_workers: run up to this many concurrent get_neighbors requests, each querying one peer for neighbors.
+        When selecting a peer to request neighbors from, workers try to balance concurrent exploration across queries.
+        A worker will expand the nearest candidate to a query with least concurrent requests from other workers.
+        If several queries have the same number of concurrent requests, prefer the one with nearest XOR distance.
+
+    :param queries_per_call: workers can pack up to this many queries in one get_neighbors call. These queries contain
+        the primary query (see num_workers above) and up to `queries_per_call - 1` nearest unfinished queries.
+
+    :param get_neighbors: A function that requests a given peer to find nearest neighbors for multiple queries
+        async def get_neighbors(peer, queries) -> {query1: ([nearest1, nearest2, ...], False), query2: ([...], True)}
+        For each query in queries, return nearest neighbors (known to a given peer) and a boolean "should_stop" flag
+        If should_stop is True, traverse_dht will no longer search for this query or request it from other peers.
+        The search terminates iff each query is either stopped via should_stop or finds beam_size nearest nodes.
+
+    :param found_callback: if specified, call this callback for each finished query the moment it finishes or is stopped
+        More specifically, run asyncio.create_task(found_callback(query, nearest_to_query, visited_for_query))
+        Using this callback allows one to process results faster before traverse_dht finishes for all queries.
+
+    :param await_all_tasks: if True, wait for all tasks to finish before returning; otherwise return as soon as the
+        nearest neighbors are found, and let the remaining tasks (callbacks and queries to known-but-unvisited nodes)
+        finish in the background
+
+    :param visited_nodes: for each query, do not call get_neighbors on these nodes, nor return them among nearest.
+    :note: the source code of this function can get tricky to read. Take a look at `simple_traverse_dht` function
+        for reference. That function implements a special case of traverse_dht with a single query and one worker.
+
+    :returns: a dict of nearest nodes, and another dict of visited nodes
+        nearest nodes: { query -> a list of up to beam_size nearest nodes, ordered nearest-first }
+        visited nodes: { query -> a set of all nodes that received requests for a given query }
+    """
+    if len(queries) == 0:
+        return {}, dict(visited_nodes)
+
+    unfinished_queries = set(queries)                             # all queries that haven't triggered finish_search yet
+    candidate_nodes: Dict[DHTID, List[Tuple[int, DHTID]]] = {}    # heap: unvisited nodes, ordered nearest-to-farthest
+    nearest_nodes: Dict[DHTID, List[Tuple[int, DHTID]]] = {}      # heap: top-k nearest nodes, farthest-to-nearest
+    known_nodes: Dict[DHTID, Set[DHTID]] = {}                     # all nodes ever added to the heap (for deduplication)
+    visited_nodes: Dict[DHTID, Set[DHTID]] = dict(visited_nodes)  # where we requested get_neighbors for a given query
+    pending_tasks = set()                                         # all active tasks (get_neighbors and found_callback)
+    active_workers = Counter({q: 0 for q in queries})             # count workers that search for this query
+
+    search_finished_event = asyncio.Event()  # used to immediately stop all workers when the search is finished
+    heap_updated_event = asyncio.Event()  # if a worker has no nodes to explore, it will await other workers
+    heap_updated_event.set()
+
+    # initialize data structures
+    for query in queries:
+        distances = query.xor_distance(initial_nodes)
+        candidate_nodes[query] = list(zip(distances, initial_nodes))
+        nearest_nodes[query] = list(zip([-d for d in distances], initial_nodes))
+        heapq.heapify(candidate_nodes[query])
+        heapq.heapify(nearest_nodes[query])
+        while len(nearest_nodes[query]) > beam_size:
+            heapq.heappop(nearest_nodes[query])
+        known_nodes[query] = set(initial_nodes)
+        visited_nodes[query] = set(visited_nodes.get(query, ()))
+
+    def heuristic_priority(heap_query: DHTID):
+        """ Workers prioritize expanding nodes that lead to under-explored queries (by other workers) """
+        if len(candidate_nodes[heap_query]) == 0:
+            return float('inf'), float('inf')
+        else:  # prefer candidates in heaps with least number of concurrent workers, break ties by distance to query
+            return active_workers[heap_query], candidate_nodes[heap_query][ROOT][0]
+
+    def upper_bound(query: DHTID):
+        """ Any node that is farther from query than upper_bound(query) will not be added to heaps """
+        return -nearest_nodes[query][ROOT][0] if len(nearest_nodes[query]) >= beam_size else float('inf')
+
+    def finish_search(query):
+        """ Remove query from a list of targets """
+        unfinished_queries.remove(query)
+        if len(unfinished_queries) == 0:
+            search_finished_event.set()
+        if found_callback:
+            nearest_neighbors = [peer for _, peer in heapq.nlargest(beam_size, nearest_nodes[query])]
+            pending_tasks.add(asyncio.create_task(found_callback(query, nearest_neighbors, set(visited_nodes[query]))))
+
+    async def worker():
+        while unfinished_queries:
+            # select the heap based on priority
+            chosen_query: DHTID = min(unfinished_queries, key=heuristic_priority)
+
+            if len(candidate_nodes[chosen_query]) == 0:  # if there are no peers to explore...
+                other_workers_pending = active_workers.most_common(1)[0][1] > 0
+                if other_workers_pending:  # ... wait for other workers (if any) or add more peers
+                    heap_updated_event.clear()
+                    await heap_updated_event.wait()
+                    continue
+                else:  # ... or if there is no hope of new nodes, finish search immediately
+                    for query in list(unfinished_queries):
+                        finish_search(query)
+                    break
+
+            # select vertex to be explored
+            chosen_distance_to_query, chosen_peer = heapq.heappop(candidate_nodes[chosen_query])
+            if chosen_peer in visited_nodes[chosen_query]:
+                continue
+            if chosen_distance_to_query > upper_bound(chosen_query):
+                finish_search(chosen_query)
+                continue
+
+            # find additional queries to pack in the same request
+            possible_additional_queries = [query for query in unfinished_queries
+                                           if query != chosen_query and chosen_peer not in visited_nodes[query]]
+            queries_to_call = [chosen_query] + heapq.nsmallest(
+                queries_per_call - 1, possible_additional_queries, key=chosen_peer.xor_distance)
+
+            # update priorities for subsequent workers
+            active_workers.update(queries_to_call)
+            for query_to_call in queries_to_call:
+                visited_nodes[query_to_call].add(chosen_peer)
+
+            # get nearest neighbors (over network) and update search heaps. Abort if search finishes early
+            get_neighbors_task = asyncio.create_task(get_neighbors(chosen_peer, queries_to_call))
+            pending_tasks.add(get_neighbors_task)
+            await asyncio.wait([get_neighbors_task, search_finished_event.wait()], return_when=asyncio.FIRST_COMPLETED)
+            if search_finished_event.is_set():
+                break  # other worker triggered finish_search, we exit immediately
+            pending_tasks.remove(get_neighbors_task)
+
+            # add nearest neighbors to their respective heaps
+            for query, (neighbors_for_query, should_stop) in get_neighbors_task.result().items():
+                if should_stop and (query in unfinished_queries):
+                    finish_search(query)
+                if query not in unfinished_queries:
+                    continue  # either we finished search or someone else did while we awaited
+                for neighbor in neighbors_for_query:
+                    if neighbor not in known_nodes[query]:
+                        known_nodes[query].add(neighbor)
+                        distance = query.xor_distance(neighbor)
+                        if distance <= upper_bound(query) or len(nearest_nodes[query]) < beam_size:
+                            heapq.heappush(candidate_nodes[query], (distance, neighbor))
+                            if len(nearest_nodes[query]) < beam_size:
+                                heapq.heappush(nearest_nodes[query], (-distance, neighbor))
+                            else:
+                                heapq.heappushpop(nearest_nodes[query], (-distance, neighbor))
+
+            # we finished processing a request, update priorities for other workers
+            active_workers.subtract(queries_to_call)
+            heap_updated_event.set()
+
+    # spawn all workers and wait for them to terminate; workers terminate after exhausting unfinished_queries
+    await asyncio.wait([asyncio.create_task(worker()) for _ in range(num_workers)],
+                       return_when=asyncio.FIRST_COMPLETED)  # first worker finishes when the search is over
+    assert len(unfinished_queries) == 0 and search_finished_event.is_set()
+
+    if await_all_tasks:
+        await asyncio.gather(*pending_tasks)
+
+    nearest_neighbors_per_query = {
+        query: [peer for _, peer in heapq.nlargest(beam_size, nearest_nodes[query])]
+        for query in queries
+    }
+    return nearest_neighbors_per_query, visited_nodes
+
+
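For reference, here is a minimal usage sketch of traverse_dht (not part of the diff). It mocks get_neighbors over a small in-memory graph, so no networking is involved; the import paths (hivemind.dht.traverse, hivemind.dht.node) and the keyword argument names are assumptions based on the identifiers that appear in the code above.

import asyncio
from hivemind.dht.node import DHTID              # assumption: DHTID is exposed here, as in the tests below
from hivemind.dht.traverse import traverse_dht   # assumption: this diff adds traverse_dht to this module

async def toy_search():
    # a small static graph: every "peer" knows a few random neighbors
    graph = {DHTID.generate(): [DHTID.generate() for _ in range(3)] for _ in range(16)}

    async def get_neighbors(peer, queries):
        # contract from the docstring: {query: (neighbors_known_to_peer, should_stop)}
        return {query: (graph.get(peer, []), False) for query in queries}

    nearest, visited = await traverse_dht(
        queries=[DHTID.generate()], initial_nodes=list(graph)[:3], beam_size=4,
        num_workers=2, queries_per_call=2, get_neighbors=get_neighbors)
    print({query: len(peers) for query, peers in nearest.items()}, sum(map(len, visited.values())))

asyncio.run(toy_search())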

+ 6 - 6
hivemind/runtime/expert_backend.py

@@ -18,12 +18,12 @@ class ExpertBackend(nn.Module):
 
     :param expert: nn.Module to be wrapped into a backend. Arbitrary pytorch module with a few limitations:
 
-        - Experts must always receive the same set of \*args and \*\*kwargs and produce output tensors of same type
-        - All \*args, \*\*kwargs and outputs must be **tensors** where 0-th dimension represents to batch size
-        - We recommend using experts that are ~invariant to the order in which they process batches
-        - Using randomness (e.g. Dropout) leads to different samples at forward and backward. If you want to ensure consistency,
-            you should explicitly register these random variables as model outputs, so that they are sent back to the client.
-            See hivemind.utils.custom_layers.DeterministicDropout for an example
+     - Experts must always receive the same set of \*args and \*\*kwargs and produce output tensors of the same type
+     - All args, kwargs and outputs must be **tensors** where the 0-th dimension represents the batch size
+     - We recommend using experts that are ~invariant to the order in which they process batches
+     - Using randomness (e.g. Dropout) leads to different samples at forward and backward. If you want consistency,
+        you should explicitly register these random variables as model inputs or outputs.
+        See hivemind.utils.custom_layers.DeterministicDropout for an example
 
     :param opt: torch optimizer to be applied on every backward call
     :param args_schema: description of positional arguments to expert.forward, list of BatchTensorProto
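
The constraints listed above can be illustrated with a minimal expert module (a sketch, not part of the diff): a fixed, tensor-only signature, batch size in dimension 0, and no unregistered randomness.

import torch
from torch import nn

class FeedforwardExpert(nn.Module):
    """ A sketch of an expert that satisfies the ExpertBackend constraints listed above. """
    def __init__(self, hidden_dim: int = 1024):
        super().__init__()
        self.ffn = nn.Sequential(nn.Linear(hidden_dim, 4 * hidden_dim), nn.ReLU(),
                                 nn.Linear(4 * hidden_dim, hidden_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # the only argument and the only output are tensors of shape [batch_size, hidden_dim]
        return self.ffn(x)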

+ 1 - 1
hivemind/utils/serializer.py

@@ -57,7 +57,7 @@ class MSGPackSerializer(SerializerBase):
 
     @staticmethod
     def dumps(obj: object) -> bytes:
-        return umsgpack.dumps(obj, use_bin_type=False)
+        return umsgpack.dumps(obj, use_bin_type=False)  # TODO: strict type check, see https://github.com/msgpack/msgpack-python/pull/158
 
     @staticmethod
     def loads(buf: bytes) -> object:

+ 91 - 97
tests/benchmark_dht.py

@@ -1,104 +1,98 @@
-import argparse
 import time
-import asyncio
-import multiprocessing as mp
+import argparse
 import random
-
+from typing import Tuple
+from warnings import warn
 import hivemind
-from typing import List, Dict
-
-from hivemind import get_dht_time
-from hivemind.dht.node import DHTID, Endpoint, DHTNode, LOCALHOST
-
-
-def run_benchmark_node(node_id, port, peers, ready: mp.Event, request_perod,
-                       expiration_time, wait_before_read, time_to_test, statistics: mp.Queue, dht_loaded: mp.Event):
-    if asyncio.get_event_loop().is_running():
-        asyncio.get_event_loop().stop()  # if we're in jupyter, get rid of its built-in event loop
-    loop = asyncio.new_event_loop()
-    asyncio.set_event_loop(loop)
-    node = DHTNode(node_id, port, initial_peers=peers)
-    await_forever = hivemind.run_forever(asyncio.get_event_loop().run_forever)
-    ready.set()
-    dht_loaded.wait()
-    start = time.perf_counter()
-    while time.perf_counter() < start + time_to_test:
-        query_id = DHTID.generate()
-        store_value = random.randint(0, 256)
-
-        store_time = time.perf_counter()
-        success_store = asyncio.run_coroutine_threadsafe(
-            node.store(query_id, store_value, get_dht_time() + expiration_time), loop).result()
-        store_time = time.perf_counter() - store_time
-        if success_store:
-            time.sleep(wait_before_read)
-            get_time = time.perf_counter()
-            get_value, get_time_expiration = asyncio.run_coroutine_threadsafe(node.get(query_id), loop).result()
-            get_time = time.perf_counter() - get_time
-            success_get = (get_value == store_value)
-            statistics.put((success_store, store_time, success_get, get_time))
-        else:
-            statistics.put((success_store, store_time, None, None))
-    await_forever.result()  # process will exit only if event loop broke down
+from tqdm import trange
+
+from test_utils import increase_file_limit
+
+
+def random_endpoint() -> Tuple[str, int]:
+    return (f"{random.randint(0, 256)}.{random.randint(0, 256)}."
+            f"{random.randint(0, 256)}.{random.randint(0, 256)}", random.randint(0, 65535))
+
+
+def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_batch_size: int, random_seed: int,
+                  wait_after_request: float, wait_before_read: float, wait_timeout: float, expiration_time: float):
+    old_expiration_time, hivemind.DHT.EXPIRATION = hivemind.DHT.EXPIRATION, expiration_time
+    random.seed(random_seed)
+
+    print("Creating peers...")
+    peers = []
+    for _ in trange(num_peers):
+        neighbors = [f'0.0.0.0:{node.port}' for node in random.sample(peers, min(initial_peers, len(peers)))]
+        peer = hivemind.DHT(*neighbors, start=True, wait_timeout=wait_timeout, listen_on=f'0.0.0.0:*')
+        peers.append(peer)
+
+    store_peer, get_peer = peers[-2:]
+
+    expert_uids = list(set(f"expert.{random.randint(0, 999)}.{random.randint(0, 999)}.{random.randint(0, 999)}"
+                           for _ in range(num_experts)))
+    print(f"Sampled {len(expert_uids)} unique ids (after deduplication)")
+    random.shuffle(expert_uids)
+
+    print(f"Storing peers to dht in batches of {expert_batch_size}...")
+    successful_stores = total_stores = total_store_time = 0
+    benchmark_started = time.perf_counter()
+    endpoints = []
+
+    for start in trange(0, num_experts, expert_batch_size):
+        store_start = time.perf_counter()
+        endpoints.append(random_endpoint())
+        success_list = store_peer.declare_experts(expert_uids[start: start + expert_batch_size], *endpoints[-1])
+        total_store_time += time.perf_counter() - store_start
+
+        total_stores += len(success_list)
+        successful_stores += sum(success_list)
+        time.sleep(wait_after_request)
+
+    print(f"Store success rate: {successful_stores / total_stores * 100:.1f}% ({successful_stores} / {total_stores})")
+    print(f"Mean store time: {total_store_time / total_stores:.5}, Total: {total_store_time:.5}")
+    time.sleep(wait_before_read)
+
+    if time.perf_counter() - benchmark_started > expiration_time:
+        warn("Warning: all keys expired before benchmark started getting them. Consider increasing expiration_time")
+
+    successful_gets = total_get_time = 0
+
+    for start in trange(0, len(expert_uids), expert_batch_size):
+        get_start = time.perf_counter()
+        get_result = get_peer.get_experts(expert_uids[start: start + expert_batch_size])
+        total_get_time += time.perf_counter() - get_start
+
+        for i, expert in enumerate(get_result):
+            if expert is not None and expert.uid == expert_uids[start + i] \
+                    and (expert.host, expert.port) == endpoints[start // expert_batch_size]:
+                successful_gets += 1
+
+    if time.perf_counter() - benchmark_started > expiration_time:
+        warn("Warning: keys expired midway during get requests. If that is not desired, increase expiration_time param")
+
+    print(f"Get success rate: {successful_gets / len(expert_uids) * 100:.1f} ({successful_gets} / {len(expert_uids)})")
+    print(f"Mean get time: {total_get_time / len(expert_uids):.5f}, Total: {total_get_time:.5f}")
+
+    alive_peers = [peer.is_alive() for peer in peers]
+    print(f"Node survival rate: {len(alive_peers) / len(peers) * 100:.3f}%")
+    hivemind.DHT.EXPIRATION = old_expiration_time
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--num_nodes', type=int, default=20, required=False)
-    parser.add_argument('--request_perod', type=float, default=2, required=False)
-    parser.add_argument('--expiration_time', type=float, default=10, required=False)
-    parser.add_argument('--wait_before_read', type=float, default=1, required=False)
-    parser.add_argument('--time_to_test', type=float, default=10, required=False)
-    args = parser.parse_args()
-
-    statistics = mp.Queue()
-    dht: Dict[Endpoint, DHTID] = {}
-    processes: List[mp.Process] = []
-
-    num_nodes = args.num_nodes
-    request_perod = args.request_perod
-    expiration_time = args.expiration_time
-    wait_before_read = args.wait_before_read
-    time_to_test = args.time_to_test
-
-    dht_loaded = mp.Event()
-    for i in range(num_nodes):
-        node_id = DHTID.generate()
-        port = hivemind.find_open_port()
-        peers = random.sample(dht.keys(), min(len(dht), 5))
-        ready = mp.Event()
-        proc = mp.Process(target=run_benchmark_node, args=(node_id, port, peers, ready, request_perod,
-                                                           expiration_time, wait_before_read, time_to_test, statistics,
-                                                           dht_loaded), daemon=True)
-        proc.start()
-        ready.wait()
-        processes.append(proc)
-        dht[(LOCALHOST, port)] = node_id
-    dht_loaded.set()
-    time.sleep(time_to_test)
-    success_store = 0
-    all_store = 0
-    time_store = 0
-    success_get = 0
-    all_get = 0
-    time_get = 0
-    while not statistics.empty():
-        success_store_i, store_time_i, success_get_i, get_time_i = statistics.get()
-        all_store += 1
-        time_store += store_time_i
-        if success_store_i:
-            success_store += 1
-            all_get += 1
-            success_get += 1 if success_get_i else 0
-            time_get += get_time_i
-    alive_nodes_count = 0
-    loop = asyncio.new_event_loop()
-    node = DHTNode(loop=loop)
-    for addr, port in dht:
-        if loop.run_until_complete(node.protocol.call_ping((addr, port))) is not None:
-            alive_nodes_count += 1
-    print("store success rate: ", success_store / all_store)
-    print("mean store time: ", time_store / all_store)
-    print("get success rate: ", success_get / all_get)
-    print("mean get time: ", time_get / all_get)
-    print("death rate: ", (num_nodes - alive_nodes_count) / num_nodes)
+    parser.add_argument('--num_peers', type=int, default=32, required=False)
+    parser.add_argument('--initial_peers', type=int, default=1, required=False)
+    parser.add_argument('--num_experts', type=int, default=256, required=False)
+    parser.add_argument('--expert_batch_size', type=int, default=32, required=False)
+    parser.add_argument('--expiration_time', type=float, default=300, required=False)
+    parser.add_argument('--wait_after_request', type=float, default=0, required=False)
+    parser.add_argument('--wait_before_read', type=float, default=0, required=False)
+    parser.add_argument('--wait_timeout', type=float, default=5, required=False)
+    parser.add_argument('--random_seed', type=int, default=random.randint(1, 1000))
+    parser.add_argument('--increase_file_limit', action="store_true")
+    args = vars(parser.parse_args())
+
+    if args.pop('increase_file_limit', False):
+        increase_file_limit()
+
+    benchmark_dht(**args)
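
Since benchmark_dht is now an importable function, it can also be invoked programmatically. A hedged sketch with a small smoke-test configuration follows; the parameter values are illustrative, and it assumes the code runs from the tests/ directory so the module imports directly.

# assumption: executed from the tests/ directory so that benchmark_dht.py and test_utils/ are importable
from benchmark_dht import benchmark_dht

# tiny configuration for a local smoke test; argument names match the function signature above
benchmark_dht(num_peers=8, initial_peers=1, num_experts=32, expert_batch_size=8,
              random_seed=42, wait_after_request=0, wait_before_read=0,
              wait_timeout=5, expiration_time=300)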

+ 2 - 8
tests/benchmark_throughput.py

@@ -1,12 +1,11 @@
 import argparse
 import multiprocessing as mp
 import random
-import resource
 import sys
 import time
 
 import torch
-from test_utils import layers, print_device_info
+from test_utils import layers, print_device_info, increase_file_limit
 from hivemind import find_open_port
 
 import hivemind
@@ -139,12 +138,7 @@ if __name__ == "__main__":
         benchmark_throughput(backprop=True, num_experts=1, batch_size=1, max_batch_size=8192, num_handlers=32,
                              num_clients=512, num_batches_per_client=args.num_batches_per_client)
     elif args.preset == 'ffn_massive':
-        soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
-        try:
-            print("Setting open file limit to soft={}, hard={}".format(max(soft, 2 ** 15), max(hard, 2 ** 15)))
-            resource.setrlimit(resource.RLIMIT_NOFILE, (max(soft, 2 ** 15), max(hard, 2 ** 15)))
-        except:
-            print("Could not increase open file limit, currently at soft={}, hard={}".format(soft, hard))
+        increase_file_limit()
         benchmark_throughput(backprop=False, num_clients=512, batch_size=512,
                              max_batch_size=8192, num_batches_per_client=args.num_batches_per_client)
     elif args.preset == 'minimalistic':

+ 20 - 10
tests/test_dht.py

@@ -153,7 +153,7 @@ def run_node(node_id, peers, status_pipe: mp.Pipe):
         loop.run_forever()
 
 
-def test_dht():
+def test_dht_node():
     # create dht with 50 nodes + your 51-st node
     dht: Dict[Endpoint, DHTID] = {}
     processes: List[mp.Process] = []
@@ -177,13 +177,13 @@ def test_dht():
         me = loop.run_until_complete(DHTNode.create(initial_peers=random.sample(dht.keys(), 5), parallel_rpc=10))
 
         # test 1: find self
-        nearest = loop.run_until_complete(me.find_nearest_nodes(key_id=me.node_id, k_nearest=1))
-        assert len(nearest) == 1 and nearest[me.node_id] == (LOCALHOST, me.port)
+        nearest = loop.run_until_complete(me.find_nearest_nodes([me.node_id], k_nearest=1))[me.node_id]
+        assert len(nearest) == 1 and nearest[me.node_id] == f"{LOCALHOST}:{me.port}"
 
         # test 2: find others
         for i in range(10):
             ref_endpoint, query_id = random.choice(list(dht.items()))
-            nearest = loop.run_until_complete(me.find_nearest_nodes(key_id=query_id, k_nearest=1))
+            nearest = loop.run_until_complete(me.find_nearest_nodes([query_id], k_nearest=1))[query_id]
             assert len(nearest) == 1 and next(iter(nearest.items())) == (query_id, ref_endpoint)
 
         # test 3: find neighbors to random nodes
@@ -196,11 +196,11 @@ def test_dht():
             k_nearest = random.randint(1, 20)
             exclude_self = random.random() > 0.5
             nearest = loop.run_until_complete(
-                me.find_nearest_nodes(key_id=query_id, k_nearest=k_nearest, exclude_self=exclude_self))
+                me.find_nearest_nodes([query_id], k_nearest=k_nearest, exclude_self=exclude_self))[query_id]
             nearest_nodes = list(nearest)  # keys from ordered dict
 
             assert len(nearest_nodes) == k_nearest, "beam search must return exactly k_nearest results"
-            assert me.node_id not in nearest_nodes or not exclude_self, "if exclude, results should not contain own node id"
+            assert me.node_id not in nearest_nodes or not exclude_self, "if exclude, results shouldn't contain self"
             assert np.all(np.diff(query_id.xor_distance(nearest_nodes)) >= 0), "results must be sorted by distance"
 
             ref_nearest = heapq.nsmallest(k_nearest + 1, all_node_ids, key=query_id.xor_distance)
@@ -223,15 +223,16 @@ def test_dht():
         assert jaccard_index >= 0.9, f"Jaccard index only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
 
         # test 4: find all nodes
-        nearest = loop.run_until_complete(me.find_nearest_nodes(key_id=DHTID.generate(), k_nearest=len(dht) + 100))
+        dummy = DHTID.generate()
+        nearest = loop.run_until_complete(me.find_nearest_nodes([dummy], k_nearest=len(dht) + 100))[dummy]
         assert len(nearest) == len(dht) + 1
         assert len(set.difference(set(nearest.keys()), set(all_node_ids) | {me.node_id})) == 0
 
         # test 5: node without peers
         other_node = loop.run_until_complete(DHTNode.create())
-        nearest = loop.run_until_complete(other_node.find_nearest_nodes(DHTID.generate()))
-        assert len(nearest) == 1 and nearest[other_node.node_id] == (LOCALHOST, other_node.port)
-        nearest = loop.run_until_complete(other_node.find_nearest_nodes(DHTID.generate(), exclude_self=True))
+        nearest = loop.run_until_complete(other_node.find_nearest_nodes([dummy]))[dummy]
+        assert len(nearest) == 1 and nearest[other_node.node_id] == f"{LOCALHOST}:{other_node.port}"
+        nearest = loop.run_until_complete(other_node.find_nearest_nodes([dummy], exclude_self=True))[dummy]
         assert len(nearest) == 0
 
         # test 6 store and get value
@@ -242,6 +243,15 @@ def test_dht():
             assert expiration_time == true_time, "Wrong time"
             assert val == ["Value", 10], "Wrong value"
 
+        # test 7: bulk store and bulk get
+        keys = 'foo', 'bar', 'baz', 'zzz'
+        values = 3, 2, 'batman', [1, 2, 3]
+        store_ok = loop.run_until_complete(me.store_many(keys, values, expiration=get_dht_time() + 999))
+        assert all(store_ok.values()), "failed to store one or more keys"
+        response = loop.run_until_complete(me.get_many(keys[::-1]))
+        for key, value in zip(keys, values):
+            assert key in response and response[key][0] == value
+
         test_success.set()
 
     tester = mp.Process(target=_tester, daemon=True)

+ 1 - 1
tests/test_moe.py

@@ -66,7 +66,7 @@ def test_determinism():
 
 def test_compute_expert_scores():
     try:
-        dht = hivemind.DHT(port=hivemind.find_open_port(), start=True)
+        dht = hivemind.DHT(start=True)
         moe = hivemind.client.moe.RemoteMixtureOfExperts(
             dht=dht, in_features=1024, grid_size=(40,), k_best=4, k_min=1, timeout_after_k_min=1,
             uid_prefix='expert')

+ 13 - 0
tests/test_utils/__init__.py

@@ -1,3 +1,5 @@
+from warnings import warn
+
 import torch
 
 
@@ -12,3 +14,14 @@ def print_device_info(device=None):
         print('Memory Usage:')
         print('Allocated:', round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1), 'GB')
         print('Cached:   ', round(torch.cuda.memory_cached(0) / 1024 ** 3, 1), 'GB')
+
+
+def increase_file_limit(new_soft=2 ** 15, new_hard=2 ** 15):
+    """ Increase the maximum number of open files. On Linux, this allows spawning more processes/threads. """
+    try:
+        import resource  # note: local import to avoid ImportError for those who don't have it
+        soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
+        print(f"Increasing file limit - soft {soft}=>{new_soft}, hard {hard}=>{new_hard}")
+        return resource.setrlimit(resource.RLIMIT_NOFILE, (max(soft, new_soft), max(hard, new_hard)))
+    except Exception as e:
+        warn(f"Failed to increase file limit: {e}")

+ 3 - 3
tests/test_utils/run_server.py

@@ -46,16 +46,16 @@ def make_dummy_server(interface='0.0.0.0', port=None, num_experts=1, expert_cls=
         if not len(initial_peers):
             print("No initial peers provided. Starting additional dht as an initial peer.")
             dht_root = hivemind.DHT(
-                *initial_peers, port=root_port or hivemind.find_open_port(), start=True)
+                *initial_peers, listen_on=f"{hivemind.LOCALHOST}:{root_port or hivemind.find_open_port()}", start=True)
             print(f"Initializing DHT with port {dht_root.port}")
-            initial_peers = (('localhost', dht_root.port),)
+            initial_peers = [f"{hivemind.LOCALHOST}:{dht_root.port}"]
         else:
             print("Bootstrapping dht with peers:", initial_peers)
             if root_port is not None:
                 print(f"Warning: root_port={root_port} will not be used since we already have peers.")
 
         dht = hivemind.DHT(
-            *initial_peers, port=dht_port or hivemind.find_open_port(), start=True)
+            *initial_peers, listen_on=f"{hivemind.LOCALHOST}:{dht_port or hivemind.find_open_port()}", start=True)
         if verbose:
             print(f"Running dht node on port {dht.port}")