
DHT miscellaneous (#48)

* add option to share keys with new peers that *should* be sharing them (improves data availability if many peers join)

* added option to NOT refresh the routing table (new default)
* initial dht crawl is no longer blocking

* sphinx-friendly escape char

* pass cache params to KademliaProtocol

* typo

* rpc congestion

* add max concurrent rpc

* fix bug in welcome protocol: previously DHT peers always considered each other "new nodes" and sent EVERYTHING to EVERYONE on each rpc call. Now DHT nodes only request a store on an explicit .store call OR when a new peer knocks on their DHT

* increase to 128 concurrent rpc

* await dht traversal in bootstrap

* minor comment
justheuristic, 5 years ago
commit 68675b255c
5 changed files with 83 additions and 31 deletions
  1. hivemind/dht/node.py (+44 -17)
  2. hivemind/dht/protocol.py (+32 -11)
  3. hivemind/dht/routing.py (+3 -0)
  4. hivemind/runtime/task_pool.py (+1 -1)
  5. tests/test_dht.py (+3 -2)
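
For quick reference, here is a minimal usage sketch of the new DHTNode options introduced in this change. The import path is inferred from hivemind/dht/node.py, and the peer endpoints are made up for illustration; treat this as a sketch rather than a tested snippet.

    from hivemind.dht.node import DHTNode

    # assume two peers are already listening at these (host, port) endpoints
    node = DHTNode(
        initial_peers=[('192.168.0.5', 31337), ('192.168.0.6', 31337)],
        max_concurrent_rpc=128,   # new: cap on simultaneous outgoing RPC requests
        staleness_timeout=None,   # new default: do not refresh stale buckets
        cache_locally=True,       # keep every stored/found value in a node-local cache
        cache_size=10_000,        # new: bound that local cache
    )
    print(f"DHT node listening on port {node.port}")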

hivemind/dht/node.py (+44 -17)

@@ -1,4 +1,5 @@
 import asyncio
+import random
 from collections import OrderedDict
 from functools import partial
 from typing import Optional, Tuple, List, Dict
@@ -23,9 +24,16 @@ class DHTNode:
       Recommended value: $k$ is chosen s.t. any given k nodes are very unlikely to all fail after staleness_timeout
     :param num_replicas: (≈k) - number of nearest nodes that will be asked to store a given key, default = bucket_size
     :param depth_modulo: (b) - kademlia can split bucket if it contains root OR up to the nearest multiple of this value
+    :param max_concurrent_rpc: maximum number of outgoing RPC requests emitted by KademliaProtocol in parallel
+        Reduce this value if your RPC requests time out even though the peer does send a response.
     :param wait_timeout: a kademlia rpc request is deemed lost if we did not receive a reply in this many seconds
     :param staleness_timeout: a bucket is considered stale if no node from that bucket was updated in this many seconds
+        if staleness_timeout is None, DHTNode will not refresh stale buckets (which is usually okay)
     :param bootstrap_timeout: after one of peers responds, await other peers for at most this many seconds
+    :param cache_locally: if True, caches all values (stored or found) in a node-local cache
+    :param cache_nearest: if above 0, whenever DHTNode finds a value, it will also cache that value on this many of
+        the nearest nodes visited by the search algorithm. Prefers nodes that are nearest to :key: but have no value yet.
+    :param cache_size: if specified, local cache will store up to this many records (as in LRU cache)
     :param interface: provide 0.0.0.0 to operate over ipv4, :: to operate over ipv6, localhost to operate locally, etc.
 
     :note: Hivemind DHT is optimized to store temporary metadata that is regularly updated.
@@ -47,9 +55,9 @@ class DHTNode:
 
     def __init__(self, node_id: Optional[DHTID] = None, port: Optional[Port] = None, initial_peers: List[Endpoint] = (),
                  bucket_size: int = 20, num_replicas: Optional[int] = None, depth_modulo: int = 5,
-                 wait_timeout: float = 5, staleness_timeout: Optional[float] = 600,
+                 max_concurrent_rpc: int = 128, wait_timeout: float = 5, staleness_timeout: Optional[float] = None,
                  bootstrap_timeout: Optional[float] = None, cache_locally: bool = True, cache_nearest: int = 1,
-                 interface: Hostname = '0.0.0.0'):
+                 cache_size=None, interface: Hostname = '0.0.0.0'):
         self.node_id = node_id = node_id if node_id is not None else DHTID.generate()
         self.port = port = port if port is not None else find_open_port()
         self.num_replicas = num_replicas if num_replicas is not None else bucket_size
@@ -58,34 +66,40 @@ class DHTNode:
 
         # create kademlia protocol and make it listen to a port
         loop = asyncio.get_event_loop()
-        make_protocol = partial(KademliaProtocol, self.node_id, bucket_size, depth_modulo, wait_timeout)
+        make_protocol = partial(KademliaProtocol, self.node_id, bucket_size, depth_modulo, wait_timeout,
+                                max_concurrent_rpc, num_replicas, cache_size)
         listener = loop.run_until_complete(loop.create_datagram_endpoint(make_protocol, local_addr=(interface, port)))
         self.transport: asyncio.Transport = listener[0]
         self.protocol: KademliaProtocol = listener[1]
 
         if initial_peers:
-            # bootstrap part 1: ping initial_peers, add each other to the routing table
+            # stage 1: ping initial_peers, add each other to the routing table
             bootstrap_timeout = bootstrap_timeout if bootstrap_timeout is not None else wait_timeout
-            began_bootstrap_time = get_dht_time()
+            start_time = get_dht_time()
             ping_tasks = map(self.protocol.call_ping, initial_peers)
-            finished_tasks, remaining_tasks = loop.run_until_complete(
-                asyncio.wait(ping_tasks, timeout=wait_timeout, return_when=asyncio.FIRST_COMPLETED))
-            time_to_first_response = get_dht_time() - began_bootstrap_time
-            # bootstrap part 2: gather all peers who responded within bootstrap_timeout, but at least one peer
-            if remaining_tasks:
+            finished_ping_tasks, remaining_ping_tasks = loop.run_until_complete(
+                asyncio.wait(ping_tasks, return_when=asyncio.FIRST_COMPLETED))
+
+            # stage 2: gather remaining peers (those who respond within bootstrap_timeout)
+            if remaining_ping_tasks:
                 finished_in_time, stragglers = loop.run_until_complete(
-                    asyncio.wait(remaining_tasks, timeout=bootstrap_timeout - time_to_first_response))
+                    asyncio.wait(remaining_ping_tasks, timeout=bootstrap_timeout - get_dht_time() + start_time))
                 for straggler in stragglers:
                     straggler.cancel()
-                finished_tasks |= finished_in_time
+                finished_ping_tasks |= finished_in_time
 
-            peer_ids = [task.result() for task in finished_tasks if task.result() is not None]
-            if len(peer_ids) == 0 and len(initial_peers) != 0:
+            if not finished_ping_tasks:
                 warn("DHTNode bootstrap failed: none of the initial_peers responded to a ping.")
 
-            # bootstrap part 3: run beam search for my node id to add my own nearest neighbors to the routing table
-            # ... and maybe receive some values that we are meant to store (see protocol.update_routing_table)
-            loop.run_until_complete(self.find_nearest_nodes(query_id=self.node_id))
+            # stage 3: traverse dht to find my own nearest neighbors and populate the routing table
+            # ... maybe receive some values that we are meant to store (see protocol.update_routing_table)
+            # note: using asyncio.wait instead of wait_for because wait_for cancels task on timeout
+            loop.run_until_complete(asyncio.wait([loop.create_task(self.find_nearest_nodes(query_id=self.node_id)),
+                                                  asyncio.sleep(bootstrap_timeout - get_dht_time() + start_time)],
+                                                 return_when=asyncio.FIRST_COMPLETED))
+
+        if self.staleness_timeout is not None:
+            loop.create_task(self._refresh_routing_table(period=self.staleness_timeout))
 
     async def find_nearest_nodes(self, query_id: DHTID, k_nearest: Optional[int] = None,
                                  beam_size: Optional[int] = None, exclude_self: bool = False) -> Dict[DHTID, Endpoint]:
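
Stages 1-2 above follow a "first responder, then bounded stragglers" pattern: asyncio.wait returns as soon as one ping succeeds, and the remaining pings only get whatever is left of bootstrap_timeout. Stage 3 then races the DHT traversal against the remaining budget with asyncio.wait instead of wait_for, so a timeout does not cancel the crawl. Below is a self-contained sketch of the stage 1-2 pattern; bounded_gather and the delays are illustrative, not hivemind APIs.

    import asyncio, time

    async def bounded_gather(coros, budget: float):
        """Wait for the first result, then give the rest only the remaining time budget."""
        start = time.monotonic()
        tasks = [asyncio.ensure_future(coro) for coro in coros]
        done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        if pending:
            done_late, stragglers = await asyncio.wait(
                pending, timeout=budget - (time.monotonic() - start))
            for task in stragglers:
                task.cancel()  # too slow: dropped, like straggler pings in stage 2 above
            done |= done_late
        return [task.result() for task in done]

    async def demo():
        async def reply(delay, value):
            await asyncio.sleep(delay)
            return value
        # with a 0.5s budget, the 2.0s "peer" is cancelled while the fast ones make it
        print(await bounded_gather([reply(0.1, 'a'), reply(0.3, 'b'), reply(2.0, 'c')], budget=0.5))

    asyncio.run(demo())
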
@@ -207,3 +221,16 @@ class DHTNode:
                     break
 
         return (latest_value, latest_expiration) if latest_expiration != -float('inf') else (None, None)
+
+    async def _refresh_routing_table(self, *, period: Optional[float]) -> None:
+        """ Tries to find new nodes for buckets that were unused for more than self.staleness_timeout """
+        while period is not None:  # run forever; exits immediately if period is None
+            refresh_time = get_dht_time()
+            staleness_threshold = refresh_time - self.staleness_timeout
+            stale_buckets = [bucket for bucket in self.protocol.routing_table.buckets
+                             if bucket.last_updated < staleness_threshold]
+            for bucket in stale_buckets:
+                refresh_id = DHTID(random.randint(bucket.lower, bucket.upper - 1))
+                await self.find_nearest_nodes(refresh_id)
+
+            await asyncio.sleep(max(0.0, period - (get_dht_time() - refresh_time)))

hivemind/dht/protocol.py (+32 -11)

@@ -1,6 +1,6 @@
 import asyncio
 import heapq
-from typing import Optional, List, Tuple, Dict
+from typing import Optional, List, Tuple, Dict, Iterator
 from rpcudp.protocol import RPCProtocol
 
 from .routing import RoutingTable, DHTID, DHTValue, DHTExpiration, BinaryDHTID, get_dht_time
@@ -21,10 +21,11 @@ class KademliaProtocol(RPCProtocol):
      Read more: https://github.com/bmuller/rpcudp/tree/master/rpcudp
     """
 
-    def __init__(self, node_id: DHTID, bucket_size: int, depth_modulo: int,
-                 wait_timeout: float, cache_size: Optional[int] = None):
+    def __init__(self, node_id: DHTID, bucket_size: int, depth_modulo: int, wait_timeout: float,
+                 max_concurrent_rpc: int, num_replicas: Optional[int] = None, cache_size: Optional[int] = None):
         super().__init__(wait_timeout)
-        self.node_id, self.bucket_size = node_id, bucket_size
+        self.node_id, self.bucket_size, self.num_replicas = node_id, bucket_size, num_replicas or bucket_size
+        self.rpc_semaphore = asyncio.BoundedSemaphore(value=max_concurrent_rpc)
         self.routing_table = RoutingTable(node_id, bucket_size, depth_modulo)
         self.storage = LocalStorage()
         self.cache = LocalStorage(maxsize=cache_size)
@@ -36,7 +37,8 @@ class KademliaProtocol(RPCProtocol):
 
     async def call_ping(self, recipient: Endpoint) -> Optional[DHTID]:
         """ Get recipient's node id and add him to the routing table. If recipient doesn't respond, return None """
-        responded, response = await self.ping(recipient, bytes(self.node_id))
+        async with self.rpc_semaphore:
+            responded, response = await self.ping(recipient, bytes(self.node_id))
         recipient_node_id = DHTID.from_bytes(response) if responded else None
         asyncio.ensure_future(self.update_routing_table(recipient_node_id, recipient, responded=responded))
         return recipient_node_id
@@ -58,8 +60,9 @@ class KademliaProtocol(RPCProtocol):
 
         :returns: True if value was accepted, False if it was rejected (recipient has newer value), None if no response
         """
-        responded, response = await self.store(recipient, bytes(self.node_id), bytes(key),
-                                               value, expiration_time, in_cache)
+        async with self.rpc_semaphore:
+            responded, response = await self.store(recipient, bytes(self.node_id), bytes(key),
+                                                   value, expiration_time, in_cache)
         if responded:
             store_accepted, recipient_node_id = response[0], DHTID.from_bytes(response[1])
             asyncio.ensure_future(self.update_routing_table(recipient_node_id, recipient, responded=responded))
@@ -86,7 +89,8 @@ class KademliaProtocol(RPCProtocol):
 
         :returns: a dictionary[node id => address] as per Section 2.3 of the paper
         """
-        responded, response = await self.find_node(recipient, bytes(self.node_id), bytes(query_id))
+        async with self.rpc_semaphore:
+            responded, response = await self.find_node(recipient, bytes(self.node_id), bytes(query_id))
         if responded:
             peers = {DHTID.from_bytes(peer_id_bytes): tuple(addr) for peer_id_bytes, addr in response[0]}
             # Note: we convert addr from list to tuple here --^ because some msgpack versions convert tuples to lists
@@ -122,7 +126,8 @@ class KademliaProtocol(RPCProtocol):
          neighbors:  a dictionary[node id => address] as per Section 2.3 of the paper;
         :note: if no response, returns None, None, {}
         """
-        responded, response = await self.find_value(recipient, bytes(self.node_id), bytes(key))
+        async with self.rpc_semaphore:
+            responded, response = await self.find_value(recipient, bytes(self.node_id), bytes(key))
         if responded:
             (value, expiration_time, peers_bytes), recipient_id = response[:-1], DHTID.from_bytes(response[-1])
             peers = {DHTID.from_bytes(peer_id_bytes): tuple(addr) for peer_id_bytes, addr in peers_bytes}
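
Every outgoing call above (ping, store, find_node, find_value) now acquires self.rpc_semaphore, an asyncio.BoundedSemaphore created with max_concurrent_rpc, before touching the network. A minimal standalone sketch of the same throttling pattern; fake_rpc and the numbers are purely illustrative.

    import asyncio

    async def main(max_concurrent_rpc: int = 128, total_requests: int = 1000):
        rpc_semaphore = asyncio.BoundedSemaphore(value=max_concurrent_rpc)

        async def fake_rpc(i: int) -> int:
            await asyncio.sleep(0.01)  # stands in for a UDP round-trip
            return i

        async def throttled_call(i: int) -> int:
            # at most max_concurrent_rpc calls are in flight at any moment;
            # the rest wait on the semaphore instead of flooding the transport
            async with rpc_semaphore:
                return await fake_rpc(i)

        results = await asyncio.gather(*(throttled_call(i) for i in range(total_requests)))
        assert list(results) == list(range(total_requests))

    asyncio.run(main())
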
@@ -140,11 +145,23 @@ class KademliaProtocol(RPCProtocol):
           For incoming requests, this should always be True
         """
         if responded:  # incoming request or outgoing request with response
+            if node_id not in self.routing_table:
+                # we just met a new node, maybe we know some values that it *should* store
+                for key, value, expiration in list(self.storage.items()):
+                    neighbors = self.routing_table.get_nearest_neighbors(key, self.num_replicas, exclude=self.node_id)
+                    if neighbors:
+                        nearest_distance = neighbors[0][0].xor_distance(key)
+                        farthest_distance = neighbors[-1][0].xor_distance(key)
+                        new_node_should_store = node_id.xor_distance(key) < farthest_distance
+                        this_node_is_responsible = self.node_id.xor_distance(key) < nearest_distance
+                    if not neighbors or (new_node_should_store and this_node_is_responsible):
+                        asyncio.create_task(self.call_store(addr, key, value, expiration))
+
             maybe_node_to_ping = self.routing_table.add_or_update_node(node_id, addr)
             if maybe_node_to_ping is not None:
                 # we couldn't add new node because the table was full. Check if existing peers are alive (Section 2.2)
                 # ping one least-recently updated peer: if it won't respond, remove it from the table, else update it
-                await self.call_ping(maybe_node_to_ping[1])  # [1]-th element is that node's endpoint
+                asyncio.create_task(self.call_ping(maybe_node_to_ping[1]))  # [1]-th element is that node's endpoint
 
         else:  # outgoing request and peer did not respond
             if node_id is not None and node_id in self.routing_table:
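
The "welcome" rule above decides whether a freshly met node should receive a copy of a locally stored key: send it if we know no other neighbors for that key, or if the newcomer is closer to the key than our farthest known replica while we ourselves are closer than the nearest one. A toy illustration with small integers standing in for DHTIDs (xor_distance(a, b) is simply a ^ b in Kademlia; the numbers are arbitrary):

    def xor_distance(a: int, b: int) -> int:
        return a ^ b  # Kademlia's distance metric

    key = 0b1010
    my_id, new_node_id = 0b1000, 0b1011
    neighbors = [0b1110, 0b0010]  # known replica candidates for this key, nearest first
    nearest = xor_distance(neighbors[0], key)    # 4
    farthest = xor_distance(neighbors[-1], key)  # 8

    new_node_should_store = xor_distance(new_node_id, key) < farthest   # 1 < 8 -> True
    this_node_is_responsible = xor_distance(my_id, key) < nearest       # 2 < 4 -> True
    if new_node_should_store and this_node_is_responsible:
        print("send a store request for this key to the new node")
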
@@ -160,7 +177,6 @@ class KademliaProtocol(RPCProtocol):
             super()._accept_response(msg_id, data, address)
 
 
-
 class LocalStorage:
     def __init__(self, maxsize: Optional[int] = None):
         self.cache_size = maxsize or float("inf")
@@ -200,3 +216,8 @@ class LocalStorage:
         if key in self.data:
             return self.data[key]
         return None, None
+
+    def items(self) -> Iterator[Tuple[DHTID, DHTValue, DHTExpiration]]:
+        """ Iterate over (key, value, expiration_time) tuples stored in this storage """
+        self.remove_outdated()
+        return ((key, value, expiration) for key, (value, expiration) in self.data.items())

hivemind/dht/routing.py (+3 -0)

@@ -193,6 +193,9 @@ class KBucket:
                 newnode_id, newnode = self.replacement_nodes.popitem()
                 self.nodes_to_addr[newnode_id] = newnode
 
+    def __contains__(self, node_id: DHTID):
+        return node_id in self.nodes_to_addr or node_id in self.replacement_nodes
+
     def __len__(self):
         return len(self.nodes_to_addr)
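
The new __contains__ lets routing-table membership checks (used by update_routing_table in protocol.py to decide whether a caller is a "new node") also count peers parked in replacement_nodes. A quick check of that semantics using a stub with the same two dicts; the stub class itself is invented for illustration:

    class BucketStub:
        """Toy stand-in for KBucket: only the two dicts that __contains__ inspects."""
        def __init__(self, nodes_to_addr: dict, replacement_nodes: dict):
            self.nodes_to_addr = nodes_to_addr
            self.replacement_nodes = replacement_nodes

        def __contains__(self, node_id) -> bool:
            # same rule as the new KBucket.__contains__ above
            return node_id in self.nodes_to_addr or node_id in self.replacement_nodes

    bucket = BucketStub(nodes_to_addr={1: ('127.0.0.1', 8080)},
                        replacement_nodes={2: ('127.0.0.1', 8081)})
    assert 1 in bucket and 2 in bucket and 3 not in bucket  # replacements count as known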
 

hivemind/runtime/task_pool.py (+1 -1)

@@ -55,7 +55,7 @@ class TaskPool(TaskPoolBase):
     to process these batches and dispatches results back to request sources. Operates as a background process.
 
     :param process_func: function to be applied to every formed batch; called by Runtime
-        Note that process_func should accept only \*args Tensors and return a flat tuple of Tensors
+        Note that process_func should accept only positional args (Tensors) and return a flat tuple of Tensors
     :param max_batch_size: process at most this many inputs in a batch (a task may contain one or several inputs)
     :param min_batch_size: process at least this many inputs in a batch, otherwise wait for more
     :param timeout: wait for a subsequent task for at most this many seconds
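
For illustration, a process_func that satisfies the constraint described above: positional Tensor arguments in, a flat tuple of Tensors out. The torch usage and names are illustrative, not hivemind code.

    import torch

    def process_func(batch_inputs: torch.Tensor, batch_weights: torch.Tensor):
        # accepts only positional Tensor args and returns a flat tuple of Tensors
        return (batch_inputs @ batch_weights, batch_inputs.sum(dim=-1))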

tests/test_dht.py (+3 -2)

@@ -20,7 +20,7 @@ from hivemind.dht.protocol import LocalStorage
 def run_protocol_listener(port: int, dhtid: DHTID, started: mp.synchronize.Event,
                           ping: Optional[hivemind.Endpoint] = None):
     loop = asyncio.new_event_loop()
-    protocol = partial(KademliaProtocol, dhtid, bucket_size=20, depth_modulo=5, wait_timeout=5)
+    protocol = partial(KademliaProtocol, dhtid, bucket_size=20, depth_modulo=5, wait_timeout=5, max_concurrent_rpc=128)
     listen = loop.create_datagram_endpoint(protocol, local_addr=('127.0.0.1', port))
     transport, protocol = loop.run_until_complete(listen)
     print(f"Started peer id={protocol.node_id} port={port}", flush=True)
@@ -47,7 +47,8 @@ def test_kademlia_protocol():
 
         port = hivemind.find_open_port()
         loop = asyncio.new_event_loop()
-        protocol = partial(KademliaProtocol, DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5)
+        protocol = partial(KademliaProtocol, DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5,
+                           max_concurrent_rpc=128)
         listen = loop.create_datagram_endpoint(protocol, local_addr=('127.0.0.1', port))
         transport, protocol = loop.run_until_complete(listen)
         print(f"Self id={protocol.node_id} port={port}", flush=True)