
Re-write DHT (#37)

* documentation

* DHTNodeID and tests

* wip: dht

* define depth according to 4.2; split routing tests

* better docs on depth modulo

* dht notuml

* move TODOs to PR

* allow excluding any id, add kademlia rpc, add expiration time

* fluff

* kademliaprotocol no longer uses DHTNode

* kademliaprotocol no longer uses DHTNode

* do not override default to_bytes

* call parent init

* disable notimplementederror for test purposes

* add tests for dhtid <=> bytes conversion

* send binary dhtid over rpc

* handle tuple to list for older msgpack

* add basic test for kademlia rpc

* convert key to bytes

* dhtnode bootstrap

* basic dht crawling based on beam search

* fix protocol tests

* minor rename

* * bucket staleness is now handled by DHTNode, not KademliaProtocol
* pinging nodes is now performed by protocol, not routing table
* protocol now updates routing table asynchronously (reduces rpc latency under load)
* routing table will now avoid pinging node if this node is already being ping-ed

* fallback to standard eviction policy

* * fix bug with nearest neighbors for incomplete buckets
* add test for nearest neighbors with incomplete buckets
* make DHTNode asynchronous, as expected

* temp fix circleci

* temp fix circleci

* Implemented True logic for LocalStorage

* Added tests for LocalStorage class

* finish merging

* remove debugprint

* add beam search tests

* test: check that beam search always finds top-1 node

* Change import order

* fixed None maxsize

* fixed expiration time

* removed useless test call

* moved localstorage tests to test_kademlia

* Local storage (#44)

* Implemented True logic for LocalStorage

* Added tests for LocalStorage class

* Change import order

* fixed None maxsize

* fixed expiration time

* removed useless test call

* moved localstorage tests to test_kademlia

* Implemented get and store node functions

* Fixed value order in node.get

* Added dht benchmark

* Fixed node store/get test

* close connection before returning

* a more definitive find_open_port

* fix test for neighbor lookups

* remove ipynb-specific asyncio calls

* move beam_search to a separate file

* use preferred way to create task

* fix warnings

* make DHTNode get/store work on arbitrary keys, not just DHTID

* Local storage with caching (#46)

* Implemented better removal of old  keys from LocalStorage

* merged specialize_dht and fixed tests

* Fixed unixtime

* Fixed indexError in LocalStorage remove_outdated

* Fixed indexError in LocalStorage remove_outdated

* Fixed error where changing the expiration time was not reflected in the table

* Fixed error with keeping expired value

* Fixed error with keeping expired value

* Added tests for LocalStorage

* Implemented cache in LocalStorage

* Added test for caching in LocalStorage

* Fixing cache cleaning for LocalStorage

* Renamed maxsize to cache size

* Fixed maxsize to cachesize

* Removed cache. Separate storage for cache

* Removed cache tests

* Implemented caching in KademliaProtocol

* simplify public DHT interface

* switch from time.monotonic to unixtime, allow changing it in a centralized function

* remove kademlia requirement, use rpcudp directly

* rename, revert to run_coroutine_threadsafe

* separate wait and timeout args

* add test for hivemind dht

* handle responses to cancelled tasks

* moved comments to issue

Co-authored-by: Vsevolod-pl <vsevolod-pl@yandex.ru>
Co-authored-by: Vsevolod-pl <Vsevolod-pl@users.noreply.github.com>
justheuristic 5 years ago
Parent
Commit
d961ceb6ae

BIN
docs/_static/dht.odp


BIN
docs/_static/dht.pdf


BIN
docs/_static/dht.png


+ 2 - 2
docs/user/quickstart.md

@@ -35,11 +35,11 @@ do something complex with it, please contact us by opening an issue (less prefer
 - **`Runtime`** (`hivemind/runtime/__init__.py`) aggregates batches
   and performs inference/training of experts according to their priority.
 - **`Server`** (`hivemind/server/__init__.py`) wraps runtime and
-  periodically uploads experts into `DHTNode`.
+  periodically uploads experts into `DHT`.
 
 **DHT:**
 
-- **`DHTNode`**(`hivemind/dht/__init__.py`) is a node of
+- **`DHT`**(`hivemind/dht/__init__.py`) is a node of
   Kademlia-based DHT that stores metadata used by trainer and runtime.
 
 ## Limitations

+ 1 - 1
hivemind/client/moe.py

@@ -26,7 +26,7 @@ class RemoteMixtureOfExperts(nn.Module):
     :param grid_size: hivemind dimensions that form expert uid (see below)
     :param uid_prefix: common prefix for all expert uids
      expert uid follows the pattern {uid_prefix}.{0...grid_size[0]}.{0...grid_size[1]}...{0...grid_size[-1]}
-    :param dht: DHTNode where the experts reside
+    :param dht: DHT where the experts reside
     :param num_workers: number of threads for parallel dht operation
     :param k_best: queries this many experts with highest scores
     :param k_min: makes sure at least this many experts returned output

+ 88 - 59
hivemind/dht/__init__.py

@@ -1,36 +1,57 @@
+"""
+This sub-module implements a node in a Kademlia-based DHT. The code is organized as follows:
+ * class DHT (below) - high-level class for model training. Runs DHTNode in a background process.
+ * class DHTNode (node.py) - an asyncio implementation of a dht server that stores AND gets keys.
+ * class KademliaProtocol (protocol.py) - an rpc protocol to request data from dht nodes. Asyncio-based.
+
+The code in this module is a modified version of https://github.com/bmuller/kademlia
+Brian, if you're reading this: THANK YOU! you're awesome :)
+"""
 import asyncio
-import datetime
 import multiprocessing as mp
 import warnings
 from typing import Tuple, List, Optional
 
-from kademlia.network import Server
+from .node import DHTNode, DHTID, DHTExpiration
+from .routing import get_dht_time
 
-from hivemind.client import RemoteExpert
-from hivemind.utils import run_forever, SharedFuture, PickleSerializer
+from ..client import RemoteExpert
+from ..utils import SharedFuture, find_open_port, Hostname, Port, run_in_background
 
 
-class DHTNode(mp.Process):
+class DHT(mp.Process):
+    """
+    A high-level interface to hivemind DHT. Runs a dht node in a background process.
+    :param initial_peers: one or multiple pairs of (host, port) pointing to active DHT peers. Default: no peers
+    :param port: a port where DHT will listen to incoming connections. Defaults to hivemind.utils.find_open_port
+    :param start: if True, automatically starts the background process on creation. Otherwise await manual start
+    :param daemon: if True, the background process is marked as daemon and terminated automatically after the main process exits
+    :param node_params: any other params will be forwarded to DHTNode upon creation
+    """
     UID_DELIMETER = '.'  # splits expert uids over this delimeter
-    HEARTBEAT_EXPIRATION = 120  # expert is inactive iff it fails to post timestamp for *this many seconds*
+    EXPIRATION = 120  # anything written to DHT is considered expired after this many seconds
     make_key = "{}::{}".format
 
-    def __init__(self, *initial_peers: Tuple[str, int], port=8081, start=False, daemon=True):
+    def __init__(self, *initial_peers: Tuple[Hostname, Port], port: Optional[Port] = None,
+                 start: bool, daemon: bool = True, **node_params):
         super().__init__()
-        self.port, self.initial_peers = port, initial_peers
+        port = find_open_port() if port is None else port
+        self.node: Optional[DHTNode] = None  # to be initialized in self.run
+        self.port, self.initial_peers, self.node_params = port, initial_peers, node_params
         self._pipe, self.pipe = mp.Pipe(duplex=False)
         self.ready = mp.Event()
-        self.server = Server()
         self.daemon = daemon
         if start:
             self.run_in_background(await_ready=True)
 
     def run(self) -> None:
+        if asyncio.get_event_loop().is_running():
+            asyncio.get_event_loop().stop()  # if we're in jupyter, get rid of its built-in event loop
         loop = asyncio.new_event_loop()
         asyncio.set_event_loop(loop)
-        loop.run_until_complete(self.server.listen(self.port))
-        loop.run_until_complete(self.server.bootstrap(self.initial_peers))
-        run_forever(loop.run_forever)
+
+        self.node = DHTNode(initial_peers=list(self.initial_peers), port=self.port, **self.node_params)
+        run_in_background(loop.run_forever)
         self.ready.set()
 
         while True:
@@ -39,7 +60,7 @@ class DHTNode(mp.Process):
 
     def run_in_background(self, await_ready=True, timeout=None):
         """
-        Starts DHTNode in a background process. if await_ready, this method will wait until background dht
+        Starts DHT in a background process. If await_ready, this method will wait until the background dht
         is ready to process incoming requests or for :timeout: seconds max.
         """
         self.start()
@@ -53,98 +74,106 @@ class DHTNode(mp.Process):
         else:
             warnings.warn("DHT shutdown has no effect: dht process is already not alive")
 
-    def get_experts(self, uids: List[str], heartbeat_expiration=HEARTBEAT_EXPIRATION) -> List[Optional[RemoteExpert]]:
-        """ Find experts across DHT using their ids; Return a list of [RemoteExpert if found else None]"""
+    def get_experts(self, uids: List[str], expiration=None) -> List[Optional[RemoteExpert]]:
+        """
+        :param uids: find experts with these ids from across the DHT
+        :param expiration: returns experts that expire no sooner than this (based on get_dht_time), default = now
+        :returns: a list of [RemoteExpert if found else None]
+        """
         future, _future = SharedFuture.make_pair()
-        self.pipe.send(('_get_experts', [], dict(uids=uids, heartbeat_expiration=heartbeat_expiration, future=_future)))
+        self.pipe.send(('_get_experts', [], dict(uids=uids, expiration=expiration, future=_future)))
         return future.result()
 
-    def _get_experts(self, uids: List[str], heartbeat_expiration: float, future: SharedFuture):
+    def _get_experts(self, uids: List[str], expiration: Optional[DHTExpiration], future: SharedFuture):
         loop = asyncio.get_event_loop()
+        expiration = expiration or get_dht_time()
+
         lookup_futures = [asyncio.run_coroutine_threadsafe(
-            self.server.get(self.make_key('expert', uid)), loop) for uid in uids]
-        current_time = datetime.datetime.now()
+            self.node.get(self.make_key('expert', uid), expiration), loop) for uid in uids]
 
-        experts = [None] * len(uids)
+        experts: List[Optional[RemoteExpert]] = [None] * len(uids)
         for i, (uid, lookup) in enumerate(zip(uids, lookup_futures)):
-            if lookup.result() is not None:
-                (host, port), timestamp = PickleSerializer.loads(lookup.result())
-                if (current_time - timestamp).total_seconds() <= heartbeat_expiration:
-                    experts[i] = RemoteExpert(uid=uid, host=host, port=port)
+            maybe_result, maybe_expiration = lookup.result()
+            if maybe_expiration is not None:  # if we found a value
+                experts[i] = RemoteExpert(uid=uid, host=maybe_result[0], port=maybe_result[1])
 
         future.set_result(experts)
 
-    def declare_experts(self, uids: List[str], addr, port, wait_timeout=0):
+    def declare_experts(self, uids: List[str], addr, port, wait=True, timeout=None) -> Optional[List[bool]]:
         """
         Make experts available to DHT; update timestamps if already available
         :param uids: a list of expert ids to update
         :param addr: hostname that can be used to call this expert
         :param port: port that can be used to call this expert
-        :param wait_timeout: if wait_timeout > 0, waits for the procedure to finish
+        :param wait: if True, waits for the declaration to finish, otherwise runs in background
+        :param timeout: waits for the procedure to finish for at most this many seconds, None means wait indefinitely
+        :returns: if wait is True, returns a list of booleans (True = store succeeded, False = store rejected)
         """
-        done_event = mp.Event() if wait_timeout else None
-        self.pipe.send(('_declare_experts', [], dict(uids=list(uids), addr=addr, port=port, done_event=done_event)))
-        if done_event is not None:
-            done_event.wait(wait_timeout)
+        future, _future = SharedFuture.make_pair() if wait else (None, None)
+        self.pipe.send(('_declare_experts', [], dict(uids=list(uids), addr=addr, port=port, future=_future)))
+        if wait:
+            return future.result(timeout)
 
-    def _declare_experts(self, uids: List[str], addr: str, port: int, done_event: Optional[mp.Event]):
+    def _declare_experts(self, uids: List[str], addr: str, port: int, future: Optional[SharedFuture]):
+        assert self.node is not None, "This method should only be accessed from inside .run method"
         loop = asyncio.get_event_loop()
-        timestamp = datetime.datetime.now()
-        expert_metadata = PickleSerializer.dumps(((addr, port), timestamp))
-        prefix_metadata = PickleSerializer.dumps(timestamp)
-
+        expiration_time = get_dht_time() + self.EXPIRATION
         unique_prefixes = set()
+        coroutines = []
 
         for uid in uids:
-            asyncio.run_coroutine_threadsafe(self.server.set(self.make_key('expert', uid), expert_metadata), loop)
+            coroutines.append(asyncio.run_coroutine_threadsafe(
+                self.node.store(self.make_key('expert', uid), value=(addr, port),
+                                expiration_time=expiration_time),
+                loop))
             uid_parts = uid.split(self.UID_DELIMETER)
             unique_prefixes.update([self.UID_DELIMETER.join(uid_parts[:i + 1]) for i in range(len(uid_parts))])
 
         for prefix in unique_prefixes:
-            asyncio.run_coroutine_threadsafe(self.server.set(self.make_key('prefix', prefix), prefix_metadata), loop)
+            coroutines.append(asyncio.run_coroutine_threadsafe(
+                self.node.store(self.make_key('prefix', prefix), True, expiration_time), loop))
 
-        if done_event is not None:
-            done_event.set()
+        if future is not None:
+            future.set_result([coro.result() for coro in coroutines])  # wait for all coroutines to finish
 
-    def first_k_active(self, prefixes: List[str], k: int, heartbeat_expiration=HEARTBEAT_EXPIRATION, max_prefetch=None):
+    def first_k_active(self, prefixes: List[str], k: int, max_prefetch=None):
         """
         Find k prefixes with active experts; may return less if there aren't enough; used for DMoE beam search
         :param prefixes: a list of uid prefixes ordered from highest to lowest priority
         :param k: return at most *this many* active prefixes
-        :param heartbeat_expiration: consider expert active if his last heartbeat was sent at most this many seconds ago
         :param max_prefetch: pre-dispatch up to *this many* asynchronous expert requests, defaults to pre-dispatch = k
         :returns: a list of at most :k: prefixes that have at least one active expert each;
         """
+        assert isinstance(prefixes, (list, tuple)), "please provide a list/tuple of prefixes as the first argument"
         future, _future = SharedFuture.make_pair()
-        self.pipe.send(('_first_k_active', [], dict(prefixes=prefixes, k=k, heartbeat_expiration=heartbeat_expiration,
-                                                    max_prefetch=max_prefetch or k, future=_future)))
+        self.pipe.send(('_first_k_active', [],
+                        dict(prefixes=prefixes, k=k, max_prefetch=max_prefetch or k, future=_future)))
         return future.result()
 
-    def _first_k_active(self, prefixes: List[str], k, heartbeat_expiration, max_prefetch, future: SharedFuture):
+    def _first_k_active(self, prefixes: List[str], k: int, max_prefetch: Optional[int], future: SharedFuture):
+        assert self.node is not None, "This method should only be accessed from inside .run method"
+        max_prefetch = max_prefetch or len(prefixes)
         loop = asyncio.get_event_loop()
-        lookup_prefetch = [asyncio.run_coroutine_threadsafe(
-            self.server.get(self.make_key('prefix', prefix)), loop) for prefix in prefixes[:max_prefetch]]
-        current_time = datetime.datetime.now()
-
+        lookup_prefetch = [asyncio.run_coroutine_threadsafe(self.node.get(self.make_key('prefix', prefix)), loop)
+                           for prefix in prefixes[:max_prefetch]]
         active_prefixes = []
 
         for i, prefix in enumerate(prefixes):
-            lookup = lookup_prefetch[i]
+            _, maybe_expiration = lookup_prefetch[i].result()
 
-            if lookup.result() is not None:
-                timestamp = PickleSerializer.loads(lookup.result())
-                if (current_time - timestamp).total_seconds() <= heartbeat_expiration:
-                    active_prefixes.append(prefix)
-                    if len(active_prefixes) >= k:
-                        future.set_result(active_prefixes)
-                        return
+            if maybe_expiration is not None:
+                active_prefixes.append(prefix)
+                if len(active_prefixes) >= k:
+                    future.set_result(active_prefixes)
+                    for task in lookup_prefetch[i:]:
+                        task.cancel()
+                    return
 
             # pre-dispatch the next request in line
             if len(lookup_prefetch) < len(prefixes):
                 lookup_prefetch.append(
-                    asyncio.run_coroutine_threadsafe(self.server.get(
-                        self.make_key('prefix', prefixes[len(lookup_prefetch)])), loop)
-                )
+                    asyncio.run_coroutine_threadsafe(
+                        self.node.get(self.make_key('prefix', prefixes[len(lookup_prefetch)])), loop))
 
         # could not find enough active prefixes; return what we can
         future.set_result(active_prefixes)
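
For readers skimming the diff, below is a minimal usage sketch of the new `DHT` wrapper. The expert uids, hostnames and ports are made up for illustration; the calls follow the signatures introduced in this file (`declare_experts`, `get_experts`, `first_k_active`).

```python
from hivemind.dht import DHT

# start the first DHT node in a background process (the port is picked automatically)
first_node = DHT(start=True)
# a second node joins the swarm by listing the first one as an initial peer
second_node = DHT(('127.0.0.1', first_node.port), start=True)

# announce experts served at 127.0.0.1:8080; entries expire after DHT.EXPIRATION seconds
second_node.declare_experts(['ffn.0.3', 'ffn.2.7'], addr='127.0.0.1', port=8080, wait=True)

# look them up from the other node: RemoteExpert for found uids, None otherwise
experts = first_node.get_experts(['ffn.0.3', 'ffn.9.9'])

# find up to k prefixes that have at least one active expert (used by DMoE beam search)
active_prefixes = first_node.first_k_active(['ffn.0', 'ffn.1', 'ffn.2'], k=2)

first_node.shutdown(); second_node.shutdown()
```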

+ 205 - 0
hivemind/dht/node.py

@@ -0,0 +1,205 @@
+import asyncio
+from collections import OrderedDict
+from functools import partial
+from typing import Optional, Tuple, List, Dict
+from warnings import warn
+
+from .protocol import KademliaProtocol
+from .routing import DHTID, DHTValue, DHTExpiration, DHTKey, get_dht_time
+from .search import traverse_dht
+from ..utils import find_open_port, Endpoint, Hostname, Port, LOCALHOST
+
+
+class DHTNode:
+    """
+    A low-level class that represents a DHT participant.
+    Each DHTNode has an identifier, a local storage and access to other nodes via KademliaProtocol.
+
+    :param node_id: current node's identifier, determines which keys it will store locally, defaults to random id
+    :param port: port to which this DHTNode will listen, by default find some open port
+    :param initial_peers: connects to these peers to populate routing table, defaults to no peers
+    :param bucket_size: (k) - max number of nodes in one k-bucket. Trying to add {k+1}st node will cause a bucket to
+      either split in two buckets along the midpoint or reject the new node (but still save it as a replacement)
+      Recommended value: $k$ is chosen s.t. any given k nodes are very unlikely to all fail after staleness_timeout
+    :param num_replicas: (≈k) - number of nearest nodes that will be asked to store a given key, default = bucket_size
+    :param depth_modulo: (b) - a full bucket can be split if it contains this node's own id or if its depth is not a multiple of this value (Section 4.2)
+    :param wait_timeout: a kademlia rpc request is deemed lost if we did not receive a reply in this many seconds
+    :param staleness_timeout: a bucket is considered stale if no node from that bucket was updated in this many seconds
+    :param bootstrap_timeout: after one of peers responds, await other peers for at most this many seconds
+    :param interface: provide 0.0.0.0 to operate over ipv4, :: to operate over ipv6, localhost to operate locally, etc.
+
+    :note: Hivemind DHT is optimized to store temporary metadata that is regularly updated.
+     For example, an expert alive timestamp that is emitted by the Server responsible for that expert.
+     Such metadata does not require maintenance such as ensuring at least k hosts have it or (de)serialization in case
+     of node shutdown. Instead, DHTNode is designed to reduce the latency of looking up such data.
+
+    Every (key, value) pair in this DHT has an expiration_time - a float measured with get_dht_time() (UnixTime by default)
+    Informally, dht nodes always prefer values with higher expiration_time and may delete any value past its expiration.
+
+    Formally, DHTNode follows this contract:
+      - when asked to store(key, value, expiration_time), a node must store (key, value) at least until expiration_time
+       unless it already stores that key with greater or equal expiration_time - if so, the node must keep the previous value
+      - when asked to get(key), a node must return the value with highest expiration time IF that time has not come yet
+       if the highest expiration time is less than current get_dht_time(), DHTNode *may* return None
+    """
+
+    def __init__(self, node_id: Optional[DHTID] = None, port: Optional[Port] = None, initial_peers: List[Endpoint] = (),
+                 bucket_size: int = 20, num_replicas: Optional[int] = None, depth_modulo: int = 5,
+                 wait_timeout: float = 5, staleness_timeout: Optional[float] = 600,
+                 bootstrap_timeout: Optional[float] = None, cache_locally: bool = True, cache_nearest: int = 1,
+                 interface: Hostname = '0.0.0.0'):
+        self.node_id = node_id = node_id if node_id is not None else DHTID.generate()
+        self.port = port = port if port is not None else find_open_port()
+        self.num_replicas = num_replicas if num_replicas is not None else bucket_size
+        self.cache_locally, self.cache_nearest = cache_locally, cache_nearest
+        self.staleness_timeout = staleness_timeout
+
+        # create kademlia protocol and make it listen to a port
+        loop = asyncio.get_event_loop()
+        make_protocol = partial(KademliaProtocol, self.node_id, bucket_size, depth_modulo, wait_timeout)
+        listener = loop.run_until_complete(loop.create_datagram_endpoint(make_protocol, local_addr=(interface, port)))
+        self.transport: asyncio.Transport = listener[0]
+        self.protocol: KademliaProtocol = listener[1]
+
+        if initial_peers:
+            # bootstrap part 1: ping initial_peers, add each other to the routing table
+            bootstrap_timeout = bootstrap_timeout if bootstrap_timeout is not None else wait_timeout
+            began_bootstrap_time = get_dht_time()
+            ping_tasks = map(self.protocol.call_ping, initial_peers)
+            finished_tasks, remaining_tasks = loop.run_until_complete(
+                asyncio.wait(ping_tasks, timeout=wait_timeout, return_when=asyncio.FIRST_COMPLETED))
+            time_to_first_response = get_dht_time() - began_bootstrap_time
+            # bootstrap part 2: gather all peers who responded within bootstrap_timeout, but at least one peer
+            if remaining_tasks:
+                finished_in_time, stragglers = loop.run_until_complete(
+                    asyncio.wait(remaining_tasks, timeout=bootstrap_timeout - time_to_first_response))
+                for straggler in stragglers:
+                    straggler.cancel()
+                finished_tasks |= finished_in_time
+
+            peer_ids = [task.result() for task in finished_tasks if task.result() is not None]
+            if len(peer_ids) == 0 and len(initial_peers) != 0:
+                warn("DHTNode bootstrap failed: none of the initial_peers responded to a ping.")
+
+            # bootstrap part 3: run beam search for my node id to add my own nearest neighbors to the routing table
+            # ... and maybe receive some values that we are meant to store (see protocol.update_routing_table)
+            loop.run_until_complete(self.find_nearest_nodes(query_id=self.node_id))
+
+    async def find_nearest_nodes(self, query_id: DHTID, k_nearest: Optional[int] = None,
+                                 beam_size: Optional[int] = None, exclude_self: bool = False) -> Dict[DHTID, Endpoint]:
+        """
+        Traverse the DHT and find :k_nearest: nodes to a given :query_id:, optionally :exclude_self: from the results.
+        :note: this is a thin wrapper over dht.search.beam_search, look there for more details
+        :returns: an ordered dictionary of [peer DHTID -> network Endpoint], ordered from nearest to farthest neighbor
+        """
+        k_nearest = k_nearest if k_nearest is not None else self.protocol.bucket_size
+        beam_size = beam_size if beam_size is not None else max(self.protocol.bucket_size, k_nearest)
+        node_to_addr = dict(
+            self.protocol.routing_table.get_nearest_neighbors(query_id, beam_size, exclude=self.node_id))
+
+        async def get_neighbors(node: DHTID) -> Tuple[List[DHTID], bool]:
+            peers: Dict[DHTID, Endpoint] = await self.protocol.call_find_node(node_to_addr[node], query_id)
+            node_to_addr.update(peers)
+            return list(peers.keys()), False  # False means "do not interrupt beam search"
+
+        nearest_nodes, visited_nodes = await traverse_dht(
+            query_id=query_id, initial_nodes=list(node_to_addr), k_nearest=k_nearest, beam_size=beam_size,
+            get_neighbors=get_neighbors, visited_nodes=(self.node_id,))
+
+        if not exclude_self:
+            nearest_nodes = sorted(nearest_nodes + [self.node_id], key=query_id.xor_distance)[:k_nearest]
+            node_to_addr[self.node_id] = (LOCALHOST, self.port)
+
+        return OrderedDict((node, node_to_addr[node]) for node in nearest_nodes)
+
+    async def store(self, key: DHTKey, value: DHTValue, expiration_time: DHTExpiration) -> bool:
+        """
+        Find :num_replicas: nearest nodes to the key and ask them to store (key, value) at least until expiration time.
+        Also cache (key, value, expiration_time) at all nodes you met along the way (see Section 2.1 end)
+        :return: True if store succeeds, False if it fails (due to no response or newer value)
+        """
+        key_id = DHTID.generate(key)
+        nearest_node_to_addr = await self.find_nearest_nodes(key_id, k_nearest=self.num_replicas, exclude_self=True)
+        tasks = [asyncio.create_task(self.protocol.call_store(endpoint, key_id, value, expiration_time))
+                 for endpoint in nearest_node_to_addr.values()]
+        done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
+        return any(done)
+
+    async def get(self, key: DHTKey, sufficient_expiration_time: Optional[DHTExpiration] = None,
+                  beam_size: Optional[int] = None) -> Tuple[Optional[DHTValue], Optional[DHTExpiration]]:
+        """
+        :param key: traverse the DHT and find the value for this key (or return None if it does not exist)
+        :param sufficient_expiration_time: if the search finds a value that expires after this time,
+            it can stop early and return that value. Default = time of call, i.e. any value that has not expired yet
+            If sufficient_expiration_time=float('inf'), this method will find the value with the _latest_ expiration
+        :returns: value and its expiration time. If nothing is found, returns (None, None).
+        :note: in order to check whether get found a value, check that the returned expiration_time is not None
+        """
+        key_id = DHTID.generate(key)
+        sufficient_expiration_time = sufficient_expiration_time or get_dht_time()
+        beam_size = beam_size if beam_size is not None else self.protocol.bucket_size
+        latest_value, latest_expiration, latest_node_id = None, -float('inf'), None
+        node_to_addr, nodes_checked_for_value = dict(), set()
+        should_cache = False  # True if found value in DHT that is newer than local value
+
+        # Option A: the value may already be in our local storage or cache
+        maybe_value, maybe_expiration = self.protocol.storage.get(key_id)
+        if maybe_expiration is None:
+            maybe_value, maybe_expiration = self.protocol.cache.get(key_id)
+        if maybe_expiration is not None and maybe_expiration > latest_expiration:
+            latest_value, latest_expiration, latest_node_id = maybe_value, maybe_expiration, self.node_id
+            # TODO(jheuristic) we may want to run background beam search to update our cache
+        nodes_checked_for_value.add(self.node_id)
+
+        # Option B: go beam search the DHT
+        if latest_expiration < sufficient_expiration_time:
+            node_to_addr.update(self.protocol.routing_table.get_nearest_neighbors(
+                key_id, self.protocol.bucket_size, exclude=self.node_id))
+
+            async def get_neighbors(node: DHTID) -> Tuple[List[DHTID], bool]:
+                nonlocal latest_value, latest_expiration, node_to_addr, nodes_checked_for_value
+                maybe_value, maybe_expiration, peers = await self.protocol.call_find_value(node_to_addr[node], key_id)
+                node_to_addr.update(peers)
+                nodes_checked_for_value.add(node)
+                if maybe_expiration is not None and maybe_expiration > latest_expiration:
+                    latest_value, latest_expiration, latest_node_id = maybe_value, maybe_expiration, node
+                should_interrupt = (latest_expiration >= sufficient_expiration_time)
+                return list(peers.keys()), should_interrupt
+
+            nearest_nodes, visited_nodes = await traverse_dht(
+                query_id=key_id, initial_nodes=list(node_to_addr), k_nearest=beam_size, beam_size=beam_size,
+                get_neighbors=get_neighbors, visited_nodes=nodes_checked_for_value)
+            # normally, by this point we will have found a sufficiently recent value in one of get_neighbors calls
+            should_cache = latest_expiration >= sufficient_expiration_time  # if we found a newer value, cache it later
+
+        # Option C: didn't find good-enough value in beam search, make a last-ditch effort to find it in unvisited nodes
+        if latest_expiration < sufficient_expiration_time:
+            nearest_unvisited = [node_id for node_id in nearest_nodes if node_id not in nodes_checked_for_value]
+            tasks = [self.protocol.call_find_value(node_to_addr[node_id], key_id) for node_id in nearest_unvisited]
+            pending_tasks = set(tasks)
+            for task in asyncio.as_completed(tasks):
+                pending_tasks.remove(task)
+                maybe_value, maybe_expiration, _ = await task
+                if maybe_expiration is not None and maybe_expiration > latest_expiration:
+                    latest_value, latest_expiration = maybe_value, maybe_expiration
+                    if latest_expiration >= sufficient_expiration_time:
+                        break
+            for task in pending_tasks:
+                task.close()
+            should_cache = latest_expiration >= sufficient_expiration_time  # if we found a newer value, cache it later
+
+        # step 4: if we found a value in the DHT that is newer than our local value, cache it locally and/or at the nearest nodes
+        if should_cache and self.cache_locally:
+            self.protocol.cache.store(key_id, latest_value, latest_expiration)
+        if should_cache and self.cache_nearest:
+            num_cached_nodes = 0
+            for node_id in nearest_nodes:
+                if node_id == latest_node_id:
+                    continue
+                asyncio.create_task(self.protocol.call_store(
+                    node_to_addr[node_id], key_id, latest_value, latest_expiration, in_cache=True))
+                num_cached_nodes += 1
+                if num_cached_nodes >= self.cache_nearest:
+                    break
+
+        return (latest_value, latest_expiration) if latest_expiration != -float('inf') else (None, None)
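
To make the store/get contract above concrete, here is a rough two-node sketch. The key, endpoint and 60-second lifetime are arbitrary, and it assumes both nodes live in one process and share the default event loop (DHTNode.__init__ drives the loop itself, so nodes are created outside of any coroutine).

```python
import asyncio
from hivemind.dht.node import DHTNode
from hivemind.dht.routing import get_dht_time

loop = asyncio.get_event_loop()

# DHTNode.__init__ opens a UDP endpoint on the current event loop
alice = DHTNode()
bob = DHTNode(initial_peers=[('127.0.0.1', alice.port)])

# ask the DHT to store a value for ~60 seconds, then read it back from the other node
store_ok = loop.run_until_complete(
    bob.store('expert.ffn.0.3', value=('127.0.0.1', 8080),
              expiration_time=get_dht_time() + 60))
value, expiration = loop.run_until_complete(alice.get('expert.ffn.0.3'))
# if the store succeeded and has not expired, value holds the stored endpoint
# (depending on the msgpack version, the tuple may come back as a list)
```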

+ 196 - 0
hivemind/dht/protocol.py

@@ -0,0 +1,196 @@
+import asyncio
+import heapq
+from typing import Optional, List, Tuple, Dict
+from rpcudp.protocol import RPCProtocol
+
+from .routing import RoutingTable, DHTID, DHTValue, DHTExpiration, BinaryDHTID, get_dht_time
+from ..utils import Endpoint
+
+
+class KademliaProtocol(RPCProtocol):
+    """
+    A protocol that allows DHT nodes to request keys/neighbors from other DHT nodes.
+    As a side-effect, KademliaProtocol also maintains a routing table as described in
+    https://pdos.csail.mit.edu/~petar/papers/maymounkov-kademlia-lncs.pdf
+
+    See DHTNode (node.py) for a more detailed description.
+
+    :note: the rpc_* methods defined in this class will be automatically exposed to other DHT nodes,
+     for instance, def rpc_ping can be called as protocol.call_ping(addr, dht_id) from a remote machine
+     Only the call_* methods are meant to be called publicly, e.g. from DHTNode
+     Read more: https://github.com/bmuller/rpcudp/tree/master/rpcudp
+    """
+
+    def __init__(self, node_id: DHTID, bucket_size: int, depth_modulo: int,
+                 wait_timeout: float, cache_size: Optional[int] = None):
+        super().__init__(wait_timeout)
+        self.node_id, self.bucket_size = node_id, bucket_size
+        self.routing_table = RoutingTable(node_id, bucket_size, depth_modulo)
+        self.storage = LocalStorage()
+        self.cache = LocalStorage(maxsize=cache_size)
+
+    def rpc_ping(self, sender: Endpoint, sender_id_bytes: BinaryDHTID) -> BinaryDHTID:
+        """ Some dht node wants us to add it to our routing table. """
+        asyncio.ensure_future(self.update_routing_table(DHTID.from_bytes(sender_id_bytes), sender))
+        return bytes(self.node_id)
+
+    async def call_ping(self, recipient: Endpoint) -> Optional[DHTID]:
+        """ Get recipient's node id and add him to the routing table. If recipient doesn't respond, return None """
+        responded, response = await self.ping(recipient, bytes(self.node_id))
+        recipient_node_id = DHTID.from_bytes(response) if responded else None
+        asyncio.ensure_future(self.update_routing_table(recipient_node_id, recipient, responded=responded))
+        return recipient_node_id
+
+    def rpc_store(self, sender: Endpoint, sender_id_bytes: BinaryDHTID, key_bytes: BinaryDHTID,
+                  value: DHTValue, expiration_time: DHTExpiration, in_cache: bool) -> Tuple[bool, BinaryDHTID]:
+        """ Some node wants us to store this (key, value) pair """
+        asyncio.ensure_future(self.update_routing_table(DHTID.from_bytes(sender_id_bytes), sender))
+        if in_cache:
+            store_accepted = self.cache.store(DHTID.from_bytes(key_bytes), value, expiration_time)
+        else:
+            store_accepted = self.storage.store(DHTID.from_bytes(key_bytes), value, expiration_time)
+        return store_accepted, bytes(self.node_id)
+
+    async def call_store(self, recipient: Endpoint, key: DHTID, value: DHTValue,
+                         expiration_time: DHTExpiration, in_cache: bool = False) -> Optional[bool]:
+        """
+        Ask a recipient to store (key, value) pair until expiration time or update their older value
+        :returns: True if value was accepted, False if it was rejected (recipient has newer value), None if no response
+        """
+        responded, response = await self.store(recipient, bytes(self.node_id), bytes(key),
+                                               value, expiration_time, in_cache)
+        if responded:
+            store_accepted, recipient_node_id = response[0], DHTID.from_bytes(response[1])
+            asyncio.ensure_future(self.update_routing_table(recipient_node_id, recipient, responded=responded))
+            return store_accepted
+        return None
+
+    def rpc_find_node(self, sender: Endpoint, sender_id_bytes: BinaryDHTID,
+                      query_id_bytes: BinaryDHTID) -> Tuple[List[Tuple[BinaryDHTID, Endpoint]], BinaryDHTID]:
+        """
+        Someone wants to find :query_id: in the DHT. Return the k nearest neighbors from our routing table
+        :returns: a list of pairs (node_id, address) of :bucket_size: nodes nearest to query_id according to XOR distance,
+         also returns our own node id for routing table maintenance
+        """
+        query_id, sender_id = DHTID.from_bytes(query_id_bytes), DHTID.from_bytes(sender_id_bytes)
+        asyncio.ensure_future(self.update_routing_table(sender_id, sender))
+        peer_ids_and_addr = self.routing_table.get_nearest_neighbors(query_id, k=self.bucket_size, exclude=sender_id)
+        return [(bytes(peer_id), peer_addr) for peer_id, peer_addr in peer_ids_and_addr], bytes(self.node_id)
+
+    async def call_find_node(self, recipient: Endpoint, query_id: DHTID) -> Dict[DHTID, Endpoint]:
+        """
+        Ask a recipient to give you its nearest neighbors to query_id. If the recipient knows query_id directly,
+         it will be returned as the first of the neighbors; if the recipient does not respond, return an empty dict.
+        :returns: a dictionary[node id => address] as per Section 2.3 of the paper
+        """
+        responded, response = await self.find_node(recipient, bytes(self.node_id), bytes(query_id))
+        if responded:
+            peers = {DHTID.from_bytes(peer_id_bytes): tuple(addr) for peer_id_bytes, addr in response[0]}
+            # Note: we convert addr from list to tuple here --^ because some msgpack versions convert tuples to lists
+            recipient_node_id = DHTID.from_bytes(response[1])
+            asyncio.ensure_future(self.update_routing_table(recipient_node_id, recipient, responded=responded))
+            return peers
+        return {}
+
+    def rpc_find_value(self, sender: Endpoint, sender_id_bytes: BinaryDHTID, key_bytes: BinaryDHTID) -> \
+            Tuple[Optional[DHTValue], Optional[DHTExpiration], List[Tuple[BinaryDHTID, Endpoint]], BinaryDHTID]:
+        """
+        Someone wants to find value corresponding to key. If we have the value, return the value and its expiration time
+         Either way, return :bucket_size: nearest neighbors to that node.
+        :note: this is a deviation from Section 2.3 of the paper; original kademlia returned EITHER value OR neighbors
+        :returns: (value or None, expiration time or None, nearest neighbors, our own dht id)
+        """
+        maybe_value, maybe_expiration = self.storage.get(DHTID.from_bytes(key_bytes))
+        cached_value, cached_expiration = self.cache.get(DHTID.from_bytes(key_bytes))
+        if (cached_expiration or -float('inf')) > (maybe_expiration or -float('inf')):
+            maybe_value, maybe_expiration = cached_value, cached_expiration
+        nearest_neighbors, my_id = self.rpc_find_node(sender, sender_id_bytes, key_bytes)
+        return maybe_value, maybe_expiration, nearest_neighbors, my_id
+
+    async def call_find_value(self, recipient: Endpoint, key: DHTID) -> \
+            Tuple[Optional[DHTValue], Optional[DHTExpiration], Dict[DHTID, Endpoint]]:
+        """
+        Ask a recipient to give you the value, if it has one, or nearest neighbors to your key.
+        :returns: (optional value, optional expiration time, and neighbors)
+         value: whatever was the latest value stored by the recipient with that key (see DHTNode contract)
+         expiration time: expiration time of the returned value, None if no value was found
+         neighbors:  a dictionary[node id => address] as per Section 2.3 of the paper;
+        Note: if no response, returns None, None, {}
+        """
+        responded, response = await self.find_value(recipient, bytes(self.node_id), bytes(key))
+        if responded:
+            (value, expiration_time, peers_bytes), recipient_id = response[:-1], DHTID.from_bytes(response[-1])
+            peers = {DHTID.from_bytes(peer_id_bytes): tuple(addr) for peer_id_bytes, addr in peers_bytes}
+            asyncio.ensure_future(self.update_routing_table(recipient_id, recipient, responded=responded))
+            return value, expiration_time, peers
+        return None, None, {}
+
+    async def update_routing_table(self, node_id: Optional[DHTID], addr: Endpoint, responded=True):
+        """
+        This method is called on every incoming AND outgoing request to update the routing table
+        :param addr: sender endpoint for incoming requests, recipient endpoint for outgoing requests
+        :param node_id: sender node id for incoming requests, recipient node id for outgoing requests
+        :param responded: for outgoing requests, this indicates whether the recipient responded or not.
+          For incoming requests, this should always be True
+        """
+        if responded:  # incoming request or outgoing request with response
+            maybe_node_to_ping = self.routing_table.add_or_update_node(node_id, addr)
+            if maybe_node_to_ping is not None:
+                # we couldn't add new node because the table was full. Check if existing peers are alive (Section 2.2)
+                # ping one least-recently updated peer: if it won't respond, remove it from the table, else update it
+                await self.call_ping(maybe_node_to_ping[1])  # [1]-th element is that node's endpoint
+
+        else:  # outgoing request and peer did not respond
+            if node_id is not None and node_id in self.routing_table:
+                del self.routing_table[node_id]
+
+    def _accept_response(self, msg_id, data, address):
+        """ Override for RPCProtocol._accept_response to handle cancelled tasks """
+        future, timeout = self._outstanding[msg_id]
+        if future.cancelled():
+            timeout.cancel()
+            del self._outstanding[msg_id]
+        else:
+            super()._accept_response(msg_id, data, address)
+
+
+
+class LocalStorage:
+    def __init__(self, maxsize: Optional[int] = None):
+        self.cache_size = maxsize or float("inf")
+        self.data = dict()
+        self.expiration_heap = []
+        self.key_to_heap = dict()
+
+    def remove_outdated(self):
+        while self.expiration_heap and (self.expiration_heap[0][0] < get_dht_time()
+                                        or len(self.expiration_heap) > self.cache_size):
+            heap_entry = heapq.heappop(self.expiration_heap)
+            key = heap_entry[1]
+            if self.key_to_heap[key] == heap_entry:
+                del self.data[key], self.key_to_heap[key]
+
+    def store(self, key: DHTID, value: DHTValue, expiration_time: DHTExpiration) -> bool:
+        """
+        Store a (key, value) pair locally at least until expiration_time. See class docstring for details.
+        :returns: True if new value was stored, False if it was rejected (current value is newer)
+        """
+        if expiration_time < get_dht_time():
+            return False
+        self.key_to_heap[key] = (expiration_time, key)
+        heapq.heappush(self.expiration_heap, (expiration_time, key))
+        if key in self.data:
+            if self.data[key][1] < expiration_time:
+                self.data[key] = (value, expiration_time)
+                return True
+            return False
+        self.data[key] = (value, expiration_time)
+        self.remove_outdated()
+        return True
+
+    def get(self, key: DHTID) -> Tuple[Optional[DHTValue], Optional[DHTExpiration]]:
+        """ Get a value corresponding to a key if that (key, value) pair was previously stored here. """
+        self.remove_outdated()
+        if key in self.data:
+            return self.data[key]
+        return None, None
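
The LocalStorage rules (a newer expiration wins, expired or over-capacity entries are evicted) can be illustrated with a short sketch; the key names and timings below are arbitrary.

```python
from hivemind.dht.protocol import LocalStorage
from hivemind.dht.routing import DHTID, get_dht_time

storage = LocalStorage(maxsize=2)
key = DHTID.generate(source='some_key')
now = get_dht_time()

assert storage.store(key, b'old', now + 10)           # accepted: the key is new
assert not storage.store(key, b'outdated', now + 5)   # rejected: stored value expires later
assert storage.store(key, b'new', now + 60)           # accepted: new value expires later
assert storage.get(key) == (b'new', now + 60)
assert storage.get(DHTID.generate(source='missing')) == (None, None)
```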

+ 266 - 0
hivemind/dht/routing.py

@@ -0,0 +1,266 @@
+from __future__ import annotations
+
+import hashlib
+import os
+import random
+
+import time
+import heapq
+from collections.abc import Iterable
+from itertools import chain
+from typing import Tuple, Optional, List, Dict, Set, Union, Any, Sequence, Iterator
+
+from ..utils import Endpoint, PickleSerializer
+
+
+class RoutingTable:
+    """
+    A data structure that contains DHT peers bucketed according to their distance to node_id
+    :param node_id: node id used to measure distance
+    :param bucket_size: parameter $k$ from Kademlia paper Section 2.2
+    :param depth_modulo: parameter $b$ from Kademlia paper Section 2.2.
+    :note: a more detailed description can be found in the DHTNode class docstring (see node.py)
+    :note: kademlia paper refers to https://pdos.csail.mit.edu/~petar/papers/maymounkov-kademlia-lncs.pdf
+    """
+
+    def __init__(self, node_id: DHTID, bucket_size: int, depth_modulo: int):
+        self.node_id, self.bucket_size, self.depth_modulo = node_id, bucket_size, depth_modulo
+        self.buckets = [KBucket(node_id.MIN, node_id.MAX, bucket_size)]
+
+    def get_bucket_index(self, node_id: DHTID) -> int:
+        """ Get the index of the bucket that the given node would fall into. """
+        # TODO use binsearch aka from bisect import bisect.
+        for index, bucket in enumerate(self.buckets):
+            if bucket.lower <= node_id < bucket.upper:
+                return index
+        raise ValueError(f"Failed to get bucket for node_id={node_id}, this should not be possible.")
+
+    def add_or_update_node(self, node_id: DHTID, addr: Endpoint) -> Optional[Tuple[DHTID, Endpoint]]:
+        """
+        Update routing table after an incoming request from :addr: (host:port) or outgoing request to :addr:
+        :returns: If we cannot add node_id to the routing table, return the least-recently-updated node (Section 2.2)
+        :note: KademliaProtocol calls this method for every incoming and outgoing request if there was a response.
+          If this method returned a node to be ping-ed, the protocol will ping it to check and either move it to
+          the start of the table or remove that node and replace it with one of the pending replacement nodes
+        """
+        bucket_index = self.get_bucket_index(node_id)
+        bucket = self.buckets[bucket_index]
+
+        if bucket.add_or_update_node(node_id, addr):
+            return  # this will succeed unless the bucket is full
+
+        # Per section 4.2 of paper, split if the bucket has node's own id in its range
+        # or if bucket depth is not congruent to 0 mod $b$
+        if bucket.has_in_range(self.node_id) or bucket.depth % self.depth_modulo != 0:
+            self.split_bucket(bucket_index)
+            return self.add_or_update_node(node_id, addr)
+
+        # The bucket is full and won't split further. Return a node to ping (see this method's docstring)
+        return bucket.request_ping_node()
+
+    def split_bucket(self, index: int) -> None:
+        """ Split bucket range in two equal parts and reassign nodes to the appropriate half """
+        first, second = self.buckets[index].split()
+        self.buckets[index] = first
+        self.buckets.insert(index + 1, second)
+
+    def __getitem__(self, node_id: DHTID) -> Endpoint:
+        return self.buckets[self.get_bucket_index(node_id)][node_id]
+
+    def __setitem__(self, node_id: DHTID, addr: Endpoint) -> None:
+        raise NotImplementedError("RoutingTable doesn't support direct item assignment. Use add_or_update_node instead")
+
+    def __contains__(self, node_id: DHTID) -> bool:
+        return node_id in self.buckets[self.get_bucket_index(node_id)]
+
+    def __delitem__(self, node_id: DHTID):
+        node_bucket = self.buckets[self.get_bucket_index(node_id)]
+        del node_bucket[node_id]
+
+    def get_nearest_neighbors(
+            self, query_id: DHTID, k: int, exclude: Optional[DHTID] = None) -> List[Tuple[DHTID, Endpoint]]:
+        """
+        Find k nearest neighbors from routing table according to XOR distance, does NOT include self.node_id
+        :param query_id: find neighbors of this node
+        :param k: find this many neighbors. If there aren't enough nodes in the table, returns all nodes
+        :param exclude: if provided, results will not contain this node id even if it is in the table
+        :returns: a list of tuples (node_id, endpoint) for up to k neighbors sorted from nearest to farthest
+
+        :note: this is a semi-exhaustive search of nodes that takes O(n * log k) time.
+            One can implement a more efficient knn search using a binary skip-tree in some
+            more elegant language such as c++ / cython / numba.
+            Here's a sketch
+
+            Preparation: construct a non-regular binary tree of depth (2 * DHTID.HASH_NBYTES)
+             Each leaf corresponds to a binary DHTID with '0' for every left turn and '1' for right turn
+             Each non-leaf corresponds to a certain prefix, e.g. 0010110???...???
+             If there are no nodes under some prefix xxxY???..., the corresponding node xxx????...
+             will only have one child.
+            Add(node):
+             Traverse down a tree, on i-th level go left if node_i == 0, right if node_i == 1
+             If the corresponding node is missing, simply create it
+            Search(query, k):
+             Traverse the tree with a depth-first search, on i-th level go left if query_i == 0, else right
+             If the corresponding node is missing, go the other way. Proceed until you find a leaf.
+             This leaf is your nearest neighbor. Now add more neighbors by considering alternative paths
+             bottom-up, i.e. if your nearest neighbor is 01011, first try 01010, then 0100x, then 011xx, ...
+
+            This results in O(num_nodes * bit_length) complexity for add and search
+            Better yet: use binary tree with skips for O(num_nodes * log(num_nodes))
+        """
+        all_nodes: Iterator[Tuple[DHTID, Endpoint]] = chain(*self.buckets)  # uses KBucket.__iter__
+        nearest_nodes_with_addr: List[Tuple[DHTID, Endpoint]] = heapq.nsmallest(
+            k + int(exclude is not None), all_nodes, lambda id_and_endpoint: query_id.xor_distance(id_and_endpoint[0]))
+        if exclude is not None:
+            for i, (node_i, addr_i) in enumerate(list(nearest_nodes_with_addr)):
+                if node_i == exclude:
+                    del nearest_nodes_with_addr[i]
+                    break
+            if len(nearest_nodes_with_addr) > k:
+                nearest_nodes_with_addr.pop()  # if excluded element is not among (k + 1) nearest, simply crop to k
+        return nearest_nodes_with_addr
+
+    def __repr__(self):
+        bucket_info = "\n".join(repr(bucket) for bucket in self.buckets)
+        return f"{self.__class__.__name__}(node_id={self.node_id}, bucket_size={self.bucket_size}," \
+               f" modulo={self.depth_modulo},\nbuckets=[\n{bucket_info})"
+
+
+class KBucket:
+    """
+    A bucket containing up to :size: of DHTIDs in [lower, upper) semi-interval.
+    Maps DHT node ids to their endpoints (hostname, port)
+    """
+    def __init__(self, lower: int, upper: int, size: int, depth: int = 0):
+        assert upper - lower == 2 ** (DHTID.HASH_NBYTES * 8 - depth)
+        self.lower, self.upper, self.size, self.depth = lower, upper, size, depth
+        self.nodes_to_addr: Dict[DHTID, Endpoint] = {}
+        self.replacement_nodes: Dict[DHTID, Endpoint] = {}
+        self.nodes_requested_for_ping: Set[DHTID] = set()
+        self.last_updated = get_dht_time()
+
+    def has_in_range(self, node_id: DHTID):
+        """ Check if node_id is between this bucket's lower and upper bounds """
+        return self.lower <= node_id < self.upper
+
+    def add_or_update_node(self, node_id: DHTID, addr: Endpoint) -> bool:
+        """
+        Add node to KBucket or update existing node, return True if successful, False if the bucket is full.
+        If the bucket is full, keep track of node in a replacement list, per section 4.1 of the paper.
+        :param node_id: dht node identifier that should be added or moved to the front of bucket
+        :param addr: a pair of (hostname, port) associated with that node id
+        :note: this function has a side-effect of resetting KBucket.last_updated time
+        """
+        if node_id in self.nodes_requested_for_ping:
+            self.nodes_requested_for_ping.remove(node_id)
+        self.last_updated = get_dht_time()
+        if node_id in self.nodes_to_addr:
+            del self.nodes_to_addr[node_id]
+            self.nodes_to_addr[node_id] = addr
+        elif len(self) < self.size:
+            self.nodes_to_addr[node_id] = addr
+        else:
+            if node_id in self.replacement_nodes:
+                del self.replacement_nodes[node_id]
+            self.replacement_nodes[node_id] = addr
+            return False
+        return True
+
+    def request_ping_node(self) -> Optional[Tuple[DHTID, Endpoint]]:
+        """ :returns: least-recently updated node that isn't already being pinged right now -- if such node exists """
+        for uid, endpoint in self.nodes_to_addr.items():
+            if uid not in self.nodes_requested_for_ping:
+                return uid, endpoint
+
+    def __getitem__(self, node_id: DHTID) -> Endpoint:
+        return self.nodes_to_addr[node_id] if node_id in self.nodes_to_addr else self.replacement_nodes[node_id]
+
+    def __delitem__(self, node_id: DHTID):
+        if not (node_id in self.nodes_to_addr or node_id in self.replacement_nodes):
+            raise KeyError(f"KBucket does not contain node id={node_id}.")
+
+        if node_id in self.replacement_nodes:
+            del self.replacement_nodes[node_id]
+
+        if node_id in self.nodes_to_addr:
+            del self.nodes_to_addr[node_id]
+
+            if self.replacement_nodes:
+                newnode_id, newnode = self.replacement_nodes.popitem()
+                self.nodes_to_addr[newnode_id] = newnode
+
+    def __len__(self):
+        return len(self.nodes_to_addr)
+
+    def __iter__(self):
+        return iter(self.nodes_to_addr.items())
+
+    def split(self) -> Tuple[KBucket, KBucket]:
+        """ Split bucket over midpoint, rounded down, assign nodes to according to their id """
+        midpoint = (self.lower + self.upper) // 2
+        assert self.lower < midpoint < self.upper, f"Bucket too small to be split: [{self.lower}: {self.upper})"
+        left = KBucket(self.lower, midpoint, self.size, depth=self.depth + 1)
+        right = KBucket(midpoint, self.upper, self.size, depth=self.depth + 1)
+        for node_id, addr in chain(self.nodes_to_addr.items(), self.replacement_nodes.items()):
+            bucket = left if int(node_id) < midpoint else right  # the midpoint itself belongs to the right bucket [midpoint, upper)
+            bucket.add_or_update_node(node_id, addr)
+        return left, right
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({len(self.nodes_to_addr)} nodes" \
+               f" with {len(self.replacement_nodes)} replacements, depth={self.depth}, max size={self.size}" \
+               f" lower={hex(self.lower)}, upper={hex(self.upper)})"
+
+
+class DHTID(int):
+    HASH_FUNC = hashlib.sha1
+    HASH_NBYTES = 20  # SHA1 produces a 20-byte (aka 160bit) number
+    RANGE = MIN, MAX = 0, 2 ** (HASH_NBYTES * 8)  # inclusive min, exclusive max
+
+    def __new__(cls, value: int):
+        assert cls.MIN <= value < cls.MAX, f"DHTID must be in [{cls.MIN}, {cls.MAX}) but got {value}"
+        return super().__new__(cls, value)
+
+    @classmethod
+    def generate(cls, source: Optional[Any] = None, nbits: int = 255):
+        """
+        Generates random uid based on SHA1
+        :param source: if provided, converts this value to bytes and uses it as input for hashing function;
+            by default, generates a random dhtid from :nbits: random bits
+        """
+        source = random.getrandbits(nbits).to_bytes(nbits, byteorder='big') if source is None else source
+        source = PickleSerializer.dumps(source) if not isinstance(source, bytes) else source
+        raw_uid = cls.HASH_FUNC(source).digest()
+        return cls(int(raw_uid.hex(), 16))
+
+    def xor_distance(self, other: Union[DHTID, Sequence[DHTID]]) -> Union[int, List[int]]:
+        """
+        :param other: one or multiple DHTIDs. If given multiple DHTIDs as other, this function
+         will compute distance from self to each of DHTIDs in other.
+        :return: a number or a list of numbers whose binary representations equal bitwise xor between DHTIDs.
+        """
+        if isinstance(other, Iterable):
+            return list(map(self.xor_distance, other))  # TODO make some SIMD
+        return int(self) ^ int(other)
+
+    @classmethod
+    def longest_common_prefix_length(cls, *ids: DHTID) -> int:
+        ids_bits = [bin(uid)[2:].rjust(8 * cls.HASH_NBYTES, '0') for uid in ids]
+        return len(os.path.commonprefix(ids_bits))
+
+    def to_bytes(self, length=HASH_NBYTES, byteorder='big', *, signed=False) -> bytes:
+        return super().to_bytes(length, byteorder, signed=signed)
+
+    @classmethod
+    def from_bytes(cls, bytes, byteorder='big', *, signed=False) -> DHTID:
+        return DHTID(super().from_bytes(bytes, byteorder=byteorder, signed=signed))
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({hex(self)})"
+
+    def __bytes__(self):
+        return self.to_bytes()
+
+
+DHTKey, DHTValue, DHTExpiration, BinaryDHTID = Any, Any, float, bytes  # flavour types
+get_dht_time = time.time  # time used by all DHT functionality; can be replaced with any infrastructure-wide clock
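
For reference, a minimal sanity-check sketch of the DHTID helpers defined above (the seed bytes are arbitrary and not part of the diff):

    from hivemind.dht.routing import DHTID

    a, b = DHTID.generate(), DHTID.generate(source=b"some peer-specific seed")
    assert DHTID.MIN <= a < DHTID.MAX and DHTID.MIN <= b < DHTID.MAX
    assert a.xor_distance(a) == 0 and a.xor_distance(b) == int(a) ^ int(b)
    assert a.xor_distance([a, b]) == [0, int(a) ^ int(b)]  # batched form returns a list of distances
    assert DHTID.from_bytes(bytes(a)) == a  # 20-byte big-endian round-trip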

+ 62 - 0
hivemind/dht/search.py

@@ -0,0 +1,62 @@
+import heapq
+from typing import Collection, Callable, Tuple, List, Awaitable, Set
+from warnings import warn
+
+from .routing import DHTID
+
+
+async def traverse_dht(query_id: DHTID, initial_nodes: Collection[DHTID], k_nearest: int, beam_size: int,
+                       get_neighbors: Callable[[DHTID], Awaitable[Tuple[Collection[DHTID], bool]]],
+                       visited_nodes: Collection[DHTID] = ()) -> Tuple[List[DHTID], Set[DHTID]]:
+    """
+    Asynchronous beam search over the DHT. Not meant to be called by the user; please use DHTNode.store/get instead.
+    Traverse the DHT graph using the get_neighbors function and find up to k_nearest nodes according to DHTID.xor_distance.
+    Approximate time complexity: O(T * log T) where T = (path_to_true_nearest + beam_size) * mean_num_neighbors
+
+    :param query_id: search query, find k_nearest neighbors of this DHTID
+    :param initial_nodes: nodes used to pre-populate beam search heap, e.g. [my_own_DHTID, *maybe_some_peers]
+    :param k_nearest: find up to this many nearest neighbors. If there are fewer nodes in the DHT, return all nodes
+    :param beam_size: beam search will not give up until it exhausts this many nearest nodes (to query_id) from the heap.
+        Recommended value: a beam size of k_nearest * (2-5) will yield near-perfect results.
+
+    :param get_neighbors: A function that returns neighbors of a given node and controls beam search stopping criteria.
+        async def get_neighbors(node: DHTID) -> neighbors_of_that_node: List[DHTID], should_continue: bool
+        If should_continue is False, beam search will halt and return k_nearest of whatever it found by then.
+
+    :param visited_nodes: beam search will neither call get_neighbors on these nodes, nor return them as nearest
+    :returns: a list of k nearest nodes (nearest to farthest), and a set of all visited nodes (including visited_nodes)
+    """
+    if beam_size < k_nearest:
+        warn(f"beam search: beam_size({beam_size}) is too small, beam search may fail to find k neighbors.")
+    visited_nodes = set(visited_nodes)  # note: copy visited_nodes because we will add more nodes to this collection.
+    initial_nodes = [node_id for node_id in initial_nodes if node_id not in visited_nodes]
+    if not initial_nodes:
+        return [], visited_nodes
+
+    unvisited_nodes = [(distance, uid) for uid, distance in zip(initial_nodes, query_id.xor_distance(initial_nodes))]
+    heapq.heapify(unvisited_nodes)  # nearest-first heap of candidates, unlimited size
+
+    nearest_nodes = [(-distance, node_id) for distance, node_id in heapq.nsmallest(beam_size, unvisited_nodes)]
+    heapq.heapify(nearest_nodes)  # farthest-first heap of size beam_size, used for early-stopping and to select results
+    while len(nearest_nodes) > beam_size:
+        heapq.heappop(nearest_nodes)
+
+    visited_nodes |= set(initial_nodes)
+    upper_bound = -nearest_nodes[0][0]  # distance to farthest element that is still in beam
+    was_interrupted = False  # will be set to True if the host triggered beam search to stop via get_neighbors
+
+    while (not was_interrupted) and len(unvisited_nodes) != 0 and unvisited_nodes[0][0] <= upper_bound:
+        _, node_id = heapq.heappop(unvisited_nodes)  # note: this  --^ is the smallest element in heap (see heapq)
+        neighbors, was_interrupted = await get_neighbors(node_id)
+        neighbors = [node_id for node_id in neighbors if node_id not in visited_nodes]
+        visited_nodes.update(neighbors)
+
+        for neighbor_id, distance in zip(neighbors, query_id.xor_distance(neighbors)):
+            if distance <= upper_bound or len(nearest_nodes) < beam_size:
+                heapq.heappush(unvisited_nodes, (distance, neighbor_id))
+
+                heapq_add_or_replace = heapq.heappush if len(nearest_nodes) < beam_size else heapq.heappushpop
+                heapq_add_or_replace(nearest_nodes, (-distance, neighbor_id))
+                upper_bound = -nearest_nodes[0][0]  # distance to beam_size-th nearest element found so far
+
+    return [node_id for _, node_id in heapq.nlargest(k_nearest, nearest_nodes)], visited_nodes
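
A self-contained toy example of traverse_dht; the random topology and helper names below are illustrative only and not part of the PR:

    import asyncio
    import random

    from hivemind.dht.routing import DHTID
    from hivemind.dht.search import traverse_dht

    async def toy_search():
        all_nodes = [DHTID.generate() for _ in range(100)]
        topology = {node: random.sample(all_nodes, 10) for node in all_nodes}  # static toy graph

        async def get_neighbors(node: DHTID):
            return topology[node], True  # True = "do not interrupt the search"

        query = DHTID.generate()
        nearest, visited = await traverse_dht(
            query_id=query, initial_nodes=all_nodes[:3], k_nearest=5, beam_size=20,
            get_neighbors=get_neighbors)
        print(f"found {len(nearest)} nearest nodes after visiting {len(visited)} nodes")

    asyncio.run(toy_search())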

+ 4 - 4
hivemind/server/__init__.py

@@ -2,11 +2,11 @@ import multiprocessing as mp
 import os
 import threading
 from socket import socket, AF_INET, SOCK_STREAM, SO_REUSEADDR, SOL_SOCKET, timeout
-from typing import Dict
+from typing import Dict, Optional
 
 from .connection_handler import handle_connection
 from .dht_handler import DHTHandlerThread
-from ..dht import DHTNode
+from ..dht import DHT
 from ..runtime import Runtime, ExpertBackend
 
 
@@ -20,7 +20,7 @@ class Server(threading.Thread):
      - publishes updates to expert status every :update_period: seconds
      - follows orders from HivemindController - if it exists
 
-    :type dht: DHTNode or None. Server with dht=None will NOT be visible from DHT,
+    :type dht: DHT or None. Server with dht=None will NOT be visible from DHT,
      but it will still support accessing experts directly with RemoteExpert(uid=UID, host=IPADDR, port=PORT).
     :param expert_backends: dict{expert uid (str) : ExpertBackend} for all expert hosted by this server.
     :param addr: server's dht address that determines how it can be accessed. Default is local connections only.
@@ -33,7 +33,7 @@ class Server(threading.Thread):
         is ready (see .ready below)
     """
 
-    def __init__(self, dht: DHTNode, expert_backends: Dict[str, ExpertBackend], addr='127.0.0.1',
+    def __init__(self, dht: Optional[DHT], expert_backends: Dict[str, ExpertBackend], addr='127.0.0.1',
                  port: int = 8080, conn_handler_processes: int = 1, update_period: int = 30, start=False,
                  **kwargs):
         super().__init__()

+ 2 - 2
hivemind/server/dht_handler.py

@@ -1,11 +1,11 @@
 import threading
 import time
 
-from ..dht import DHTNode
+from ..dht import DHT
 
 
 class DHTHandlerThread(threading.Thread):
-    def __init__(self, experts, dht: DHTNode,
+    def __init__(self, experts, dht: DHT,
                  update_period: int = 5, addr: str = '127.0.0.1', port: int = 8080):
         super(DHTHandlerThread, self).__init__()
         self.port = port

+ 16 - 10
hivemind/utils/connection.py

@@ -1,7 +1,11 @@
-from contextlib import AbstractContextManager
-from socket import socket
+import socket
+from contextlib import AbstractContextManager, closing
 from typing import Tuple
 
+Hostname, Port = str, int  # flavour types
+Endpoint = Tuple[Hostname, Port]  # https://networkengineering.stackexchange.com/a/9435
+LOCALHOST = '127.0.0.1'
+
 
 class Connection(AbstractContextManager):
     header_size = 4  # number of characters in all headers
@@ -9,12 +13,12 @@ class Connection(AbstractContextManager):
 
     __slots__ = ('conn', 'addr')
 
-    def __init__(self, conn: socket, addr: Tuple[str, int]):
+    def __init__(self, conn: socket, addr: Endpoint):
         self.conn, self.addr = conn, addr
 
     @staticmethod
     def create(host: str, port: int):
-        sock = socket()
+        sock = socket.socket()
         addr = (host, port)
         sock.connect(addr)
         return Connection(sock, addr)
@@ -54,10 +58,12 @@ class Connection(AbstractContextManager):
         self.conn.close()
 
 
-def find_open_port():
+def find_open_port(params=(socket.AF_INET, socket.SOCK_STREAM), opt=(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)):
+    """ Finds a tcp port that can be occupied with a socket with *params and use *opt options """
     try:
-        sock = socket()
-        sock.bind(('', 0))
-        return sock.getsockname()[1]
-    except:
-        raise ValueError("Could not find open port")
+        with closing(socket.socket(*params)) as sock:
+            sock.bind(('', 0))
+            sock.setsockopt(*opt)
+            return sock.getsockname()[1]
+    except Exception as e:
+        raise e
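
A short usage sketch for the helpers above (assuming they are imported directly from this module):

    from hivemind.utils.connection import find_open_port, Endpoint, LOCALHOST

    port = find_open_port()                 # the OS picks a free TCP port; the probe socket is closed immediately
    endpoint: Endpoint = (LOCALHOST, port)  # (hostname, port) pair used as a peer address elsewhere in the DHT code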

+ 1 - 1
requirements.txt

@@ -3,6 +3,6 @@ joblib>=0.13
 numpy>=1.17
 requests>=2.22.0
 tqdm
-kademlia>=2.2
+rpcudp>=4.0.0
 prefetch_generator>=1.0.1
 nose>=1.3.0

+ 104 - 0
tests/benchmark_dht.py

@@ -0,0 +1,104 @@
+import argparse
+import time
+import asyncio
+import multiprocessing as mp
+import random
+
+import hivemind
+from typing import List, Dict
+
+from hivemind import get_dht_time
+from hivemind.dht.node import DHTID, Endpoint, DHTNode, LOCALHOST
+
+
+def run_benchmark_node(node_id, port, peers, ready: mp.Event, request_period,
+                       expiration_time, wait_before_read, time_to_test, statistics: mp.Queue, dht_loaded: mp.Event):
+    if asyncio.get_event_loop().is_running():
+        asyncio.get_event_loop().stop()  # if we're in jupyter, get rid of its built-in event loop
+    loop = asyncio.new_event_loop()
+    asyncio.set_event_loop(loop)
+    node = DHTNode(node_id, port, initial_peers=peers)
+    await_forever = hivemind.run_forever(asyncio.get_event_loop().run_forever)
+    ready.set()
+    dht_loaded.wait()
+    start = time.perf_counter()
+    while time.perf_counter() < start + time_to_test:
+        query_id = DHTID.generate()
+        store_value = random.randint(0, 256)
+
+        store_time = time.perf_counter()
+        success_store = asyncio.run_coroutine_threadsafe(
+            node.store(query_id, store_value, get_dht_time() + expiration_time), loop).result()
+        store_time = time.perf_counter() - store_time
+        if success_store:
+            time.sleep(wait_before_read)
+            get_time = time.perf_counter()
+            get_value, get_time_expiration = asyncio.run_coroutine_threadsafe(node.get(query_id), loop).result()
+            get_time = time.perf_counter() - get_time
+            success_get = (get_value == store_value)
+            statistics.put((success_store, store_time, success_get, get_time))
+        else:
+            statistics.put((success_store, store_time, None, None))
+    await_forever.result()  # process will exit only if event loop broke down
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--num_nodes', type=int, default=20, required=False)
+    parser.add_argument('--request_period', type=float, default=2, required=False)
+    parser.add_argument('--expiration_time', type=float, default=10, required=False)
+    parser.add_argument('--wait_before_read', type=float, default=1, required=False)
+    parser.add_argument('--time_to_test', type=float, default=10, required=False)
+    args = parser.parse_args()
+
+    statistics = mp.Queue()
+    dht: Dict[Endpoint, DHTID] = {}
+    processes: List[mp.Process] = []
+
+    num_nodes = args.num_nodes
+    request_period = args.request_period
+    expiration_time = args.expiration_time
+    wait_before_read = args.wait_before_read
+    time_to_test = args.time_to_test
+
+    dht_loaded = mp.Event()
+    for i in range(num_nodes):
+        node_id = DHTID.generate()
+        port = hivemind.find_open_port()
+        peers = random.sample(list(dht.keys()), min(len(dht), 5))
+        ready = mp.Event()
+        proc = mp.Process(target=run_benchmark_node, args=(node_id, port, peers, ready, request_period,
+                                                           expiration_time, wait_before_read, time_to_test, statistics,
+                                                           dht_loaded), daemon=True)
+        proc.start()
+        ready.wait()
+        processes.append(proc)
+        dht[(LOCALHOST, port)] = node_id
+    dht_loaded.set()
+    time.sleep(time_to_test)
+    success_store = 0
+    all_store = 0
+    time_store = 0
+    success_get = 0
+    all_get = 0
+    time_get = 0
+    while not statistics.empty():
+        success_store_i, store_time_i, success_get_i, get_time_i = statistics.get()
+        all_store += 1
+        time_store += store_time_i
+        if success_store_i:
+            success_store += 1
+            all_get += 1
+            success_get += 1 if success_get_i else 0
+            time_get += get_time_i
+    alive_nodes_count = 0
+    loop = asyncio.new_event_loop()
+    node = DHTNode(loop=loop)
+    for addr, port in dht:
+        if loop.run_until_complete(node.protocol.call_ping((addr, port))) is not None:
+            alive_nodes_count += 1
+    print("store success rate: ", success_store / all_store)
+    print("mean store time: ", time_store / all_store)
+    print("get success rate: ", success_get / all_get)
+    print("mean get time: ", time_get / all_get)
+    print("death rate: ", (num_nodes - alive_nodes_count) / num_nodes)

+ 276 - 0
tests/test_dht.py

@@ -0,0 +1,276 @@
+import time
+import asyncio
+import multiprocessing as mp
+import random
+import heapq
+import uuid
+from functools import partial
+from itertools import chain
+from typing import Optional
+import numpy as np
+
+import hivemind
+from typing import List, Dict
+
+from hivemind import get_dht_time
+from hivemind.dht.node import DHTID, Endpoint, DHTNode, LOCALHOST, KademliaProtocol
+from hivemind.dht.protocol import LocalStorage
+
+
+def run_protocol_listener(port: int, dhtid: DHTID, started: mp.synchronize.Event,
+                          ping: Optional[hivemind.Endpoint] = None):
+    loop = asyncio.new_event_loop()
+    protocol = partial(KademliaProtocol, dhtid, bucket_size=20, depth_modulo=5, wait_timeout=5)
+    listen = loop.create_datagram_endpoint(protocol, local_addr=('127.0.0.1', port))
+    transport, protocol = loop.run_until_complete(listen)
+    print(f"Started peer id={protocol.node_id} port={port}", flush=True)
+
+    if ping is not None:
+        loop.run_until_complete(protocol.call_ping(ping))
+    started.set()
+    loop.run_forever()
+    print(f"Finished peer id={protocol.node_id} port={port}", flush=True)
+
+
+def test_kademlia_protocol():
+    try:
+        # create the first peer
+        peer1_port, peer1_id, peer1_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
+        peer1_proc = mp.Process(target=run_protocol_listener, args=(peer1_port, peer1_id, peer1_started), daemon=True)
+        peer1_proc.start(), peer1_started.wait()
+
+        # create another peer that connects to the first peer
+        peer2_port, peer2_id, peer2_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
+        peer2_proc = mp.Process(target=run_protocol_listener, args=(peer2_port, peer2_id, peer2_started),
+                                kwargs={'ping': ('127.0.0.1', peer1_port)}, daemon=True)
+        peer2_proc.start(), peer2_started.wait()
+
+        port = hivemind.find_open_port()
+        loop = asyncio.new_event_loop()
+        protocol = partial(KademliaProtocol, DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5)
+        listen = loop.create_datagram_endpoint(protocol, local_addr=('127.0.0.1', port))
+        transport, protocol = loop.run_until_complete(listen)
+        print(f"Self id={protocol.node_id} port={port}", flush=True)
+
+        assert loop.run_until_complete(protocol.call_ping(('127.0.0.1', peer1_port))) == peer1_id
+
+        key, value, expiration = DHTID.generate(), [123, {'ololo': 'pyshpysh'}], get_dht_time() + 1e3
+        assert loop.run_until_complete(protocol.call_store(('127.0.0.1', peer1_port), key, value, expiration))
+
+        # peer 1 must know about peer 2
+        nodes_found = loop.run_until_complete(
+            protocol.call_find_node(('127.0.0.1', peer1_port), key))
+        (recv_id, recv_endpoint) = next(iter(nodes_found.items()))
+        assert recv_id == peer2_id and recv_endpoint == ('127.0.0.1', peer2_port), \
+            f"expected id={peer2_id}, port={('127.0.0.1', peer2_port)} but got {recv_id}, {recv_endpoint}"
+
+        # peer 2 must know about peer 1
+        nodes_found_2 = loop.run_until_complete(protocol.call_find_node(('127.0.0.1', peer2_port), key))
+        (recv_id, recv_endpoint) = next(iter(nodes_found_2.items()))
+        assert recv_id == peer1_id and recv_endpoint == ('127.0.0.1', peer1_port), \
+            f"expected id={peer1_id}, port={('127.0.0.1', peer1_port)} but got {recv_id}, {recv_endpoint}"
+
+        recv_value, recv_expiration, recv_peers = loop.run_until_complete(
+            protocol.call_find_value(('127.0.0.1', peer1_port), key))
+        assert recv_value == value and recv_expiration == expiration, "call_find_value expected " \
+                                                                      f"{value} (expires by {expiration}) but got {recv_value} (expires by {recv_expiration})"
+        print(recv_peers, nodes_found)
+        assert recv_peers == nodes_found, "call_find_value must return the same peers as call_find_node"
+        print("Kademlia test finished sucessfully!")
+
+    finally:
+        peer1_proc.terminate()
+        peer2_proc.terminate()
+
+
+def run_node(node_id, port, peers, status_pipe: mp.Pipe):
+    if asyncio.get_event_loop().is_running():
+        asyncio.get_event_loop().stop()  # if we're in jupyter, get rid of its built-in event loop
+    asyncio.set_event_loop(asyncio.new_event_loop())
+    try:
+        node = DHTNode(node_id, port, initial_peers=peers)
+        status_pipe.send('STARTED')
+        while True:
+            asyncio.get_event_loop().run_forever()
+    except BaseException as e:
+        status_pipe.send(e)  # report exception to master
+        if not isinstance(e, OSError):
+            raise e
+
+
+def test_dht():
+    # create dht with 50 nodes + your 51-st node
+    dht: Dict[Endpoint, DHTID] = {}
+    processes: List[mp.Process] = []
+    port_fails, max_port_fails = 0, 10
+
+    while len(dht) < 50:
+        node_id = DHTID.generate()
+        peers = random.sample(list(dht.keys()), min(len(dht), 5))
+        port = hivemind.find_open_port()
+        pipe_recv, pipe_send = mp.Pipe(duplex=False)
+        proc = mp.Process(target=run_node, args=(node_id, port, peers, pipe_send), daemon=True)
+        proc.start()
+
+        status = pipe_recv.recv()
+        if status == 'STARTED':
+            processes.append(proc)
+            dht[(LOCALHOST, port)] = node_id
+        else:
+            assert isinstance(status, BaseException)
+            proc.terminate()
+            if isinstance(status, OSError):  # port already in use. It just happens sometimes.
+                port_fails += 1
+                if port_fails > max_port_fails:
+                    raise OSError("Too many 'Address already in use' errors.")
+            else:
+                raise ValueError(f"Failed to create node due to an error {status}, see traceback above")
+
+    loop = asyncio.get_event_loop()
+    me = hivemind.dht.node.DHTNode(initial_peers=random.sample(peers, 5), port=0)  # port=0 means os-specified port
+
+    # test 1: find self
+    nearest = loop.run_until_complete(me.find_nearest_nodes(query_id=me.node_id, k_nearest=1))
+    assert len(nearest) == 1 and nearest[me.node_id] == (LOCALHOST, me.port)
+
+    # test 2: find others
+    for i in range(10):
+        ref_endpoint, query_id = random.choice(list(dht.items()))
+        nearest = loop.run_until_complete(me.find_nearest_nodes(query_id=query_id, k_nearest=1))
+        assert len(nearest) == 1 and next(iter(nearest.items())) == (query_id, ref_endpoint)
+
+    # test 3: find neighbors to random nodes
+    accuracy_numerator = accuracy_denominator = 0  # top-1 nearest neighbor accuracy
+    jaccard_numerator = jaccard_denominator = 0  # jaccard similarity aka intersection over union
+    all_node_ids = list(dht.values())
+
+    for i in range(100):
+        query_id = DHTID.generate()
+        k_nearest = random.randint(1, 20)
+        exclude_self = random.random() > 0.5
+        nearest = loop.run_until_complete(
+            me.find_nearest_nodes(query_id=query_id, k_nearest=k_nearest, exclude_self=exclude_self))
+        nearest_nodes = list(nearest)  # keys from ordered dict
+
+        assert len(nearest_nodes) == k_nearest, "beam search must return exactly k_nearest results"
+        assert me.node_id not in nearest_nodes or not exclude_self, "if exclude, results should not contain own node id"
+        assert np.all(np.diff(query_id.xor_distance(nearest_nodes)) >= 0), "results must be sorted by distance"
+
+        ref_nearest = heapq.nsmallest(k_nearest + 1, all_node_ids, key=query_id.xor_distance)
+        if exclude_self and me.node_id in ref_nearest:
+            ref_nearest.remove(me.node_id)
+        if len(ref_nearest) > k_nearest:
+            ref_nearest.pop()
+
+        accuracy_numerator += nearest_nodes[0] == ref_nearest[0]
+        accuracy_denominator += 1
+
+        jaccard_numerator += len(set.intersection(set(nearest_nodes), set(ref_nearest)))
+        jaccard_denominator += k_nearest
+
+    accuracy = accuracy_numerator / accuracy_denominator
+    print("Top-1 accuracy:", accuracy)  # should be 98-100%
+    jaccard_index = jaccard_numerator / jaccard_denominator
+    print("Jaccard index (intersection over union):", jaccard_index)  # should be 95-100%
+    assert accuracy >= 0.9, f"Top-1 accuracy only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
+    assert jaccard_index >= 0.9, f"Jaccard index only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
+
+    # test 4: find all nodes
+    nearest = loop.run_until_complete(
+        me.find_nearest_nodes(query_id=DHTID.generate(), k_nearest=len(dht) + 100))
+    assert len(nearest) == len(dht) + 1
+    assert len(set.difference(set(nearest.keys()), set(all_node_ids) | {me.node_id})) == 0
+
+    # test 5: node without peers
+    other_node = hivemind.dht.node.DHTNode()
+    nearest = loop.run_until_complete(other_node.find_nearest_nodes(DHTID.generate()))
+    assert len(nearest) == 1 and nearest[other_node.node_id] == (LOCALHOST, other_node.port)
+    nearest = loop.run_until_complete(other_node.find_nearest_nodes(DHTID.generate(), exclude_self=True))
+    assert len(nearest) == 0
+
+    # test 6 store and get value
+    true_time = get_dht_time() + 1200
+    assert loop.run_until_complete(me.store("mykey", ["Value", 10], true_time))
+    val, expiration_time = loop.run_until_complete(me.get("mykey"))
+    assert expiration_time == true_time, "Wrong time"
+    assert val == ["Value", 10], "Wrong value"
+
+    # terminate remaining processes
+    for proc in processes:
+        proc.terminate()
+
+
+def test_hivemind_dht():
+    peers = [hivemind.dht.DHT(start=True)]
+    for i in range(10):
+        neighbors_i = [('localhost', node.port) for node in random.sample(peers, min(3, len(peers)))]
+        peers.append(hivemind.DHT(*neighbors_i, start=True))
+
+    you: hivemind.dht.DHT = random.choice(peers)
+    theguyshetoldyounottoworryabout: hivemind.dht.DHT = random.choice(peers)
+
+    expert_uids = [str(uuid.uuid4()) for _ in range(110)]
+    batch_size = 10
+    for batch_start in range(0, len(expert_uids), batch_size):
+        you.declare_experts(expert_uids[batch_start: batch_start + batch_size], 'localhost', 1234)
+
+    found = theguyshetoldyounottoworryabout.get_experts(random.sample(expert_uids, 5) + ['foo', 'bar'])
+    assert all(res is not None for res in found[:-2]), "Could not find some existing experts"
+    assert all(res is None for res in found[-2:]), "Found non-existing experts"
+
+    that_guys_expert, that_guys_port = str(uuid.uuid4()), random.randint(1000, 9999)
+    theguyshetoldyounottoworryabout.declare_experts([that_guys_expert], 'that_host', that_guys_port)
+    you_notfound, you_found = you.get_experts(['foobar', that_guys_expert])
+    assert isinstance(you_found, hivemind.RemoteExpert)
+    assert you_found.host == 'that_host' and you_found.port == that_guys_port
+
+    # test first_k_active
+    assert theguyshetoldyounottoworryabout.first_k_active(expert_uids, k=10) == expert_uids[:10]
+
+    some_permuted_experts = random.sample(expert_uids, k=32)
+    assert theguyshetoldyounottoworryabout.first_k_active(some_permuted_experts, k=32) == some_permuted_experts
+    assert theguyshetoldyounottoworryabout.first_k_active(some_permuted_experts, k=1) == some_permuted_experts[:1]
+    fake_and_real_experts = list(chain(*zip(
+        [str(uuid.uuid4()) for _ in some_permuted_experts], some_permuted_experts)))
+    assert theguyshetoldyounottoworryabout.first_k_active(fake_and_real_experts, k=9) == some_permuted_experts[:9]
+
+    for peer in peers:
+        peer.shutdown()
+
+
+def test_store():
+    d = LocalStorage()
+    d.store("key", "val", get_dht_time() + 10)
+    assert d.get("key")[0] == "val", "Wrong value"
+    print("Test store passed")
+
+
+def test_get_expired():
+    d = LocalStorage()
+    d.store("key", "val", get_dht_time() + 1)
+    time.sleep(2)
+    assert d.get("key") == (None, None), "Expired value must be deleted"
+    print("Test get expired passed")
+
+
+def test_get_empty():
+    d = LocalStorage()
+    assert d.get("key") == (None, None), "Expired value must be deleted"
+    print("Test get expired passed")
+
+
+def test_change_expiration_time():
+    d = LocalStorage()
+    d.store("key", "val1", get_dht_time() + 2)
+    d.store("key", "val2", get_dht_time() + 200)
+    time.sleep(4)
+    assert d.get("key")[0] == "val2", "Value must be changed, but still kept in table"
+    print("Test change expiration time passed")
+
+
+def test_maxsize_cache():
+    d = LocalStorage(maxsize=1)
+    d.store("key1", "val1", get_dht_time() + 1)
+    d.store("key2", "val2", get_dht_time() + 200)
+    assert d.get("key2")[0] == "val2", "Value with bigger exp. time must be kept"
+    assert d.get("key1")[0] is None, "Value with less exp time, must be deleted"

+ 1 - 1
tests/test_moe.py

@@ -66,7 +66,7 @@ def test_determinism():
 
 def test_compute_expert_scores():
     try:
-        dht = hivemind.DHTNode(port=hivemind.find_open_port(), start=True)
+        dht = hivemind.DHT(port=hivemind.find_open_port(), start=True)
         moe = hivemind.client.moe.RemoteMixtureOfExperts(
             dht=dht, in_features=1024, grid_size=(40,), k_best=4, k_min=1, timeout_after_k_min=1,
             uid_prefix='expert')

+ 114 - 0
tests/test_routing.py

@@ -0,0 +1,114 @@
+import random
+import heapq
+import operator
+from itertools import chain, zip_longest
+
+from hivemind.dht.routing import RoutingTable, DHTID
+from hivemind.utils.serializer import PickleSerializer
+
+
+def test_ids_basic():
+    # basic functionality tests
+    for i in range(100):
+        id1, id2 = DHTID.generate(), DHTID.generate()
+        assert DHTID.MIN <= id1 < DHTID.MAX and DHTID.MIN <= id2 < DHTID.MAX
+        assert DHTID.xor_distance(id1, id1) == DHTID.xor_distance(id2, id2) == 0
+        assert DHTID.xor_distance(id1, id2) > 0 or (id1 == id2)
+        assert len(PickleSerializer.dumps(id1)) - len(PickleSerializer.dumps(int(id1))) < 40
+        assert DHTID.from_bytes(bytes(id1)) == id1 and DHTID.from_bytes(id2.to_bytes()) == id2
+
+
+def test_ids_depth():
+    for i in range(100):
+        ids = [random.randint(0, 4096) for i in range(random.randint(1, 256))]
+        ours = DHTID.longest_common_prefix_length(*map(DHTID, ids))
+
+        ids_bitstr = [
+            "".join(bin(bite)[2:].rjust(8, '0') for bite in uid.to_bytes(20, 'big'))
+            for uid in ids
+        ]
+        reference = len(shared_prefix(*ids_bitstr))
+        assert reference == ours, f"ours {ours} != reference {reference}, ids: {ids}"
+
+
+def test_routing_table_basic():
+    node_id = DHTID.generate()
+    routing_table = RoutingTable(node_id, bucket_size=20, depth_modulo=5)
+
+    for phony_neighbor_port in random.sample(range(10000), 100):
+        phony_id = DHTID.generate()
+        routing_table.add_or_update_node(phony_id, ('localhost', phony_neighbor_port))
+        assert routing_table[phony_id] == ('localhost', phony_neighbor_port)
+
+    assert routing_table.buckets[0].lower == DHTID.MIN and routing_table.buckets[-1].upper == DHTID.MAX
+    for bucket in routing_table.buckets:
+        assert len(bucket.replacement_nodes) == 0, "There should be no replacement nodes in a table with 100 entries"
+    assert 3 <= len(routing_table.buckets) <= 10, len(routing_table.buckets)
+
+
+def test_routing_table_parameters():
+    for (bucket_size, modulo, min_nbuckets, max_nbuckets) in [
+        (20,          5,      45,           65),
+        (50,          5,      35,           45),
+        (20,          10,     650,          800),
+        (20,          1,      7,            15),
+    ]:
+        node_id = DHTID.generate()
+        routing_table = RoutingTable(node_id, bucket_size=bucket_size, depth_modulo=modulo)
+        for phony_neighbor_port in random.sample(range(1_000_000), 10_000):
+            routing_table.add_or_update_node(DHTID.generate(), ('localhost', phony_neighbor_port))
+        for bucket in routing_table.buckets:
+            assert len(bucket.replacement_nodes) == 0 or len(bucket.nodes_to_addr) <= bucket.size
+        assert min_nbuckets <= len(routing_table.buckets) <= max_nbuckets, (
+            f"Unexpected number of buckets: {min_nbuckets} <= {len(routing_table.buckets)} <= {max_nbuckets}")
+
+
+def test_routing_table_search():
+    for table_size, lower_active, upper_active in [
+        (10, 10, 10), (10_000, 800, 1100)
+    ]:
+        node_id = DHTID.generate()
+        routing_table = RoutingTable(node_id, bucket_size=20, depth_modulo=5)
+        num_added = 0
+        for phony_neighbor_port in random.sample(range(1_000_000), table_size):
+            num_added += routing_table.add_or_update_node(DHTID.generate(), ('localhost', phony_neighbor_port)) is None
+        num_replacements = sum(len(bucket.replacement_nodes) for bucket in routing_table.buckets)
+    
+        all_active_neighbors = list(chain(
+            *(bucket.nodes_to_addr.keys() for bucket in routing_table.buckets)
+        ))
+        assert lower_active <= len(all_active_neighbors) <= upper_active
+        assert len(all_active_neighbors) == num_added
+        assert num_added + num_replacements == table_size
+    
+        # random queries
+        for i in range(500):
+            k = random.randint(1, 100)
+            query_id = DHTID.generate()
+            exclude = query_id if random.random() < 0.5 else None
+            our_knn, our_addrs = zip(*routing_table.get_nearest_neighbors(query_id, k=k, exclude=exclude))
+            reference_knn = heapq.nsmallest(k, all_active_neighbors, key=query_id.xor_distance)
+            assert all(our == ref for our, ref in zip_longest(our_knn, reference_knn))
+            assert all(our_addr == routing_table[our_node] for our_node, our_addr in zip(our_knn, our_addrs))
+    
+        # queries from table
+        for i in range(500):
+            k = random.randint(1, 100)
+            query_id = random.choice(all_active_neighbors)
+            our_knn, our_addrs = zip(*routing_table.get_nearest_neighbors(query_id, k=k, exclude=query_id))
+            reference_knn = heapq.nsmallest(
+                k + 1, all_active_neighbors,
+                key=lambda uid: query_id.xor_distance(uid))
+            if query_id in reference_knn:
+                reference_knn.remove(query_id)
+            assert len(our_knn) == len(reference_knn)
+            assert all(query_id.xor_distance(our) == query_id.xor_distance(ref)
+                       for our, ref in zip_longest(our_knn, reference_knn))
+            assert routing_table.get_nearest_neighbors(query_id, k=k, exclude=None)[0][0] == query_id
+
+
+def shared_prefix(*strings: str):
+    for i in range(min(map(len, strings))):
+        if len(set(map(operator.itemgetter(i), strings))) != 1:
+            return strings[0][:i]
+    return min(strings, key=len)

+ 3 - 3
tests/test_utils/run_server.py

@@ -11,7 +11,7 @@ from .layers import name_to_block, name_to_input
 def make_dummy_server(interface='0.0.0.0', port=None, num_experts=1, expert_cls='ffn', hidden_dim=1024,
                       num_handlers=None, expert_prefix='expert', expert_offset=0, max_batch_size=16384, device=None,
                       no_optimizer=False, no_dht=False, initial_peers=(), dht_port=None, root_port=None, verbose=True,
-                      UID_DELIMETER=hivemind.DHTNode.UID_DELIMETER, start=False, **kwargs) -> hivemind.Server:
+                      UID_DELIMETER=hivemind.DHT.UID_DELIMETER, start=False, **kwargs) -> hivemind.Server:
     """
     Instantiate a server with several identical experts. See argparse comments below for details
     :param interface: 'localhost' for local connections only, '0.0.0.0' for ipv4 '::' for ipv6
@@ -45,7 +45,7 @@ def make_dummy_server(interface='0.0.0.0', port=None, num_experts=1, expert_cls=
     if not no_dht:
         if not len(initial_peers):
             print("No initial peers provided. Starting additional dht as an initial peer.")
-            dht_root = hivemind.DHTNode(
+            dht_root = hivemind.DHT(
                 *initial_peers, port=root_port or hivemind.find_open_port(), start=True)
             print(f"Initializing DHT with port {dht_root.port}")
             initial_peers = (('localhost', dht_root.port),)
@@ -54,7 +54,7 @@ def make_dummy_server(interface='0.0.0.0', port=None, num_experts=1, expert_cls=
             if root_port is not None:
                 print(f"Warning: root_port={root_port} will not be used since we already have peers.")
 
-        dht = hivemind.DHTNode(
+        dht = hivemind.DHT(
             *initial_peers, port=dht_port or hivemind.find_open_port(), start=True)
         if verbose:
             print(f"Running dht node on port {dht.port}")