Explorar el Código

implement logarithmic lookup, add fast inverse lookup, make harsher tests (#60)

* implement logarithmic lookup, add fast inverse lookup, harsher tests

* use caching for replacement nodes as well

* rename nodes_to_addr -> nodes_to_endpoint (for consistency); remove builtins for KBucket (they produce unintuitive results w.r.t. replacement_nodes)

* update DHTProtocol for changes in RoutingTable

* review
justheuristic hace 5 años
padre
commit
7f985843c2
Se han modificado 3 ficheros con 129 adiciones y 97 borrados
  1. 3 3
      hivemind/dht/protocol.py
  2. 95 85
      hivemind/dht/routing.py
  3. 31 9
      tests/test_routing.py

+ 3 - 3
hivemind/dht/protocol.py

@@ -146,7 +146,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
             return response.store_ok
         except grpc.experimental.aio.AioRpcError as error:
             logger.warning(f"DHTProtocol failed to store at {peer}: {error.code()}")
-            asyncio.create_task(self.update_routing_table(self.routing_table.get_id(peer), peer, responded=False))
+            asyncio.create_task(self.update_routing_table(self.routing_table.get(endpoint=peer), peer, responded=False))
             return [False] * len(keys)
 
     async def rpc_store(self, request: dht_pb2.StoreRequest, context: grpc.ServicerContext) -> dht_pb2.StoreResponse:
@@ -193,7 +193,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
             return output
         except grpc.experimental.aio.AioRpcError as error:
             logger.warning(f"DHTProtocol failed to store at {peer}: {error.code()}")
-            asyncio.create_task(self.update_routing_table(self.routing_table.get_id(peer), peer, responded=False))
+            asyncio.create_task(self.update_routing_table(self.routing_table.get(endpoint=peer), peer, responded=False))
 
     async def rpc_find(self, request: dht_pb2.FindRequest, context: grpc.ServicerContext) -> dht_pb2.FindResponse:
         """
@@ -231,7 +231,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
         :param responded: for outgoing requests, this indicates whether the recipient responded or not.
           For incoming requests, this should always be True
         """
-        node_id = node_id if node_id is not None else self.routing_table.get_id(peer_endpoint)
+        node_id = node_id if node_id is not None else self.routing_table.get(endpoint=peer_endpoint)
         if responded:  # incoming request or outgoing request with response
             if node_id not in self.routing_table:
                 # we just met a new node, maybe we know some values that it *should* store

+ 95 - 85
hivemind/dht/routing.py

@@ -27,16 +27,22 @@ class RoutingTable:
     def __init__(self, node_id: DHTID, bucket_size: int, depth_modulo: int):
         self.node_id, self.bucket_size, self.depth_modulo = node_id, bucket_size, depth_modulo
         self.buckets = [KBucket(node_id.MIN, node_id.MAX, bucket_size)]
+        self.endpoint_to_uid: Dict[Endpoint, DHTID] = dict()  # all nodes currently in buckets, including replacements
+        self.uid_to_endpoint: Dict[DHTID, Endpoint] = dict()  # all nodes currently in buckets, including replacements
 
     def get_bucket_index(self, node_id: DHTID) -> int:
         """ Get the index of the bucket that the given node would fall into. """
-        # TODO use binsearch aka from bisect import bisect.
-        for index, bucket in enumerate(self.buckets):
-            if bucket.lower <= node_id < bucket.upper:
-                return index
-        raise ValueError(f"Failed to get bucket for node_id={node_id}, this should not be possible.")
-
-    def add_or_update_node(self, node_id: DHTID, addr: Endpoint) -> Optional[Tuple[DHTID, Endpoint]]:
+        lower_index, upper_index = 0, len(self.buckets)
+        while upper_index - lower_index > 1:
+            pivot_index = (lower_index + upper_index + 1) // 2
+            if node_id >= self.buckets[pivot_index].lower:
+                lower_index = pivot_index
+            else:  # node_id < self.buckets[pivot_index].lower
+                upper_index = pivot_index
+        assert upper_index - lower_index == 1
+        return lower_index
+
+    def add_or_update_node(self, node_id: DHTID, endpoint: Endpoint) -> Optional[Tuple[DHTID, Endpoint]]:
         """
         Update routing table after an incoming request from :addr: (host:port) or outgoing request to :addr:
 
@@ -47,18 +53,22 @@ class RoutingTable:
         """
         bucket_index = self.get_bucket_index(node_id)
         bucket = self.buckets[bucket_index]
+        store_success = bucket.add_or_update_node(node_id, endpoint)
 
-        if bucket.add_or_update_node(node_id, addr):
-            return  # this will succeed unless the bucket is full
+        if node_id in bucket.nodes_to_endpoint or node_id in bucket.replacement_nodes:
+            # if we added node to bucket or as a replacement, throw it into lookup dicts as well
+            self.uid_to_endpoint[node_id] = endpoint
+            self.endpoint_to_uid[endpoint] = node_id
 
-        # Per section 4.2 of paper, split if the bucket has node's own id in its range
-        # or if bucket depth is not congruent to 0 mod $b$
-        if bucket.has_in_range(self.node_id) or bucket.depth % self.depth_modulo != 0:
-            self.split_bucket(bucket_index)
-            return self.add_or_update_node(node_id, addr)
+        if not store_success:
+            # Per section 4.2 of paper, split if the bucket has node's own id in its range
+            # or if bucket depth is not congruent to 0 mod $b$
+            if bucket.has_in_range(self.node_id) or bucket.depth % self.depth_modulo != 0:
+                self.split_bucket(bucket_index)
+                return self.add_or_update_node(node_id, endpoint)
 
-        # The bucket is full and won't split further. Return a node to ping (see this method's docstring)
-        return bucket.request_ping_node()
+            # The bucket is full and won't split further. Return a node to ping (see this method's docstring)
+            return bucket.request_ping_node()
 
     def split_bucket(self, index: int) -> None:
         """ Split bucket range in two equal parts and reassign nodes to the appropriate half """
@@ -66,24 +76,29 @@ class RoutingTable:
         self.buckets[index] = first
         self.buckets.insert(index + 1, second)
 
-    def get(self, node_id: DHTID, default=None) -> Optional[Endpoint]:
-        return self[node_id] if node_id in self else default
-
-    def get_id(self, peer: Endpoint, default=None) -> Optional[DHTID]:
-        return None #TODO(jheuristic)
+    def get(self, *, node_id: Optional[DHTID] = None, endpoint: Optional[Endpoint] = None, default=None):
+        """ Find endpoint for a given DHTID or vice versa """
+        assert (node_id is None) != (endpoint is None), "Please specify either node_id or endpoint, but not both"
+        if node_id is not None:
+            return self.uid_to_endpoint.get(node_id, default)
+        else:
+            return self.endpoint_to_uid.get(endpoint, default)
 
-    def __getitem__(self, node_id: DHTID) -> Endpoint:
-        return self.buckets[self.get_bucket_index(node_id)][node_id]
+    def __getitem__(self, item: Union[DHTID, Endpoint]) -> Union[Endpoint, DHTID]:
+        """ Find endpoint for a given DHTID or vice versa """
+        return self.uid_to_endpoint[item] if isinstance(item, DHTID) else self.endpoint_to_uid[item]
 
     def __setitem__(self, node_id: DHTID, addr: Endpoint) -> NotImplementedError:
-        raise NotImplementedError("KBucket doesn't support direct item assignment. Use KBucket.try_add_node instead")
+        raise NotImplementedError("RoutingTable doesn't support direct item assignment. Use table.try_add_node instead")
 
-    def __contains__(self, node_id: DHTID) -> bool:
-        return node_id in self.buckets[self.get_bucket_index(node_id)]
+    def __contains__(self, item: Union[DHTID, Endpoint]) -> bool:
+        return (item in self.uid_to_endpoint) if isinstance(item, DHTID) else (item in self.endpoint_to_uid)
 
     def __delitem__(self, node_id: DHTID):
-        node_bucket = self.buckets[self.get_bucket_index(node_id)]
-        del node_bucket[node_id]
+        del self.buckets[self.get_bucket_index(node_id)][node_id]
+        node_endpoint = self.uid_to_endpoint.pop(node_id)
+        if self.endpoint_to_uid.get(node_endpoint) == node_id:
+            del self.endpoint_to_uid[node_endpoint]
 
     def get_nearest_neighbors(
             self, query_id: DHTID, k: int, exclude: Optional[DHTID] = None) -> List[Tuple[DHTID, Endpoint]]:
@@ -93,41 +108,45 @@ class RoutingTable:
         :param query_id: find neighbors of this node
         :param k: find this many neighbors. If there aren't enough nodes in the table, returns all nodes
         :param exclude: if not None, results will not contain this node id even if it is in the table
-        :returns: a list of tuples (node_id, endpoint) for up to k neighbors sorted from nearest to farthest
-
-        :note: this is a semi-exhaustive search of nodes that takes O(n * log k) time.
-            One can implement a more efficient knn search using a binary skip-tree in some
-            more elegant language such as c++ / cython / numba.
-            Here's a sketch
-
-            Preparation: construct a non-regular binary tree of depth (2 * DHTID.HASH_NBYTES)
-             Each leaf corresponds to a binary DHTID with '0' for every left turn and '1' for right turn
-             Each non-leaf corresponds to a certain prefix, e.g. 0010110???...???
-             If there are no nodes under some prefix xxxY???..., the corresponding node xxx????...
-             will only have one child.
-            Add(node):
-             Traverse down a tree, on i-th level go left if node_i == 0, right if node_i == 1
-             If the corresponding node is missing, simply create it
-            Search(query, k):
-             Traverse the tree with a depth-first search, on i-th level go left if query_i == 0, else right
-             If the corresponding node is missing, go the other way. Proceed until you found a leaf.
-             This leaf is your nearest neighbor. Now add more neighbors by considering alternative paths
-             bottom-up, i.e. if your nearest neighbor is 01011, first try 01010, then 0100x, then 011xx, ...
-
-            This results in O(num_nodes * bit_length) complexity for add and search
-            Better yet: use binary tree with skips for O(num_nodes * log(num_nodes))
+        :return: a list of tuples (node_id, endpoint) for up to k neighbors sorted from nearest to farthest
         """
-        all_nodes: Iterator[Tuple[DHTID, Endpoint]] = chain(*self.buckets)  # uses KBucket.__iter__
-        nearest_nodes_with_addr: List[Tuple[DHTID, Endpoint]] = heapq.nsmallest(
-            k + int(exclude is not None), all_nodes, lambda id_and_endpoint: query_id.xor_distance(id_and_endpoint[0]))
-        if exclude is not None:
-            for i, (node_i, addr_i) in enumerate(list(nearest_nodes_with_addr)):
-                if node_i == exclude:
-                    del nearest_nodes_with_addr[i]
-                    break
-            if len(nearest_nodes_with_addr) > k:
-                nearest_nodes_with_addr.pop()  # if excluded element is not among (k + 1) nearest, simply crop to k
-        return nearest_nodes_with_addr
+        # algorithm: first add up all buckets that can contain one of k nearest nodes, then heap-sort to find best
+        candidates: List[Tuple[int, DHTID, Endpoint]] = []  # min-heap based on xor distance to query_id
+
+        # step 1: add current bucket to the candidates heap
+        nearest_index = self.get_bucket_index(query_id)
+        nearest_bucket = self.buckets[nearest_index]
+        for node_id, endpoint in nearest_bucket.nodes_to_endpoint.items():
+            heapq.heappush(candidates, (query_id.xor_distance(node_id), node_id, endpoint))
+
+        # step 2: add adjacent buckets by ascending code tree one level at a time until you have enough nodes
+        left_index, right_index = nearest_index, nearest_index + 1  # bucket indices considered, [left, right)
+        current_lower, current_upper, current_depth = nearest_bucket.lower, nearest_bucket.upper, nearest_bucket.depth
+
+        while current_depth > 0 and len(candidates) < k + int(exclude is not None):
+            split_direction = current_lower // 2 ** (DHTID.HASH_NBYTES * 8 - current_depth) % 2
+            # ^-- current leaf direction from pre-leaf node, 0 = left, 1 = right
+            current_depth -= 1  # traverse one level closer to the root and add all child nodes to the candidates heap
+
+            if split_direction == 0:  # leaf was split on the left, merge its right peer(s)
+                current_upper += current_upper - current_lower
+                while right_index < len(self.buckets) and self.buckets[right_index].upper <= current_upper:
+                    for node_id, endpoint in self.buckets[right_index].nodes_to_endpoint.items():
+                        heapq.heappush(candidates, (query_id.xor_distance(node_id), node_id, endpoint))
+                    right_index += 1  # note: we may need to add more than one bucket if they are on a lower depth level
+                assert self.buckets[right_index - 1].upper == current_upper
+
+            else:  # split_direction == 1, leaf was split on the right, merge its left peer(s)
+                current_lower -= current_upper - current_lower
+                while left_index > 0 and self.buckets[left_index - 1].lower >= current_lower:
+                    left_index -= 1  # note: we may need to add more than one bucket if they are on a lower depth level
+                    for node_id, endpoint in self.buckets[left_index].nodes_to_endpoint.items():
+                        heapq.heappush(candidates, (query_id.xor_distance(node_id), node_id, endpoint))
+                assert self.buckets[left_index].lower == current_lower
+
+        # step 3: select k nearest vertices from candidates heap
+        heap_top: List[Tuple[int, DHTID, Endpoint]] = heapq.nsmallest(k + int(exclude is not None), candidates)
+        return [(node, endpoint) for _, node, endpoint in heap_top if node != exclude][:k]
 
     def __repr__(self):
         bucket_info = "\n".join(repr(bucket) for bucket in self.buckets)
@@ -143,7 +162,7 @@ class KBucket:
     def __init__(self, lower: int, upper: int, size: int, depth: int = 0):
         assert upper - lower == 2 ** (DHTID.HASH_NBYTES * 8 - depth)
         self.lower, self.upper, self.size, self.depth = lower, upper, size, depth
-        self.nodes_to_addr: Dict[DHTID, Endpoint] = {}
+        self.nodes_to_endpoint: Dict[DHTID, Endpoint] = {}
         self.replacement_nodes: Dict[DHTID, Endpoint] = {}
         self.nodes_requested_for_ping: Set[DHTID] = set()
         self.last_updated = get_dht_time()
@@ -164,11 +183,11 @@ class KBucket:
         if node_id in self.nodes_requested_for_ping:
             self.nodes_requested_for_ping.remove(node_id)
         self.last_updated = get_dht_time()
-        if node_id in self.nodes_to_addr:
-            del self.nodes_to_addr[node_id]
-            self.nodes_to_addr[node_id] = addr
-        elif len(self) < self.size:
-            self.nodes_to_addr[node_id] = addr
+        if node_id in self.nodes_to_endpoint:
+            del self.nodes_to_endpoint[node_id]
+            self.nodes_to_endpoint[node_id] = addr
+        elif len(self.nodes_to_endpoint) < self.size:
+            self.nodes_to_endpoint[node_id] = addr
         else:
             if node_id in self.replacement_nodes:
                 del self.replacement_nodes[node_id]
@@ -178,36 +197,27 @@ class KBucket:
 
     def request_ping_node(self) -> Optional[Tuple[DHTID, Endpoint]]:
         """ :returns: least-recently updated node that isn't already being pinged right now -- if such node exists """
-        for uid, endpoint in self.nodes_to_addr.items():
+        for uid, endpoint in self.nodes_to_endpoint.items():
             if uid not in self.nodes_requested_for_ping:
                 self.nodes_requested_for_ping.add(uid)
                 return uid, endpoint
 
     def __getitem__(self, node_id: DHTID) -> Endpoint:
-        return self.nodes_to_addr[node_id] if node_id in self.nodes_to_addr else self.replacement_nodes[node_id]
+        return self.nodes_to_endpoint[node_id] if node_id in self.nodes_to_endpoint else self.replacement_nodes[node_id]
 
     def __delitem__(self, node_id: DHTID):
-        if not (node_id in self.nodes_to_addr or node_id in self.replacement_nodes):
+        if not (node_id in self.nodes_to_endpoint or node_id in self.replacement_nodes):
             raise KeyError(f"KBucket does not contain node id={node_id}.")
 
         if node_id in self.replacement_nodes:
             del self.replacement_nodes[node_id]
 
-        if node_id in self.nodes_to_addr:
-            del self.nodes_to_addr[node_id]
+        if node_id in self.nodes_to_endpoint:
+            del self.nodes_to_endpoint[node_id]
 
             if self.replacement_nodes:
                 newnode_id, newnode = self.replacement_nodes.popitem()
-                self.nodes_to_addr[newnode_id] = newnode
-
-    def __contains__(self, node_id: DHTID):
-        return node_id in self.nodes_to_addr or node_id in self.replacement_nodes
-
-    def __len__(self):
-        return len(self.nodes_to_addr)
-
-    def __iter__(self):
-        return iter(self.nodes_to_addr.items())
+                self.nodes_to_endpoint[newnode_id] = newnode
 
     def split(self) -> Tuple[KBucket, KBucket]:
         """ Split bucket over midpoint, rounded down, assign nodes to according to their id """
@@ -215,13 +225,13 @@ class KBucket:
         assert self.lower < midpoint < self.upper, f"Bucket to small to be split: [{self.lower}: {self.upper})"
         left = KBucket(self.lower, midpoint, self.size, depth=self.depth + 1)
         right = KBucket(midpoint, self.upper, self.size, depth=self.depth + 1)
-        for node_id, addr in chain(self.nodes_to_addr.items(), self.replacement_nodes.items()):
+        for node_id, addr in chain(self.nodes_to_endpoint.items(), self.replacement_nodes.items()):
             bucket = left if int(node_id) <= midpoint else right
             bucket.add_or_update_node(node_id, addr)
         return left, right
 
     def __repr__(self):
-        return f"{self.__class__.__name__}({len(self.nodes_to_addr)} nodes" \
+        return f"{self.__class__.__name__}({len(self.nodes_to_endpoint)} nodes" \
                f" with {len(self.replacement_nodes)} replacements, depth={self.depth}, max size={self.size}" \
                f" lower={hex(self.lower)}, upper={hex(self.upper)})"
 

+ 31 - 9
tests/test_routing.py

@@ -35,17 +35,40 @@ def test_ids_depth():
 def test_routing_table_basic():
     node_id = DHTID.generate()
     routing_table = RoutingTable(node_id, bucket_size=20, depth_modulo=5)
+    added_nodes = []
 
     for phony_neighbor_port in random.sample(range(10000), 100):
         phony_id = DHTID.generate()
         routing_table.add_or_update_node(phony_id, f'{LOCALHOST}:{phony_neighbor_port}')
+        assert phony_id in routing_table
+        assert f'{LOCALHOST}:{phony_neighbor_port}' in routing_table
         assert routing_table[phony_id] == f'{LOCALHOST}:{phony_neighbor_port}'
+        assert routing_table[f'{LOCALHOST}:{phony_neighbor_port}'] == phony_id
+        added_nodes.append(phony_id)
 
     assert routing_table.buckets[0].lower == DHTID.MIN and routing_table.buckets[-1].upper == DHTID.MAX
     for bucket in routing_table.buckets:
         assert len(bucket.replacement_nodes) == 0, "There should be no replacement nodes in a table with 100 entries"
     assert 3 <= len(routing_table.buckets) <= 10, len(routing_table.buckets)
 
+    random_node = random.choice(added_nodes)
+    assert routing_table.get(node_id=random_node) == routing_table[random_node]
+    dummy_node = DHTID.generate()
+    assert (dummy_node not in routing_table) == (routing_table.get(node_id=dummy_node) is None)
+
+    for node in added_nodes:
+        found_bucket_index = routing_table.get_bucket_index(node)
+        for bucket_index, bucket in enumerate(routing_table.buckets):
+            if bucket.lower <= node < bucket.upper:
+                break
+        else:
+            raise ValueError("Naive search could not find bucket. Universe has gone crazy.")
+        assert bucket_index == found_bucket_index
+
+
+
+
+
 
 def test_routing_table_parameters():
     for (bucket_size, modulo, min_nbuckets, max_nbuckets) in [
@@ -59,7 +82,7 @@ def test_routing_table_parameters():
         for phony_neighbor_port in random.sample(range(1_000_000), 10_000):
             routing_table.add_or_update_node(DHTID.generate(), f'{LOCALHOST}:{phony_neighbor_port}')
         for bucket in routing_table.buckets:
-            assert len(bucket.replacement_nodes) == 0 or len(bucket.nodes_to_addr) <= bucket.size
+            assert len(bucket.replacement_nodes) == 0 or len(bucket.nodes_to_endpoint) <= bucket.size
         assert min_nbuckets <= len(routing_table.buckets) <= max_nbuckets, (
             f"Unexpected number of buckets: {min_nbuckets} <= {len(routing_table.buckets)} <= {max_nbuckets}")
 
@@ -75,20 +98,20 @@ def test_routing_table_search():
 
         for phony_neighbor_port in random.sample(range(1_000_000), table_size):
             routing_table.add_or_update_node(DHTID.generate(), f'{LOCALHOST}:{phony_neighbor_port}')
-            new_total = sum(len(bucket.nodes_to_addr) for bucket in routing_table.buckets)
+            new_total = sum(len(bucket.nodes_to_endpoint) for bucket in routing_table.buckets)
             num_added += new_total > total_nodes
             total_nodes = new_total
         num_replacements = sum(len(bucket.replacement_nodes) for bucket in routing_table.buckets)
     
         all_active_neighbors = list(chain(
-            *(bucket.nodes_to_addr.keys() for bucket in routing_table.buckets)
+            *(bucket.nodes_to_endpoint.keys() for bucket in routing_table.buckets)
         ))
         assert lower_active <= len(all_active_neighbors) <= upper_active
         assert len(all_active_neighbors) == num_added
         assert num_added + num_replacements == table_size
     
         # random queries
-        for i in range(500):
+        for i in range(1000):
             k = random.randint(1, 100)
             query_id = DHTID.generate()
             exclude = query_id if random.random() < 0.5 else None
@@ -96,15 +119,14 @@ def test_routing_table_search():
             reference_knn = heapq.nsmallest(k, all_active_neighbors, key=query_id.xor_distance)
             assert all(our == ref for our, ref in zip_longest(our_knn, reference_knn))
             assert all(our_addr == routing_table[our_node] for our_node, our_addr in zip(our_knn, our_addrs))
-    
+
         # queries from table
-        for i in range(500):
+        for i in range(1000):
             k = random.randint(1, 100)
             query_id = random.choice(all_active_neighbors)
             our_knn, our_addrs = zip(*routing_table.get_nearest_neighbors(query_id, k=k, exclude=query_id))
-            reference_knn = heapq.nsmallest(
-                k + 1, all_active_neighbors,
-                key=lambda uid: query_id.xor_distance(uid))
+
+            reference_knn = heapq.nsmallest(k + 1, all_active_neighbors, key=query_id.xor_distance)
             if query_id in reference_knn:
                 reference_knn.remove(query_id)
             assert len(our_knn) == len(reference_knn)