
DHT miscellaneous (#48)

* add option to share keys with new peers that *should* be sharing them (improves data availability when many peers join)

* added option to NOT refresh table (new default)

* initial dht crawl is no longer blocking

* sphinx-friendly escape char

* pass cache params to kademliaprotocol

* typo

* rpc congestion

* rpc congestion

* rpc congestion

* add max concurrent rpc

* add max concurrent rpc

* fix bug in welcome protocol: previously, DHT peers always considered each other "new nodes" and sent EVERYTHING to EVERYONE on each rpc call. Now DHT nodes only request a store on an explicit .store call OR when a new peer knocks on their DHT

* increase to 128 concurrent rpc

* await dht traversal in bootstrap

* await dht traversal in bootstrap

* minor comment
justheuristic 5 years ago
parent
commit
68675b255c
5 changed files with 83 additions and 31 deletions
  1. hivemind/dht/node.py (+44, -17)
  2. hivemind/dht/protocol.py (+32, -11)
  3. hivemind/dht/routing.py (+3, -0)
  4. hivemind/runtime/task_pool.py (+1, -1)
  5. tests/test_dht.py (+3, -2)

hivemind/dht/node.py (+44, -17)

@@ -1,4 +1,5 @@
 import asyncio
+import random
 from collections import OrderedDict
 from functools import partial
 from typing import Optional, Tuple, List, Dict
@@ -23,9 +24,16 @@ class DHTNode:
       Recommended value: $k$ is chosen s.t. any given k nodes are very unlikely to all fail after staleness_timeout
     :param num_replicas: (≈k) - number of nearest nodes that will be asked to store a given key, default = bucket_size
     :param depth_modulo: (b) - kademlia can split bucket if it contains root OR up to the nearest multiple of this value
+    :param max_concurrent_rpc: maximum number of outgoing RPC requests emitted by KademliaProtocol in parallel
+        Reduce this value if your RPC requests register no response despite the peer sending the response.
     :param wait_timeout: a kademlia rpc request is deemed lost if we did not receive a reply in this many seconds
     :param staleness_timeout: a bucket is considered stale if no node from that bucket was updated in this many seconds
+        if staleness_timeout is None, DHTNode will not refresh stale buckets (which is usually okay)
     :param bootstrap_timeout: after one of the peers responds, await the remaining peers for at most this many seconds
+    :param cache_locally: if True, caches all values (stored or found) in a node-local cache
+    :param cache_nearest: if above 0, whenever DHTNode finds a value, it will also store (cache) this value on this many
+        of the nearest nodes visited by the search algorithm. Prefers nodes that are nearest to :key: but have no value yet.
+    :param cache_size: if specified, local cache will store up to this many records (as in LRU cache)
     :param interface: provide 0.0.0.0 to operate over ipv4, :: to operate over ipv6, localhost to operate locally, etc.

     :note: Hivemind DHT is optimized to store temporary metadata that is regularly updated.
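A minimal usage sketch of the new constructor arguments (illustrative only; the import path follows this diff, while the port number is an arbitrary assumption):

    from hivemind.dht.node import DHTNode

    # first node: no initial peers; with staleness_timeout=None (the new default),
    # it never refreshes its routing table in the background
    alice = DHTNode(port=33010, staleness_timeout=None)

    # second node bootstraps off the first; max_concurrent_rpc caps parallel outgoing
    # RPCs, cache_size bounds the node-local LRU cache added by this commit
    bob = DHTNode(initial_peers=[('127.0.0.1', 33010)], max_concurrent_rpc=128,
                  cache_locally=True, cache_nearest=1, cache_size=10_000)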
@@ -47,9 +55,9 @@ class DHTNode:

     def __init__(self, node_id: Optional[DHTID] = None, port: Optional[Port] = None, initial_peers: List[Endpoint] = (),
                  bucket_size: int = 20, num_replicas: Optional[int] = None, depth_modulo: int = 5,
-                 wait_timeout: float = 5, staleness_timeout: Optional[float] = 600,
+                 max_concurrent_rpc: int = 128, wait_timeout: float = 5, staleness_timeout: Optional[float] = None,
                  bootstrap_timeout: Optional[float] = None, cache_locally: bool = True, cache_nearest: int = 1,
-                 interface: Hostname = '0.0.0.0'):
+                 cache_size=None, interface: Hostname = '0.0.0.0'):
         self.node_id = node_id = node_id if node_id is not None else DHTID.generate()
         self.port = port = port if port is not None else find_open_port()
         self.num_replicas = num_replicas if num_replicas is not None else bucket_size
@@ -58,34 +66,40 @@ class DHTNode:

         # create kademlia protocol and make it listen to a port
         loop = asyncio.get_event_loop()
-        make_protocol = partial(KademliaProtocol, self.node_id, bucket_size, depth_modulo, wait_timeout)
+        make_protocol = partial(KademliaProtocol, self.node_id, bucket_size, depth_modulo, wait_timeout,
+                                max_concurrent_rpc, num_replicas, cache_size)
         listener = loop.run_until_complete(loop.create_datagram_endpoint(make_protocol, local_addr=(interface, port)))
         self.transport: asyncio.Transport = listener[0]
         self.protocol: KademliaProtocol = listener[1]

         if initial_peers:
-            # bootstrap part 1: ping initial_peers, add each other to the routing table
+            # stage 1: ping initial_peers, add each other to the routing table
             bootstrap_timeout = bootstrap_timeout if bootstrap_timeout is not None else wait_timeout
-            began_bootstrap_time = get_dht_time()
+            start_time = get_dht_time()
             ping_tasks = map(self.protocol.call_ping, initial_peers)
-            finished_tasks, remaining_tasks = loop.run_until_complete(
-                asyncio.wait(ping_tasks, timeout=wait_timeout, return_when=asyncio.FIRST_COMPLETED))
-            time_to_first_response = get_dht_time() - began_bootstrap_time
-            # bootstrap part 2: gather all peers who responded within bootstrap_timeout, but at least one peer
-            if remaining_tasks:
+            finished_ping_tasks, remaining_ping_tasks = loop.run_until_complete(
+                asyncio.wait(ping_tasks, return_when=asyncio.FIRST_COMPLETED))
+
+            # stage 2: gather remaining peers (those who respond within bootstrap_timeout)
+            if remaining_ping_tasks:
                 finished_in_time, stragglers = loop.run_until_complete(
-                    asyncio.wait(remaining_tasks, timeout=bootstrap_timeout - time_to_first_response))
+                    asyncio.wait(remaining_ping_tasks, timeout=bootstrap_timeout - get_dht_time() + start_time))
                 for straggler in stragglers:
                     straggler.cancel()
-                finished_tasks |= finished_in_time
+                finished_ping_tasks |= finished_in_time

-            peer_ids = [task.result() for task in finished_tasks if task.result() is not None]
-            if len(peer_ids) == 0 and len(initial_peers) != 0:
+            if not finished_ping_tasks:
                 warn("DHTNode bootstrap failed: none of the initial_peers responded to a ping.")

-            # bootstrap part 3: run beam search for my node id to add my own nearest neighbors to the routing table
-            # ... and maybe receive some values that we are meant to store (see protocol.update_routing_table)
-            loop.run_until_complete(self.find_nearest_nodes(query_id=self.node_id))
+            # stage 3: traverse dht to find my own nearest neighbors and populate the routing table
+            # ... maybe receive some values that we are meant to store (see protocol.update_routing_table)
+            # note: using asyncio.wait instead of wait_for because wait_for cancels task on timeout
+            loop.run_until_complete(asyncio.wait([loop.create_task(self.find_nearest_nodes(query_id=self.node_id)),
+                                                  asyncio.sleep(bootstrap_timeout - get_dht_time() + start_time)],
+                                                 return_when=asyncio.FIRST_COMPLETED))
+
+        if self.staleness_timeout is not None:
+            loop.create_task(self._refresh_routing_table(period=self.staleness_timeout))

     async def find_nearest_nodes(self, query_id: DHTID, k_nearest: Optional[int] = None,
                                  beam_size: Optional[int] = None, exclude_self: bool = False) -> Dict[DHTID, Endpoint]:
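The note about asyncio.wait vs wait_for above is the crux of the non-blocking crawl: wait_for cancels its awaitable when the timeout fires, while asyncio.wait merely stops waiting and leaves the task running. A standalone sketch of stock asyncio behavior (not code from this commit):

    import asyncio

    async def traversal():
        await asyncio.sleep(10)  # stand-in for a long dht crawl

    async def main():
        task = asyncio.ensure_future(traversal())
        timeout = asyncio.ensure_future(asyncio.sleep(0.1))
        # returns after ~0.1s, but `task` keeps crawling in the background
        await asyncio.wait([task, timeout], return_when=asyncio.FIRST_COMPLETED)
        print(task.cancelled(), task.done())  # False False: not cancelled, still running

    asyncio.get_event_loop().run_until_complete(main())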
@@ -207,3 +221,16 @@ class DHTNode:
                    break

         return (latest_value, latest_expiration) if latest_expiration != -float('inf') else (None, None)
+
+    async def _refresh_routing_table(self, *, period: Optional[float]) -> None:
+        """ Tries to find new nodes for buckets that were unused for more than self.staleness_timeout """
+        while period is not None:  # period never changes here, so this loops forever unless period is None
+            refresh_time = get_dht_time()
+            staleness_threshold = refresh_time - self.staleness_timeout
+            stale_buckets = [bucket for bucket in self.protocol.routing_table.buckets
+                             if bucket.last_updated < staleness_threshold]
+            for bucket in stale_buckets:
+                refresh_id = DHTID(random.randint(bucket.lower, bucket.upper - 1))
+                await self.find_nearest_nodes(refresh_id)
+
+            await asyncio.sleep(max(0.0, period - (get_dht_time() - refresh_time)))
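For a stale bucket, the refresh id is drawn uniformly from that bucket's id range, so the lookup lands inside the bucket and repopulates it. Restated with plain integers (illustrative; real ids are DHTIDs):

    import random

    lower, upper = 0, 2 ** 16  # hypothetical bucket bounds
    refresh_id = random.randint(lower, upper - 1)  # mirrors DHTID(random.randint(bucket.lower, bucket.upper - 1))
    assert lower <= refresh_id < upper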

hivemind/dht/protocol.py (+32, -11)

@@ -1,6 +1,6 @@
 import asyncio
 import heapq
-from typing import Optional, List, Tuple, Dict
+from typing import Optional, List, Tuple, Dict, Iterator
 from rpcudp.protocol import RPCProtocol

 from .routing import RoutingTable, DHTID, DHTValue, DHTExpiration, BinaryDHTID, get_dht_time
@@ -21,10 +21,11 @@ class KademliaProtocol(RPCProtocol):
      Read more: https://github.com/bmuller/rpcudp/tree/master/rpcudp
     """

-    def __init__(self, node_id: DHTID, bucket_size: int, depth_modulo: int,
-                 wait_timeout: float, cache_size: Optional[int] = None):
+    def __init__(self, node_id: DHTID, bucket_size: int, depth_modulo: int, wait_timeout: float,
+                 max_concurrent_rpc: int, num_replicas: Optional[int] = None, cache_size: Optional[int] = None):
         super().__init__(wait_timeout)
-        self.node_id, self.bucket_size = node_id, bucket_size
+        self.node_id, self.bucket_size, self.num_replicas = node_id, bucket_size, num_replicas or bucket_size
+        self.rpc_semaphore = asyncio.BoundedSemaphore(value=max_concurrent_rpc)
         self.routing_table = RoutingTable(node_id, bucket_size, depth_modulo)
         self.storage = LocalStorage()
         self.cache = LocalStorage(maxsize=cache_size)
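The rpc_semaphore introduced here is the whole congestion fix: a BoundedSemaphore(value=N) admits at most N holders at once, so at most max_concurrent_rpc requests are in flight while the rest queue on `async with`. A standalone sketch of the pattern (not code from this commit):

    import asyncio

    sem = asyncio.BoundedSemaphore(value=2)  # cf. max_concurrent_rpc

    async def rpc(i):
        async with sem:               # take a slot; released automatically on exit
            await asyncio.sleep(0.1)  # stand-in for one UDP request/response
            return i

    async def main():
        # six calls, but only two run at a time: total ~0.3s rather than ~0.1s
        return await asyncio.gather(*[rpc(i) for i in range(6)])

    print(asyncio.get_event_loop().run_until_complete(main()))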
@@ -36,7 +37,8 @@ class KademliaProtocol(RPCProtocol):

     async def call_ping(self, recipient: Endpoint) -> Optional[DHTID]:
         """ Get recipient's node id and add him to the routing table. If recipient doesn't respond, return None """
-        responded, response = await self.ping(recipient, bytes(self.node_id))
+        async with self.rpc_semaphore:
+            responded, response = await self.ping(recipient, bytes(self.node_id))
         recipient_node_id = DHTID.from_bytes(response) if responded else None
         asyncio.ensure_future(self.update_routing_table(recipient_node_id, recipient, responded=responded))
         return recipient_node_id
@@ -58,8 +60,9 @@ class KademliaProtocol(RPCProtocol):

         :returns: True if value was accepted, False if it was rejected (recipient has newer value), None if no response
         """
-        responded, response = await self.store(recipient, bytes(self.node_id), bytes(key),
-                                               value, expiration_time, in_cache)
+        async with self.rpc_semaphore:
+            responded, response = await self.store(recipient, bytes(self.node_id), bytes(key),
+                                                   value, expiration_time, in_cache)
         if responded:
             store_accepted, recipient_node_id = response[0], DHTID.from_bytes(response[1])
             asyncio.ensure_future(self.update_routing_table(recipient_node_id, recipient, responded=responded))
@@ -86,7 +89,8 @@ class KademliaProtocol(RPCProtocol):

         :returns: a dictionary[node id => address] as per Section 2.3 of the paper
         """
-        responded, response = await self.find_node(recipient, bytes(self.node_id), bytes(query_id))
+        async with self.rpc_semaphore:
+            responded, response = await self.find_node(recipient, bytes(self.node_id), bytes(query_id))
         if responded:
             peers = {DHTID.from_bytes(peer_id_bytes): tuple(addr) for peer_id_bytes, addr in response[0]}
             # Note: we convert addr from list to tuple here --^ because some msgpack versions convert tuples to lists
@@ -122,7 +126,8 @@ class KademliaProtocol(RPCProtocol):
          neighbors:  a dictionary[node id => address] as per Section 2.3 of the paper;
         :note: if no response, returns None, None, {}
         """
-        responded, response = await self.find_value(recipient, bytes(self.node_id), bytes(key))
+        async with self.rpc_semaphore:
+            responded, response = await self.find_value(recipient, bytes(self.node_id), bytes(key))
         if responded:
             (value, expiration_time, peers_bytes), recipient_id = response[:-1], DHTID.from_bytes(response[-1])
             peers = {DHTID.from_bytes(peer_id_bytes): tuple(addr) for peer_id_bytes, addr in peers_bytes}
@@ -140,11 +145,23 @@ class KademliaProtocol(RPCProtocol):
           For incoming requests, this should always be True
         """
         if responded:  # incoming request or outgoing request with response
+            if node_id not in self.routing_table:
+                # we just met a new node, maybe we know some values that it *should* store
+                for key, value, expiration in list(self.storage.items()):
+                    neighbors = self.routing_table.get_nearest_neighbors(key, self.num_replicas, exclude=self.node_id)
+                    if neighbors:
+                        nearest_distance = neighbors[0][0].xor_distance(key)
+                        farthest_distance = neighbors[-1][0].xor_distance(key)
+                        new_node_should_store = node_id.xor_distance(key) < farthest_distance
+                        this_node_is_responsible = self.node_id.xor_distance(key) < nearest_distance
+                    if not neighbors or (new_node_should_store and this_node_is_responsible):
+                        asyncio.create_task(self.call_store(addr, key, value, expiration))
+
             maybe_node_to_ping = self.routing_table.add_or_update_node(node_id, addr)
             if maybe_node_to_ping is not None:
                 # we couldn't add new node because the table was full. Check if existing peers are alive (Section 2.2)
                 # ping one least-recently updated peer: if it won't respond, remove it from the table, else update it
-                await self.call_ping(maybe_node_to_ping[1])  # [1]-th element is that node's endpoint
+                asyncio.create_task(self.call_ping(maybe_node_to_ping[1]))  # [1]-th element is that node's endpoint

         else:  # outgoing request and peer did not respond
             if node_id is not None and node_id in self.routing_table:
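The welcome logic above sends a stored key to a newcomer only if the newcomer would rank among the num_replicas nearest holders (closer than the current farthest) and this node is itself the closest known holder, which is what stops peers from re-sending everything to everyone. The same rule restated with plain integers standing in for DHTIDs (illustrative; the real code uses DHTID.xor_distance):

    def should_send(key, self_id, new_id, neighbor_ids, num_replicas):
        # neighbors sorted by xor distance to the key, as get_nearest_neighbors would return them
        neighbors = sorted(neighbor_ids, key=lambda nid: nid ^ key)[:num_replicas]
        if not neighbors:
            return True  # we know no one else: the newcomer gets everything we store
        nearest, farthest = neighbors[0] ^ key, neighbors[-1] ^ key
        new_node_should_store = (new_id ^ key) < farthest
        this_node_is_responsible = (self_id ^ key) < nearest
        return new_node_should_store and this_node_is_responsible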
@@ -160,7 +177,6 @@ class KademliaProtocol(RPCProtocol):
             super()._accept_response(msg_id, data, address)


-
 class LocalStorage:
     def __init__(self, maxsize: Optional[int] = None):
         self.cache_size = maxsize or float("inf")
@@ -200,3 +216,8 @@ class LocalStorage:
         if key in self.data:
             return self.data[key]
         return None, None
+
+    def items(self) -> Iterator[Tuple[DHTID, DHTValue, DHTExpiration]]:
+        """ Iterate over (key, value, expiration_time) tuples stored in this storage """
+        self.remove_outdated()
+        return ((key, value, expiration) for key, (value, expiration) in self.data.items())
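A sketch of the new items() iterator (assumes a store(key, value, expiration_time) method matching call_store above, which is not shown in this diff; outdated entries are purged before iteration):

    from hivemind.dht.protocol import LocalStorage
    from hivemind.dht.routing import DHTID, get_dht_time

    storage = LocalStorage(maxsize=1000)
    storage.store(DHTID.generate(), b'some value', get_dht_time() + 60)
    for key, value, expiration in storage.items():
        print(key, value, expiration)  # skips anything already expired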

hivemind/dht/routing.py (+3, -0)

@@ -193,6 +193,9 @@ class KBucket:
                 newnode_id, newnode = self.replacement_nodes.popitem()
                 self.nodes_to_addr[newnode_id] = newnode

+    def __contains__(self, node_id: DHTID):
+        return node_id in self.nodes_to_addr or node_id in self.replacement_nodes
+
     def __len__(self):
         return len(self.nodes_to_addr)
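With KBucket.__contains__ in place, the `node_id not in self.routing_table` check in protocol.py also covers replacement nodes, not just active ones. A sketch using only signatures visible in this diff (that RoutingTable.__contains__ delegates to its buckets is an assumption):

    from hivemind.dht.routing import RoutingTable, DHTID

    table = RoutingTable(DHTID.generate(), bucket_size=20, depth_modulo=5)
    some_id = DHTID.generate()
    print(some_id in table)  # False until add_or_update_node(some_id, addr) is called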
 
 

hivemind/runtime/task_pool.py (+1, -1)

@@ -55,7 +55,7 @@ class TaskPool(TaskPoolBase):
     to process these batches and dispatches results back to request sources. Operates as a background process.

     :param process_func: function to be applied to every formed batch; called by Runtime
-        Note that process_func should accept only \*args Tensors and return a flat tuple of Tensors
+        Note that process_func should accept only positional args (Tensors) and return a flat tuple of Tensors
     :param max_batch_size: process at most this many inputs in a batch (a task may contain one or several inputs)
     :param min_batch_size: process at least this many inputs in a batch, otherwise wait for more
     :param timeout: wait for a subsequent task for at most this many seconds
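A process_func of the documented shape (sketch; torch is assumed here since Runtime batches Tensors):

    import torch

    def process_func(*batch: torch.Tensor):
        # positional Tensor args in, flat tuple of Tensors out
        return tuple(tensor * 2 for tensor in batch)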

tests/test_dht.py (+3, -2)

@@ -20,7 +20,7 @@ from hivemind.dht.protocol import LocalStorage
 def run_protocol_listener(port: int, dhtid: DHTID, started: mp.synchronize.Event,
                           ping: Optional[hivemind.Endpoint] = None):
     loop = asyncio.new_event_loop()
-    protocol = partial(KademliaProtocol, dhtid, bucket_size=20, depth_modulo=5, wait_timeout=5)
+    protocol = partial(KademliaProtocol, dhtid, bucket_size=20, depth_modulo=5, wait_timeout=5, max_concurrent_rpc=128)
     listen = loop.create_datagram_endpoint(protocol, local_addr=('127.0.0.1', port))
     transport, protocol = loop.run_until_complete(listen)
     print(f"Started peer id={protocol.node_id} port={port}", flush=True)
@@ -47,7 +47,8 @@ def test_kademlia_protocol():

         port = hivemind.find_open_port()
         loop = asyncio.new_event_loop()
-        protocol = partial(KademliaProtocol, DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5)
+        protocol = partial(KademliaProtocol, DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5,
+                           max_concurrent_rpc=128)
         listen = loop.create_datagram_endpoint(protocol, local_addr=('127.0.0.1', port))
         transport, protocol = loop.run_until_complete(listen)
         print(f"Self id={protocol.node_id} port={port}", flush=True)