[part 2] grpc-based dht (#51)

* add option to share keys with new peers that *should* be sharing them (improves data availability if many peers join)

* * added option to NOT refresh table (new default)
* initial dht crawl is no longer blocking

* sphinx-friendly escape char

* pass cache params to kademliaprotocol

* typo

* rpc congestion

* rpc congestion

* rpc congestion

* add max concurrent rpc

* add max concurrent rpc

* fix bug in welcome protocol: previously, DHT peers always considered each other "new nodes" and sent EVERYTHING to EVERYONE on each rpc call. Now DHT nodes only request a store on a .store call OR when a new peer knocks on their DHT

* increase to 128 concurrent rpc

* await dht traversal in bootstrap

* await dht traversal in bootstrap

* minor comment

* rename TensorProto -> TensorDescriptor to avoid name conflicts with protobuf

* add grpc requirements (break tests for now)

* add grpc requirements (break tests for now)

* add grpc requirements (break tests for now)

* minor bugfix: always add peer to nodes requested for ping

* [this breaks tests]
* implement DHTProtocol via gRPC
* DHTProtocol now stores bytes only (not enforcing msgpack)
* add grpc requirements

* update node.py to grpc DHTProtocol

* reminder to implement congestion

* reminder to implement congestion

* format

* temporary patch: adapt bulk RPCs to individual search

* temporary patch: adapt bulk RPCs to individual search

* rename KademliaProtocol -> DHTProtocol (rationale: no longer Kademlia-compliant)

* semicolon

* pep

* pep

* pep

* KademliaProtocol -> DHTProtocol

* update tests for new eviction policy (do not evict the same node twice)

* init aio in node constructor

* rename KademliaProtocol => DHTProtocol everywhere

* minor sphinx formatting fix

* partially update test_dht

* test: typo fix

* test: typo fix

* update test_dht for new dht interface

* compile grpc from master

* compile grpc from master

* compile grpc from master

* add umsgpack to requirements

* cache compiled grpcio

* ensure umsgpack version compatibility

* remove unused imports from dht folder

* update schemes

* update schemes

* review

* review

* ensure_future => create_task

Co-authored-by: xtinkt <ant.sinitsin@gmail.com>
justheuristic committed 5 years ago
commit 8bded39d9b
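
A minimal sketch of the asyncio API introduced by this PR, following the node.py changes below; the endpoint format, key and value are illustrative assumptions rather than part of the commit:

import asyncio
from hivemind.dht.node import DHTNode
from hivemind.dht.routing import get_dht_time

async def demo():
    alice = await DHTNode.create(listen_on="0.0.0.0:*")                    # starts a gRPC server on a free port
    bob = await DHTNode.create(initial_peers=[f"localhost:{alice.port}"])  # pings alice and crawls the DHT
    assert await bob.store("expert.1", b"alive", get_dht_time() + 120)     # bulk rpc_store to the nearest peers
    value, expiration = await alice.get("expert.1")                        # rpc_find + beam search over the DHT
    await asyncio.gather(alice.shutdown(), bob.shutdown())

asyncio.run(demo())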

+ 12 - 1
.circleci/config.yml

@@ -9,8 +9,19 @@ jobs:
     steps:
       - checkout
       - python/load-cache
+      - run:
+          command: |
+            if [[ $(pip show grpcio | grep Version) != *1.31* ]]; then
+              git clone https://github.com/grpc/grpc --recurse-submodules
+              cd grpc
+              sudo pip install -r requirements.txt
+              export GRPC_PYTHON_BUILD_WITH_CYTHON=1
+              sudo pip install .
+              cd -
+            fi
+          name: compile-grpc  # remove this command when v1.31 becomes available via pip install -r requirements.txt
+      - run: sudo pip install codecov pytest grpcio-tools
       - python/install-deps
-      - run: sudo pip install codecov pytest
       - python/save-cache
       - run:
           command: sudo python setup.py develop

BIN
docs/_static/dht.odp


BIN
docs/_static/dht.png


+ 1 - 1
docs/modules/dht.rst

@@ -20,7 +20,7 @@
 
 .. currentmodule:: hivemind.dht.protocol
 
-.. autoclass:: KademliaProtocol
+.. autoclass:: DHTProtocol
    :members:
    :member-order: bysource
 

+ 6 - 5
hivemind/dht/__init__.py

@@ -2,7 +2,7 @@
 This sub-module implements a node in a Kademlia-based DHT. The code is organized as follows:
  * class DHT (below) - high-level class for model training. Runs DHTNode in a background process.
  * class DHTNode (node.py) - an asyncio implementation of dht server, stores AND gets keys. Asyncio-based.
- * class KademliaProtocol (protocol.py) - an rpc protocol to request data from dht nodes. Asyncio-based.
+ * class DHTProtocol (protocol.py) - an rpc protocol to request data from dht nodes. Asyncio-based.
 
 The code in this module is a modified version of https://github.com/bmuller/kademlia
 Brian, if you're reading this: THANK YOU! you're awesome :)
@@ -10,13 +10,13 @@ Brian, if you're reading this: THANK YOU! you're awesome :)
 import asyncio
 import multiprocessing as mp
 import warnings
-from typing import Tuple, List, Optional
+from typing import List, Optional
 
 from .node import DHTNode, DHTID, DHTExpiration
 from .routing import get_dht_time
 
 from ..client import RemoteExpert
-from ..utils import SharedFuture, find_open_port, Hostname, Port, run_in_background
+from ..utils import SharedFuture, find_open_port, Endpoint, Port, run_in_background, LOCALHOST
 
 
 class DHT(mp.Process):
@@ -33,7 +33,7 @@ class DHT(mp.Process):
     EXPIRATION = 120  # anything written to DHT is considered expired after this many seconds
     make_key = "{}::{}".format
 
-    def __init__(self, *initial_peers: Tuple[Hostname, Port], port: Optional[Port] = None,
+    def __init__(self, *initial_peers: Endpoint, port: Optional[Port] = None,
                  start: bool, daemon: bool = True, **node_params):
         super().__init__()
         port = find_open_port() if port is None else port
@@ -52,7 +52,8 @@ class DHT(mp.Process):
         loop = asyncio.new_event_loop()
         asyncio.set_event_loop(loop)
 
-        self.node = DHTNode(initial_peers=list(self.initial_peers), port=self.port, **self.node_params)
+        self.node = loop.run_until_complete(DHTNode.create(
+            initial_peers=list(self.initial_peers), listen_on=f"{LOCALHOST}:{self.port}", **self.node_params))
         run_in_background(loop.run_forever)
         self.ready.set()
 

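A hedged usage sketch of the process-based wrapper updated above, assuming start=True launches the background process as in the pre-existing DHT class and that "localhost:<port>" is a valid Endpoint string:

from hivemind.dht import DHT

first_peer = DHT(start=True)                                   # spawns a DHTNode in a background process
second_peer = DHT(f"localhost:{first_peer.port}", start=True)  # joins the DHT via the first peer
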
+ 55 - 0
hivemind/dht/dht.proto

@@ -0,0 +1,55 @@
+syntax = "proto3";
+
+// this protocol defines how Hivemind nodes form a distributed hash table.
+// For more info, see https://learning-at-home.readthedocs.io/en/latest/modules/dht.html or help(hivemind.dht.DHTNode)
+
+service DHT {
+    // find out recipient's DHTID and possibly update its routing table
+    rpc rpc_ping(NodeInfo) returns (NodeInfo);
+
+    // request a node to store one or multiple data items (key - value - expiration)
+    rpc rpc_store(StoreRequest) returns (StoreResponse);
+
+    // for given keys, request values (if stored) or a list of peers that are likely to have them
+    rpc rpc_find(FindRequest) returns (FindResponse);
+}
+
+message NodeInfo {
+    // note: both node_id and port are optional: if specified, ask peer to add you to its routing table;
+    // if either node_id or port is absent, simply request recipient info (for client-only mode)
+    bytes node_id = 1;                // sender's own node id serialized with DHTID.to_bytes()
+    int32 rpc_port = 2;               // port to which sender listens for DHT RPCs
+}
+
+message StoreRequest {
+    // three lists of the same length representing dht keys, dht values and expiration
+    repeated bytes keys = 1;          // keys in the form of DHTID.generate(raw_key).to_bytes()
+    repeated bytes values = 2;        // binary-encoded value for i-th key
+    repeated double expiration = 3;   // expirations for i-th key (type = DHTExpiration)
+    repeated bool in_cache = 4;       // if in_cache[i], store i-th key in cache, else store normally
+    NodeInfo peer = 5;                // (optional) sender's own node info, same behavior as in DHT.rpc_ping
+}
+
+message StoreResponse {
+    repeated bool store_ok = 1;       // for every key, True means store accepted, False means store rejected/failed
+    NodeInfo peer = 2;                // respondent's node id, for you to update routing table
+}
+
+message FindRequest {
+    repeated bytes keys = 1;          // a list of DHTID search keys encoded as bytes
+    NodeInfo peer = 2;                // optional, same behavior as in DHT.ping
+}
+
+message Peers {
+   // two aligned arrays: DHTIDs and Endpoints, i-th endpoint corresponds to peer with i-th node id
+   repeated bytes node_ids = 1;       // DHTID serialized with node_id.to_bytes()
+   repeated string endpoints = 2;     // e.g. 123.123.123.123:1337 or [2a21:6c8:b192:2105]:8888
+}
+
+message FindResponse {
+    repeated bytes values = 1;        // value for i-th key, b'' means not found locally
+    repeated double expiration = 2;   // expiration time for i-th value, only valid if the value is found
+    repeated Peers nearest = 3;       // peers ordered from nearest to farthest based on distance to i-th key
+    NodeInfo peer = 4;                // respondent's node id, for you to update routing table
+}
+
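
A rough illustration of how these messages are built from Python once the schema above is compiled; dht_pb2 is the module produced at import time by hivemind's compile_grpc helper, and the ids, value and port below are placeholder assumptions:

from hivemind.dht.protocol import dht_pb2
from hivemind.dht.routing import DHTID, get_dht_time

my_node_id, key_id = DHTID.generate(), DHTID.generate(source="expert.1")   # placeholder identifiers
request = dht_pb2.StoreRequest(
    keys=[key_id.to_bytes()],            # DHTID.generate(raw_key).to_bytes()
    values=[b"value-bytes"],             # opaque binary payload, serialization is up to the caller
    expiration=[get_dht_time() + 120],   # absolute expiration timestamps (DHTExpiration)
    in_cache=[False],                    # regular storage rather than the LRU cache
    peer=dht_pb2.NodeInfo(node_id=my_node_id.to_bytes(), rpc_port=1337))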

+ 133 - 90
hivemind/dht/node.py

@@ -1,107 +1,137 @@
+from __future__ import annotations
 import asyncio
 import random
 from collections import OrderedDict
-from functools import partial
 from typing import Optional, Tuple, List, Dict
 from warnings import warn
 
-from .protocol import KademliaProtocol
-from .routing import DHTID, DHTValue, DHTExpiration, DHTKey, get_dht_time
+from .protocol import DHTProtocol
+from .routing import DHTID, BinaryDHTValue, DHTExpiration, DHTKey, get_dht_time, DHTValue
 from .search import traverse_dht
-from ..utils import find_open_port, Endpoint, Hostname, Port, LOCALHOST
+from ..utils import Endpoint, LOCALHOST, MSGPackSerializer
 
 
 class DHTNode:
     """
-    A low-level class that represents a DHT participant.
-    Each DHTNode has an identifier, a local storage and access too other nodes via KademliaProtocol.
-
-    :param node_id: current node's identifier, determines which keys it will store locally, defaults to random id
-    :param port: port to which this DHTNode will listen, by default find some open port
-    :param initial_peers: connects to these peers to populate routing table, defaults to no peers
-    :param bucket_size: (k) - max number of nodes in one k-bucket. Trying to add {k+1}st node will cause a bucket to
-      either split in two buckets along the midpoint or reject the new node (but still save it as a replacement)
-      Recommended value: $k$ is chosen s.t. any given k nodes are very unlikely to all fail after staleness_timeout
-    :param num_replicas: (≈k) - number of nearest nodes that will be asked to store a given key, default = bucket_size
-    :param depth_modulo: (b) - kademlia can split bucket if it contains root OR up to the nearest multiple of this value
-    :param max_concurrent_rpc: maximum number of outgoing RPC requests emitted by KademliaProtocol in parallel
-        Reduce this value if your RPC requests register no response despite the peer sending the response.
-    :param wait_timeout: a kademlia rpc request is deemed lost if we did not recieve a reply in this many seconds
-    :param staleness_timeout: a bucket is considered stale if no node from that bucket was updated in this many seconds
-        if staleness_timeout is None, DHTNode will not refresh stale buckets (which is usually okay)
-    :param bootstrap_timeout: after one of peers responds, await other peers for at most this many seconds
-    :param cache_locally: if True, caches all values (stored or found) in a node-local cache
-    :param cache_nearest: if above 0, whenever DHTNode finds a value, it will also store (cache) this value on this many
-        nodes nearest nodes visited by search algorithm. Prefers nodes that are nearest to :key: but have no value yet.
-    :param cache_size: if specified, local cache will store up to this many records (as in LRU cache)
-    :param interface: provide 0.0.0.0 to operate over ipv4, :: to operate over ipv6, localhost to operate locally, etc.
-
-    :note: Hivemind DHT is optimized to store temporary metadata that is regularly updated.
+    A low-level class that represents a DHT participant. Please see DHTNode.create for parameters
+    Each DHTNode has an identifier, a local storage and access to other nodes via DHTProtocol.
+
+    :note: Hivemind DHT is optimized to store a lot of temporary metadata that is regularly updated.
      For example, an expert alive timestamp that emitted by the Server responsible for that expert.
-     Such metadata does not require maintenance such as ensuring at least k hosts have it or (de)serialization in case
-     of node shutdown. Instead, DHTNode is designed to reduce the latency of looking up such data.
+     Such metadata does not require regular maintenance by peers or persistence on shutdown.
+     Instead, DHTNode is designed to rapidly send bulk data and resolve conflicts.
 
-    Every (key, value) pair in this DHT has expiration_time - float number computed as get_dht_time(), default: UnixTime
-    Informally, dht nodes always prefer values with higher expiration_time and may delete any value past its expiration.
+    Every (key, value) pair in this DHT has an expiration time - float computed as get_dht_time(), UnixTime by default
+    DHT nodes always prefer values with higher expiration time and may delete any value past its expiration.
 
-    Formally, DHTNode follows this contract:
+    Compared to Kademlia RPC protocol, hivemind DHT has 3 RPCs:
 
-    - when asked to store(key, value, expiration_time), a node must store (key, value) at least until expiration_time
-      unless it already stores that key with greater or equal expiration_time - if so, node must keep the previous key
-    - when asked to get(key), a node must return the value with highest expiration time IF that time has not come yet
-      if expiration time is greater than current get_dht_time(), DHTNode *may* return None
+    * ping - request peer's identifier and update routing table (same as Kademlia PING RPC)
+    * store - send several (key, value, expiration) pairs to the same peer (like Kademlia STORE, but in bulk)
+    * find - request one or several keys, get values & expiration (if peer finds it locally) and :bucket_size: of
+        nearest peers from recipient's routing table (ordered nearest-to-farthest, not including recipient itself)
+        This RPC is a mixture between Kademlia FIND_NODE and FIND_VALUE with multiple keys per call.
 
-    """
+    Formally, DHTNode follows the following contract:
 
-    def __init__(self, node_id: Optional[DHTID] = None, port: Optional[Port] = None, initial_peers: List[Endpoint] = (),
-                 bucket_size: int = 20, num_replicas: Optional[int] = None, depth_modulo: int = 5,
-                 max_concurrent_rpc: int = 128, wait_timeout: float = 5, staleness_timeout: Optional[float] = None,
-                 bootstrap_timeout: Optional[float] = None, cache_locally: bool = True, cache_nearest: int = 1,
-                 cache_size=None, interface: Hostname = '0.0.0.0'):
+    - when asked to get(key), a node must find and return a value with highest expiration time that it found across DHT
+      IF that time has not come yet. if expiration time is smaller than current get_dht_time(), node may return None;
+    - when requested to store(key: value, expiration), a node must store (key => value) until the expiration time
+      or until DHTNode gets the same key with greater expiration time. If a node is asked to store a key but it already
+      has the same key with newer expiration, the older key will not be stored. Return True if stored, False if refused;
+    - when requested to store(key: value, expiration, in_cache=True), stores (key => value) in a separate "cache".
+      Cache operates same as regular storage, but it has a limited size and evicts least recently used nodes when full;
+
+    """
+    node_id: int; port: int; num_replicas: int; cache_locally: bool; cache_nearest: int; refresh_timeout: float
+    protocol: DHTProtocol
+    serializer = MSGPackSerializer  # used to pack/unpack DHT Values for transfer over network
+
+
+    @classmethod
+    async def create(
+            cls, node_id: Optional[DHTID] = None, initial_peers: List[Endpoint] = (),
+            bucket_size: int = 20, num_replicas: Optional[int] = None, depth_modulo: int = 5, max_requests: int = 0,
+            wait_timeout: float = 5, refresh_timeout: Optional[float] = None, bootstrap_timeout: Optional[float] = None,
+            cache_locally: bool = True, cache_nearest: int = 1, cache_size=None,
+            listen: bool = True, listen_on: Endpoint = "0.0.0.0:*", **kwargs) -> DHTNode:
+        """
+        :param node_id: current node's identifier, determines which keys it will store locally, defaults to random id
+        :param port: port to which this DHTNode will listen, by default find some open port
+        :param initial_peers: connects to these peers to populate routing table, defaults to no peers
+        :param bucket_size: max number of nodes in one k-bucket (k). Trying to add {k+1}st node will cause a bucket to
+          either split in two buckets along the midpoint or reject the new node (but still save it as a replacement)
+          Recommended value: k is chosen s.t. any given k nodes are very unlikely to all fail after staleness_timeout
+        :param num_replicas: number of nearest nodes that will be asked to store a given key, default = bucket_size (≈k)
+        :param depth_modulo: split full k-bucket if it contains root OR up to the nearest multiple of this value (≈b)
+        :param max_requests: maximum number of outgoing RPC requests emitted by DHTProtocol in parallel
+          Reduce this value if your RPC requests register no response despite the peer sending the response.
+        :param wait_timeout: a kademlia rpc request is deemed lost if we did not receive a reply in this many seconds
+        :param refresh_timeout: refresh buckets if no node from that bucket was updated in this many seconds
+          if refresh_timeout is None, DHTNode will not refresh stale buckets (which is usually okay)
+        :param bootstrap_timeout: after one of peers responds, await other peers for at most this many seconds
+        :param cache_locally: if True, caches all values (stored or found) in a node-local cache
+        :param cache_nearest: whenever DHTNode finds a value, it will also store (cache) this value on this many
+          nearest nodes visited by the search algorithm. Prefers nodes that are nearest to :key: but have no value yet
+        :param cache_size: if specified, local cache will store up to this many records (as in LRU cache)
+        :param listen: if True (default), this node will accept incoming requests and otherwise be a DHT "citizen"
+          if False, this node will refuse any incoming request, effectively being only a "client"
+        :param listen_on: network interface for incoming RPCs, e.g. "0.0.0.0:1337" or "localhost:\*" or "[::]:7654"
+        :param channel_options: options for grpc.aio.insecure_channel, e.g. [('grpc.enable_retries', 0)]
+          see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
+        :param kwargs: extra parameters used in grpc.aio.server
+        """
+        assert max_requests == 0, "TODO(jheuristic) implement congestion!"
+        self = cls(_initialized_with_create=True)
         self.node_id = node_id = node_id if node_id is not None else DHTID.generate()
-        self.port = port = port if port is not None else find_open_port()
         self.num_replicas = num_replicas if num_replicas is not None else bucket_size
         self.cache_locally, self.cache_nearest = cache_locally, cache_nearest
-        self.staleness_timeout = staleness_timeout
+        self.refresh_timeout = refresh_timeout
+
+        self.protocol = await DHTProtocol.create(self.node_id, bucket_size, depth_modulo, num_replicas, wait_timeout,
+                                                 cache_size, listen, listen_on, **kwargs)
+        self.port = self.protocol.port
 
-        # create kademlia protocol and make it listen to a port
-        loop = asyncio.get_event_loop()
-        make_protocol = partial(KademliaProtocol, self.node_id, bucket_size, depth_modulo, wait_timeout,
-                                max_concurrent_rpc, num_replicas, cache_size)
-        listener = loop.run_until_complete(loop.create_datagram_endpoint(make_protocol, local_addr=(interface, port)))
-        self.transport: asyncio.Transport = listener[0]
-        self.protocol: KademliaProtocol = listener[1]
 
         if initial_peers:
             # stage 1: ping initial_peers, add each other to the routing table
             bootstrap_timeout = bootstrap_timeout if bootstrap_timeout is not None else wait_timeout
             start_time = get_dht_time()
             ping_tasks = map(self.protocol.call_ping, initial_peers)
-            finished_ping_tasks, remaining_ping_tasks = loop.run_until_complete(
-                asyncio.wait(ping_tasks, return_when=asyncio.FIRST_COMPLETED))
+            finished_pings, unfinished_pings = await asyncio.wait(ping_tasks, return_when=asyncio.FIRST_COMPLETED)
 
             # stage 2: gather remaining peers (those who respond within bootstrap_timeout)
-            if remaining_ping_tasks:
-                finished_in_time, stragglers = loop.run_until_complete(
-                    asyncio.wait(remaining_ping_tasks, timeout=bootstrap_timeout - get_dht_time() + start_time))
+            if unfinished_pings:
+                finished_in_time, stragglers = await asyncio.wait(
+                    unfinished_pings, timeout=bootstrap_timeout - get_dht_time() + start_time)
                 for straggler in stragglers:
                     straggler.cancel()
-                finished_ping_tasks |= finished_in_time
+                finished_pings |= finished_in_time
 
-            if not finished_ping_tasks:
+            if not finished_pings:
                 warn("DHTNode bootstrap failed: none of the initial_peers responded to a ping.")
 
             # stage 3: traverse dht to find my own nearest neighbors and populate the routing table
             # ... maybe receive some values that we are meant to store (see protocol.update_routing_table)
             # note: using asyncio.wait instead of wait_for because wait_for cancels task on timeout
-            loop.run_until_complete(asyncio.wait([loop.create_task(self.find_nearest_nodes(query_id=self.node_id)),
-                                                  asyncio.sleep(bootstrap_timeout - get_dht_time() + start_time)],
-                                                 return_when=asyncio.FIRST_COMPLETED))
+            await asyncio.wait([asyncio.create_task(self.find_nearest_nodes(key_id=self.node_id)),
+                                asyncio.sleep(bootstrap_timeout - get_dht_time() + start_time)],
+                               return_when=asyncio.FIRST_COMPLETED)
+
+        if self.refresh_timeout is not None:
+            asyncio.create_task(self._refresh_routing_table(period=self.refresh_timeout))
+        return self
 
-        if self.staleness_timeout is not None:
-            loop.create_task(self._refresh_routing_table(period=self.staleness_timeout))
+    def __init__(self, *, _initialized_with_create=False):
+        """ Internal init method. Please use DHTNode.create coroutine to spawn new node instances """
+        assert _initialized_with_create, " Please use DHTNode.create coroutine to spawn new node instances "
+        super().__init__()
 
-    async def find_nearest_nodes(self, query_id: DHTID, k_nearest: Optional[int] = None,
+    async def shutdown(self, timeout=None):
+        """ Process existing requests, close all connections and stop the server """
+        await self.protocol.shutdown(timeout)
+
+    async def find_nearest_nodes(self, key_id: DHTID, k_nearest: Optional[int] = None,
                                  beam_size: Optional[int] = None, exclude_self: bool = False) -> Dict[DHTID, Endpoint]:
         """
         Traverse the DHT and find :k_nearest: nodes to a given :query_id:, optionally :exclude_self: from the results.
@@ -112,36 +142,41 @@ class DHTNode:
         k_nearest = k_nearest if k_nearest is not None else self.protocol.bucket_size
         beam_size = beam_size if beam_size is not None else max(self.protocol.bucket_size, k_nearest)
         node_to_addr = dict(
-            self.protocol.routing_table.get_nearest_neighbors(query_id, beam_size, exclude=self.node_id))
+            self.protocol.routing_table.get_nearest_neighbors(key_id, beam_size, exclude=self.node_id))
+
+        async def get_neighbors(node_id: DHTID) -> Tuple[List[DHTID], bool]:
+            response = await self.protocol.call_find(node_to_addr[node_id], [key_id])
+            if not response or key_id not in response:
+                return [], False  # False means "do not interrupt search"
 
-        async def get_neighbors(node: DHTID) -> Tuple[List[DHTID], bool]:
-            peers: Dict[DHTID, Endpoint] = await self.protocol.call_find_node(node_to_addr[node], query_id)
+            peers: Dict[DHTID, Endpoint] = response[key_id][-1]
             node_to_addr.update(peers)
-            return list(peers.keys()), False  # False means "do not interrupt beam search"
+            return list(peers.keys()), False  # False means "do not interrupt search"
 
         nearest_nodes, visited_nodes = await traverse_dht(
-            query_id=query_id, initial_nodes=list(node_to_addr), k_nearest=k_nearest, beam_size=beam_size,
+            query_id=key_id, initial_nodes=list(node_to_addr), k_nearest=k_nearest, beam_size=beam_size,
             get_neighbors=get_neighbors, visited_nodes=(self.node_id,))
 
         if not exclude_self:
-            nearest_nodes = sorted(nearest_nodes + [self.node_id], key=query_id.xor_distance)[:k_nearest]
+            nearest_nodes = sorted(nearest_nodes + [self.node_id], key=key_id.xor_distance)[:k_nearest]
             node_to_addr[self.node_id] = (LOCALHOST, self.port)
 
         return OrderedDict((node, node_to_addr[node]) for node in nearest_nodes)
 
-    async def store(self, key: DHTKey, value: DHTValue, expiration_time: DHTExpiration) -> bool:
+    async def store(self, key: DHTKey, value: BinaryDHTValue, expiration_time: DHTExpiration) -> bool:
         """
         Find beam_size best nodes to store (key, value) and store it there at least until expiration time.
-        Also cache (key, value, expiration_time) at all nodes you met along the way (see Section 2.1 end)
+        Optionally cache (key, value, expiration) on nodes you met along the way (see Section 2.1 end) TODO(jheuristic)
 
         :returns: True if store succeeds, False if it fails (due to no response or newer value)
         """
-        key_id = DHTID.generate(key)
+        key_id, value_bytes = DHTID.generate(source=key), self.serializer.dumps(value)
         nearest_node_to_addr = await self.find_nearest_nodes(key_id, k_nearest=self.num_replicas, exclude_self=True)
-        tasks = [asyncio.create_task(self.protocol.call_store(endpoint, key_id, value, expiration_time))
+        tasks = [asyncio.create_task(self.protocol.call_store(endpoint, [key_id], [value_bytes], [expiration_time]))
                  for endpoint in nearest_node_to_addr.values()]
         done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
-        return any(done)
+
+        return any(store_ok for response in done for store_ok in response.result())
 
     async def get(self, key: DHTKey, sufficient_expiration_time: Optional[DHTExpiration] = None,
                   beam_size: Optional[int] = None) -> Tuple[Optional[DHTValue], Optional[DHTExpiration]]:
@@ -150,14 +185,15 @@ class DHTNode:
         :param sufficient_expiration_time: if the search finds a value that expires after this time,
             default = time of call, find any value that did not expire by the time of call
             If min_expiration_time=float('inf'), this method will find a value with _latest_ expiration
+        :param beam_size: maintains up to this many nearest nodes when crawling dht, default beam_size = bucket_size
         :returns: value and its expiration time. If nothing is found , returns (None, None).
         :note: in order to check if get returned a value, please check (expiration_time is None)
         """
         key_id = DHTID.generate(key)
         sufficient_expiration_time = sufficient_expiration_time or get_dht_time()
         beam_size = beam_size if beam_size is not None else self.protocol.bucket_size
-        latest_value, latest_expiration, latest_node_id = None, -float('inf'), None
-        node_to_addr, nodes_checked_for_value = dict(), set()
+        latest_value_bytes, latest_expiration, latest_node_id = b'', -float('inf'), None
+        node_to_addr, nodes_checked_for_value, nearest_nodes = dict(), set(), []
         should_cache = False  # True if found value in DHT that is newer than local value
 
         # Option A: value can be stored in our local cache
@@ -165,7 +201,7 @@ class DHTNode:
         if maybe_expiration is None:
             maybe_value, maybe_expiration = self.protocol.cache.get(key_id)
         if maybe_expiration is not None and maybe_expiration > latest_expiration:
-            latest_value, latest_expiration, latest_node_id = maybe_value, maybe_expiration, self.node_id
+            latest_value_bytes, latest_expiration, latest_node_id = maybe_value, maybe_expiration, self.node_id
             # TODO(jheuristic) we may want to run background beam search to update our cache
         nodes_checked_for_value.add(self.node_id)
 
@@ -175,12 +211,16 @@ class DHTNode:
                 key_id, self.protocol.bucket_size, exclude=self.node_id))
 
             async def get_neighbors(node: DHTID) -> Tuple[List[DHTID], bool]:
-                nonlocal latest_value, latest_expiration, node_to_addr, nodes_checked_for_value
-                maybe_value, maybe_expiration, peers = await self.protocol.call_find_value(node_to_addr[node], key_id)
-                node_to_addr.update(peers)
+                nonlocal latest_value_bytes, latest_expiration, latest_node_id, node_to_addr, nodes_checked_for_value
+                response = await self.protocol.call_find(node_to_addr[node], [key_id])
                 nodes_checked_for_value.add(node)
+                if not response or key_id not in response:
+                    return [], False
+
+                maybe_value, maybe_expiration, peers = response[key_id]
+                node_to_addr.update(peers)
                 if maybe_expiration is not None and maybe_expiration > latest_expiration:
-                    latest_value, latest_expiration, latest_node_id = maybe_value, maybe_expiration, node
+                    latest_value_bytes, latest_expiration, latest_node_id = maybe_value, maybe_expiration, node
                 should_interrupt = (latest_expiration >= sufficient_expiration_time)
                 return list(peers.keys()), should_interrupt
 
@@ -193,13 +233,14 @@ class DHTNode:
         # Option C: didn't find good-enough value in beam search, make a last-ditch effort to find it in unvisited nodes
         if latest_expiration < sufficient_expiration_time:
             nearest_unvisited = [node_id for node_id in nearest_nodes if node_id not in nodes_checked_for_value]
-            tasks = [self.protocol.call_find_value(node_to_addr[node_id], key_id) for node_id in nearest_unvisited]
+            tasks = [self.protocol.call_find(node_to_addr[node_id], [key_id]) for node_id in nearest_unvisited]
             pending_tasks = set(tasks)
             for task in asyncio.as_completed(tasks):
                 pending_tasks.remove(task)
-                maybe_value, maybe_expiration, _ = await task
+                found = task.result() or {}
+                maybe_value, maybe_expiration, _ = found.get(key_id, (None, None, {}))
                 if maybe_expiration is not None and maybe_expiration > latest_expiration:
-                    latest_value, latest_expiration = maybe_value, maybe_expiration
+                    latest_value_bytes, latest_expiration = maybe_value, maybe_expiration
                     if latest_expiration >= sufficient_expiration_time:
                         break
             for task in pending_tasks:
@@ -208,19 +249,21 @@ class DHTNode:
 
         # step 4: we have not found entry with sufficient_expiration_time, but we may have found *something* older
         if should_cache and self.cache_locally:
-            self.protocol.cache.store(key_id, latest_value, latest_expiration)
+            self.protocol.cache.store(key_id, latest_value_bytes, latest_expiration)
         if should_cache and self.cache_nearest:
             num_cached_nodes = 0
             for node_id in nearest_nodes:
                 if node_id == latest_node_id:
                     continue
                 asyncio.create_task(self.protocol.call_store(
-                    node_to_addr[node_id], key_id, latest_value, latest_expiration, in_cache=True))
+                    node_to_addr[node_id], [key_id], [latest_value_bytes], [latest_expiration], in_cache=True))
                 num_cached_nodes += 1
                 if num_cached_nodes >= self.cache_nearest:
                     break
-
-        return (latest_value, latest_expiration) if latest_expiration != -float('inf') else (None, None)
+        if latest_expiration != -float('inf'):
+            return self.serializer.loads(latest_value_bytes), latest_expiration
+        else:
+            return None, None
 
     async def _refresh_routing_table(self, *, period: Optional[float]) -> None:
         """ Tries to find new nodes for buckets that were unused for more than self.staleness_timeout """

+ 198 - 129
hivemind/dht/protocol.py

@@ -1,152 +1,224 @@
-import asyncio
+from __future__ import annotations
+import os
 import heapq
-from typing import Optional, List, Tuple, Dict, Iterator
-from rpcudp.protocol import RPCProtocol
+import asyncio
+import logging
+import urllib.parse
+from typing import Optional, List, Tuple, Dict, Iterator, Any, Sequence, Union
+from warnings import warn
+from .routing import RoutingTable, DHTID, BinaryDHTValue, DHTExpiration, get_dht_time
+from ..utils import Endpoint, compile_grpc
+import grpc, grpc.experimental.aio
 
-from .routing import RoutingTable, DHTID, DHTValue, DHTExpiration, BinaryDHTID, get_dht_time
-from ..utils import Endpoint
+with open(os.path.join(os.path.dirname(__file__), 'dht.proto'), 'r') as f_proto:
+    dht_pb2, dht_grpc = compile_grpc(f_proto.read())
 
 
-class KademliaProtocol(RPCProtocol):
-    """
-    A protocol that allows DHT nodes to request keys/neighbors from other DHT nodes.
-    As a side-effect, KademliaProtocol also maintains a routing table as described in
-    https://pdos.csail.mit.edu/~petar/papers/maymounkov-kademlia-lncs.pdf
+class DHTProtocol(dht_grpc.DHTServicer):
+    node_id: DHTID; port: int; bucket_size: int; num_replicas: int; wait_timeout: float; node_info: dht_pb2.NodeInfo
+    channel_options: Optional[Sequence[Tuple[str, Any]]]; server: grpc.experimental.aio.Server
+    storage: LocalStorage; cache: LocalStorage; routing_table: RoutingTable
 
-    See DHTNode (node.py) for a more detailed description.
+    @classmethod
+    async def create(cls, node_id: DHTID, bucket_size: int, depth_modulo: int, num_replicas: int, wait_timeout: float,
+                     cache_size: Optional[int] = None, listen=True, listen_on='0.0.0.0:*',
+                     channel_options: Optional[Sequence[Tuple[str, Any]]] = None, **kwargs) -> DHTProtocol:
+        """
+        A protocol that allows DHT nodes to request keys/neighbors from other DHT nodes.
+        As a side-effect, DHTProtocol also maintains a routing table as described in
+        https://pdos.csail.mit.edu/~petar/papers/maymounkov-kademlia-lncs.pdf
 
-    :note: the rpc_* methods defined in this class will be automatically exposed to other DHT nodes,
-     for instance, def rpc_ping can be called as protocol.call_ping(addr, dht_id) from a remote machine
-     Only the call_* methods are meant to be called publicly, e.g. from DHTNode
-     Read more: https://github.com/bmuller/rpcudp/tree/master/rpcudp
-    """
+        See DHTNode (node.py) for a more detailed description.
 
-    def __init__(self, node_id: DHTID, bucket_size: int, depth_modulo: int, wait_timeout: float,
-                 max_concurrent_rpc: int, num_replicas: Optional[int] = None, cache_size: Optional[int] = None):
-        super().__init__(wait_timeout)
-        self.node_id, self.bucket_size, self.num_replicas = node_id, bucket_size, num_replicas or bucket_size
-        self.rpc_semaphore = asyncio.BoundedSemaphore(value=max_concurrent_rpc)
+        :note: the rpc_* methods defined in this class will be automatically exposed to other DHT nodes,
+         for instance, def rpc_ping can be called as protocol.call_ping(addr, dht_id) from a remote machine
+         Only the call_* methods are meant to be called publicly, e.g. from DHTNode
+         Read more: https://github.com/bmuller/rpcudp/tree/master/rpcudp
+        """
+        self = cls(_initialized_with_create=True)
+        self.node_id, self.bucket_size, self.num_replicas = node_id, bucket_size, num_replicas
+        self.wait_timeout, self.channel_options = wait_timeout, channel_options
+        self.storage, self.cache = LocalStorage(), LocalStorage(maxsize=cache_size)
         self.routing_table = RoutingTable(node_id, bucket_size, depth_modulo)
-        self.storage = LocalStorage()
-        self.cache = LocalStorage(maxsize=cache_size)
-
-    def rpc_ping(self, sender: Endpoint, sender_id_bytes: BinaryDHTID) -> BinaryDHTID:
-        """ Some dht node wants us to add it to our routing table. """
-        asyncio.ensure_future(self.update_routing_table(DHTID.from_bytes(sender_id_bytes), sender))
-        return bytes(self.node_id)
-
-    async def call_ping(self, recipient: Endpoint) -> Optional[DHTID]:
-        """ Get recipient's node id and add him to the routing table. If recipient doesn't respond, return None """
-        async with self.rpc_semaphore:
-            responded, response = await self.ping(recipient, bytes(self.node_id))
-        recipient_node_id = DHTID.from_bytes(response) if responded else None
-        asyncio.ensure_future(self.update_routing_table(recipient_node_id, recipient, responded=responded))
-        return recipient_node_id
-
-    def rpc_store(self, sender: Endpoint, sender_id_bytes: BinaryDHTID, key_bytes: BinaryDHTID,
-                  value: DHTValue, expiration_time: DHTExpiration, in_cache: bool) -> Tuple[bool, BinaryDHTID]:
-        """ Some node wants us to store this (key, value) pair """
-        asyncio.ensure_future(self.update_routing_table(DHTID.from_bytes(sender_id_bytes), sender))
-        if in_cache:
-            store_accepted = self.cache.store(DHTID.from_bytes(key_bytes), value, expiration_time)
+
+        if listen:  # set up server to process incoming rpc requests
+            grpc.experimental.aio.init_grpc_aio()
+            self.server = grpc.experimental.aio.server(**kwargs)
+            dht_grpc.add_DHTServicer_to_server(self, self.server)
+
+            found_port = self.server.add_insecure_port(listen_on)
+            assert found_port != 0, f"Failed to listen to {listen_on}"
+            self.node_info = dht_pb2.NodeInfo(node_id=node_id.to_bytes(), rpc_port=found_port)
+            self.port = found_port
+            await self.server.start()
+        else:  # not listening to incoming requests, client-only mode
+            # note: use empty node_info so peers won't add you to their routing tables
+            self.node_info, self.server, self.port = dht_pb2.NodeInfo(), None, None
+            if listen_on != '0.0.0.0:*' or len(kwargs) != 0:
+                warn(f"DHTProtocol has no server (due to listen=False), listen_on "
+                     f"and kwargs have no effect (unused kwargs: {kwargs})")
+        return self
+
+    def __init__(self, *, _initialized_with_create=False):
+        """ Internal init method. Please use DHTProtocol.create coroutine to spawn new protocol instances """
+        assert _initialized_with_create, " Please use DHTProtocol.create coroutine to spawn new protocol instances "
+        super().__init__()
+
+    async def shutdown(self, timeout=None):
+        """ Process existing requests, close all connections and stop the server """
+        if self.server:
+            await self.server.stop(timeout)
         else:
-            store_accepted = self.storage.store(DHTID.from_bytes(key_bytes), value, expiration_time)
-        return store_accepted, bytes(self.node_id)
+            warn("DHTProtocol has no server (due to listen=False), it doesn't need to be shut down")
 
-    async def call_store(self, recipient: Endpoint, key: DHTID, value: DHTValue,
-                         expiration_time: DHTExpiration, in_cache: bool = False) -> Optional[bool]:
-        """
-        Ask a recipient to store (key, value) pair until expiration time or update their older value
+    def _get(self, peer: Endpoint) -> dht_grpc.DHTStub:
+        """ get a DHTStub that sends requests to a given peer """
+        channel = grpc.experimental.aio.insecure_channel(peer, options=self.channel_options)
+        return dht_grpc.DHTStub(channel)
 
-        :returns: True if value was accepted, False if it was rejected (recipient has newer value), None if no response
+    async def call_ping(self, peer: Endpoint) -> Optional[DHTID]:
         """
-        async with self.rpc_semaphore:
-            responded, response = await self.store(recipient, bytes(self.node_id), bytes(key),
-                                                   value, expiration_time, in_cache)
-        if responded:
-            store_accepted, recipient_node_id = response[0], DHTID.from_bytes(response[1])
-            asyncio.ensure_future(self.update_routing_table(recipient_node_id, recipient, responded=responded))
-            return store_accepted
-        return None
-
-    def rpc_find_node(self, sender: Endpoint, sender_id_bytes: BinaryDHTID,
-                      query_id_bytes: BinaryDHTID) -> Tuple[List[Tuple[BinaryDHTID, Endpoint]], BinaryDHTID]:
+        Get peer's node id and add him to the routing table. If peer doesn't respond, return None
+        :param peer: string network address, e.g. 123.123.123.123:1337 or [2a21:6c8:b192:2105]:8888
+        :note: if DHTProtocol was created with listen=True, also request peer to add you to his routing table
+
+        :return: node's DHTID, if peer responded and decided to send his node_id
         """
-        Someone wants to find :key_node: in the DHT. Give him k nearest neighbors from our routing table
+        try:
+            peer_info = await self._get(peer).rpc_ping(self.node_info, timeout=self.wait_timeout)
+        except grpc.experimental.aio.AioRpcError as error:
+            logging.info(f"DHTProtocol failed to ping {peer}: {error.code()}")
+            peer_info = None
+        responded = bool(peer_info and peer_info.node_id)
+        peer_id = DHTID.from_bytes(peer_info.node_id) if responded else None
+        asyncio.create_task(self.update_routing_table(peer_id, peer, responded=responded))
+        return peer_id
+
+    async def rpc_ping(self, peer_info: dht_pb2.NodeInfo, context: grpc.ServicerContext):
+        """ Some node wants us to add it to our routing table. """
+        if peer_info.node_id and peer_info.rpc_port:
+            sender_id = DHTID.from_bytes(peer_info.node_id)
+            peer_url = urllib.parse.urlparse(context.peer())
+            address = peer_url.path[:peer_url.path.rindex(':')]
+            asyncio.create_task(self.update_routing_table(sender_id, f"{address}:{peer_info.rpc_port}"))
+        return self.node_info
 
-        :returns: a list of pairs (node_id, address) of :bucket_size: nearest to key_node according to XOR distance,
-         also returns our own node id for routing table maintenance
+    async def call_store(self, peer: Endpoint, keys: Sequence[DHTID], values: Sequence[BinaryDHTValue],
+                         expirations: Union[DHTExpiration, Sequence[DHTExpiration]],
+                         in_cache: Optional[Union[bool, Sequence[bool]]] = None) -> Sequence[bool]:
         """
-        query_id, sender_id = DHTID.from_bytes(query_id_bytes), DHTID.from_bytes(sender_id_bytes)
-        asyncio.ensure_future(self.update_routing_table(sender_id, sender))
-        peer_ids_and_addr = self.routing_table.get_nearest_neighbors(query_id, k=self.bucket_size, exclude=sender_id)
-        return [(bytes(peer_id), peer_addr) for peer_id, peer_addr in peer_ids_and_addr], bytes(self.node_id)
+        Ask a recipient to store several (key, value : expiration) items or update their older value
 
-    async def call_find_node(self, recipient: Endpoint, query_id: DHTID) -> Dict[DHTID, Endpoint]:
+        :param peer: request this peer to store the data
+        :param keys: a list of N keys digested by DHTID.generate(source=some_dict_key)
+        :param values: a list of N serialized values (bytes) for each respective key
+        :param expirations: a list of N expiration timestamps for each respective key-value pair (see get_dht_time())
+        :param in_cache: a list of booleans, True = store i-th key in cache, False = store i-th key in regular storage
+        :note: the difference between storing normally and in cache is that normal storage is guaranteed to be stored
+         until expiration time (best-effort), whereas cached storage can be evicted early due to limited cache size
+
+        :return: list of [True / False] True = stored, False = failed (found newer value or no response)
+         if peer did not respond (e.g. due to timeout or congestion), returns [False] * len(keys)
         """
-        Ask a recipient to give you nearest neighbors to key_node. If recipient knows key_node directly,
-         it will be returned as first of the neighbors; if recipient does not respond, return empty dict.
+        in_cache = in_cache if in_cache is not None else [False] * len(keys)  # default value (None)
+        in_cache = [in_cache] * len(keys) if isinstance(in_cache, bool) else in_cache  # single bool
+        expirations = [expirations] * len(keys) if isinstance(expirations, DHTExpiration) else expirations
+        keys, values, expirations, in_cache = map(list, [keys, values, expirations, in_cache])
+        assert len(keys) == len(values) == len(expirations) == len(in_cache), "Data is not aligned"
+        store_request = dht_pb2.StoreRequest(keys=list(map(DHTID.to_bytes, keys)), values=values,
+                                             expiration=expirations, in_cache=in_cache, peer=self.node_info)
+        try:
+            response = await self._get(peer).rpc_store(store_request, timeout=self.wait_timeout)
+            if response.peer and response.peer.node_id:
+                peer_id = DHTID.from_bytes(response.peer.node_id)
+                asyncio.create_task(self.update_routing_table(peer_id, peer, responded=True))
+            return response.store_ok
+        except grpc.experimental.aio.AioRpcError as error:
+            logging.info(f"DHTProtocol failed to store at {peer}: {error.code()}")
+            asyncio.create_task(self.update_routing_table(self.routing_table.get_id(peer), peer, responded=False))
+            return [False] * len(keys)
+
+    async def rpc_store(self, request: dht_pb2.StoreRequest, context: grpc.ServicerContext) -> dht_pb2.StoreResponse:
+        """ Some node wants us to store this (key, value) pair """
+        if request.peer:  # if requested, add peer to the routing table
+            asyncio.create_task(self.rpc_ping(request.peer, context))
+        assert len(request.keys) == len(request.values) == len(request.expiration) == len(request.in_cache)
+        response = dht_pb2.StoreResponse(store_ok=[], peer=self.node_info)
+        for key_bytes, value_bytes, expiration_time, in_cache in zip(
+                request.keys, request.values, request.expiration, request.in_cache):
+            local_memory = self.cache if in_cache else self.storage
+            response.store_ok.append(local_memory.store(DHTID.from_bytes(key_bytes), value_bytes, expiration_time))
+        return response
 
-        :returns: a dicitionary[node id => address] as per Section 2.3 of the paper
+    async def call_find(self, peer: Endpoint, keys: Sequence[DHTID]) -> \
+            Optional[Dict[DHTID, Tuple[Optional[BinaryDHTValue], Optional[DHTExpiration], Dict[DHTID, Endpoint]]]]:
         """
-        async with self.rpc_semaphore:
-            responded, response = await self.find_node(recipient, bytes(self.node_id), bytes(query_id))
-        if responded:
-            peers = {DHTID.from_bytes(peer_id_bytes): tuple(addr) for peer_id_bytes, addr in response[0]}
-            # Note: we convert addr from list to tuple here --^ because some msgpack versions convert tuples to lists
-            recipient_node_id = DHTID.from_bytes(response[1])
-            asyncio.ensure_future(self.update_routing_table(recipient_node_id, recipient, responded=responded))
-            return peers
-        return {}
-
-    def rpc_find_value(self, sender: Endpoint, sender_id_bytes: BinaryDHTID, key_bytes: BinaryDHTID) -> \
-            Tuple[Optional[DHTValue], Optional[DHTExpiration], List[Tuple[BinaryDHTID, Endpoint]], BinaryDHTID]:
+        Request keys from a peer. For each key, look for its (value, expiration time) locally and
+         k additional peers that are most likely to have this key (ranked by XOR distance)
+
+        :returns: A dict key => Tuple[optional value, optional expiration time, nearest neighbors]
+         value: value stored by the recipient with that key, or None if peer doesn't have this value
+         expiration time: expiration time of the returned value, None if no value was found
+         neighbors: a dictionary[node_id : endpoint] containing nearest neighbors from peer's routing table
+         If peer didn't respond, returns None
         """
-        Someone wants to find value corresponding to key. If we have the value, return the value and its expiration time
-         Either way, return :bucket_size: nearest neighbors to that node.
+        keys = list(keys)
+        find_request = dht_pb2.FindRequest(keys=list(map(DHTID.to_bytes, keys)), peer=self.node_info)
+        try:
+            response = await self._get(peer).rpc_find(find_request, timeout=self.wait_timeout)
+            if response.peer and response.peer.node_id:
+                peer_id = DHTID.from_bytes(response.peer.node_id)
+                asyncio.create_task(self.update_routing_table(peer_id, peer, responded=True))
+            assert len(response.values) == len(response.expiration) == len(response.nearest) == len(keys), \
+                "DHTProtocol: response is not aligned with keys"
+
+            output = {}  # unpack data without special NOT_FOUND_* values
+            for key, value, expiration, nearest in zip(keys, response.values, response.expiration, response.nearest):
+                value = value if value != _NOT_FOUND_VALUE else None
+                expiration = expiration if expiration != _NOT_FOUND_EXPIRATION else None
+                nearest = dict(zip(map(DHTID.from_bytes, nearest.node_ids), nearest.endpoints))
+                output[key] = (value, expiration, nearest)
+            return output
+        except grpc.experimental.aio.AioRpcError as error:
+            logging.info(f"DHTProtocol failed to find at {peer}: {error.code()}")
+            asyncio.create_task(self.update_routing_table(self.routing_table.get_id(peer), peer, responded=False))
 
-        :returns: (value or None if we have no value, nearest neighbors, our own dht id)
-        :note: this is a deviation from Section 2.3 of the paper, original kademlia returner EITHER value OR neighbors
+    async def rpc_find(self, request: dht_pb2.FindRequest, context: grpc.ServicerContext) -> dht_pb2.FindResponse:
         """
-        maybe_value, maybe_expiration = self.storage.get(DHTID.from_bytes(key_bytes))
-        cached_value, cached_expiration = self.cache.get(DHTID.from_bytes(key_bytes))
-        if (cached_expiration or -float('inf')) > (maybe_expiration or -float('inf')):
-            maybe_value, maybe_expiration = cached_value, cached_expiration
-        nearest_neighbors, my_id = self.rpc_find_node(sender, sender_id_bytes, key_bytes)
-        return maybe_value, maybe_expiration, nearest_neighbors, my_id
-
-    async def call_find_value(self, recipient: Endpoint, key: DHTID) -> \
-            Tuple[Optional[DHTValue], Optional[DHTExpiration], Dict[DHTID, Endpoint]]:
+        Someone wants to find keys in the DHT. For all keys that we have locally, return value and expiration
+        Also return :bucket_size: nearest neighbors from our routing table for each key (whether or not we found value)
         """
-        Ask a recipient to give you the value, if it has one, or nearest neighbors to your key.
+        if request.peer:  # if requested, add peer to the routing table
+            asyncio.create_task(self.rpc_ping(request.peer, context))
 
-        :returns: (optional value, optional expiration time, and neighbors)
-         value: whatever was the latest value stored by the recipient with that key (see DHTNode contract)
-         expiration time: expiration time of the returned value, None if no value was found
-         neighbors:  a dictionary[node id => address] as per Section 2.3 of the paper;
-        :note: if no response, returns None, None, {}
-        """
-        async with self.rpc_semaphore:
-            responded, response = await self.find_value(recipient, bytes(self.node_id), bytes(key))
-        if responded:
-            (value, expiration_time, peers_bytes), recipient_id = response[:-1], DHTID.from_bytes(response[-1])
-            peers = {DHTID.from_bytes(peer_id_bytes): tuple(addr) for peer_id_bytes, addr in peers_bytes}
-            asyncio.ensure_future(self.update_routing_table(recipient_id, recipient, responded=responded))
-            return value, expiration_time, peers
-        return None, None, {}
-
-    async def update_routing_table(self, node_id: Optional[DHTID], addr: Endpoint, responded=True):
+        response = dht_pb2.FindResponse(values=[], expiration=[], nearest=[], peer=self.node_info)
+        for key_id in map(DHTID.from_bytes, request.keys):
+            maybe_value, maybe_expiration = self.storage.get(key_id)
+            cached_value, cached_expiration = self.cache.get(key_id)
+            if (cached_expiration or -float('inf')) > (maybe_expiration or -float('inf')):
+                maybe_value, maybe_expiration = cached_value, cached_expiration
+            peer_ids, endpoints = zip(*self.routing_table.get_nearest_neighbors(
+                key_id, k=self.bucket_size, exclude=DHTID.from_bytes(request.peer.node_id)))
+
+            response.values.append(maybe_value if maybe_value is not None else _NOT_FOUND_VALUE)
+            response.expiration.append(maybe_expiration if maybe_expiration is not None else _NOT_FOUND_EXPIRATION)
+            response.nearest.append(dht_pb2.Peers(node_ids=list(map(DHTID.to_bytes, peer_ids)), endpoints=endpoints))
+        return response
+
+    async def update_routing_table(self, node_id: Optional[DHTID], peer_endpoint: Endpoint, responded=True):
         """
         This method is called on every incoming AND outgoing request to update the routing table
 
-        :param addr: sender endpoint for incoming requests, recipient endpoint for outgoing requests
+        :param peer_endpoint: sender endpoint for incoming requests, recipient endpoint for outgoing requests
         :param node_id: sender node id for incoming requests, recipient node id for outgoing requests
         :param responded: for outgoing requests, this indicated whether recipient responded or not.
           For incoming requests, this should always be True
         """
+        node_id = node_id if node_id is not None else self.routing_table.get_id(peer_endpoint)
         if responded:  # incoming request or outgoing request with response
             if node_id not in self.routing_table:
                 # we just met a new node, maybe we know some values that it *should* store
+                data_to_send: List[Tuple[DHTID, BinaryDHTValue, DHTExpiration]] = []
                 for key, value, expiration in list(self.storage.items()):
                     neighbors = self.routing_table.get_nearest_neighbors(key, self.num_replicas, exclude=self.node_id)
                     if neighbors:
@@ -155,29 +227,26 @@ class KademliaProtocol(RPCProtocol):
                         new_node_should_store = node_id.xor_distance(key) < farthest_distance
                         this_node_is_responsible = self.node_id.xor_distance(key) < nearest_distance
                     if not neighbors or (new_node_should_store and this_node_is_responsible):
-                        asyncio.create_task(self.call_store(addr, key, value, expiration))
+                        data_to_send.append((key, value, expiration))
+                if data_to_send:
+                    asyncio.create_task(self.call_store(peer_endpoint, *zip(*data_to_send), in_cache=False))
 
-            maybe_node_to_ping = self.routing_table.add_or_update_node(node_id, addr)
+            maybe_node_to_ping = self.routing_table.add_or_update_node(node_id, peer_endpoint)
             if maybe_node_to_ping is not None:
                 # we couldn't add new node because the table was full. Check if existing peers are alive (Section 2.2)
                 # ping one least-recently updated peer: if it won't respond, remove it from the table, else update it
                 asyncio.create_task(self.call_ping(maybe_node_to_ping[1]))  # [1]-th element is that node's endpoint
 
-        else:  # outgoing request and peer did not respond
+        else:  # we sent outgoing request and peer did not respond
             if node_id is not None and node_id in self.routing_table:
                 del self.routing_table[node_id]
 
-    def _accept_response(self, msg_id, data, address):
-        """ Override for RPCProtocol._accept_response to handle cancelled tasks """
-        future, timeout = self._outstanding[msg_id]
-        if future.cancelled():
-            timeout.cancel()
-            del self._outstanding[msg_id]
-        else:
-            super()._accept_response(msg_id, data, address)
+
+_NOT_FOUND_VALUE, _NOT_FOUND_EXPIRATION = b'', -float('inf')  # internal values to represent that a value was not found
 
 
 class LocalStorage:
+    """ Local dictionary that maintains up to :maxsize: tuples of (key, value, expiration) """
     def __init__(self, maxsize: Optional[int] = None):
         self.cache_size = maxsize or float("inf")
         self.data = dict()
@@ -192,7 +261,7 @@ class LocalStorage:
             if self.key_to_heap[key] == heap_entry:
                 del self.data[key], self.key_to_heap[key]
 
-    def store(self, key: DHTID, value: DHTValue, expiration_time: DHTExpiration) -> bool:
+    def store(self, key: DHTID, value: BinaryDHTValue, expiration_time: DHTExpiration) -> bool:
         """
         Store a (key, value) pair locally at least until expiration_time. See class docstring for details.
         :returns: True if the new value was stored, False if it was rejected (current value is newer)
@@ -210,14 +279,14 @@ class LocalStorage:
         self.remove_outdated()
         return True
 
-    def get(self, key: DHTID) -> (Optional[DHTValue], Optional[DHTExpiration]):
+    def get(self, key: DHTID) -> (Optional[BinaryDHTValue], Optional[DHTExpiration]):
         """ Get a value corresponding to a key if that (key, value) pair was previously stored here. """
         self.remove_outdated()
         if key in self.data:
             return self.data[key]
         return None, None
 
-    def items(self) -> Iterator[Tuple[DHTID, DHTValue, DHTExpiration]]:
+    def items(self) -> Iterator[Tuple[DHTID, BinaryDHTValue, DHTExpiration]]:
         """ Iterate over (key, value, expiration_time) tuples stored in this storage """
         self.remove_outdated()
         return ((key, value, expiration) for key, (value, expiration) in self.data.items())
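For reference, a minimal usage sketch of the LocalStorage contract after this change (DHTID keys, bytes values, absolute expiration times). Everything below mirrors signatures from this diff and the updated tests in tests/test_dht.py; it is illustration only, not new API:

    from hivemind.dht.protocol import LocalStorage
    from hivemind.dht.routing import DHTID, get_dht_time

    storage = LocalStorage(maxsize=1000)                    # with maxsize set, the entry with the earliest expiration is evicted first
    key = DHTID.generate(source="some_key")
    assert storage.store(key, b"serialized value", get_dht_time() + 60)  # True if stored, False if the current value is newer
    value, expiration = storage.get(key)                    # (None, None) if the key is missing or expired
    for key_id, value_bytes, expiration_time in storage.items():  # iterate over live (key, value, expiration) tuples
        pass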

+ 9 - 2
hivemind/dht/routing.py

@@ -41,7 +41,7 @@ class RoutingTable:
         Update routing table after an incoming request from :addr: (host:port) or outgoing request to :addr:
 
         :returns: If we cannot add node_id to the routing table, return the least-recently-updated node (Section 2.2)
-        :note: KademliaProtocol calls this method for every incoming and outgoing request if there was a response.
+        :note: DHTProtocol calls this method for every incoming and outgoing request if there was a response.
           If this method returns a node to be pinged, the protocol will ping it to check whether it is still alive
           and either move it to the start of the table or remove that node and replace it with the new one.
         """
@@ -66,6 +66,12 @@ class RoutingTable:
         self.buckets[index] = first
         self.buckets.insert(index + 1, second)
 
+    def get(self, node_id: DHTID, default=None) -> Optional[Endpoint]:
+        return self[node_id] if node_id in self else default
+
+    def get_id(self, peer: Endpoint, default=None) -> Optional[DHTID]:
+        return None  # TODO(jheuristic)
+
     def __getitem__(self, node_id: DHTID) -> Endpoint:
         return self.buckets[self.get_bucket_index(node_id)][node_id]
 
@@ -174,6 +180,7 @@ class KBucket:
         """ :returns: least-recently updated node that isn't already being pinged right now -- if such node exists """
         for uid, endpoint in self.nodes_to_addr.items():
             if uid not in self.nodes_requested_for_ping:
+                self.nodes_requested_for_ping.add(uid)
                 return uid, endpoint
 
     def __getitem__(self, node_id: DHTID) -> Endpoint:
@@ -272,5 +279,5 @@ class DHTID(int):
         return self.to_bytes()
 
 
-DHTKey, DHTValue, DHTExpiration, BinaryDHTID = Any, Any, float, bytes  # flavour types
+DHTKey, DHTValue, DHTExpiration, BinaryDHTID, BinaryDHTValue, = Any, Any, float, bytes, bytes  # flavour types
 get_dht_time = time.time  # time used by all dht functionality. You can replace this with any infrastructure-wide time
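A short sketch of how the routing table is driven with the new string endpoints; the calls below are taken from this diff and from tests/test_routing.py further down (illustrative values only):

    from hivemind.dht.routing import RoutingTable, DHTID

    node_id = DHTID.generate()
    table = RoutingTable(node_id, bucket_size=20, depth_modulo=5)
    maybe_stale = table.add_or_update_node(DHTID.generate(), "127.0.0.1:31337")  # returns a least-recently-updated node to ping if its bucket is full, else None
    neighbors = table.get_nearest_neighbors(DHTID.generate(), k=20, exclude=node_id)  # (DHTID, "host:port") pairs nearest to the query by xor distance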

+ 5 - 5
hivemind/runtime/expert_backend.py

@@ -4,7 +4,7 @@ import torch
 from torch import nn
 
 from .task_pool import TaskPool
-from ..utils import nested_flatten, nested_pack, nested_compare, BatchTensorProto, DUMMY_BATCH_SIZE, nested_map
+from ..utils import nested_flatten, nested_pack, nested_compare, BatchTensorDescriptor, DUMMY_BATCH_SIZE, nested_map
 
 
 class ExpertBackend(nn.Module):
@@ -33,9 +33,9 @@ class ExpertBackend(nn.Module):
     """
 
     def __init__(self, name: str, expert: nn.Module, opt: torch.optim.Optimizer, *,
-                 args_schema: Tuple[BatchTensorProto, ...] = None,
-                 kwargs_schema: Dict[str, BatchTensorProto] = None,
-                 outputs_schema: Union[BatchTensorProto, Tuple[BatchTensorProto, ...]] = None,
+                 args_schema: Tuple[BatchTensorDescriptor, ...] = None,
+                 kwargs_schema: Dict[str, BatchTensorDescriptor] = None,
+                 outputs_schema: Union[BatchTensorDescriptor, Tuple[BatchTensorDescriptor, ...]] = None,
                  **kwargs):
         super().__init__()
         self.expert, self.opt, self.name = expert, opt, name
@@ -50,7 +50,7 @@ class ExpertBackend(nn.Module):
             dummy_args = tuple(sample.make_empty(DUMMY_BATCH_SIZE) for sample in args_schema)
             dummy_kwargs = {key: sample.make_empty(DUMMY_BATCH_SIZE) for key, sample in kwargs_schema.items()}
             dummy_outputs = self.expert(*dummy_args, **dummy_kwargs)
-            outputs_schema = nested_map(BatchTensorProto.from_tensor, dummy_outputs)
+            outputs_schema = nested_map(BatchTensorDescriptor.from_tensor, dummy_outputs)
 
         self.outputs_schema = outputs_schema
         self.forward_schema = (self.args_schema, self.kwargs_schema)

+ 3 - 2
hivemind/utils/__init__.py

@@ -1,8 +1,9 @@
 from .connection import *
 from .data import *
 from .nested import *
-from .proto import *
+from .tensor_descr import *
 from .serializer import *
 from .shared_future import *
 from .threading import *
-from .autograd import *
+from .autograd import *
+from .grpc import *

+ 2 - 2
hivemind/utils/connection.py

@@ -3,7 +3,7 @@ from contextlib import AbstractContextManager, closing
 from typing import Tuple
 
 Hostname, Port = str, int  # flavour types
-Endpoint = Tuple[Hostname, Port]  # https://networkengineering.stackexchange.com/a/9435
+Endpoint = str  # e.g. 1.2.3.4:1337 or [2a21:6c8:b192:2105]:8888, https://networkengineering.stackexchange.com/a/9435
 LOCALHOST = '127.0.0.1'
 
 
@@ -13,7 +13,7 @@ class Connection(AbstractContextManager):
 
     __slots__ = ('conn', 'addr')
 
-    def __init__(self, conn: socket, addr: Endpoint):
+    def __init__(self, conn: socket, addr: Tuple[Hostname, Port]):
         self.conn, self.addr = conn, addr
 
     @staticmethod
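Since Endpoint is now a plain "host:port" string rather than a (host, port) tuple, callers that still need the pair have to split it themselves. A hypothetical helper (not part of this diff) showing one way to do that:

    from typing import Tuple

    def split_endpoint(endpoint: str) -> Tuple[str, int]:
        """ Split "1.2.3.4:1337" into ("1.2.3.4", 1337); bracketed IPv6 hosts keep their brackets. """
        host, _, port = endpoint.rpartition(':')
        return host, int(port)

    assert split_endpoint("127.0.0.1:1337") == ("127.0.0.1", 1337)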

+ 44 - 0
hivemind/utils/grpc.py

@@ -0,0 +1,44 @@
+"""
+Utilities for running GRPC services: compile protobuf, patch legacy versions, etc
+"""
+import functools
+import os
+import sys
+import tempfile
+from typing import Tuple
+from argparse import Namespace
+import grpc_tools.protoc
+
+
+@functools.lru_cache(maxsize=None)
+def compile_grpc(proto: str, *args: str) -> Tuple[Namespace, Namespace]:
+    """
+    Compiles and loads grpc protocol defined by protobuf string
+
+    :param proto: protocol buffer code as a string, as in open('file.proto').read()
+    :param args: extra cli args for grpc_tools.protoc compiler, e.g. '-Imyincludepath'
+    :returns: messages, services protobuf
+    """
+    base_include = grpc_tools.protoc.pkg_resources.resource_filename('grpc_tools', '_proto')
+
+    with tempfile.TemporaryDirectory(prefix='compile_grpc_') as build_dir:
+        proto_path = tempfile.mktemp(prefix='grpc_', suffix='.proto', dir=build_dir)
+        with open(proto_path, 'w') as fproto:
+            fproto.write(proto)
+
+        cli_args = (
+            grpc_tools.protoc.__file__, f"-I{base_include}",
+            f"--python_out={build_dir}", f"--grpc_python_out={build_dir}",
+            f"-I{build_dir}", *args, os.path.basename(proto_path))
+        code = grpc_tools.protoc._protoc_compiler.run_main([arg.encode() for arg in cli_args])
+        if code:  # hint: if you get this error in jupyter, run in console for richer error message
+            raise ValueError(f"{' '.join(cli_args)} finished with exit code {code}")
+
+        try:
+            sys.path.append(build_dir)
+            pb2_fname = os.path.basename(proto_path)[:-len('.proto')] + '_pb2'
+            messages, services = __import__(pb2_fname, fromlist=['*']), __import__(pb2_fname + '_grpc')
+            return messages, services
+        finally:
+            if sys.path.pop() != build_dir:
+                raise ImportError("Something changed sys.path while compile_grpc was in progress.")
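A hedged usage sketch of compile_grpc: the tiny .proto source below is made up purely for illustration, only the compile_grpc call itself comes from this diff. The generated message and service classes follow standard grpcio-tools naming:

    from hivemind.utils.grpc import compile_grpc

    proto_source = '''
    syntax = "proto3";
    message Msg { bytes payload = 1; }
    service Echo { rpc Ping (Msg) returns (Msg); }
    '''
    messages, services = compile_grpc(proto_source)   # cached via functools.lru_cache, so repeated calls are cheap
    request = messages.Msg(payload=b"hello")
    # services.EchoStub / services.EchoServicer / services.add_EchoServicer_to_server can then be used to build a client or server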

+ 32 - 9
hivemind/utils/serializer.py

@@ -1,41 +1,64 @@
+""" A unified interface for several common serialization methods """
 import pickle
 from io import BytesIO
 
 import joblib
 import torch
+import umsgpack
 
 
-class JoblibSerializer:
+class SerializerBase:
+    @staticmethod
+    def dumps(obj: object) -> bytes:
+        raise NotImplementedError()
+
+    @staticmethod
+    def loads(buf: bytes) -> object:
+        raise NotImplementedError()
+
+
+class JoblibSerializer(SerializerBase):
 
     @staticmethod
-    def dumps(obj) -> bytes:
+    def dumps(obj: object) -> bytes:
         s = BytesIO()
         joblib.dump(obj, s)
         return s.getvalue()
 
     @staticmethod
-    def loads(buf: bytes):
+    def loads(buf: bytes) -> object:
         return joblib.load(BytesIO(buf))
 
 
-class PickleSerializer:
+class PickleSerializer(SerializerBase):
     @staticmethod
-    def dumps(obj) -> bytes:
+    def dumps(obj: object) -> bytes:
         return pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
 
     @staticmethod
-    def loads(buf: bytes):
+    def loads(buf: bytes) -> object:
         return pickle.loads(buf)
 
 
-class PytorchSerializer:
+class PytorchSerializer(SerializerBase):
 
     @staticmethod
-    def dumps(obj) -> bytes:
+    def dumps(obj: object) -> bytes:
         s = BytesIO()
         torch.save(obj, s, pickle_protocol=pickle.HIGHEST_PROTOCOL)
         return s.getvalue()
 
     @staticmethod
-    def loads(buf: bytes):
+    def loads(buf: bytes) -> object:
         return torch.load(BytesIO(buf))
+
+
+class MSGPackSerializer(SerializerBase):
+
+    @staticmethod
+    def dumps(obj: object) -> bytes:
+        return umsgpack.dumps(obj, use_bin_type=False)
+
+    @staticmethod
+    def loads(buf: bytes) -> object:
+        return umsgpack.loads(buf, raw=False)
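All serializers share the same dumps/loads interface. For example, DHT values are packed with MSGPackSerializer before being handed to DHTProtocol, which now stores raw bytes (see the updated call_store usage in tests/test_dht.py below):

    from hivemind.utils.serializer import MSGPackSerializer

    value = [42, {"ololo": "pyshpysh"}]
    value_bytes = MSGPackSerializer.dumps(value)     # bytes, safe to pass to DHTProtocol.call_store
    assert MSGPackSerializer.loads(value_bytes) == value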

+ 3 - 3
hivemind/utils/proto.py → hivemind/utils/tensor_descr.py

@@ -6,12 +6,12 @@ DUMMY_BATCH_SIZE = 3  # used for dummy runs only
 
 
 @dataclass(init=True, repr=True, frozen=True)
-class ProtoBase:
+class DescriptorBase:
     pass
 
 
 @dataclass(init=True, repr=True, frozen=True)
-class TensorProto(ProtoBase):
+class TensorDescriptor(DescriptorBase):
     size: tuple
     dtype: torch.dtype = None
     layout: torch.layout = torch.strided
@@ -34,7 +34,7 @@ class TensorProto(ProtoBase):
 
 
 @dataclass(repr=True, frozen=True)
-class BatchTensorProto(TensorProto):
+class BatchTensorDescriptor(TensorDescriptor):
     """ torch Tensor with a variable 0-th dimension, used to describe batched data """
 
     def __init__(self, *instance_size, **kwargs):  # compatibility: allow initializing with *size
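The TensorProto -> TensorDescriptor rename is mechanical; usage stays the same. A small sketch mirroring how ExpertBackend and the benchmarks below use it (the expected dummy shape is an assumption based on the variable 0-th batch dimension):

    from hivemind.utils.tensor_descr import BatchTensorDescriptor, DUMMY_BATCH_SIZE

    schema = BatchTensorDescriptor(1024)                  # a batch of 1024-dimensional vectors
    dummy_batch = schema.make_empty(DUMMY_BATCH_SIZE)     # expected shape: [DUMMY_BATCH_SIZE, 1024]
    same_schema = BatchTensorDescriptor.from_tensor(dummy_batch)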

+ 1 - 0
hivemind/utils/threading.py

@@ -65,3 +65,4 @@ def run_and_await_k(jobs: List[callable], k: int,
             future.cancel()
             outputs[index] = future.result() if not future.exception() else future.exception()
     return outputs
+

+ 4 - 1
requirements.txt

@@ -3,6 +3,9 @@ joblib>=0.13
 numpy>=1.17
 requests>=2.22.0
 tqdm
-rpcudp>=4.0.0
 prefetch_generator>=1.0.1
 pytest
+umsgpack
+grpcio
+grpcio-tools>=1.30.0
+aiologger>=0.5.0

+ 2 - 2
tests/benchmark_throughput.py

@@ -64,8 +64,8 @@ def benchmark_throughput(num_experts=16, num_handlers=None, num_clients=128, num
             expert = torch.jit.script(layers.name_to_block[expert_cls](hid_dim))
             experts[f'expert{i}'] = hivemind.ExpertBackend(name=f'expert{i}',
                                                            expert=expert, opt=torch.optim.Adam(expert.parameters()),
-                                                           args_schema=(hivemind.BatchTensorProto(hid_dim),),
-                                                           outputs_schema=hivemind.BatchTensorProto(hid_dim),
+                                                           args_schema=(hivemind.BatchTensorDescriptor(hid_dim),),
+                                                           outputs_schema=hivemind.BatchTensorDescriptor(hid_dim),
                                                            max_batch_size=max_batch_size,
                                                            )
         timestamps['created_experts'] = time.perf_counter()

+ 183 - 174
tests/test_dht.py

@@ -4,7 +4,6 @@ import multiprocessing as mp
 import random
 import heapq
 import uuid
-from functools import partial
 from itertools import chain
 from typing import Optional
 import numpy as np
@@ -13,198 +12,208 @@ import hivemind
 from typing import List, Dict
 
 from hivemind import get_dht_time
-from hivemind.dht.node import DHTID, Endpoint, DHTNode, LOCALHOST, KademliaProtocol
+from hivemind.dht.node import DHTID, Endpoint, DHTNode, LOCALHOST, DHTProtocol
 from hivemind.dht.protocol import LocalStorage
 
 
-def run_protocol_listener(port: int, dhtid: DHTID, started: mp.synchronize.Event,
-                          ping: Optional[hivemind.Endpoint] = None):
-    loop = asyncio.new_event_loop()
-    protocol = partial(KademliaProtocol, dhtid, bucket_size=20, depth_modulo=5, wait_timeout=5, max_concurrent_rpc=128)
-    listen = loop.create_datagram_endpoint(protocol, local_addr=('127.0.0.1', port))
-    transport, protocol = loop.run_until_complete(listen)
+def run_protocol_listener(port: int, dhtid: DHTID, started: mp.synchronize.Event, ping: Optional[Endpoint] = None):
+    loop = asyncio.get_event_loop()
+    protocol = loop.run_until_complete(DHTProtocol.create(
+        dhtid, bucket_size=20, depth_modulo=5, num_replicas=3, wait_timeout=5, listen_on=f"{LOCALHOST}:{port}"))
+
+    assert protocol.port == port
     print(f"Started peer id={protocol.node_id} port={port}", flush=True)
 
     if ping is not None:
         loop.run_until_complete(protocol.call_ping(ping))
     started.set()
-    loop.run_forever()
+    loop.run_until_complete(protocol.server.wait_for_termination())
     print(f"Finished peer id={protocol.node_id} port={port}", flush=True)
 
 
 def test_kademlia_protocol():
-    try:
-        # create the first peer
-        peer1_port, peer1_id, peer1_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
-        peer1_proc = mp.Process(target=run_protocol_listener, args=(peer1_port, peer1_id, peer1_started), daemon=True)
-        peer1_proc.start(), peer1_started.wait()
-
-        # create another peer that connects to the first peer
-        peer2_port, peer2_id, peer2_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
-        peer2_proc = mp.Process(target=run_protocol_listener, args=(peer2_port, peer2_id, peer2_started),
-                                kwargs={'ping': ('127.0.0.1', peer1_port)}, daemon=True)
-        peer2_proc.start(), peer2_started.wait()
-
-        port = hivemind.find_open_port()
-        loop = asyncio.new_event_loop()
-        protocol = partial(KademliaProtocol, DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5,
-                           max_concurrent_rpc=128)
-        listen = loop.create_datagram_endpoint(protocol, local_addr=('127.0.0.1', port))
-        transport, protocol = loop.run_until_complete(listen)
-        print(f"Self id={protocol.node_id} port={port}", flush=True)
-
-        assert loop.run_until_complete(protocol.call_ping(('127.0.0.1', peer1_port))) == peer1_id
-
-        key, value, expiration = DHTID.generate(), [123, {'ololo': 'pyshpysh'}], get_dht_time() + 1e3
-        assert loop.run_until_complete(protocol.call_store(('127.0.0.1', peer1_port), key, value, expiration))
-
-        # peer 1 must know about peer 2
-        nodes_found = loop.run_until_complete(
-            protocol.call_find_node(('127.0.0.1', peer1_port), key))
-        (recv_id, recv_endpoint) = next(iter(nodes_found.items()))
-        assert recv_id == peer2_id and recv_endpoint == ('127.0.0.1', peer2_port), \
-            f"expected id={peer2_id}, port={('127.0.0.1', peer2_port)} but got {recv_id}, {recv_endpoint}"
-
-        # peer 2 must know about peer 1
-        nodes_found_2 = loop.run_until_complete(protocol.call_find_node(('127.0.0.1', peer2_port), key))
-        (recv_id, recv_endpoint) = next(iter(nodes_found_2.items()))
-        assert recv_id == peer1_id and recv_endpoint == ('127.0.0.1', peer1_port), \
-            f"expected id={peer1_id}, port={('127.0.0.1', peer1_port)} but got {recv_id}, {recv_endpoint}"
-
-        recv_value, recv_expiration, recv_peers = loop.run_until_complete(
-            protocol.call_find_value(('127.0.0.1', peer1_port), key))
-        assert recv_value == value and recv_expiration == expiration, "call_find_value expected " \
-                                                                      f"{value} (expires by {expiration}) but got {recv_value} (expires by {recv_expiration})"
-        print(recv_peers, nodes_found)
-        assert recv_peers == nodes_found, "call_find_value must return the same peers as call_find_node"
-        print("Kademlia test finished sucessfully!")
-
-    finally:
-        peer1_proc.terminate()
-        peer2_proc.terminate()
-
-
-def run_node(node_id, port, peers, status_pipe: mp.Pipe):
+    # create the first peer
+    peer1_port, peer1_id, peer1_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
+    peer1_proc = mp.Process(target=run_protocol_listener, args=(peer1_port, peer1_id, peer1_started), daemon=True)
+    peer1_proc.start(), peer1_started.wait()
+
+    # create another peer that connects to the first peer
+    peer2_port, peer2_id, peer2_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
+    peer2_proc = mp.Process(target=run_protocol_listener, args=(peer2_port, peer2_id, peer2_started),
+                            kwargs={'ping': f'{LOCALHOST}:{peer1_port}'}, daemon=True)
+    peer2_proc.start(), peer2_started.wait()
+
+    test_success = mp.Event()
+
+    def _tester():
+        # note: we run everything in a separate process to re-initialize all global states from scratch
+        # this helps us avoid undesirable side-effects when running multiple tests in sequence
+
+        loop = asyncio.get_event_loop()
+        for listen in [False, True]:  # note: order matters, this test assumes that first run uses listen=False
+            protocol = loop.run_until_complete(DHTProtocol.create(
+                DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, listen=listen))
+            print(f"Self id={protocol.node_id}", flush=True)
+
+            assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{peer1_port}')) == peer1_id
+
+            key, value, expiration = DHTID.generate(), [random.random(), {'ololo': 'pyshpysh'}], get_dht_time() + 1e3
+            store_ok = loop.run_until_complete(protocol.call_store(
+                f'{LOCALHOST}:{peer1_port}', [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
+            )
+            assert all(store_ok), "DHT rejected a trivial store"
+
+            # peer 1 must know about peer 2
+            recv_value_bytes, recv_expiration, nodes_found = loop.run_until_complete(
+                protocol.call_find(f'{LOCALHOST}:{peer1_port}', [key]))[key]
+            recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
+            (recv_id, recv_endpoint) = next(iter(nodes_found.items()))
+            assert recv_id == peer2_id and recv_endpoint == f"{LOCALHOST}:{peer2_port}", \
+                f"expected id={peer2_id}, peer={LOCALHOST}:{peer2_port} but got {recv_id}, {recv_endpoint}"
+
+            assert recv_value == value and recv_expiration == expiration, "call_find expected " \
+                f"{value} (expires by {expiration}) but got {recv_value} (expires by {recv_expiration})"
+
+            # peer 2 must know about peer 1, but not have a *random* nonexistent value
+            dummy_key = DHTID.generate()
+            recv_dummy_value, recv_dummy_expiration, nodes_found_2 = loop.run_until_complete(
+                protocol.call_find(f'{LOCALHOST}:{peer2_port}', [dummy_key]))[dummy_key]
+            assert recv_dummy_value is None and recv_dummy_expiration is None, "Non-existent keys shouldn't have values"
+            (recv_id, recv_endpoint) = next(iter(nodes_found_2.items()))
+            assert recv_id == peer1_id and recv_endpoint == f"{LOCALHOST}:{peer1_port}", \
+                f"expected id={peer1_id}, peer={LOCALHOST}:{peer1_port} but got {recv_id}, {recv_endpoint}"
+
+            # cause a non-response by querying a nonexistent peer
+            dummy_port = hivemind.find_open_port()
+            assert loop.run_until_complete(protocol.call_find(f"{LOCALHOST}:{dummy_port}", [key])) is None
+
+            if listen:
+                loop.run_until_complete(protocol.shutdown())
+            print("DHTProtocol test finished sucessfully!")
+            test_success.set()
+
+    tester = mp.Process(target=_tester, daemon=True)
+    tester.start()
+    tester.join()
+    assert test_success.is_set()
+    peer1_proc.terminate()
+    peer2_proc.terminate()
+
+
+def run_node(node_id, peers, status_pipe: mp.Pipe):
     if asyncio.get_event_loop().is_running():
         asyncio.get_event_loop().stop()  # if we're in jupyter, get rid of its built-in event loop
-    asyncio.set_event_loop(asyncio.new_event_loop())
-    try:
-        node = DHTNode(node_id, port, initial_peers=peers)
-        status_pipe.send('STARTED')
-        while True:
-            asyncio.get_event_loop().run_forever()
-    except BaseException as e:
-        status_pipe.send(e)  # report exception to master
-        if not isinstance(e, OSError):
-            raise e
+        asyncio.set_event_loop(asyncio.new_event_loop())
+    loop = asyncio.get_event_loop()
+    node = loop.run_until_complete(DHTNode.create(node_id, initial_peers=peers))
+    status_pipe.send(node.port)
+    while True:
+        loop.run_forever()
 
 
 def test_dht():
     # create dht with 50 nodes + your 51-st node
     dht: Dict[Endpoint, DHTID] = {}
     processes: List[mp.Process] = []
-    port_fails, max_port_fails = 0, 10
 
-    while len(dht) < 50:
+    for i in range(50):
         node_id = DHTID.generate()
         peers = random.sample(dht.keys(), min(len(dht), 5))
-        port = hivemind.find_open_port()
         pipe_recv, pipe_send = mp.Pipe(duplex=False)
-        proc = mp.Process(target=run_node, args=(node_id, port, peers, pipe_send), daemon=True)
+        proc = mp.Process(target=run_node, args=(node_id, peers, pipe_send), daemon=True)
         proc.start()
-
-        status = pipe_recv.recv()
-        if status == 'STARTED':
-            processes.append(proc)
-            dht[(LOCALHOST, port)] = node_id
-        else:
-            assert isinstance(status, BaseException)
-            proc.terminate()
-            if isinstance(status, OSError):  # port already in use. It just happens sometimes.
-                port_fails += 1
-                if port_fails > max_port_fails:
-                    raise OSError("Too many 'Address already in use' errors.")
-            else:
-                raise ValueError(f"Failed to create node due to an error {status}, see traceback above")
-
-    loop = asyncio.get_event_loop()
-    me = hivemind.dht.node.DHTNode(initial_peers=random.sample(peers, 5), port=0)  # port=0 means os-specified port
-
-    # test 1: find self
-    nearest = loop.run_until_complete(me.find_nearest_nodes(query_id=me.node_id, k_nearest=1))
-    assert len(nearest) == 1 and nearest[me.node_id] == (LOCALHOST, me.port)
-
-    # test 2: find others
-    for i in range(10):
-        ref_endpoint, query_id = random.choice(list(dht.items()))
-        nearest = loop.run_until_complete(me.find_nearest_nodes(query_id=query_id, k_nearest=1))
-        assert len(nearest) == 1 and next(iter(nearest.items())) == (query_id, ref_endpoint)
-
-    # test 3: find neighbors to random nodes
-    accuracy_numerator = accuracy_denominator = 0  # top-1 nearest neighbor accuracy
-    jaccard_numerator = jaccard_denominator = 0  # jaccard similarity aka intersection over union
-    all_node_ids = list(dht.values())
-
-    for i in range(100):
-        query_id = DHTID.generate()
-        k_nearest = random.randint(1, 20)
-        exclude_self = random.random() > 0.5
-        nearest = loop.run_until_complete(
-            me.find_nearest_nodes(query_id=query_id, k_nearest=k_nearest, exclude_self=exclude_self))
-        nearest_nodes = list(nearest)  # keys from ordered dict
-
-        assert len(nearest_nodes) == k_nearest, "beam search must return exactly k_nearest results"
-        assert me.node_id not in nearest_nodes or not exclude_self, "if exclude, results should not contain own node id"
-        assert np.all(np.diff(query_id.xor_distance(nearest_nodes)) >= 0), "results must be sorted by distance"
-
-        ref_nearest = heapq.nsmallest(k_nearest + 1, all_node_ids, key=query_id.xor_distance)
-        if exclude_self and me.node_id in ref_nearest:
-            ref_nearest.remove(me.node_id)
-        if len(ref_nearest) > k_nearest:
-            ref_nearest.pop()
-
-        accuracy_numerator += nearest_nodes[0] == ref_nearest[0]
-        accuracy_denominator += 1
-
-        jaccard_numerator += len(set.intersection(set(nearest_nodes), set(ref_nearest)))
-        jaccard_denominator += k_nearest
-
-    accuracy = accuracy_numerator / accuracy_denominator
-    print("Top-1 accuracy:", accuracy)  # should be 98-100%
-    jaccard_index = jaccard_numerator / jaccard_denominator
-    print("Jaccard index (intersection over union):", jaccard_index)  # should be 95-100%
-    assert accuracy >= 0.9, f"Top-1 accuracy only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
-    assert jaccard_index >= 0.9, f"Jaccard index only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
-
-    # test 4: find all nodes
-    nearest = loop.run_until_complete(
-        me.find_nearest_nodes(query_id=DHTID.generate(), k_nearest=len(dht) + 100))
-    assert len(nearest) == len(dht) + 1
-    assert len(set.difference(set(nearest.keys()), set(all_node_ids) | {me.node_id})) == 0
-
-    # test 5: node without peers
-    other_node = hivemind.dht.node.DHTNode()
-    nearest = loop.run_until_complete(other_node.find_nearest_nodes(DHTID.generate()))
-    assert len(nearest) == 1 and nearest[other_node.node_id] == (LOCALHOST, other_node.port)
-    nearest = loop.run_until_complete(other_node.find_nearest_nodes(DHTID.generate(), exclude_self=True))
-    assert len(nearest) == 0
-
-    # test 6 store and get value
-    true_time = get_dht_time() + 1200
-    assert loop.run_until_complete(me.store("mykey", ["Value", 10], true_time))
-    val, expiration_time = loop.run_until_complete(me.get("mykey"))
-    assert expiration_time == true_time, "Wrong time"
-    assert val == ["Value", 10], "Wrong value"
-
-    # terminate remaining processes
+        port = pipe_recv.recv()
+        processes.append(proc)
+        dht[f"{LOCALHOST}:{port}"] = node_id
+
+    test_success = mp.Event()
+
+    def _tester():
+        # note: we run everything in a separate process to re-initialize all global states from scratch
+        # this helps us avoid undesirable side-effects when running multiple tests in sequence
+        loop = asyncio.get_event_loop()
+        me = loop.run_until_complete(DHTNode.create(initial_peers=random.sample(dht.keys(), 5)))
+
+        # test 1: find self
+        nearest = loop.run_until_complete(me.find_nearest_nodes(key_id=me.node_id, k_nearest=1))
+        assert len(nearest) == 1 and nearest[me.node_id] == (LOCALHOST, me.port)
+
+        # test 2: find others
+        for i in range(10):
+            ref_endpoint, query_id = random.choice(list(dht.items()))
+            nearest = loop.run_until_complete(me.find_nearest_nodes(key_id=query_id, k_nearest=1))
+            assert len(nearest) == 1 and next(iter(nearest.items())) == (query_id, ref_endpoint)
+
+        # test 3: find neighbors to random nodes
+        accuracy_numerator = accuracy_denominator = 0  # top-1 nearest neighbor accuracy
+        jaccard_numerator = jaccard_denominator = 0  # jaccard similarity aka intersection over union
+        all_node_ids = list(dht.values())
+
+        for i in range(100):
+            query_id = DHTID.generate()
+            k_nearest = random.randint(1, 20)
+            exclude_self = random.random() > 0.5
+            nearest = loop.run_until_complete(
+                me.find_nearest_nodes(key_id=query_id, k_nearest=k_nearest, exclude_self=exclude_self))
+            nearest_nodes = list(nearest)  # keys from ordered dict
+
+            assert len(nearest_nodes) == k_nearest, "beam search must return exactly k_nearest results"
+            assert me.node_id not in nearest_nodes or not exclude_self, "if exclude, results should not contain own node id"
+            assert np.all(np.diff(query_id.xor_distance(nearest_nodes)) >= 0), "results must be sorted by distance"
+
+            ref_nearest = heapq.nsmallest(k_nearest + 1, all_node_ids, key=query_id.xor_distance)
+            if exclude_self and me.node_id in ref_nearest:
+                ref_nearest.remove(me.node_id)
+            if len(ref_nearest) > k_nearest:
+                ref_nearest.pop()
+
+            accuracy_numerator += nearest_nodes[0] == ref_nearest[0]
+            accuracy_denominator += 1
+
+            jaccard_numerator += len(set.intersection(set(nearest_nodes), set(ref_nearest)))
+            jaccard_denominator += k_nearest
+
+        accuracy = accuracy_numerator / accuracy_denominator
+        print("Top-1 accuracy:", accuracy)  # should be 98-100%
+        jaccard_index = jaccard_numerator / jaccard_denominator
+        print("Jaccard index (intersection over union):", jaccard_index)  # should be 95-100%
+        assert accuracy >= 0.9, f"Top-1 accuracy only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
+        assert jaccard_index >= 0.9, f"Jaccard index only {jaccard_index} ({jaccard_numerator} / {jaccard_denominator})"
+
+        # test 4: find all nodes
+        nearest = loop.run_until_complete(me.find_nearest_nodes(key_id=DHTID.generate(), k_nearest=len(dht) + 100))
+        assert len(nearest) == len(dht) + 1
+        assert len(set.difference(set(nearest.keys()), set(all_node_ids) | {me.node_id})) == 0
+
+        # test 5: node without peers
+        other_node = loop.run_until_complete(DHTNode.create())
+        nearest = loop.run_until_complete(other_node.find_nearest_nodes(DHTID.generate()))
+        assert len(nearest) == 1 and nearest[other_node.node_id] == (LOCALHOST, other_node.port)
+        nearest = loop.run_until_complete(other_node.find_nearest_nodes(DHTID.generate(), exclude_self=True))
+        assert len(nearest) == 0
+
+        # test 6 store and get value
+        true_time = get_dht_time() + 1200
+        assert loop.run_until_complete(me.store("mykey", ["Value", 10], true_time))
+        for node in [me, other_node]:
+            val, expiration_time = loop.run_until_complete(me.get("mykey"))
+            assert expiration_time == true_time, "Wrong time"
+            assert val == ["Value", 10], "Wrong value"
+
+        test_success.set()
+
+    tester = mp.Process(target=_tester, daemon=True)
+    tester.start()
+    tester.join()
+    assert test_success.is_set()
     for proc in processes:
         proc.terminate()
 
 
 def test_hivemind_dht():
-    peers = [hivemind.dht.DHT(start=True)]
+    peers = [hivemind.DHT(start=True)]
     for i in range(10):
-        neighbors_i = [('localhost', node.port) for node in random.sample(peers, min(3, len(peers)))]
+        neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
         peers.append(hivemind.DHT(*neighbors_i, start=True))
 
     you: hivemind.dht.DHT = random.choice(peers)
@@ -241,37 +250,37 @@ def test_hivemind_dht():
 
 def test_store():
     d = LocalStorage()
-    d.store("key", "val", get_dht_time() + 10)
-    assert d.get("key")[0] == "val", "Wrong value"
+    d.store(DHTID.generate("key"), b"val", get_dht_time() + 10)
+    assert d.get(DHTID.generate("key"))[0] == b"val", "Wrong value"
     print("Test store passed")
 
 
 def test_get_expired():
     d = LocalStorage()
-    d.store("key", "val", get_dht_time() + 1)
+    d.store(DHTID.generate("key"), b"val", get_dht_time() + 1)
     time.sleep(2)
-    assert d.get("key") == (None, None), "Expired value must be deleted"
+    assert d.get(DHTID.generate("key")) == (None, None), "Expired value must be deleted"
     print("Test get expired passed")
 
 
 def test_get_empty():
     d = LocalStorage()
-    assert d.get("key") == (None, None), "Expired value must be deleted"
+    assert d.get(DHTID.generate(source="key")) == (None, None), "LocalStorage returned non-existent value"
     print("Test get expired passed")
 
 
 def test_change_expiration_time():
     d = LocalStorage()
-    d.store("key", "val1", get_dht_time() + 2)
-    d.store("key", "val2", get_dht_time() + 200)
+    d.store(DHTID.generate("key"), b"val1", get_dht_time() + 2)
+    d.store(DHTID.generate("key"), b"val2", get_dht_time() + 200)
     time.sleep(4)
-    assert d.get("key")[0] == "val2", "Value must be changed, but still kept in table"
+    assert d.get(DHTID.generate("key"))[0] == b"val2", "Value must be changed, but still kept in table"
     print("Test change expiration time passed")
 
 
 def test_maxsize_cache():
     d = LocalStorage(maxsize=1)
-    d.store("key1", "val1", get_dht_time() + 1)
-    d.store("key2", "val2", get_dht_time() + 200)
-    assert d.get("key2")[0] == "val2", "Value with bigger exp. time must be kept"
-    assert d.get("key1")[0] is None, "Value with less exp time, must be deleted"
+    d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 1)
+    d.store(DHTID.generate("key2"), b"val2", get_dht_time() + 200)
+    assert d.get(DHTID.generate("key2"))[0] == b"val2", "Value with bigger exp. time must be kept"
+    assert d.get(DHTID.generate("key1"))[0] is None, "Value with less exp time, must be deleted"

+ 10 - 4
tests/test_routing.py

@@ -3,6 +3,7 @@ import heapq
 import operator
 from itertools import chain, zip_longest
 
+from hivemind import LOCALHOST
 from hivemind.dht.routing import RoutingTable, DHTID
 from hivemind.utils.serializer import PickleSerializer
 
@@ -37,8 +38,8 @@ def test_routing_table_basic():
 
     for phony_neighbor_port in random.sample(range(10000), 100):
         phony_id = DHTID.generate()
-        routing_table.add_or_update_node(phony_id, ('localhost', phony_neighbor_port))
-        assert routing_table[phony_id] == ('localhost', phony_neighbor_port)
+        routing_table.add_or_update_node(phony_id, f'{LOCALHOST}:{phony_neighbor_port}')
+        assert routing_table[phony_id] == f'{LOCALHOST}:{phony_neighbor_port}'
 
     assert routing_table.buckets[0].lower == DHTID.MIN and routing_table.buckets[-1].upper == DHTID.MAX
     for bucket in routing_table.buckets:
@@ -56,7 +57,7 @@ def test_routing_table_parameters():
         node_id = DHTID.generate()
         routing_table = RoutingTable(node_id, bucket_size=bucket_size, depth_modulo=modulo)
         for phony_neighbor_port in random.sample(range(1_000_000), 10_000):
-            routing_table.add_or_update_node(DHTID.generate(), ('localhost', phony_neighbor_port))
+            routing_table.add_or_update_node(DHTID.generate(), f'{LOCALHOST}:{phony_neighbor_port}')
         for bucket in routing_table.buckets:
             assert len(bucket.replacement_nodes) == 0 or len(bucket.nodes_to_addr) <= bucket.size
         assert min_nbuckets <= len(routing_table.buckets) <= max_nbuckets, (
@@ -70,8 +71,13 @@ def test_routing_table_search():
         node_id = DHTID.generate()
         routing_table = RoutingTable(node_id, bucket_size=20, depth_modulo=5)
         num_added = 0
+        total_nodes = 0
+
         for phony_neighbor_port in random.sample(range(1_000_000), table_size):
-            num_added += routing_table.add_or_update_node(DHTID.generate(), ('localhost', phony_neighbor_port)) is None
+            routing_table.add_or_update_node(DHTID.generate(), f'{LOCALHOST}:{phony_neighbor_port}')
+            new_total = sum(len(bucket.nodes_to_addr) for bucket in routing_table.buckets)
+            num_added += new_total > total_nodes
+            total_nodes = new_total
         num_replacements = sum(len(bucket.replacement_nodes) for bucket in routing_table.buckets)
     
         all_active_neighbors = list(chain(

+ 3 - 3
tests/test_utils/run_server.py

@@ -61,9 +61,9 @@ def make_dummy_server(interface='0.0.0.0', port=None, num_experts=1, expert_cls=
 
     sample_input = name_to_input[expert_cls](4, hidden_dim)
     if isinstance(sample_input, tuple):
-        args_schema = tuple(hivemind.BatchTensorProto.from_tensor(arg) for arg in sample_input)
+        args_schema = tuple(hivemind.BatchTensorDescriptor.from_tensor(arg) for arg in sample_input)
     else:
-        args_schema = (hivemind.BatchTensorProto.from_tensor(sample_input),)
+        args_schema = (hivemind.BatchTensorDescriptor.from_tensor(sample_input),)
 
     # initialize experts
     experts = {}
@@ -73,7 +73,7 @@ def make_dummy_server(interface='0.0.0.0', port=None, num_experts=1, expert_cls=
         expert_uid = f'{expert_prefix}{UID_DELIMETER}{i + expert_offset}'
         experts[expert_uid] = hivemind.ExpertBackend(name=expert_uid, expert=expert, opt=opt,
                                                      args_schema=args_schema,
-                                                     outputs_schema=hivemind.BatchTensorProto(hidden_dim),
+                                                     outputs_schema=hivemind.BatchTensorDescriptor(hidden_dim),
                                                      max_batch_size=max_batch_size,
                                                      )
     # actually start server