
[part 2] grpc-based dht (#51)

* add option to share keys with new peers that *should* be sharing them (improves data availability if a lot of peers join)

* added option to NOT refresh table (new default)
* initial dht crawl is no longer blocking

* sphinx-friendly escape char

* pass cache params to kademliaprotocol

* typo

* rpc congestion

* rpc congestion

* rpc congestion

* add max concurrent rpc

* add max concurrent rpc

* fix bug in welcome protocol: previously, DHT peers always considered each other "new nodes" and sent EVERYTHING to EVERYONE on each RPC call. Now DHT nodes will only request store on a .store call OR when a new peer knocks on their DHT

* increase to 128 concurrent rpc

* await dht traversal in bootstrap

* await dht traversal in bootstrap

* minor comment

* rename TensorProto -> TensorDescriptor to avoid name conflicts with protobuf

* add grpc requirements (break tests for now)

* add grpc requirements (break tests for now)

* add grpc requirements (break tests for now)

* minor bugfix: always add peer to nodes requested for ping

* [this breaks tests]
* implement DHTProtocol via gRPC
* DHTProtocol now stores bytes only (not enforcing msgpack)
* add grpc requirements

* update node.py to grpc DHTProtocol

* reminder to implement congestion

* reminder to implement congestion

* format

* temporary patch: adapt bulk RPCs to individual search

* temporary patch: adapt bulk RPCs to individual search

* rename KademliaProtocol -> DHTProtocol (rationale: no longer kademlia compliant)

* semicolon

* pep

* pep

* pep

* KademliaProtocol -> DHTProtocol

* update tests for new eviction policy (do not evict the same node twice)

* init aio in node constructor

* rename KademliaProtocol => DHTProtocol everywhere

* minor sphinx formatting fix

* partially update test_dht

* test: typo fix

* test: typo fix

* update test_dht for new dht interface

* compile grpc from master

* compile grpc from master

* compile grpc from master

* add umsgpack to requirements

* cache compiled grpcio

* ensure umsgpack version compatibility

* remove unused imports from dht folder

* update schemes

* update schemes

* review

* review

* ensure_future => create_task

Co-authored-by: xtinkt <ant.sinitsin@gmail.com>
justheuristic committed 5 years ago
parent commit 8bded39d9b

+ 12 - 1
.circleci/config.yml

@@ -9,8 +9,19 @@ jobs:
     steps:
       - checkout
       - python/load-cache
+      - run:
+          command: |
+            if [[ $(pip show grpcio | grep Version) != *1.31* ]]; then
+              git clone https://github.com/grpc/grpc --recurse-submodules
+              cd grpc
+              sudo pip install -r requirements.txt
+              export GRPC_PYTHON_BUILD_WITH_CYTHON=1
+              sudo pip install .
+              cd -
+            fi
+          name: compile-grpc  # remove this command when v1.31 becomes available via pip install -r requirements.txt
+      - run: sudo pip install codecov pytest grpcio-tools
       - python/install-deps
-      - run: sudo pip install codecov pytest
       - python/save-cache
       - run:
           command: sudo python setup.py develop

BIN
docs/_static/dht.odp


BIN
docs/_static/dht.png


+ 1 - 1
docs/modules/dht.rst

@@ -20,7 +20,7 @@
 
 
 .. currentmodule:: hivemind.dht.protocol
 
 
-.. autoclass:: KademliaProtocol
+.. autoclass:: DHTProtocol
    :members:
    :member-order: bysource
 
 

+ 6 - 5
hivemind/dht/__init__.py

@@ -2,7 +2,7 @@
 This sub-module implements a node in a Kademlia-based DHT. The code is organized as follows:
  * class DHT (below) - high-level class for model training. Runs DHTNode in a background process.
  * class DHTNode (node.py) - an asyncio implementation of dht server, stores AND gets keys. Asyncio-based.
- * class KademliaProtocol (protocol.py) - an rpc protocol to request data from dht nodes. Asyncio-based.
+ * class DHTProtocol (protocol.py) - an rpc protocol to request data from dht nodes. Asyncio-based.
 
 
 The code in this module is a modified version of https://github.com/bmuller/kademlia
 Brian, if you're reading this: THANK YOU! you're awesome :)
@@ -10,13 +10,13 @@ Brian, if you're reading this: THANK YOU! you're awesome :)
 import asyncio
 import multiprocessing as mp
 import warnings
-from typing import Tuple, List, Optional
+from typing import List, Optional
 
 
 from .node import DHTNode, DHTID, DHTExpiration
 from .routing import get_dht_time
 
 
 from ..client import RemoteExpert
-from ..utils import SharedFuture, find_open_port, Hostname, Port, run_in_background
+from ..utils import SharedFuture, find_open_port, Endpoint, Port, run_in_background, LOCALHOST
 
 
 
 
 class DHT(mp.Process):
@@ -33,7 +33,7 @@ class DHT(mp.Process):
     EXPIRATION = 120  # anything written to DHT is considered expired after this many seconds
     make_key = "{}::{}".format
 
 
-    def __init__(self, *initial_peers: Tuple[Hostname, Port], port: Optional[Port] = None,
+    def __init__(self, *initial_peers: Endpoint, port: Optional[Port] = None,
                  start: bool, daemon: bool = True, **node_params):
         super().__init__()
         port = find_open_port() if port is None else port
@@ -52,7 +52,8 @@ class DHT(mp.Process):
         loop = asyncio.new_event_loop()
         asyncio.set_event_loop(loop)
 
 
-        self.node = DHTNode(initial_peers=list(self.initial_peers), port=self.port, **self.node_params)
+        self.node = loop.run_until_complete(DHTNode.create(
+            initial_peers=list(self.initial_peers), listen_on=f"{LOCALHOST}:{self.port}", **self.node_params))
         run_in_background(loop.run_forever)
         self.ready.set()
 
 

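For context (not part of the diff): a minimal sketch of how the updated DHT wrapper might be launched after this change. It assumes that start=True spawns the background process (as the flag suggests), that the chosen port is exposed as .port (run() uses self.port), and that initial peers are now passed as "host:port" endpoint strings:

    from hivemind.dht import DHT

    # start one DHT process on an automatically chosen free port
    first_peer = DHT(start=True)
    # a second process bootstraps from the first one via its endpoint string
    second_peer = DHT(f"127.0.0.1:{first_peer.port}", start=True)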
+ 55 - 0
hivemind/dht/dht.proto

@@ -0,0 +1,55 @@
+syntax = "proto3";
+
+// this protocol defines how Hivemind nodes form a distributed hash table.
+// For more info, see https://learning-at-home.readthedocs.io/en/latest/modules/dht.html or help(hivemind.dht.DHTNode)
+
+service DHT {
+    // find out recipient's DHTID and possibly update its routing table
+    rpc rpc_ping(NodeInfo) returns (NodeInfo);
+
+    // request a node to store one or multiple data items (key - value - expiration)
+    rpc rpc_store(StoreRequest) returns (StoreResponse);
+
+    // for given keys, request values (if stored) or a list of peers that are likely to have them
+    rpc rpc_find(FindRequest) returns (FindResponse);
+}
+
+message NodeInfo {
+    // note: both node_id and port are optional: if specified, ask peer to add you to its routing table;
+    // if either node_id or port is absent, simply request recipient info (for client-only mode)
+    bytes node_id = 1;                // sender's own node id serialized with DHTID.to_bytes()
+    int32 rpc_port = 2;               // port to which sender listens for DHT RPCs
+}
+
+message StoreRequest {
+    // three lists of the same length representing dht keys, dht values and expiration
+    repeated bytes keys = 1;          // keys in the form of DHTID.generate(raw_key).to_bytes()
+    repeated bytes values = 2;        // binary-encoded value for i-th key
+    repeated double expiration = 3;   // expirations for i-th key (type = DHTExpiration)
+    repeated bool in_cache = 4;       // if in_cache[i], store i-th key in cache, else store normally
+    NodeInfo peer = 5;                // (optional) sender's own node info, same behavior as in DHT.rpc_ping
+}
+
+message StoreResponse {
+    repeated bool store_ok = 1;       // for every key, True means store accepted, False means store rejected/failed
+    NodeInfo peer = 2;                // respondent's node id, for you to update routing table
+}
+
+message FindRequest {
+    repeated bytes keys = 1;          // a list of DHTID search keys encoded as bytes
+    NodeInfo peer = 2;                // optional, same behavior as in DHT.ping
+}
+
+message Peers {
+   // two aligned arrays: DHTIDs and Endpoints, i-th endpoint corresponds to peer with i-th node id
+   repeated bytes node_ids = 1;       // DHTID serialized with node_id.to_bytes()
+   repeated string endpoints = 2;     // e.g. 123.123.123.123:1337 or [2a21:6c8:b192:2105]:8888
+}
+
+message FindResponse {
+    repeated bytes values = 1;        // value for i-th key, b'' means not found locally
+    repeated double expiration = 2;   // expiration time for i-th value, only valid if the value is found
+    repeated Peers nearest = 3;       // peers ordered from nearest to farthest based on distance to i-th key
+    NodeInfo peer = 4;                // respondent's node id, for you to update routing table
+}
+

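For illustration (not part of the diff): a minimal client-side sketch of issuing one of these RPCs with the generated classes. The dht_pb2 / dht_grpc names follow the compile_grpc usage in protocol.py below; the import path and the peer address are assumptions:

    import grpc.experimental.aio
    from hivemind.dht.protocol import dht_pb2, dht_grpc   # modules compiled from dht.proto (assumed import path)
    from hivemind.dht.routing import DHTID, get_dht_time

    async def store_one(peer: str = "127.0.0.1:1337"):
        # bulk store: aligned lists of keys, values, expirations and in_cache flags
        request = dht_pb2.StoreRequest(
            keys=[DHTID.generate(source="expert.1.alive").to_bytes()],
            values=[b"serialized value"], expiration=[get_dht_time() + 120], in_cache=[False],
            peer=dht_pb2.NodeInfo())  # empty NodeInfo: client-only, do not add me to your routing table
        channel = grpc.experimental.aio.insecure_channel(peer)
        response = await dht_grpc.DHTStub(channel).rpc_store(request, timeout=5)
        return list(response.store_ok)  # [True] if the peer accepted the value

    # asyncio.run(store_one()) would execute this against a peer running DHTProtocol with listen=True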
+ 133 - 90
hivemind/dht/node.py

@@ -1,107 +1,137 @@
+from __future__ import annotations
 import asyncio
 import random
 from collections import OrderedDict
-from functools import partial
 from typing import Optional, Tuple, List, Dict
 from warnings import warn
 
 
-from .protocol import KademliaProtocol
-from .routing import DHTID, DHTValue, DHTExpiration, DHTKey, get_dht_time
+from .protocol import DHTProtocol
+from .routing import DHTID, BinaryDHTValue, DHTExpiration, DHTKey, get_dht_time, DHTValue
 from .search import traverse_dht
-from ..utils import find_open_port, Endpoint, Hostname, Port, LOCALHOST
+from ..utils import Endpoint, LOCALHOST, MSGPackSerializer
 
 
 
 
 class DHTNode:
     """
-    A low-level class that represents a DHT participant.
-    Each DHTNode has an identifier, a local storage and access too other nodes via KademliaProtocol.
-
-    :param node_id: current node's identifier, determines which keys it will store locally, defaults to random id
-    :param port: port to which this DHTNode will listen, by default find some open port
-    :param initial_peers: connects to these peers to populate routing table, defaults to no peers
-    :param bucket_size: (k) - max number of nodes in one k-bucket. Trying to add {k+1}st node will cause a bucket to
-      either split in two buckets along the midpoint or reject the new node (but still save it as a replacement)
-      Recommended value: $k$ is chosen s.t. any given k nodes are very unlikely to all fail after staleness_timeout
-    :param num_replicas: (≈k) - number of nearest nodes that will be asked to store a given key, default = bucket_size
-    :param depth_modulo: (b) - kademlia can split bucket if it contains root OR up to the nearest multiple of this value
-    :param max_concurrent_rpc: maximum number of outgoing RPC requests emitted by KademliaProtocol in parallel
-        Reduce this value if your RPC requests register no response despite the peer sending the response.
-    :param wait_timeout: a kademlia rpc request is deemed lost if we did not recieve a reply in this many seconds
-    :param staleness_timeout: a bucket is considered stale if no node from that bucket was updated in this many seconds
-        if staleness_timeout is None, DHTNode will not refresh stale buckets (which is usually okay)
-    :param bootstrap_timeout: after one of peers responds, await other peers for at most this many seconds
-    :param cache_locally: if True, caches all values (stored or found) in a node-local cache
-    :param cache_nearest: if above 0, whenever DHTNode finds a value, it will also store (cache) this value on this many
-        nodes nearest nodes visited by search algorithm. Prefers nodes that are nearest to :key: but have no value yet.
-    :param cache_size: if specified, local cache will store up to this many records (as in LRU cache)
-    :param interface: provide 0.0.0.0 to operate over ipv4, :: to operate over ipv6, localhost to operate locally, etc.
-
-    :note: Hivemind DHT is optimized to store temporary metadata that is regularly updated.
+    A low-level class that represents a DHT participant. Please see DHTNode.create for parameters
+    Each DHTNode has an identifier, a local storage and access to other nodes via DHTProtocol.
+
+    :note: Hivemind DHT is optimized to store a lot of temporary metadata that is regularly updated.
     For example, an expert alive timestamp that emitted by the Server responsible for that expert.
-     Such metadata does not require maintenance such as ensuring at least k hosts have it or (de)serialization in case
-     of node shutdown. Instead, DHTNode is designed to reduce the latency of looking up such data.
+     Such metadata does not require regular maintenance by peers or persistence on shutdown.
+     Instead, DHTNode is designed to rapidly send bulk data and resolve conflicts.
 
 
-    Every (key, value) pair in this DHT has expiration_time - float number computed as get_dht_time(), default: UnixTime
-    Informally, dht nodes always prefer values with higher expiration_time and may delete any value past its expiration.
+    Every (key, value) pair in this DHT has an expiration time - float computed as get_dht_time(), UnixTime by default
+    DHT nodes always prefer values with higher expiration time and may delete any value past its expiration.
 
 
-    Formally, DHTNode follows this contract:
+    Compared to Kademlia RPC protocol, hivemind DHT has 3 RPCs:
 
 
-    - when asked to store(key, value, expiration_time), a node must store (key, value) at least until expiration_time
-      unless it already stores that key with greater or equal expiration_time - if so, node must keep the previous key
-    - when asked to get(key), a node must return the value with highest expiration time IF that time has not come yet
-      if expiration time is greater than current get_dht_time(), DHTNode *may* return None
+    * ping - request peer's identifier and update routing table (same as Kademlia PING RPC)
+    * store - send several (key, value, expiration) pairs to the same peer (like Kademlia STORE, but in bulk)
+    * find - request one or several keys, get values & expiration (if peer finds it locally) and :bucket_size: of
+        nearest peers from recipient's routing table (ordered nearest-to-farthest, not including recipient itself)
+        This RPC is a mixture between Kademlia FIND_NODE and FIND_VALUE with multiple keys per call.
 
 
-    """
+    Formally, DHTNode follows the following contract:
 
 
-    def __init__(self, node_id: Optional[DHTID] = None, port: Optional[Port] = None, initial_peers: List[Endpoint] = (),
-                 bucket_size: int = 20, num_replicas: Optional[int] = None, depth_modulo: int = 5,
-                 max_concurrent_rpc: int = 128, wait_timeout: float = 5, staleness_timeout: Optional[float] = None,
-                 bootstrap_timeout: Optional[float] = None, cache_locally: bool = True, cache_nearest: int = 1,
-                 cache_size=None, interface: Hostname = '0.0.0.0'):
+    - when asked to get(key), a node must find and return a value with highest expiration time that it found across DHT
+      IF that time has not come yet. if expiration time is smaller than current get_dht_time(), node may return None;
+    - when requested to store(key: value, expiration), a node must store (key => value) until expiration time
+      or until DHTNode gets the same key with greater expiration time. If a node is asked to store a key but it already
+      has the same key with newer expiration, the older key will not be stored. Return True if stored, False if refused;
+    - when requested to store(key: value, expiration, in_cache=True), stores (key => value) in a separate "cache".
+      Cache operates same as regular storage, but it has a limited size and evicts least recently used nodes when full;
+
+    """
+    node_id: int; port: int; num_replicas: int; cache_locally: bool; cache_nearest: int; refresh_timeout: float
+    protocol: DHTProtocol
+    serializer = MSGPackSerializer  # used to pack/unpack DHT Values for transfer over network
+
+
+    @classmethod
+    async def create(
+            cls, node_id: Optional[DHTID] = None, initial_peers: List[Endpoint] = (),
+            bucket_size: int = 20, num_replicas: Optional[int] = None, depth_modulo: int = 5, max_requests: int = 0,
+            wait_timeout: float = 5, refresh_timeout: Optional[float] = None, bootstrap_timeout: Optional[float] = None,
+            cache_locally: bool = True, cache_nearest: int = 1, cache_size=None,
+            listen: bool = True, listen_on: Endpoint = "0.0.0.0:*", **kwargs) -> DHTNode:
+        """
+        :param node_id: current node's identifier, determines which keys it will store locally, defaults to random id
+        :param port: port to which this DHTNode will listen, by default find some open port
+        :param initial_peers: connects to these peers to populate routing table, defaults to no peers
+        :param bucket_size: max number of nodes in one k-bucket (k). Trying to add {k+1}st node will cause a bucket to
+          either split in two buckets along the midpoint or reject the new node (but still save it as a replacement)
+          Recommended value: k is chosen s.t. any given k nodes are very unlikely to all fail after staleness_timeout
+        :param num_replicas: number of nearest nodes that will be asked to store a given key, default = bucket_size (≈k)
+        :param depth_modulo: split full k-bucket if it contains root OR up to the nearest multiple of this value (≈b)
+        :param max_requests: maximum number of outgoing RPC requests emitted by DHTProtocol in parallel
+          Reduce this value if your RPC requests register no response despite the peer sending the response.
+        :param wait_timeout: a kademlia rpc request is deemed lost if we did not receive a reply in this many seconds
+        :param refresh_timeout: refresh buckets if no node from that bucket was updated in this many seconds
+          if staleness_timeout is None, DHTNode will not refresh stale buckets (which is usually okay)
+        :param bootstrap_timeout: after one of peers responds, await other peers for at most this many seconds
+        :param cache_locally: if True, caches all values (stored or found) in a node-local cache
+        :param cache_nearest: whenever DHTNode finds a value, it will also store (cache) this value on this many
+          nodes nearest nodes visited by search algorithm. Prefers nodes that are nearest to :key: but have no value yet
+        :param cache_size: if specified, local cache will store up to this many records (as in LRU cache)
+        :param listen: if True (default), this node will accept incoming requests and otherwise be a DHT "citizen"
+          if False, this node will refuse any incoming request, effectively being only a "client"
+        :param listen_on: network interface for incoming RPCs, e.g. "0.0.0.0:1337" or "localhost:\*" or "[::]:7654"
+        :param channel_options: options for grpc.aio.insecure_channel, e.g. [('grpc.enable_retries', 0)]
+          see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
+        :param kwargs: extra parameters used in grpc.aio.server
+        """
+        assert max_requests == 0, "TODO(jheuristic) implement congestion!"
+        self = cls(_initialized_with_create=True)
         self.node_id = node_id = node_id if node_id is not None else DHTID.generate()
-        self.port = port = port if port is not None else find_open_port()
         self.num_replicas = num_replicas if num_replicas is not None else bucket_size
         self.cache_locally, self.cache_nearest = cache_locally, cache_nearest
-        self.staleness_timeout = staleness_timeout
+        self.refresh_timeout = refresh_timeout
+
+        self.protocol = await DHTProtocol.create(self.node_id, bucket_size, depth_modulo, num_replicas, wait_timeout,
+                                                 cache_size, listen, listen_on, **kwargs)
+        self.port = self.protocol.port
 
 
-        # create kademlia protocol and make it listen to a port
-        loop = asyncio.get_event_loop()
-        make_protocol = partial(KademliaProtocol, self.node_id, bucket_size, depth_modulo, wait_timeout,
-                                max_concurrent_rpc, num_replicas, cache_size)
-        listener = loop.run_until_complete(loop.create_datagram_endpoint(make_protocol, local_addr=(interface, port)))
-        self.transport: asyncio.Transport = listener[0]
-        self.protocol: KademliaProtocol = listener[1]
 
 
         if initial_peers:
             # stage 1: ping initial_peers, add each other to the routing table
             bootstrap_timeout = bootstrap_timeout if bootstrap_timeout is not None else wait_timeout
             start_time = get_dht_time()
             ping_tasks = map(self.protocol.call_ping, initial_peers)
-            finished_ping_tasks, remaining_ping_tasks = loop.run_until_complete(
-                asyncio.wait(ping_tasks, return_when=asyncio.FIRST_COMPLETED))
+            finished_pings, unfinished_pings = await asyncio.wait(ping_tasks, return_when=asyncio.FIRST_COMPLETED)
 
 
             # stage 2: gather remaining peers (those who respond within bootstrap_timeout)
-            if remaining_ping_tasks:
-                finished_in_time, stragglers = loop.run_until_complete(
-                    asyncio.wait(remaining_ping_tasks, timeout=bootstrap_timeout - get_dht_time() + start_time))
+            if unfinished_pings:
+                finished_in_time, stragglers = await asyncio.wait(
+                    unfinished_pings, timeout=bootstrap_timeout - get_dht_time() + start_time)
                 for straggler in stragglers:
                     straggler.cancel()
-                finished_ping_tasks |= finished_in_time
+                finished_pings |= finished_in_time
 
 
-            if not finished_ping_tasks:
+            if not finished_pings:
                 warn("DHTNode bootstrap failed: none of the initial_peers responded to a ping.")
 
 
             # stage 3: traverse dht to find my own nearest neighbors and populate the routing table
             # ... maybe receive some values that we are meant to store (see protocol.update_routing_table)
             # note: using asyncio.wait instead of wait_for because wait_for cancels task on timeout
-            loop.run_until_complete(asyncio.wait([loop.create_task(self.find_nearest_nodes(query_id=self.node_id)),
-                                                  asyncio.sleep(bootstrap_timeout - get_dht_time() + start_time)],
-                                                 return_when=asyncio.FIRST_COMPLETED))
+            await asyncio.wait([asyncio.create_task(self.find_nearest_nodes(key_id=self.node_id)),
+                                asyncio.sleep(bootstrap_timeout - get_dht_time() + start_time)],
+                               return_when=asyncio.FIRST_COMPLETED)
+
+        if self.refresh_timeout is not None:
+            asyncio.create_task(self._refresh_routing_table(period=self.refresh_timeout))
+        return self
 
 
-        if self.staleness_timeout is not None:
-            loop.create_task(self._refresh_routing_table(period=self.staleness_timeout))
+    def __init__(self, *, _initialized_with_create=False):
+        """ Internal init method. Please use DHTNode.create coroutine to spawn new node instances """
+        assert _initialized_with_create, " Please use DHTNode.create coroutine to spawn new node instances "
+        super().__init__()
 
 
-    async def find_nearest_nodes(self, query_id: DHTID, k_nearest: Optional[int] = None,
+    async def shutdown(self, timeout=None):
+        """ Process existing requests, close all connections and stop the server """
+        await self.protocol.shutdown(timeout)
+
+    async def find_nearest_nodes(self, key_id: DHTID, k_nearest: Optional[int] = None,
                                  beam_size: Optional[int] = None, exclude_self: bool = False) -> Dict[DHTID, Endpoint]:
         """
         Traverse the DHT and find :k_nearest: nodes to a given :query_id:, optionally :exclude_self: from the results.
@@ -112,36 +142,41 @@ class DHTNode:
         k_nearest = k_nearest if k_nearest is not None else self.protocol.bucket_size
         beam_size = beam_size if beam_size is not None else max(self.protocol.bucket_size, k_nearest)
         node_to_addr = dict(
-            self.protocol.routing_table.get_nearest_neighbors(query_id, beam_size, exclude=self.node_id))
+            self.protocol.routing_table.get_nearest_neighbors(key_id, beam_size, exclude=self.node_id))
+
+        async def get_neighbors(node_id: DHTID) -> Tuple[List[DHTID], bool]:
+            response = await self.protocol.call_find(node_to_addr[node_id], [key_id])
+            if not response or key_id not in response:
+                return [], False  # False means "do not interrupt search"
 
 
-        async def get_neighbors(node: DHTID) -> Tuple[List[DHTID], bool]:
-            peers: Dict[DHTID, Endpoint] = await self.protocol.call_find_node(node_to_addr[node], query_id)
+            peers: Dict[DHTID, Endpoint] = response[key_id][-1]
             node_to_addr.update(peers)
-            return list(peers.keys()), False  # False means "do not interrupt beam search"
+            return list(peers.keys()), False  # False means "do not interrupt search"
 
 
         nearest_nodes, visited_nodes = await traverse_dht(
-            query_id=query_id, initial_nodes=list(node_to_addr), k_nearest=k_nearest, beam_size=beam_size,
+            query_id=key_id, initial_nodes=list(node_to_addr), k_nearest=k_nearest, beam_size=beam_size,
             get_neighbors=get_neighbors, visited_nodes=(self.node_id,))
 
         if not exclude_self:
-            nearest_nodes = sorted(nearest_nodes + [self.node_id], key=query_id.xor_distance)[:k_nearest]
+            nearest_nodes = sorted(nearest_nodes + [self.node_id], key=key_id.xor_distance)[:k_nearest]
             node_to_addr[self.node_id] = (LOCALHOST, self.port)
 
         return OrderedDict((node, node_to_addr[node]) for node in nearest_nodes)
 
 
-    async def store(self, key: DHTKey, value: DHTValue, expiration_time: DHTExpiration) -> bool:
+    async def store(self, key: DHTKey, value: BinaryDHTValue, expiration_time: DHTExpiration) -> bool:
         """
         Find beam_size best nodes to store (key, value) and store it there at least until expiration time.
-        Also cache (key, value, expiration_time) at all nodes you met along the way (see Section 2.1 end)
+        Optionally cache (key, value, expiration) on nodes you met along the way (see Section 2.1 end) TODO(jheuristic)
 
 
         :returns: True if store succeeds, False if it fails (due to no response or newer value)
         """
-        key_id = DHTID.generate(key)
+        key_id, value_bytes = DHTID.generate(source=key), self.serializer.dumps(value)
         nearest_node_to_addr = await self.find_nearest_nodes(key_id, k_nearest=self.num_replicas, exclude_self=True)
-        tasks = [asyncio.create_task(self.protocol.call_store(endpoint, key_id, value, expiration_time))
+        tasks = [asyncio.create_task(self.protocol.call_store(endpoint, [key_id], [value_bytes], [expiration_time]))
                  for endpoint in nearest_node_to_addr.values()]
         done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
-        return any(done)
+
+        return any(store_ok for response in done for store_ok in response.result())
 
 
     async def get(self, key: DHTKey, sufficient_expiration_time: Optional[DHTExpiration] = None,
                   beam_size: Optional[int] = None) -> Tuple[Optional[DHTValue], Optional[DHTExpiration]]:
@@ -150,14 +185,15 @@ class DHTNode:
         :param sufficient_expiration_time: if the search finds a value that expires after this time,
             default = time of call, find any value that did not expire by the time of call
             If min_expiration_time=float('inf'), this method will find a value with _latest_ expiration
+        :param beam_size: maintains up to this many nearest nodes when crawling dht, default beam_size = bucket_size
         :returns: value and its expiration time. If nothing is found , returns (None, None).
         :note: in order to check if get returned a value, please check (expiration_time is None)
         """
         key_id = DHTID.generate(key)
         sufficient_expiration_time = sufficient_expiration_time or get_dht_time()
         beam_size = beam_size if beam_size is not None else self.protocol.bucket_size
-        latest_value, latest_expiration, latest_node_id = None, -float('inf'), None
-        node_to_addr, nodes_checked_for_value = dict(), set()
+        latest_value_bytes, latest_expiration, latest_node_id = b'', -float('inf'), None
+        node_to_addr, nodes_checked_for_value, nearest_nodes = dict(), set(), []
         should_cache = False  # True if found value in DHT that is newer than local value
 
         # Option A: value can be stored in our local cache
@@ -165,7 +201,7 @@ class DHTNode:
         if maybe_expiration is None:
             maybe_value, maybe_expiration = self.protocol.cache.get(key_id)
         if maybe_expiration is not None and maybe_expiration > latest_expiration:
-            latest_value, latest_expiration, latest_node_id = maybe_value, maybe_expiration, self.node_id
+            latest_value_bytes, latest_expiration, latest_node_id = maybe_value, maybe_expiration, self.node_id
             # TODO(jheuristic) we may want to run background beam search to update our cache
         nodes_checked_for_value.add(self.node_id)
 
@@ -175,12 +211,16 @@ class DHTNode:
                 key_id, self.protocol.bucket_size, exclude=self.node_id))
 
             async def get_neighbors(node: DHTID) -> Tuple[List[DHTID], bool]:
-                nonlocal latest_value, latest_expiration, node_to_addr, nodes_checked_for_value
-                maybe_value, maybe_expiration, peers = await self.protocol.call_find_value(node_to_addr[node], key_id)
-                node_to_addr.update(peers)
+                nonlocal latest_value_bytes, latest_expiration, latest_node_id, node_to_addr, nodes_checked_for_value
+                response = await self.protocol.call_find(node_to_addr[node], [key_id])
                 nodes_checked_for_value.add(node)
+                if not response or key_id not in response:
+                    return [], False
+
+                maybe_value, maybe_expiration, peers = response[key_id]
+                node_to_addr.update(peers)
                 if maybe_expiration is not None and maybe_expiration > latest_expiration:
-                    latest_value, latest_expiration, latest_node_id = maybe_value, maybe_expiration, node
+                    latest_value_bytes, latest_expiration, latest_node_id = maybe_value, maybe_expiration, node
                 should_interrupt = (latest_expiration >= sufficient_expiration_time)
                 return list(peers.keys()), should_interrupt
 
@@ -193,13 +233,14 @@ class DHTNode:
         # Option C: didn't find good-enough value in beam search, make a last-ditch effort to find it in unvisited nodes
         if latest_expiration < sufficient_expiration_time:
             nearest_unvisited = [node_id for node_id in nearest_nodes if node_id not in nodes_checked_for_value]
-            tasks = [self.protocol.call_find_value(node_to_addr[node_id], key_id) for node_id in nearest_unvisited]
+            tasks = [self.protocol.call_find(node_to_addr[node_id], [key_id]) for node_id in nearest_unvisited]
             pending_tasks = set(tasks)
             for task in asyncio.as_completed(tasks):
                 pending_tasks.remove(task)
-                maybe_value, maybe_expiration, _ = await task
+                if task.result() and key_id in task.result():
+                    maybe_value, maybe_expiration, _ = task.result()[key_id]
                 if maybe_expiration is not None and maybe_expiration > latest_expiration:
-                    latest_value, latest_expiration = maybe_value, maybe_expiration
+                    latest_value_bytes, latest_expiration = maybe_value, maybe_expiration
                     if latest_expiration >= sufficient_expiration_time:
                         break
             for task in pending_tasks:
@@ -208,19 +249,21 @@ class DHTNode:
 
 
         # step 4: we have not found entry with sufficient_expiration_time, but we may have found *something* older
         if should_cache and self.cache_locally:
-            self.protocol.cache.store(key_id, latest_value, latest_expiration)
+            self.protocol.cache.store(key_id, latest_value_bytes, latest_expiration)
         if should_cache and self.cache_nearest:
             num_cached_nodes = 0
             for node_id in nearest_nodes:
                 if node_id == latest_node_id:
                     continue
                 asyncio.create_task(self.protocol.call_store(
-                    node_to_addr[node_id], key_id, latest_value, latest_expiration, in_cache=True))
+                    node_to_addr[node_id], [key_id], [latest_value_bytes], [latest_expiration], in_cache=True))
                 num_cached_nodes += 1
                 if num_cached_nodes >= self.cache_nearest:
                     break
-
-        return (latest_value, latest_expiration) if latest_expiration != -float('inf') else (None, None)
+        if latest_expiration != -float('inf'):
+            return self.serializer.loads(latest_value_bytes), latest_expiration
+        else:
+            return None, None
 
 
     async def _refresh_routing_table(self, *, period: Optional[float]) -> None:
         """ Tries to find new nodes for buckets that were unused for more than self.staleness_timeout """

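For context (not part of the diff): a minimal sketch of the store/get contract described in the DHTNode docstring above, assuming the import paths below and two nodes running on the same machine:

    import asyncio
    from hivemind.dht.node import DHTNode
    from hivemind.dht.routing import get_dht_time

    async def demo():
        alice = await DHTNode.create()
        bob = await DHTNode.create(initial_peers=[f"127.0.0.1:{alice.port}"])
        # store a value for two minutes; values are (de)serialized with MSGPackSerializer internally
        assert await bob.store("expert.1.alive", True, expiration_time=get_dht_time() + 120)
        value, expiration = await alice.get("expert.1.alive")  # (None, None) if nothing was found
        await alice.shutdown()
        await bob.shutdown()

    # asyncio.get_event_loop().run_until_complete(demo())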
+ 198 - 129
hivemind/dht/protocol.py

@@ -1,152 +1,224 @@
-import asyncio
+from __future__ import annotations
+import os
 import heapq
-from typing import Optional, List, Tuple, Dict, Iterator
-from rpcudp.protocol import RPCProtocol
+import asyncio
+import logging
+import urllib.parse
+from typing import Optional, List, Tuple, Dict, Iterator, Any, Sequence, Union
+from warnings import warn
+from .routing import RoutingTable, DHTID, BinaryDHTValue, DHTExpiration, get_dht_time
+from ..utils import Endpoint, compile_grpc
+import grpc, grpc.experimental.aio
 
 
-from .routing import RoutingTable, DHTID, DHTValue, DHTExpiration, BinaryDHTID, get_dht_time
-from ..utils import Endpoint
+with open(os.path.join(os.path.dirname(__file__), 'dht.proto'), 'r') as f_proto:
+    dht_pb2, dht_grpc = compile_grpc(f_proto.read())
 
 
 
 
-class KademliaProtocol(RPCProtocol):
-    """
-    A protocol that allows DHT nodes to request keys/neighbors from other DHT nodes.
-    As a side-effect, KademliaProtocol also maintains a routing table as described in
-    https://pdos.csail.mit.edu/~petar/papers/maymounkov-kademlia-lncs.pdf
+class DHTProtocol(dht_grpc.DHTServicer):
+    node_id: DHTID; port: int; bucket_size: int; num_replicas: int; wait_timeout: float; node_info: dht_pb2.NodeInfo
+    channel_options: Optional[Sequence[Tuple[str, Any]]]; server: grpc.experimental.aio.Server
+    storage: LocalStorage; cache: LocalStorage; routing_table: RoutingTable
 
 
-    See DHTNode (node.py) for a more detailed description.
+    @classmethod
+    async def create(cls, node_id: DHTID, bucket_size: int, depth_modulo: int, num_replicas: int, wait_timeout: float,
+                     cache_size: Optional[int] = None, listen=True, listen_on='0.0.0.0:*',
+                     channel_options: Optional[Sequence[Tuple[str, Any]]] = None, **kwargs) -> DHTProtocol:
+        """
+        A protocol that allows DHT nodes to request keys/neighbors from other DHT nodes.
+        As a side-effect, DHTProtocol also maintains a routing table as described in
+        https://pdos.csail.mit.edu/~petar/papers/maymounkov-kademlia-lncs.pdf
 
 
-    :note: the rpc_* methods defined in this class will be automatically exposed to other DHT nodes,
-     for instance, def rpc_ping can be called as protocol.call_ping(addr, dht_id) from a remote machine
-     Only the call_* methods are meant to be called publicly, e.g. from DHTNode
-     Read more: https://github.com/bmuller/rpcudp/tree/master/rpcudp
-    """
+        See DHTNode (node.py) for a more detailed description.
 
 
-    def __init__(self, node_id: DHTID, bucket_size: int, depth_modulo: int, wait_timeout: float,
-                 max_concurrent_rpc: int, num_replicas: Optional[int] = None, cache_size: Optional[int] = None):
-        super().__init__(wait_timeout)
-        self.node_id, self.bucket_size, self.num_replicas = node_id, bucket_size, num_replicas or bucket_size
-        self.rpc_semaphore = asyncio.BoundedSemaphore(value=max_concurrent_rpc)
+        :note: the rpc_* methods defined in this class will be automatically exposed to other DHT nodes,
+         for instance, def rpc_ping can be called as protocol.call_ping(addr, dht_id) from a remote machine
+         Only the call_* methods are meant to be called publicly, e.g. from DHTNode
+         Read more: https://github.com/bmuller/rpcudp/tree/master/rpcudp
+        """
+        self = cls(_initialized_with_create=True)
+        self.node_id, self.bucket_size, self.num_replicas = node_id, bucket_size, num_replicas
+        self.wait_timeout, self.channel_options = wait_timeout, channel_options
+        self.storage, self.cache = LocalStorage(), LocalStorage(maxsize=cache_size)
         self.routing_table = RoutingTable(node_id, bucket_size, depth_modulo)
-        self.storage = LocalStorage()
-        self.cache = LocalStorage(maxsize=cache_size)
-
-    def rpc_ping(self, sender: Endpoint, sender_id_bytes: BinaryDHTID) -> BinaryDHTID:
-        """ Some dht node wants us to add it to our routing table. """
-        asyncio.ensure_future(self.update_routing_table(DHTID.from_bytes(sender_id_bytes), sender))
-        return bytes(self.node_id)
-
-    async def call_ping(self, recipient: Endpoint) -> Optional[DHTID]:
-        """ Get recipient's node id and add him to the routing table. If recipient doesn't respond, return None """
-        async with self.rpc_semaphore:
-            responded, response = await self.ping(recipient, bytes(self.node_id))
-        recipient_node_id = DHTID.from_bytes(response) if responded else None
-        asyncio.ensure_future(self.update_routing_table(recipient_node_id, recipient, responded=responded))
-        return recipient_node_id
-
-    def rpc_store(self, sender: Endpoint, sender_id_bytes: BinaryDHTID, key_bytes: BinaryDHTID,
-                  value: DHTValue, expiration_time: DHTExpiration, in_cache: bool) -> Tuple[bool, BinaryDHTID]:
-        """ Some node wants us to store this (key, value) pair """
-        asyncio.ensure_future(self.update_routing_table(DHTID.from_bytes(sender_id_bytes), sender))
-        if in_cache:
-            store_accepted = self.cache.store(DHTID.from_bytes(key_bytes), value, expiration_time)
+
+        if listen:  # set up server to process incoming rpc requests
+            grpc.experimental.aio.init_grpc_aio()
+            self.server = grpc.experimental.aio.server(**kwargs)
+            dht_grpc.add_DHTServicer_to_server(self, self.server)
+
+            found_port = self.server.add_insecure_port(listen_on)
+            assert found_port != 0, f"Failed to listen to {listen_on}"
+            self.node_info = dht_pb2.NodeInfo(node_id=node_id.to_bytes(), rpc_port=found_port)
+            self.port = found_port
+            await self.server.start()
+        else:  # not listening to incoming requests, client-only mode
+            # note: use empty node_info so peers wont add you to their routing tables
+            self.node_info, self.server, self.port = dht_pb2.NodeInfo(), None, None
+            if listen_on != '0.0.0.0:*' or len(kwargs) != 0:
+                warn(f"DHTProtocol has no server (due to listen=False), listen_on"
+                     f"and kwargs have no effect (unused kwargs: {kwargs})")
+        return self
+
+    def __init__(self, *, _initialized_with_create=False):
+        """ Internal init method. Please use DHTProtocol.create coroutine to spawn new protocol instances """
+        assert _initialized_with_create, " Please use DHTProtocol.create coroutine to spawn new protocol instances "
+        super().__init__()
+
+    async def shutdown(self, timeout=None):
+        """ Process existing requests, close all connections and stop the server """
+        if self.server:
+            await self.server.stop(timeout)
         else:
         else:
-        return store_accepted, bytes(self.node_id)
+            warn("DHTProtocol has no server (due to listen=False), it doesn't need to be shut down")
 
 
-    async def call_store(self, recipient: Endpoint, key: DHTID, value: DHTValue,
-                         expiration_time: DHTExpiration, in_cache: bool = False) -> Optional[bool]:
-        """
-        Ask a recipient to store (key, value) pair until expiration time or update their older value
+    def _get(self, peer: Endpoint) -> dht_grpc.DHTStub:
+        """ get a DHTStub that sends requests to a given peer """
+        channel = grpc.experimental.aio.insecure_channel(peer, options=self.channel_options)
+        return dht_grpc.DHTStub(channel)
 
 
-        :returns: True if value was accepted, False if it was rejected (recipient has newer value), None if no response
+    async def call_ping(self, peer: Endpoint) -> Optional[DHTID]:
         """
         """
-        async with self.rpc_semaphore:
-            responded, response = await self.store(recipient, bytes(self.node_id), bytes(key),
-                                                   value, expiration_time, in_cache)
-        if responded:
-            store_accepted, recipient_node_id = response[0], DHTID.from_bytes(response[1])
-            asyncio.ensure_future(self.update_routing_table(recipient_node_id, recipient, responded=responded))
-            return store_accepted
-        return None
-
-    def rpc_find_node(self, sender: Endpoint, sender_id_bytes: BinaryDHTID,
-                      query_id_bytes: BinaryDHTID) -> Tuple[List[Tuple[BinaryDHTID, Endpoint]], BinaryDHTID]:
+        Get peer's node id and add him to the routing table. If peer doesn't respond, return None
+        :param peer: string network address, e.g. 123.123.123.123:1337 or [2a21:6с8:b192:2105]:8888
+        :note: if DHTProtocol was created with listen=True, also request peer to add you to his routing table
+
+        :return: node's DHTID, if peer responded and decided to send his node_id
         """
         """
-        Someone wants to find :key_node: in the DHT. Give him k nearest neighbors from our routing table
+        try:
+            peer_info = await self._get(peer).rpc_ping(self.node_info, timeout=self.wait_timeout)
+        except grpc.experimental.aio.AioRpcError as error:
+            logging.info(f"DHTProtocol failed to ping {peer}: {error.code()}")
+            peer_info = None
+        responded = bool(peer_info and peer_info.node_id)
+        peer_id = DHTID.from_bytes(peer_info.node_id) if responded else None
+        asyncio.create_task(self.update_routing_table(peer_id, peer, responded=responded))
+        return peer_id
+
+    async def rpc_ping(self, peer_info: dht_pb2.NodeInfo, context: grpc.ServicerContext):
+        """ Some node wants us to add it to our routing table. """
+        if peer_info.node_id and peer_info.rpc_port:
+            sender_id = DHTID.from_bytes(peer_info.node_id)
+            peer_url = urllib.parse.urlparse(context.peer())
+            address = peer_url.path[:peer_url.path.rindex(':')]
+            asyncio.create_task(self.update_routing_table(sender_id, f"{address}:{peer_info.rpc_port}"))
+        return self.node_info
 
 
-        :returns: a list of pairs (node_id, address) of :bucket_size: nearest to key_node according to XOR distance,
-         also returns our own node id for routing table maintenance
+    async def call_store(self, peer: Endpoint, keys: Sequence[DHTID], values: Sequence[BinaryDHTValue],
+                         expirations: Union[DHTExpiration, Sequence[DHTExpiration]],
+                         in_cache: Optional[Union[bool, Sequence[bool]]] = None) -> Sequence[bool]:
         """
         """
-        query_id, sender_id = DHTID.from_bytes(query_id_bytes), DHTID.from_bytes(sender_id_bytes)
-        asyncio.ensure_future(self.update_routing_table(sender_id, sender))
-        peer_ids_and_addr = self.routing_table.get_nearest_neighbors(query_id, k=self.bucket_size, exclude=sender_id)
-        return [(bytes(peer_id), peer_addr) for peer_id, peer_addr in peer_ids_and_addr], bytes(self.node_id)
+        Ask a recipient to store several (key, value : expiration) items or update their older value
 
 
-    async def call_find_node(self, recipient: Endpoint, query_id: DHTID) -> Dict[DHTID, Endpoint]:
+        :param peer: request this peer to store the data
+        :param keys: a list of N keys digested by DHTID.generate(source=some_dict_key)
+        :param values: a list of N serialized values (bytes) for each respective key
+        :param expirations: a list of N expiration timestamps for each respective key-value pair (see get_dht_time())
+        :param in_cache: a list of booleans, True = store i-th key in cache, value = store i-th key locally
+        :note: the difference between storing normally and in cache is that normal storage is guaranteed to be stored
+         until expiration time (best-effort), whereas cached storage can be evicted early due to limited cache size
+
+        :return: list of [True / False] True = stored, False = failed (found newer value or no response)
+         if peer did not respond (e.g. due to timeout or congestion), returns None
         """
         """
-        Ask a recipient to give you nearest neighbors to key_node. If recipient knows key_node directly,
-         it will be returned as first of the neighbors; if recipient does not respond, return empty dict.
+        in_cache = in_cache if in_cache is not None else [False] * len(keys)  # default value (None)
+        in_cache = [in_cache] * len(keys) if isinstance(in_cache, bool) else in_cache  # single bool
+        expirations = [expirations] * len(keys) if isinstance(expirations, DHTExpiration) else expirations
+        keys, values, expirations, in_cache = map(list, [keys, values, expirations, in_cache])
+        assert len(keys) == len(values) == len(expirations) == len(in_cache), "Data is not aligned"
+        store_request = dht_pb2.StoreRequest(keys=list(map(DHTID.to_bytes, keys)), values=values,
+                                             expiration=expirations, in_cache=in_cache, peer=self.node_info)
+        try:
+            response = await self._get(peer).rpc_store(store_request, timeout=self.wait_timeout)
+            if response.peer and response.peer.node_id:
+                peer_id = DHTID.from_bytes(response.peer.node_id)
+                asyncio.create_task(self.update_routing_table(peer_id, peer, responded=True))
+            return response.store_ok
+        except grpc.experimental.aio.AioRpcError as error:
+            logging.info(f"DHTProtocol failed to store at {peer}: {error.code()}")
+            asyncio.create_task(self.update_routing_table(self.routing_table.get_id(peer), peer, responded=False))
+            return [False] * len(keys)
+
+    async def rpc_store(self, request: dht_pb2.StoreRequest, context: grpc.ServicerContext) -> dht_pb2.StoreResponse:
+        """ Some node wants us to store this (key, value) pair """
+        if request.peer:  # if requested, add peer to the routing table
+            asyncio.create_task(self.rpc_ping(request.peer, context))
+        assert len(request.keys) == len(request.values) == len(request.expiration) == len(request.in_cache)
+        response = dht_pb2.StoreResponse(store_ok=[], peer=self.node_info)
+        for key_bytes, value_bytes, expiration_time, in_cache in zip(
+                request.keys, request.values, request.expiration, request.in_cache):
+            local_memory = self.cache if in_cache else self.storage
+            response.store_ok.append(local_memory.store(DHTID.from_bytes(key_bytes), value_bytes, expiration_time))
+        return response
 
 
-        :returns: a dicitionary[node id => address] as per Section 2.3 of the paper
+    async def call_find(self, peer: Endpoint, keys: Sequence[DHTID]) -> \
+            Optional[Dict[DHTID, Tuple[Optional[BinaryDHTValue], Optional[DHTExpiration], Dict[DHTID, Endpoint]]]]:
         """
-        async with self.rpc_semaphore:
-            responded, response = await self.find_node(recipient, bytes(self.node_id), bytes(query_id))
-        if responded:
-            peers = {DHTID.from_bytes(peer_id_bytes): tuple(addr) for peer_id_bytes, addr in response[0]}
-            # Note: we convert addr from list to tuple here --^ because some msgpack versions convert tuples to lists
-            recipient_node_id = DHTID.from_bytes(response[1])
-            asyncio.ensure_future(self.update_routing_table(recipient_node_id, recipient, responded=responded))
-            return peers
-        return {}
-
-    def rpc_find_value(self, sender: Endpoint, sender_id_bytes: BinaryDHTID, key_bytes: BinaryDHTID) -> \
-            Tuple[Optional[DHTValue], Optional[DHTExpiration], List[Tuple[BinaryDHTID, Endpoint]], BinaryDHTID]:
+        Request keys from a peer. For each key, the recipient looks up its (value, expiration time) locally and
+         returns k additional peers that are most likely to have this key (ranked by XOR distance)
+
+        :returns: A dict key => Tuple[optional value, optional expiration time, nearest neighbors]
+         value: value stored by the recipient with that key, or None if the peer doesn't have this value
+         expiration time: expiration time of the returned value, None if no value was found
+         neighbors: a dictionary[node_id : endpoint] containing nearest neighbors from the peer's routing table
+         If the peer didn't respond, returns None
         """
         """
-        Someone wants to find value corresponding to key. If we have the value, return the value and its expiration time
-         Either way, return :bucket_size: nearest neighbors to that node.
+        keys = list(keys)
+        find_request = dht_pb2.FindRequest(keys=list(map(DHTID.to_bytes, keys)), peer=self.node_info)
+        try:
+            response = await self._get(peer).rpc_find(find_request, timeout=self.wait_timeout)
+            if response.peer and response.peer.node_id:
+                peer_id = DHTID.from_bytes(response.peer.node_id)
+                asyncio.create_task(self.update_routing_table(peer_id, peer, responded=True))
+            assert len(response.values) == len(response.expiration) == len(response.nearest) == len(keys), \
+                "DHTProtocol: response is not aligned with keys"
+
+            output = {}  # unpack data without special NOT_FOUND_* values
+            for key, value, expiration, nearest in zip(keys, response.values, response.expiration, response.nearest):
+                value = value if value != _NOT_FOUND_VALUE else None
+                expiration = expiration if expiration != _NOT_FOUND_EXPIRATION else None
+                nearest = dict(zip(map(DHTID.from_bytes, nearest.node_ids), nearest.endpoints))
+                output[key] = (value, expiration, nearest)
+            return output
+        except grpc.experimental.aio.AioRpcError as error:
+            logging.info(f"DHTProtocol failed to store at {peer}: {error.code()}")
+            asyncio.create_task(self.update_routing_table(self.routing_table.get_id(peer), peer, responded=False))
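
And a matching sketch for the lookup side: how a caller might unpack the per-key tuples that call_find returns (None altogether if the peer is unreachable). The protocol instance and peer endpoint are assumed to exist already:

import hivemind

async def find_example(protocol, peer: str, keys: list):
    results = await protocol.call_find(peer, keys)
    if results is None:  # the peer did not respond at all
        return
    for key in keys:
        value_bytes, expiration, nearest = results[key]
        if value_bytes is not None:
            print(key, "->", hivemind.MSGPackSerializer.loads(value_bytes), "expires at", expiration)
        else:
            print(key, "not found here; nearest known peers:", nearest)  # {DHTID: "host:port"}
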
 
 
-        :returns: (value or None if we have no value, nearest neighbors, our own dht id)
-        :note: this is a deviation from Section 2.3 of the paper, original kademlia returner EITHER value OR neighbors
+    async def rpc_find(self, request: dht_pb2.FindRequest, context: grpc.ServicerContext) -> dht_pb2.FindResponse:
         """
         """
-        maybe_value, maybe_expiration = self.storage.get(DHTID.from_bytes(key_bytes))
-        cached_value, cached_expiration = self.cache.get(DHTID.from_bytes(key_bytes))
-        if (cached_expiration or -float('inf')) > (maybe_expiration or -float('inf')):
-            maybe_value, maybe_expiration = cached_value, cached_expiration
-        nearest_neighbors, my_id = self.rpc_find_node(sender, sender_id_bytes, key_bytes)
-        return maybe_value, maybe_expiration, nearest_neighbors, my_id
-
-    async def call_find_value(self, recipient: Endpoint, key: DHTID) -> \
-            Tuple[Optional[DHTValue], Optional[DHTExpiration], Dict[DHTID, Endpoint]]:
+        Someone wants to find keys in the DHT. For every key that we have locally, return its value and expiration.
+        Also return :bucket_size: nearest neighbors from our routing table for each key (whether or not we found a value)
         """
         """
-        Ask a recipient to give you the value, if it has one, or nearest neighbors to your key.
+        if request.peer:  # if the sender provided its node info, add it to the routing table
+            asyncio.create_task(self.rpc_ping(request.peer, context))
 
 
-        :returns: (optional value, optional expiration time, and neighbors)
-         value: whatever was the latest value stored by the recipient with that key (see DHTNode contract)
-         expiration time: expiration time of the returned value, None if no value was found
-         neighbors:  a dictionary[node id => address] as per Section 2.3 of the paper;
-        :note: if no response, returns None, None, {}
-        """
-        async with self.rpc_semaphore:
-            responded, response = await self.find_value(recipient, bytes(self.node_id), bytes(key))
-        if responded:
-            (value, expiration_time, peers_bytes), recipient_id = response[:-1], DHTID.from_bytes(response[-1])
-            peers = {DHTID.from_bytes(peer_id_bytes): tuple(addr) for peer_id_bytes, addr in peers_bytes}
-            asyncio.ensure_future(self.update_routing_table(recipient_id, recipient, responded=responded))
-            return value, expiration_time, peers
-        return None, None, {}
-
-    async def update_routing_table(self, node_id: Optional[DHTID], addr: Endpoint, responded=True):
+        response = dht_pb2.FindResponse(values=[], expiration=[], nearest=[], peer=self.node_info)
+        for key_id in map(DHTID.from_bytes, request.keys):
+            maybe_value, maybe_expiration = self.storage.get(key_id)
+            cached_value, cached_expiration = self.cache.get(key_id)
+            if (cached_expiration or -float('inf')) > (maybe_expiration or -float('inf')):
+                maybe_value, maybe_expiration = cached_value, cached_expiration
+            peer_ids, endpoints = zip(*self.routing_table.get_nearest_neighbors(
+                key_id, k=self.bucket_size, exclude=DHTID.from_bytes(request.peer.node_id)))
+
+            response.values.append(maybe_value if maybe_value is not None else _NOT_FOUND_VALUE)
+            response.expiration.append(maybe_expiration if maybe_expiration is not None else _NOT_FOUND_EXPIRATION)
+            response.nearest.append(dht_pb2.Peers(node_ids=list(map(DHTID.to_bytes, peer_ids)), endpoints=endpoints))
+        return response
+
+    async def update_routing_table(self, node_id: Optional[DHTID], peer_endpoint: Endpoint, responded=True):
         """
         """
         This method is called on every incoming AND outgoing request to update the routing table
         This method is called on every incoming AND outgoing request to update the routing table
 
 
-        :param addr: sender endpoint for incoming requests, recipient endpoint for outgoing requests
+        :param peer_endpoint: sender endpoint for incoming requests, recipient endpoint for outgoing requests
         :param node_id: sender node id for incoming requests, recipient node id for outgoing requests
         :param node_id: sender node id for incoming requests, recipient node id for outgoing requests
         :param responded: for outgoing requests, this indicates whether the recipient responded or not.
         :param responded: for outgoing requests, this indicates whether the recipient responded or not.
           For incoming requests, this should always be True
           For incoming requests, this should always be True
         """
         """
+        node_id = node_id if node_id is not None else self.routing_table.get_id(peer_endpoint)
         if responded:  # incoming request or outgoing request with response
         if responded:  # incoming request or outgoing request with response
             if node_id not in self.routing_table:
             if node_id not in self.routing_table:
                 # we just met a new node, maybe we know some values that it *should* store
                 # we just met a new node, maybe we know some values that it *should* store
+                data_to_send: List[Tuple[DHTID, BinaryDHTValue, DHTExpiration]] = []
                 for key, value, expiration in list(self.storage.items()):
                 for key, value, expiration in list(self.storage.items()):
                     neighbors = self.routing_table.get_nearest_neighbors(key, self.num_replicas, exclude=self.node_id)
                     neighbors = self.routing_table.get_nearest_neighbors(key, self.num_replicas, exclude=self.node_id)
                     if neighbors:
                     if neighbors:
@@ -155,29 +227,26 @@ class KademliaProtocol(RPCProtocol):
                         new_node_should_store = node_id.xor_distance(key) < farthest_distance
                         new_node_should_store = node_id.xor_distance(key) < farthest_distance
                         this_node_is_responsible = self.node_id.xor_distance(key) < nearest_distance
                         this_node_is_responsible = self.node_id.xor_distance(key) < nearest_distance
                     if not neighbors or (new_node_should_store and this_node_is_responsible):
                     if not neighbors or (new_node_should_store and this_node_is_responsible):
-                        asyncio.create_task(self.call_store(addr, key, value, expiration))
+                        data_to_send.append((key, value, expiration))
+                if data_to_send:
+                    asyncio.create_task(self.call_store(peer_endpoint, *zip(*data_to_send), in_cache=False))
 
 
-            maybe_node_to_ping = self.routing_table.add_or_update_node(node_id, addr)
+            maybe_node_to_ping = self.routing_table.add_or_update_node(node_id, peer_endpoint)
             if maybe_node_to_ping is not None:
             if maybe_node_to_ping is not None:
                 # we couldn't add new node because the table was full. Check if existing peers are alive (Section 2.2)
                 # we couldn't add new node because the table was full. Check if existing peers are alive (Section 2.2)
                 # ping one least-recently updated peer: if it won't respond, remove it from the table, else update it
                 # ping one least-recently updated peer: if it won't respond, remove it from the table, else update it
                 asyncio.create_task(self.call_ping(maybe_node_to_ping[1]))  # [1]-th element is that node's endpoint
                 asyncio.create_task(self.call_ping(maybe_node_to_ping[1]))  # [1]-th element is that node's endpoint
 
 
-        else:  # outgoing request and peer did not respond
+        else:  # we sent outgoing request and peer did not respond
             if node_id is not None and node_id in self.routing_table:
             if node_id is not None and node_id in self.routing_table:
                 del self.routing_table[node_id]
                 del self.routing_table[node_id]
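
The "welcome" branch above hinges on comparing XOR distances; a tiny standalone illustration of that primitive (ids below are freshly generated, so the actual numbers will vary):

from hivemind.dht.routing import DHTID

key, new_node, this_node = DHTID.generate(), DHTID.generate(), DHTID.generate()
print("new node's distance to key:", new_node.xor_distance(key))
print("our own distance to key:  ", this_node.xor_distance(key))
# update_routing_table forwards (key, value, expiration) to a newly met node only if we have no
# other replicas for that key, or if the new node is closer than our farthest known replica
# while we ourselves are closer than the nearest one.
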
 
 
-    def _accept_response(self, msg_id, data, address):
-        """ Override for RPCProtocol._accept_response to handle cancelled tasks """
-        future, timeout = self._outstanding[msg_id]
-        if future.cancelled():
-            timeout.cancel()
-            del self._outstanding[msg_id]
-        else:
-            super()._accept_response(msg_id, data, address)
+
+_NOT_FOUND_VALUE, _NOT_FOUND_EXPIRATION = b'', -float('inf')  # internal values to represent that a value was not found
 
 
 
 
 class LocalStorage:
 class LocalStorage:
+    """ Local dictionary that maintains up to :maxsize: tuples of (key, value, expiration) """
     def __init__(self, maxsize: Optional[int] = None):
     def __init__(self, maxsize: Optional[int] = None):
         self.cache_size = maxsize or float("inf")
         self.cache_size = maxsize or float("inf")
         self.data = dict()
         self.data = dict()
@@ -192,7 +261,7 @@ class LocalStorage:
             if self.key_to_heap[key] == heap_entry:
             if self.key_to_heap[key] == heap_entry:
                 del self.data[key], self.key_to_heap[key]
                 del self.data[key], self.key_to_heap[key]
 
 
-    def store(self, key: DHTID, value: DHTValue, expiration_time: DHTExpiration) -> bool:
+    def store(self, key: DHTID, value: BinaryDHTValue, expiration_time: DHTExpiration) -> bool:
         """
         """
         Store a (key, value) pair locally at least until expiration_time. See class docstring for details.
         Store a (key, value) pair locally at least until expiration_time. See class docstring for details.
         :returns: True if new value was stored, False it was rejected (current value is newer)
         :returns: True if new value was stored, False it was rejected (current value is newer)
@@ -210,14 +279,14 @@ class LocalStorage:
         self.remove_outdated()
         self.remove_outdated()
         return True
         return True
 
 
-    def get(self, key: DHTID) -> (Optional[DHTValue], Optional[DHTExpiration]):
+    def get(self, key: DHTID) -> (Optional[BinaryDHTValue], Optional[DHTExpiration]):
         """ Get a value corresponding to a key if that (key, value) pair was previously stored here. """
         """ Get a value corresponding to a key if that (key, value) pair was previously stored here. """
         self.remove_outdated()
         self.remove_outdated()
         if key in self.data:
         if key in self.data:
             return self.data[key]
             return self.data[key]
         return None, None
         return None, None
 
 
-    def items(self) -> Iterator[Tuple[DHTID, DHTValue, DHTExpiration]]:
+    def items(self) -> Iterator[Tuple[DHTID, BinaryDHTValue, DHTExpiration]]:
         """ Iterate over (key, value, expiration_time) tuples stored in this storage """
         """ Iterate over (key, value, expiration_time) tuples stored in this storage """
         self.remove_outdated()
         self.remove_outdated()
         return ((key, value, expiration) for key, (value, expiration) in self.data.items())
         return ((key, value, expiration) for key, (value, expiration) in self.data.items())
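
A condensed sketch of the LocalStorage contract (latest expiration wins, expired entries disappear on access), assuming nothing beyond the docstrings above; the key label and payloads are arbitrary:

from hivemind import get_dht_time
from hivemind.dht.protocol import LocalStorage
from hivemind.dht.routing import DHTID

storage = LocalStorage(maxsize=10)          # keeps at most 10 (key, value, expiration) tuples
key = DHTID.generate(source="example-key")
assert storage.store(key, b"v1", get_dht_time() + 60)      # accepted: the key is new
assert not storage.store(key, b"v0", get_dht_time() + 1)   # rejected: the current value expires later
value, expiration = storage.get(key)
assert value == b"v1"
print(list(storage.items()))                # -> [(key, b'v1', expiration)]
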

+ 9 - 2
hivemind/dht/routing.py

@@ -41,7 +41,7 @@ class RoutingTable:
         Update routing table after an incoming request from :addr: (host:port) or outgoing request to :addr:
         Update routing table after an incoming request from :addr: (host:port) or outgoing request to :addr:
 
 
         :returns: If we cannot add node_id to the routing table, return the least-recently-updated node (Section 2.2)
         :returns: If we cannot add node_id to the routing table, return the least-recently-updated node (Section 2.2)
-        :note: KademliaProtocol calls this method for every incoming and outgoing request if there was a response.
+        :note: DHTProtocol calls this method for every incoming and outgoing request if there was a response.
           If this method returned a node to be ping-ed, the protocol will ping it to check and either move it to
           If this method returned a node to be ping-ed, the protocol will ping it to check and either move it to
           the start of the table or remove that node and replace it with
           the start of the table or remove that node and replace it with
         """
         """
@@ -66,6 +66,12 @@ class RoutingTable:
         self.buckets[index] = first
         self.buckets[index] = first
         self.buckets.insert(index + 1, second)
         self.buckets.insert(index + 1, second)
 
 
+    def get(self, node_id: DHTID, default=None) -> Optional[Endpoint]:
+        return self[node_id] if node_id in self else default
+
+    def get_id(self, peer: Endpoint, default=None) -> Optional[DHTID]:
+        return default  # TODO(jheuristic): implement reverse lookup (endpoint -> node id)
+
     def __getitem__(self, node_id: DHTID) -> Endpoint:
     def __getitem__(self, node_id: DHTID) -> Endpoint:
         return self.buckets[self.get_bucket_index(node_id)][node_id]
         return self.buckets[self.get_bucket_index(node_id)][node_id]
 
 
@@ -174,6 +180,7 @@ class KBucket:
         """ :returns: least-recently updated node that isn't already being pinged right now -- if such node exists """
         """ :returns: least-recently updated node that isn't already being pinged right now -- if such node exists """
         for uid, endpoint in self.nodes_to_addr.items():
         for uid, endpoint in self.nodes_to_addr.items():
             if uid not in self.nodes_requested_for_ping:
             if uid not in self.nodes_requested_for_ping:
+                self.nodes_requested_for_ping.add(uid)
                 return uid, endpoint
                 return uid, endpoint
 
 
     def __getitem__(self, node_id: DHTID) -> Endpoint:
     def __getitem__(self, node_id: DHTID) -> Endpoint:
@@ -272,5 +279,5 @@ class DHTID(int):
         return self.to_bytes()
         return self.to_bytes()
 
 
 
 
-DHTKey, DHTValue, DHTExpiration, BinaryDHTID = Any, Any, float, bytes  # flavour types
+DHTKey, DHTValue, DHTExpiration, BinaryDHTID, BinaryDHTValue = Any, Any, float, bytes, bytes  # flavour types
 get_dht_time = time.time  # time used by all dht functionality. You can replace this with any infrastructure-wide time
 get_dht_time = time.time  # time used by all dht functionality. You can replace this with any infrastructure-wide time
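
To illustrate the contract described in the add_or_update_node docstring, a brief sketch with synthetic ids and ports: the table either absorbs the peer or returns a least-recently-updated (id, endpoint) pair for the protocol to ping:

from hivemind import LOCALHOST
from hivemind.dht.routing import RoutingTable, DHTID

table = RoutingTable(DHTID.generate(), bucket_size=20, depth_modulo=5)
for port in range(9000, 9100):
    maybe_stale = table.add_or_update_node(DHTID.generate(), f"{LOCALHOST}:{port}")
    if maybe_stale is not None:
        stale_id, stale_endpoint = maybe_stale
        # the relevant bucket is full: DHTProtocol would now ping stale_endpoint and
        # evict it only if it fails to respond
        print("bucket full, eviction candidate:", stale_endpoint)
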

+ 5 - 5
hivemind/runtime/expert_backend.py

@@ -4,7 +4,7 @@ import torch
 from torch import nn
 from torch import nn
 
 
 from .task_pool import TaskPool
 from .task_pool import TaskPool
-from ..utils import nested_flatten, nested_pack, nested_compare, BatchTensorProto, DUMMY_BATCH_SIZE, nested_map
+from ..utils import nested_flatten, nested_pack, nested_compare, BatchTensorDescriptor, DUMMY_BATCH_SIZE, nested_map
 
 
 
 
 class ExpertBackend(nn.Module):
 class ExpertBackend(nn.Module):
@@ -33,9 +33,9 @@ class ExpertBackend(nn.Module):
     """
     """
 
 
     def __init__(self, name: str, expert: nn.Module, opt: torch.optim.Optimizer, *,
     def __init__(self, name: str, expert: nn.Module, opt: torch.optim.Optimizer, *,
-                 args_schema: Tuple[BatchTensorProto, ...] = None,
-                 kwargs_schema: Dict[str, BatchTensorProto] = None,
-                 outputs_schema: Union[BatchTensorProto, Tuple[BatchTensorProto, ...]] = None,
+                 args_schema: Tuple[BatchTensorDescriptor, ...] = None,
+                 kwargs_schema: Dict[str, BatchTensorDescriptor] = None,
+                 outputs_schema: Union[BatchTensorDescriptor, Tuple[BatchTensorDescriptor, ...]] = None,
                  **kwargs):
                  **kwargs):
         super().__init__()
         super().__init__()
         self.expert, self.opt, self.name = expert, opt, name
         self.expert, self.opt, self.name = expert, opt, name
@@ -50,7 +50,7 @@ class ExpertBackend(nn.Module):
             dummy_args = tuple(sample.make_empty(DUMMY_BATCH_SIZE) for sample in args_schema)
             dummy_args = tuple(sample.make_empty(DUMMY_BATCH_SIZE) for sample in args_schema)
             dummy_kwargs = {key: sample.make_empty(DUMMY_BATCH_SIZE) for key, sample in kwargs_schema.items()}
             dummy_kwargs = {key: sample.make_empty(DUMMY_BATCH_SIZE) for key, sample in kwargs_schema.items()}
             dummy_outputs = self.expert(*dummy_args, **dummy_kwargs)
             dummy_outputs = self.expert(*dummy_args, **dummy_kwargs)
-            outputs_schema = nested_map(BatchTensorProto.from_tensor, dummy_outputs)
+            outputs_schema = nested_map(BatchTensorDescriptor.from_tensor, dummy_outputs)
 
 
         self.outputs_schema = outputs_schema
         self.outputs_schema = outputs_schema
         self.forward_schema = (self.args_schema, self.kwargs_schema)
         self.forward_schema = (self.args_schema, self.kwargs_schema)

+ 3 - 2
hivemind/utils/__init__.py

@@ -1,8 +1,9 @@
 from .connection import *
 from .connection import *
 from .data import *
 from .data import *
 from .nested import *
 from .nested import *
-from .proto import *
+from .tensor_descr import *
 from .serializer import *
 from .serializer import *
 from .shared_future import *
 from .shared_future import *
 from .threading import *
 from .threading import *
-from .autograd import *
+from .autograd import *
+from .grpc import *

+ 2 - 2
hivemind/utils/connection.py

@@ -3,7 +3,7 @@ from contextlib import AbstractContextManager, closing
 from typing import Tuple
 from typing import Tuple
 
 
 Hostname, Port = str, int  # flavour types
 Hostname, Port = str, int  # flavour types
-Endpoint = Tuple[Hostname, Port]  # https://networkengineering.stackexchange.com/a/9435
+Endpoint = str  # e.g. 1.2.3.4:1337 or [2a21:6c8:b192:2105]:8888, https://networkengineering.stackexchange.com/a/9435
 LOCALHOST = '127.0.0.1'
 LOCALHOST = '127.0.0.1'
 
 
 
 
@@ -13,7 +13,7 @@ class Connection(AbstractContextManager):
 
 
     __slots__ = ('conn', 'addr')
     __slots__ = ('conn', 'addr')
 
 
-    def __init__(self, conn: socket, addr: Endpoint):
+    def __init__(self, conn: socket, addr: Tuple[Hostname, Port]):
         self.conn, self.addr = conn, addr
         self.conn, self.addr = conn, addr
 
 
     @staticmethod
     @staticmethod
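
With Endpoint now a plain "host:port" string while Connection still expects a (host, port) tuple, something like the hypothetical helper below (not part of this commit) can bridge the two representations:

from typing import Tuple

Hostname, Port = str, int  # same flavour types as above

def split_endpoint(endpoint: str) -> Tuple[Hostname, Port]:
    """ Split 'host:port' into (host, port); bracketed IPv6 hosts such as '[::1]:8888' stay intact """
    host, port = endpoint.rsplit(':', 1)
    return host, int(port)

assert split_endpoint("127.0.0.1:1337") == ("127.0.0.1", 1337)
assert split_endpoint("[::1]:8888") == ("[::1]", 8888)
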

+ 44 - 0
hivemind/utils/grpc.py

@@ -0,0 +1,44 @@
+"""
+Utilities for running GRPC services: compile protobuf, patch legacy versions, etc
+"""
+import functools
+import os
+import sys
+import tempfile
+from typing import Tuple
+from argparse import Namespace
+import grpc_tools.protoc
+
+
+@functools.lru_cache(maxsize=None)
+def compile_grpc(proto: str, *args: str) -> Tuple[Namespace, Namespace]:
+    """
+    Compiles and loads grpc protocol defined by protobuf string
+
+    :param proto: protocol buffer code as a string, as in open('file.proto').read()
+    :param args: extra cli args for grpc_tools.protoc compiler, e.g. '-Imyincludepath'
+    :returns: messages, services protobuf
+    """
+    base_include = grpc_tools.protoc.pkg_resources.resource_filename('grpc_tools', '_proto')
+
+    with tempfile.TemporaryDirectory(prefix='compile_grpc_') as build_dir:
+        proto_path = tempfile.mktemp(prefix='grpc_', suffix='.proto', dir=build_dir)
+        with open(proto_path, 'w') as fproto:
+            fproto.write(proto)
+
+        cli_args = (
+            grpc_tools.protoc.__file__, f"-I{base_include}",
+            f"--python_out={build_dir}", f"--grpc_python_out={build_dir}",
+            f"-I{build_dir}", *args, os.path.basename(proto_path))
+        code = grpc_tools.protoc._protoc_compiler.run_main([arg.encode() for arg in cli_args])
+        if code:  # hint: if you get this error in jupyter, run in console for richer error message
+            raise ValueError(f"{' '.join(cli_args)} finished with exit code {code}")
+
+        try:
+            sys.path.append(build_dir)
+            pb2_fname = os.path.basename(proto_path)[:-len('.proto')] + '_pb2'
+            messages, services = __import__(pb2_fname, fromlist=['*']), __import__(pb2_fname + '_grpc')
+            return messages, services
+        finally:
+            if sys.path.pop() != build_dir:
+                raise ImportError("Something changed sys.path while compile_grpc was in progress.")

+ 32 - 9
hivemind/utils/serializer.py

@@ -1,41 +1,64 @@
+""" A unified interface for several common serialization methods """
 import pickle
 import pickle
 from io import BytesIO
 from io import BytesIO
 
 
 import joblib
 import joblib
 import torch
 import torch
+import umsgpack
 
 
 
 
-class JoblibSerializer:
+class SerializerBase:
+    @staticmethod
+    def dumps(obj: object) -> bytes:
+        raise NotImplementedError()
+
+    @staticmethod
+    def loads(buf: bytes) -> object:
+        raise NotImplementedError()
+
+
+class JoblibSerializer(SerializerBase):
 
 
     @staticmethod
     @staticmethod
-    def dumps(obj) -> bytes:
+    def dumps(obj: object) -> bytes:
         s = BytesIO()
         s = BytesIO()
         joblib.dump(obj, s)
         joblib.dump(obj, s)
         return s.getvalue()
         return s.getvalue()
 
 
     @staticmethod
     @staticmethod
-    def loads(buf: bytes):
+    def loads(buf: bytes) -> object:
         return joblib.load(BytesIO(buf))
         return joblib.load(BytesIO(buf))
 
 
 
 
-class PickleSerializer:
+class PickleSerializer(SerializerBase):
     @staticmethod
     @staticmethod
-    def dumps(obj) -> bytes:
+    def dumps(obj: object) -> bytes:
         return pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
         return pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
 
 
     @staticmethod
     @staticmethod
-    def loads(buf: bytes):
+    def loads(buf: bytes) -> object:
         return pickle.loads(buf)
         return pickle.loads(buf)
 
 
 
 
-class PytorchSerializer:
+class PytorchSerializer(SerializerBase):
 
 
     @staticmethod
     @staticmethod
-    def dumps(obj) -> bytes:
+    def dumps(obj: object) -> bytes:
         s = BytesIO()
         s = BytesIO()
         torch.save(obj, s, pickle_protocol=pickle.HIGHEST_PROTOCOL)
         torch.save(obj, s, pickle_protocol=pickle.HIGHEST_PROTOCOL)
         return s.getvalue()
         return s.getvalue()
 
 
     @staticmethod
     @staticmethod
-    def loads(buf: bytes):
+    def loads(buf: bytes) -> object:
         return torch.load(BytesIO(buf))
         return torch.load(BytesIO(buf))
+
+
+class MSGPackSerializer(SerializerBase):
+
+    @staticmethod
+    def dumps(obj: object) -> bytes:
+        return umsgpack.dumps(obj, use_bin_type=False)
+
+    @staticmethod
+    def loads(buf: bytes) -> object:
+        return umsgpack.loads(buf, raw=False)
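
All serializers above expose the same two static methods, so they can be swapped freely; a quick round-trip check with an arbitrary payload:

from hivemind.utils.serializer import MSGPackSerializer, PickleSerializer

payload = {"expert": "ffn.0", "scores": [0.25, 0.75]}
for serializer in (MSGPackSerializer, PickleSerializer):
    buffer = serializer.dumps(payload)           # -> bytes
    assert serializer.loads(buffer) == payload   # -> the original object
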

+ 3 - 3
hivemind/utils/proto.py → hivemind/utils/tensor_descr.py

@@ -6,12 +6,12 @@ DUMMY_BATCH_SIZE = 3  # used for dummy runs only
 
 
 
 
 @dataclass(init=True, repr=True, frozen=True)
 @dataclass(init=True, repr=True, frozen=True)
-class ProtoBase:
+class DescriptorBase:
     pass
     pass
 
 
 
 
 @dataclass(init=True, repr=True, frozen=True)
 @dataclass(init=True, repr=True, frozen=True)
-class TensorProto(ProtoBase):
+class TensorDescriptor(DescriptorBase):
     size: tuple
     size: tuple
     dtype: torch.dtype = None
     dtype: torch.dtype = None
     layout: torch.layout = torch.strided
     layout: torch.layout = torch.strided
@@ -34,7 +34,7 @@ class TensorProto(ProtoBase):
 
 
 
 
 @dataclass(repr=True, frozen=True)
 @dataclass(repr=True, frozen=True)
-class BatchTensorProto(TensorProto):
+class BatchTensorDescriptor(TensorDescriptor):
     """ torch Tensor with a variable 0-th dimension, used to describe batched data """
     """ torch Tensor with a variable 0-th dimension, used to describe batched data """
 
 
     def __init__(self, *instance_size, **kwargs):  # compatibility: allow initializing with *size
     def __init__(self, *instance_size, **kwargs):  # compatibility: allow initializing with *size
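
The *Proto to *Descriptor rename is mechanical, but the descriptor API deserves a short sketch: describe a batched tensor once, then materialize dummy batches from it, which is how ExpertBackend infers its output schema above. The hidden size is arbitrary:

import torch
import hivemind

hidden_dim = 64
schema = hivemind.BatchTensorDescriptor(hidden_dim)  # the 0-th (batch) dimension stays variable
dummy_batch = schema.make_empty(3)                   # uninitialized tensor of shape [3, hidden_dim]
print(dummy_batch.shape)

# a descriptor can also be inferred from an existing tensor, as in ExpertBackend.__init__
inferred = hivemind.BatchTensorDescriptor.from_tensor(torch.randn(16, hidden_dim))
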

+ 1 - 0
hivemind/utils/threading.py

@@ -65,3 +65,4 @@ def run_and_await_k(jobs: List[callable], k: int,
             future.cancel()
             future.cancel()
             outputs[index] = future.result() if not future.exception() else future.exception()
             outputs[index] = future.result() if not future.exception() else future.exception()
     return outputs
     return outputs
+

+ 4 - 1
requirements.txt

@@ -3,6 +3,9 @@ joblib>=0.13
 numpy>=1.17
 numpy>=1.17
 requests>=2.22.0
 requests>=2.22.0
 tqdm
 tqdm
-rpcudp>=4.0.0
 prefetch_generator>=1.0.1
 prefetch_generator>=1.0.1
 pytest
 pytest
+umsgpack
+grpcio
+grpcio-tools>=1.30.0
+aiologger>=0.5.0

+ 2 - 2
tests/benchmark_throughput.py

@@ -64,8 +64,8 @@ def benchmark_throughput(num_experts=16, num_handlers=None, num_clients=128, num
             expert = torch.jit.script(layers.name_to_block[expert_cls](hid_dim))
             expert = torch.jit.script(layers.name_to_block[expert_cls](hid_dim))
             experts[f'expert{i}'] = hivemind.ExpertBackend(name=f'expert{i}',
             experts[f'expert{i}'] = hivemind.ExpertBackend(name=f'expert{i}',
                                                            expert=expert, opt=torch.optim.Adam(expert.parameters()),
                                                            expert=expert, opt=torch.optim.Adam(expert.parameters()),
-                                                           args_schema=(hivemind.BatchTensorProto(hid_dim),),
-                                                           outputs_schema=hivemind.BatchTensorProto(hid_dim),
+                                                           args_schema=(hivemind.BatchTensorDescriptor(hid_dim),),
+                                                           outputs_schema=hivemind.BatchTensorDescriptor(hid_dim),
                                                            max_batch_size=max_batch_size,
                                                            max_batch_size=max_batch_size,
                                                            )
                                                            )
         timestamps['created_experts'] = time.perf_counter()
         timestamps['created_experts'] = time.perf_counter()

+ 183 - 174
tests/test_dht.py

@@ -4,7 +4,6 @@ import multiprocessing as mp
 import random
 import random
 import heapq
 import heapq
 import uuid
 import uuid
-from functools import partial
 from itertools import chain
 from itertools import chain
 from typing import Optional
 from typing import Optional
 import numpy as np
 import numpy as np
@@ -13,198 +12,208 @@ import hivemind
 from typing import List, Dict
 from typing import List, Dict
 
 
 from hivemind import get_dht_time
 from hivemind import get_dht_time
-from hivemind.dht.node import DHTID, Endpoint, DHTNode, LOCALHOST, KademliaProtocol
+from hivemind.dht.node import DHTID, Endpoint, DHTNode, LOCALHOST, DHTProtocol
 from hivemind.dht.protocol import LocalStorage
 from hivemind.dht.protocol import LocalStorage
 
 
 
 
-def run_protocol_listener(port: int, dhtid: DHTID, started: mp.synchronize.Event,
-                          ping: Optional[hivemind.Endpoint] = None):
-    loop = asyncio.new_event_loop()
-    protocol = partial(KademliaProtocol, dhtid, bucket_size=20, depth_modulo=5, wait_timeout=5, max_concurrent_rpc=128)
-    listen = loop.create_datagram_endpoint(protocol, local_addr=('127.0.0.1', port))
-    transport, protocol = loop.run_until_complete(listen)
+def run_protocol_listener(port: int, dhtid: DHTID, started: mp.synchronize.Event, ping: Optional[Endpoint] = None):
+    loop = asyncio.get_event_loop()
+    protocol = loop.run_until_complete(DHTProtocol.create(
+        dhtid, bucket_size=20, depth_modulo=5, num_replicas=3, wait_timeout=5, listen_on=f"{LOCALHOST}:{port}"))
+
+    assert protocol.port == port
     print(f"Started peer id={protocol.node_id} port={port}", flush=True)
     print(f"Started peer id={protocol.node_id} port={port}", flush=True)
 
 
     if ping is not None:
     if ping is not None:
         loop.run_until_complete(protocol.call_ping(ping))
         loop.run_until_complete(protocol.call_ping(ping))
     started.set()
     started.set()
-    loop.run_forever()
+    loop.run_until_complete(protocol.server.wait_for_termination())
     print(f"Finished peer id={protocol.node_id} port={port}", flush=True)
     print(f"Finished peer id={protocol.node_id} port={port}", flush=True)
 
 
 
 
 def test_kademlia_protocol():
 def test_kademlia_protocol():
-    try:
-        # create the first peer
-        peer1_port, peer1_id, peer1_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
-        peer1_proc = mp.Process(target=run_protocol_listener, args=(peer1_port, peer1_id, peer1_started), daemon=True)
-        peer1_proc.start(), peer1_started.wait()
-
-        # create another peer that connects to the first peer
-        peer2_port, peer2_id, peer2_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
-        peer2_proc = mp.Process(target=run_protocol_listener, args=(peer2_port, peer2_id, peer2_started),
-                                kwargs={'ping': ('127.0.0.1', peer1_port)}, daemon=True)
-        peer2_proc.start(), peer2_started.wait()
-
-        port = hivemind.find_open_port()
-        loop = asyncio.new_event_loop()
-        protocol = partial(KademliaProtocol, DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5,
-                           max_concurrent_rpc=128)
-        listen = loop.create_datagram_endpoint(protocol, local_addr=('127.0.0.1', port))
-        transport, protocol = loop.run_until_complete(listen)
-        print(f"Self id={protocol.node_id} port={port}", flush=True)
-
-        assert loop.run_until_complete(protocol.call_ping(('127.0.0.1', peer1_port))) == peer1_id
-
-        key, value, expiration = DHTID.generate(), [123, {'ololo': 'pyshpysh'}], get_dht_time() + 1e3
-        assert loop.run_until_complete(protocol.call_store(('127.0.0.1', peer1_port), key, value, expiration))
-
-        # peer 1 must know about peer 2
-        nodes_found = loop.run_until_complete(
-            protocol.call_find_node(('127.0.0.1', peer1_port), key))
-        (recv_id, recv_endpoint) = next(iter(nodes_found.items()))
-        assert recv_id == peer2_id and recv_endpoint == ('127.0.0.1', peer2_port), \
-            f"expected id={peer2_id}, port={('127.0.0.1', peer2_port)} but got {recv_id}, {recv_endpoint}"
-
-        # peer 2 must know about peer 1
-        nodes_found_2 = loop.run_until_complete(protocol.call_find_node(('127.0.0.1', peer2_port), key))
-        (recv_id, recv_endpoint) = next(iter(nodes_found_2.items()))
-        assert recv_id == peer1_id and recv_endpoint == ('127.0.0.1', peer1_port), \
-            f"expected id={peer1_id}, port={('127.0.0.1', peer1_port)} but got {recv_id}, {recv_endpoint}"
-
-        recv_value, recv_expiration, recv_peers = loop.run_until_complete(
-            protocol.call_find_value(('127.0.0.1', peer1_port), key))
-        assert recv_value == value and recv_expiration == expiration, "call_find_value expected " \
-                                                                      f"{value} (expires by {expiration}) but got {recv_value} (expires by {recv_expiration})"
-        print(recv_peers, nodes_found)
-        assert recv_peers == nodes_found, "call_find_value must return the same peers as call_find_node"
-        print("Kademlia test finished sucessfully!")
-
-    finally:
-        peer1_proc.terminate()
-        peer2_proc.terminate()
-
-
-def run_node(node_id, port, peers, status_pipe: mp.Pipe):
+    # create the first peer
+    peer1_port, peer1_id, peer1_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
+    peer1_proc = mp.Process(target=run_protocol_listener, args=(peer1_port, peer1_id, peer1_started), daemon=True)
+    peer1_proc.start(), peer1_started.wait()
+
+    # create another peer that connects to the first peer
+    peer2_port, peer2_id, peer2_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
+    peer2_proc = mp.Process(target=run_protocol_listener, args=(peer2_port, peer2_id, peer2_started),
+                            kwargs={'ping': f'{LOCALHOST}:{peer1_port}'}, daemon=True)
+    peer2_proc.start(), peer2_started.wait()
+
+    test_success = mp.Event()
+
+    def _tester():
+        # note: we run everything in a separate process to re-initialize all global states from scratch
+        # this helps us avoid undesirable side-effects when running multiple tests in sequence
+
+        loop = asyncio.get_event_loop()
+        for listen in [False, True]:  # note: order matters, this test assumes that first run uses listen=False
+            protocol = loop.run_until_complete(DHTProtocol.create(
+                DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, listen=listen))
+            print(f"Self id={protocol.node_id}", flush=True)
+
+            assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{peer1_port}')) == peer1_id
+
+            key, value, expiration = DHTID.generate(), [random.random(), {'ololo': 'pyshpysh'}], get_dht_time() + 1e3
+            store_ok = loop.run_until_complete(protocol.call_store(
+                f'{LOCALHOST}:{peer1_port}', [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
+            )
+            assert all(store_ok), "DHT rejected a trivial store"
+
+            # peer 1 must know about peer 2
+            recv_value_bytes, recv_expiration, nodes_found = loop.run_until_complete(
+                protocol.call_find(f'{LOCALHOST}:{peer1_port}', [key]))[key]
+            recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
+            (recv_id, recv_endpoint) = next(iter(nodes_found.items()))
+            assert recv_id == peer2_id and recv_endpoint == f"{LOCALHOST}:{peer2_port}", \
+                f"expected id={peer2_id}, peer={LOCALHOST}:{peer2_port} but got {recv_id}, {recv_endpoint}"
+
+            assert recv_value == value and recv_expiration == expiration, "call_find expected " \
+                f"{value} (expires by {expiration}) but got {recv_value} (expires by {recv_expiration})"
+
+            # peer 2 must know about peer 1, but not have a *random* nonexistent value
+            dummy_key = DHTID.generate()
+            recv_dummy_value, recv_dummy_expiration, nodes_found_2 = loop.run_until_complete(
+                protocol.call_find(f'{LOCALHOST}:{peer2_port}', [dummy_key]))[dummy_key]
+            assert recv_dummy_value is None and recv_dummy_expiration is None, "Non-existent keys shouldn't have values"
+            (recv_id, recv_endpoint) = next(iter(nodes_found_2.items()))
+            assert recv_id == peer1_id and recv_endpoint == f"{LOCALHOST}:{peer1_port}", \
+                f"expected id={peer1_id}, peer={LOCALHOST}:{peer1_port} but got {recv_id}, {recv_endpoint}"
+
+            # cause a non-response by querying a nonexistent peer
+            dummy_port = hivemind.find_open_port()
+            assert loop.run_until_complete(protocol.call_find(f"{LOCALHOST}:{dummy_port}", [key])) is None
+
+            if listen:
+                loop.run_until_complete(protocol.shutdown())
+            print("DHTProtocol test finished sucessfully!")
+            test_success.set()
+
+    tester = mp.Process(target=_tester, daemon=True)
+    tester.start()
+    tester.join()
+    assert test_success.is_set()
+    peer1_proc.terminate()
+    peer2_proc.terminate()
+
+
+def run_node(node_id, peers, status_pipe: mp.Pipe):
     if asyncio.get_event_loop().is_running():
     if asyncio.get_event_loop().is_running():
         asyncio.get_event_loop().stop()  # if we're in jupyter, get rid of its built-in event loop
         asyncio.get_event_loop().stop()  # if we're in jupyter, get rid of its built-in event loop
-    asyncio.set_event_loop(asyncio.new_event_loop())
-    try:
-        node = DHTNode(node_id, port, initial_peers=peers)
-        status_pipe.send('STARTED')
-        while True:
-            asyncio.get_event_loop().run_forever()
-    except BaseException as e:
-        status_pipe.send(e)  # report exception to master
-        if not isinstance(e, OSError):
-            raise e
+        asyncio.set_event_loop(asyncio.new_event_loop())
+    loop = asyncio.get_event_loop()
+    node = loop.run_until_complete(DHTNode.create(node_id, initial_peers=peers))
+    status_pipe.send(node.port)
+    while True:
+        loop.run_forever()
 
 
 
 
 def test_dht():
 def test_dht():
     # create dht with 50 nodes + your 51-st node
     # create dht with 50 nodes + your 51-st node
     dht: Dict[Endpoint, DHTID] = {}
     dht: Dict[Endpoint, DHTID] = {}
     processes: List[mp.Process] = []
     processes: List[mp.Process] = []
-    port_fails, max_port_fails = 0, 10
 
 
-    while len(dht) < 50:
+    for i in range(50):
         node_id = DHTID.generate()
         node_id = DHTID.generate()
         peers = random.sample(dht.keys(), min(len(dht), 5))
         peers = random.sample(dht.keys(), min(len(dht), 5))
-        port = hivemind.find_open_port()
         pipe_recv, pipe_send = mp.Pipe(duplex=False)
         pipe_recv, pipe_send = mp.Pipe(duplex=False)
-        proc = mp.Process(target=run_node, args=(node_id, port, peers, pipe_send), daemon=True)
+        proc = mp.Process(target=run_node, args=(node_id, peers, pipe_send), daemon=True)
         proc.start()
         proc.start()
-
-        status = pipe_recv.recv()
-        if status == 'STARTED':
-            processes.append(proc)
-            dht[(LOCALHOST, port)] = node_id
-        else:
-            assert isinstance(status, BaseException)
-            proc.terminate()
-            if isinstance(status, OSError):  # port already in use. It just happens sometimes.
-                port_fails += 1
-                if port_fails > max_port_fails:
-                    raise OSError("Too many 'Address already in use' errors.")
-            else:
-                raise ValueError(f"Failed to create node due to an error {status}, see traceback above")
-
-    loop = asyncio.get_event_loop()
-    me = hivemind.dht.node.DHTNode(initial_peers=random.sample(peers, 5), port=0)  # port=0 means os-specified port
-
-    # test 1: find self
-    nearest = loop.run_until_complete(me.find_nearest_nodes(query_id=me.node_id, k_nearest=1))
-    assert len(nearest) == 1 and nearest[me.node_id] == (LOCALHOST, me.port)
-
-    # test 2: find others
-    for i in range(10):
-        ref_endpoint, query_id = random.choice(list(dht.items()))
-        nearest = loop.run_until_complete(me.find_nearest_nodes(query_id=query_id, k_nearest=1))
-        assert len(nearest) == 1 and next(iter(nearest.items())) == (query_id, ref_endpoint)
-
-    # test 3: find neighbors to random nodes
-    accuracy_numerator = accuracy_denominator = 0  # top-1 nearest neighbor accuracy
-    jaccard_numerator = jaccard_denominator = 0  # jaccard similarity aka intersection over union
-    all_node_ids = list(dht.values())
-
-    for i in range(100):
-        query_id = DHTID.generate()
-        k_nearest = random.randint(1, 20)
-        exclude_self = random.random() > 0.5
-        nearest = loop.run_until_complete(
-            me.find_nearest_nodes(query_id=query_id, k_nearest=k_nearest, exclude_self=exclude_self))
-        nearest_nodes = list(nearest)  # keys from ordered dict
-
-        assert len(nearest_nodes) == k_nearest, "beam search must return exactly k_nearest results"
-        assert me.node_id not in nearest_nodes or not exclude_self, "if exclude, results should not contain own node id"
-        assert np.all(np.diff(query_id.xor_distance(nearest_nodes)) >= 0), "results must be sorted by distance"
-
-        ref_nearest = heapq.nsmallest(k_nearest + 1, all_node_ids, key=query_id.xor_distance)
-        if exclude_self and me.node_id in ref_nearest:
-            ref_nearest.remove(me.node_id)
-        if len(ref_nearest) > k_nearest:
-            ref_nearest.pop()
-
-        accuracy_numerator += nearest_nodes[0] == ref_nearest[0]
-        accuracy_denominator += 1
-
-        jaccard_numerator += len(set.intersection(set(nearest_nodes), set(ref_nearest)))
-        jaccard_denominator += k_nearest
-
-    accuracy = accuracy_numerator / accuracy_denominator
-    print("Top-1 accuracy:", accuracy)  # should be 98-100%
-    jaccard_index = jaccard_numerator / jaccard_denominator
-    print("Jaccard index (intersection over union):", jaccard_index)  # should be 95-100%
-    assert accuracy >= 0.9, f"Top-1 accuracy only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
-    assert jaccard_index >= 0.9, f"Jaccard index only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
-
-    # test 4: find all nodes
-    nearest = loop.run_until_complete(
-        me.find_nearest_nodes(query_id=DHTID.generate(), k_nearest=len(dht) + 100))
-    assert len(nearest) == len(dht) + 1
-    assert len(set.difference(set(nearest.keys()), set(all_node_ids) | {me.node_id})) == 0
-
-    # test 5: node without peers
-    other_node = hivemind.dht.node.DHTNode()
-    nearest = loop.run_until_complete(other_node.find_nearest_nodes(DHTID.generate()))
-    assert len(nearest) == 1 and nearest[other_node.node_id] == (LOCALHOST, other_node.port)
-    nearest = loop.run_until_complete(other_node.find_nearest_nodes(DHTID.generate(), exclude_self=True))
-    assert len(nearest) == 0
-
-    # test 6 store and get value
-    true_time = get_dht_time() + 1200
-    assert loop.run_until_complete(me.store("mykey", ["Value", 10], true_time))
-    val, expiration_time = loop.run_until_complete(me.get("mykey"))
-    assert expiration_time == true_time, "Wrong time"
-    assert val == ["Value", 10], "Wrong value"
-
-    # terminate remaining processes
+        port = pipe_recv.recv()
+        processes.append(proc)
+        dht[f"{LOCALHOST}:{port}"] = node_id
+
+    test_success = mp.Event()
+
+    def _tester():
+        # note: we run everything in a separate process to re-initialize all global states from scratch
+        # this helps us avoid undesirable side-effects when running multiple tests in sequence
+        loop = asyncio.get_event_loop()
+        me = loop.run_until_complete(DHTNode.create(initial_peers=random.sample(dht.keys(), 5)))
+
+        # test 1: find self
+        nearest = loop.run_until_complete(me.find_nearest_nodes(key_id=me.node_id, k_nearest=1))
+        assert len(nearest) == 1 and nearest[me.node_id] == (LOCALHOST, me.port)
+
+        # test 2: find others
+        for i in range(10):
+            ref_endpoint, query_id = random.choice(list(dht.items()))
+            nearest = loop.run_until_complete(me.find_nearest_nodes(key_id=query_id, k_nearest=1))
+            assert len(nearest) == 1 and next(iter(nearest.items())) == (query_id, ref_endpoint)
+
+        # test 3: find neighbors to random nodes
+        accuracy_numerator = accuracy_denominator = 0  # top-1 nearest neighbor accuracy
+        jaccard_numerator = jaccard_denominator = 0  # jaccard similarity aka intersection over union
+        all_node_ids = list(dht.values())
+
+        for i in range(100):
+            query_id = DHTID.generate()
+            k_nearest = random.randint(1, 20)
+            exclude_self = random.random() > 0.5
+            nearest = loop.run_until_complete(
+                me.find_nearest_nodes(key_id=query_id, k_nearest=k_nearest, exclude_self=exclude_self))
+            nearest_nodes = list(nearest)  # keys from ordered dict
+
+            assert len(nearest_nodes) == k_nearest, "beam search must return exactly k_nearest results"
+            assert me.node_id not in nearest_nodes or not exclude_self, "if exclude, results should not contain own node id"
+            assert np.all(np.diff(query_id.xor_distance(nearest_nodes)) >= 0), "results must be sorted by distance"
+
+            ref_nearest = heapq.nsmallest(k_nearest + 1, all_node_ids, key=query_id.xor_distance)
+            if exclude_self and me.node_id in ref_nearest:
+                ref_nearest.remove(me.node_id)
+            if len(ref_nearest) > k_nearest:
+                ref_nearest.pop()
+
+            accuracy_numerator += nearest_nodes[0] == ref_nearest[0]
+            accuracy_denominator += 1
+
+            jaccard_numerator += len(set.intersection(set(nearest_nodes), set(ref_nearest)))
+            jaccard_denominator += k_nearest
+
+        accuracy = accuracy_numerator / accuracy_denominator
+        print("Top-1 accuracy:", accuracy)  # should be 98-100%
+        jaccard_index = jaccard_numerator / jaccard_denominator
+        print("Jaccard index (intersection over union):", jaccard_index)  # should be 95-100%
+        assert accuracy >= 0.9, f"Top-1 accuracy only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
+        assert jaccard_index >= 0.9, f"Jaccard index only {jaccard_index} ({jaccard_numerator} / {jaccard_denominator})"
+
+        # test 4: find all nodes
+        nearest = loop.run_until_complete(me.find_nearest_nodes(key_id=DHTID.generate(), k_nearest=len(dht) + 100))
+        assert len(nearest) == len(dht) + 1
+        assert len(set.difference(set(nearest.keys()), set(all_node_ids) | {me.node_id})) == 0
+
+        # test 5: node without peers
+        other_node = loop.run_until_complete(DHTNode.create())
+        nearest = loop.run_until_complete(other_node.find_nearest_nodes(DHTID.generate()))
+        assert len(nearest) == 1 and nearest[other_node.node_id] == (LOCALHOST, other_node.port)
+        nearest = loop.run_until_complete(other_node.find_nearest_nodes(DHTID.generate(), exclude_self=True))
+        assert len(nearest) == 0
+
+        # test 6 store and get value
+        true_time = get_dht_time() + 1200
+        assert loop.run_until_complete(me.store("mykey", ["Value", 10], true_time))
+        for node in [me, other_node]:
+            val, expiration_time = loop.run_until_complete(me.get("mykey"))
+            assert expiration_time == true_time, "Wrong time"
+            assert val == ["Value", 10], "Wrong value"
+
+        test_success.set()
+
+    tester = mp.Process(target=_tester, daemon=True)
+    tester.start()
+    tester.join()
+    assert test_success.is_set()
     for proc in processes:
     for proc in processes:
         proc.terminate()
         proc.terminate()
 
 
 
 
 def test_hivemind_dht():
 def test_hivemind_dht():
-    peers = [hivemind.dht.DHT(start=True)]
+    peers = [hivemind.DHT(start=True)]
     for i in range(10):
     for i in range(10):
-        neighbors_i = [('localhost', node.port) for node in random.sample(peers, min(3, len(peers)))]
+        neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
         peers.append(hivemind.DHT(*neighbors_i, start=True))
         peers.append(hivemind.DHT(*neighbors_i, start=True))
 
 
     you: hivemind.dht.DHT = random.choice(peers)
     you: hivemind.dht.DHT = random.choice(peers)
@@ -241,37 +250,37 @@ def test_hivemind_dht():
 
 
 def test_store():
 def test_store():
     d = LocalStorage()
     d = LocalStorage()
-    d.store("key", "val", get_dht_time() + 10)
-    assert d.get("key")[0] == "val", "Wrong value"
+    d.store(DHTID.generate("key"), b"val", get_dht_time() + 10)
+    assert d.get(DHTID.generate("key"))[0] == b"val", "Wrong value"
     print("Test store passed")
     print("Test store passed")
 
 
 
 
 def test_get_expired():
 def test_get_expired():
     d = LocalStorage()
     d = LocalStorage()
-    d.store("key", "val", get_dht_time() + 1)
+    d.store(DHTID.generate("key"), b"val", get_dht_time() + 1)
     time.sleep(2)
     time.sleep(2)
-    assert d.get("key") == (None, None), "Expired value must be deleted"
+    assert d.get(DHTID.generate("key")) == (None, None), "Expired value must be deleted"
     print("Test get expired passed")
     print("Test get expired passed")
 
 
 
 
 def test_get_empty():
 def test_get_empty():
     d = LocalStorage()
     d = LocalStorage()
-    assert d.get("key") == (None, None), "Expired value must be deleted"
+    assert d.get(DHTID.generate(source="key")) == (None, None), "LocalStorage returned non-existent value"
     print("Test get expired passed")
     print("Test get expired passed")
 
 
 
 
 def test_change_expiration_time():
 def test_change_expiration_time():
     d = LocalStorage()
     d = LocalStorage()
-    d.store("key", "val1", get_dht_time() + 2)
-    d.store("key", "val2", get_dht_time() + 200)
+    d.store(DHTID.generate("key"), b"val1", get_dht_time() + 2)
+    d.store(DHTID.generate("key"), b"val2", get_dht_time() + 200)
     time.sleep(4)
     time.sleep(4)
-    assert d.get("key")[0] == "val2", "Value must be changed, but still kept in table"
+    assert d.get(DHTID.generate("key"))[0] == b"val2", "Value must be changed, but still kept in table"
     print("Test change expiration time passed")
     print("Test change expiration time passed")
 
 
 
 
 def test_maxsize_cache():
 def test_maxsize_cache():
     d = LocalStorage(maxsize=1)
     d = LocalStorage(maxsize=1)
-    d.store("key1", "val1", get_dht_time() + 1)
-    d.store("key2", "val2", get_dht_time() + 200)
-    assert d.get("key2")[0] == "val2", "Value with bigger exp. time must be kept"
-    assert d.get("key1")[0] is None, "Value with less exp time, must be deleted"
+    d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 1)
+    d.store(DHTID.generate("key2"), b"val2", get_dht_time() + 200)
+    assert d.get(DHTID.generate("key2"))[0] == b"val2", "Value with bigger exp. time must be kept"
+    assert d.get(DHTID.generate("key1"))[0] is None, "Value with less exp time, must be deleted"

+ 10 - 4
tests/test_routing.py

@@ -3,6 +3,7 @@ import heapq
 import operator
 import operator
 from itertools import chain, zip_longest
 from itertools import chain, zip_longest
 
 
+from hivemind import LOCALHOST
 from hivemind.dht.routing import RoutingTable, DHTID
 from hivemind.dht.routing import RoutingTable, DHTID
 from hivemind.utils.serializer import PickleSerializer
 from hivemind.utils.serializer import PickleSerializer
 
 
@@ -37,8 +38,8 @@ def test_routing_table_basic():
 
 
     for phony_neighbor_port in random.sample(range(10000), 100):
     for phony_neighbor_port in random.sample(range(10000), 100):
         phony_id = DHTID.generate()
         phony_id = DHTID.generate()
-        routing_table.add_or_update_node(phony_id, ('localhost', phony_neighbor_port))
-        assert routing_table[phony_id] == ('localhost', phony_neighbor_port)
+        routing_table.add_or_update_node(phony_id, f'{LOCALHOST}:{phony_neighbor_port}')
+        assert routing_table[phony_id] == f'{LOCALHOST}:{phony_neighbor_port}'
 
 
     assert routing_table.buckets[0].lower == DHTID.MIN and routing_table.buckets[-1].upper == DHTID.MAX
     assert routing_table.buckets[0].lower == DHTID.MIN and routing_table.buckets[-1].upper == DHTID.MAX
     for bucket in routing_table.buckets:
     for bucket in routing_table.buckets:
@@ -56,7 +57,7 @@ def test_routing_table_parameters():
         node_id = DHTID.generate()
         node_id = DHTID.generate()
         routing_table = RoutingTable(node_id, bucket_size=bucket_size, depth_modulo=modulo)
         routing_table = RoutingTable(node_id, bucket_size=bucket_size, depth_modulo=modulo)
         for phony_neighbor_port in random.sample(range(1_000_000), 10_000):
         for phony_neighbor_port in random.sample(range(1_000_000), 10_000):
-            routing_table.add_or_update_node(DHTID.generate(), ('localhost', phony_neighbor_port))
+            routing_table.add_or_update_node(DHTID.generate(), f'{LOCALHOST}:{phony_neighbor_port}')
         for bucket in routing_table.buckets:
         for bucket in routing_table.buckets:
             assert len(bucket.replacement_nodes) == 0 or len(bucket.nodes_to_addr) <= bucket.size
             assert len(bucket.replacement_nodes) == 0 or len(bucket.nodes_to_addr) <= bucket.size
         assert min_nbuckets <= len(routing_table.buckets) <= max_nbuckets, (
         assert min_nbuckets <= len(routing_table.buckets) <= max_nbuckets, (
@@ -70,8 +71,13 @@ def test_routing_table_search():
         node_id = DHTID.generate()
         node_id = DHTID.generate()
         routing_table = RoutingTable(node_id, bucket_size=20, depth_modulo=5)
         routing_table = RoutingTable(node_id, bucket_size=20, depth_modulo=5)
         num_added = 0
         num_added = 0
+        total_nodes = 0
+
         for phony_neighbor_port in random.sample(range(1_000_000), table_size):
         for phony_neighbor_port in random.sample(range(1_000_000), table_size):
-            num_added += routing_table.add_or_update_node(DHTID.generate(), ('localhost', phony_neighbor_port)) is None
+            routing_table.add_or_update_node(DHTID.generate(), f'{LOCALHOST}:{phony_neighbor_port}')
+            new_total = sum(len(bucket.nodes_to_addr) for bucket in routing_table.buckets)
+            num_added += new_total > total_nodes
+            total_nodes = new_total
         num_replacements = sum(len(bucket.replacement_nodes) for bucket in routing_table.buckets)
         num_replacements = sum(len(bucket.replacement_nodes) for bucket in routing_table.buckets)
     
     
         all_active_neighbors = list(chain(
         all_active_neighbors = list(chain(

+ 3 - 3
tests/test_utils/run_server.py

@@ -61,9 +61,9 @@ def make_dummy_server(interface='0.0.0.0', port=None, num_experts=1, expert_cls=
 
 
     sample_input = name_to_input[expert_cls](4, hidden_dim)
     sample_input = name_to_input[expert_cls](4, hidden_dim)
     if isinstance(sample_input, tuple):
     if isinstance(sample_input, tuple):
-        args_schema = tuple(hivemind.BatchTensorProto.from_tensor(arg) for arg in sample_input)
+        args_schema = tuple(hivemind.BatchTensorDescriptor.from_tensor(arg) for arg in sample_input)
     else:
     else:
-        args_schema = (hivemind.BatchTensorProto.from_tensor(sample_input),)
+        args_schema = (hivemind.BatchTensorDescriptor.from_tensor(sample_input),)
 
 
     # initialize experts
     # initialize experts
     experts = {}
     experts = {}
@@ -73,7 +73,7 @@ def make_dummy_server(interface='0.0.0.0', port=None, num_experts=1, expert_cls=
         expert_uid = f'{expert_prefix}{UID_DELIMETER}{i + expert_offset}'
         expert_uid = f'{expert_prefix}{UID_DELIMETER}{i + expert_offset}'
         experts[expert_uid] = hivemind.ExpertBackend(name=expert_uid, expert=expert, opt=opt,
         experts[expert_uid] = hivemind.ExpertBackend(name=expert_uid, expert=expert, opt=opt,
                                                      args_schema=args_schema,
                                                      args_schema=args_schema,
-                                                     outputs_schema=hivemind.BatchTensorProto(hidden_dim),
+                                                     outputs_schema=hivemind.BatchTensorDescriptor(hidden_dim),
                                                      max_batch_size=max_batch_size,
                                                      max_batch_size=max_batch_size,
                                                      )
                                                      )
     # actually start server
     # actually start server