
implement logarithmic lookup, add fast inverse lookup, make harsher tests (#60)

* implement logarithmic lookup, add fast inverse lookup, harsher tests

* use caching for replacement nodes as well

* rename nodes_to_addr -> nodes_to_endpoint (consistency); remove builtins for KBucket (they produce unintuitive results w.r.t. replacement_nodes)

* update DHTProtocol for changes in RoutingTable

* review
justheuristic · 5 years ago · commit 7f985843c2
3 changed files with 129 additions and 97 deletions
  1. hivemind/dht/protocol.py (+3 -3)
  2. hivemind/dht/routing.py (+95 -85)
  3. tests/test_routing.py (+31 -9)

+ 3 - 3
hivemind/dht/protocol.py

@@ -146,7 +146,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
             return response.store_ok
         except grpc.experimental.aio.AioRpcError as error:
             logger.warning(f"DHTProtocol failed to store at {peer}: {error.code()}")
-            asyncio.create_task(self.update_routing_table(self.routing_table.get_id(peer), peer, responded=False))
+            asyncio.create_task(self.update_routing_table(self.routing_table.get(endpoint=peer), peer, responded=False))
             return [False] * len(keys)

     async def rpc_store(self, request: dht_pb2.StoreRequest, context: grpc.ServicerContext) -> dht_pb2.StoreResponse:
@@ -193,7 +193,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
             return output
         except grpc.experimental.aio.AioRpcError as error:
             logger.warning(f"DHTProtocol failed to store at {peer}: {error.code()}")
-            asyncio.create_task(self.update_routing_table(self.routing_table.get_id(peer), peer, responded=False))
+            asyncio.create_task(self.update_routing_table(self.routing_table.get(endpoint=peer), peer, responded=False))

     async def rpc_find(self, request: dht_pb2.FindRequest, context: grpc.ServicerContext) -> dht_pb2.FindResponse:
         """
@@ -231,7 +231,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
         :param responded: for outgoing requests, this indicated whether recipient responded or not.
           For incoming requests, this should always be True
         """
-        node_id = node_id if node_id is not None else self.routing_table.get_id(peer_endpoint)
+        node_id = node_id if node_id is not None else self.routing_table.get(endpoint=peer_endpoint)
         if responded:  # incoming request or outgoing request with response
             if node_id not in self.routing_table:
                 # we just met a new node, maybe we know some values that it *should* store

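All three call sites above make the same substitution: the stubbed-out routing_table.get_id(peer) is replaced by the new keyword-only inverse lookup. A minimal sketch of the shared failure-path pattern, assuming the RoutingTable from routing.py below; on_rpc_failure is a hypothetical consolidation, not part of the commit:

    # hypothetical helper mirroring the error handlers in rpc_store / rpc_find
    async def on_rpc_failure(self, peer: Endpoint):
        node_id = self.routing_table.get(endpoint=peer)  # DHTID, or None if this peer was never seen
        # mark the peer unresponsive so a cached replacement may eventually take its slot
        asyncio.create_task(self.update_routing_table(node_id, peer, responded=False))
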
+ 95 - 85
hivemind/dht/routing.py

@@ -27,16 +27,22 @@ class RoutingTable:
     def __init__(self, node_id: DHTID, bucket_size: int, depth_modulo: int):
         self.node_id, self.bucket_size, self.depth_modulo = node_id, bucket_size, depth_modulo
         self.buckets = [KBucket(node_id.MIN, node_id.MAX, bucket_size)]
+        self.endpoint_to_uid: Dict[Endpoint, DHTID] = dict()  # all nodes currently in buckets, including replacements
+        self.uid_to_endpoint: Dict[DHTID, Endpoint] = dict()  # all nodes currently in buckets, including replacements

     def get_bucket_index(self, node_id: DHTID) -> int:
         """ Get the index of the bucket that the given node would fall into. """
-        # TODO use binsearch aka from bisect import bisect.
-        for index, bucket in enumerate(self.buckets):
-            if bucket.lower <= node_id < bucket.upper:
-                return index
-        raise ValueError(f"Failed to get bucket for node_id={node_id}, this should not be possible.")
-
-    def add_or_update_node(self, node_id: DHTID, addr: Endpoint) -> Optional[Tuple[DHTID, Endpoint]]:
+        lower_index, upper_index = 0, len(self.buckets)
+        while upper_index - lower_index > 1:
+            pivot_index = (lower_index + upper_index + 1) // 2
+            if node_id >= self.buckets[pivot_index].lower:
+                lower_index = pivot_index
+            else:  # node_id < self.buckets[pivot_index].lower
+                upper_index = pivot_index
+        assert upper_index - lower_index == 1
+        return lower_index
+
+    def add_or_update_node(self, node_id: DHTID, endpoint: Endpoint) -> Optional[Tuple[DHTID, Endpoint]]:
         """
         Update routing table after an incoming request from :addr: (host:port) or outgoing request to :addr:

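This hand-rolled binary search is what makes get_bucket_index logarithmic and retires the old TODO ("use binsearch aka from bisect import bisect"). For intuition, a sketch of an equivalent standard-library lookup under the same invariant the loop relies on: buckets are sorted by lower bound and tile the whole ID range with half-open intervals. get_bucket_index_bisect is illustrative, not from the commit:

    from bisect import bisect_right

    def get_bucket_index_bisect(buckets, node_id) -> int:
        # the rightmost bucket whose lower bound is <= node_id must contain it,
        # because the [lower, upper) ranges are contiguous and sorted
        lowers = [bucket.lower for bucket in buckets]
        return bisect_right(lowers, node_id) - 1

Building the lowers list is itself O(n) per call, which is presumably why the commit bisects over self.buckets in place instead.
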
@@ -47,18 +53,22 @@ class RoutingTable:
         """
         bucket_index = self.get_bucket_index(node_id)
         bucket = self.buckets[bucket_index]
+        store_success = bucket.add_or_update_node(node_id, endpoint)

-        if bucket.add_or_update_node(node_id, addr):
-            return  # this will succeed unless the bucket is full
+        if node_id in bucket.nodes_to_endpoint or node_id in bucket.replacement_nodes:
+            # if we added node to bucket or as a replacement, throw it into lookup dicts as well
+            self.uid_to_endpoint[node_id] = endpoint
+            self.endpoint_to_uid[endpoint] = node_id

-        # Per section 4.2 of paper, split if the bucket has node's own id in its range
-        # or if bucket depth is not congruent to 0 mod $b$
-        if bucket.has_in_range(self.node_id) or bucket.depth % self.depth_modulo != 0:
-            self.split_bucket(bucket_index)
-            return self.add_or_update_node(node_id, addr)
+        if not store_success:
+            # Per section 4.2 of paper, split if the bucket has node's own id in its range
+            # or if bucket depth is not congruent to 0 mod $b$
+            if bucket.has_in_range(self.node_id) or bucket.depth % self.depth_modulo != 0:
+                self.split_bucket(bucket_index)
+                return self.add_or_update_node(node_id, endpoint)

-        # The bucket is full and won't split further. Return a node to ping (see this method's docstring)
-        return bucket.request_ping_node()
+            # The bucket is full and won't split further. Return a node to ping (see this method's docstring)
+            return bucket.request_ping_node()

     def split_bucket(self, index: int) -> None:
         """ Split bucket range in two equal parts and reassign nodes to the appropriate half """
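The refactor separates storage from policy: the node is always offered to the bucket first, the lookup dicts are synced if it landed anywhere (active slot or replacement cache), and only then does the split-or-ping logic run. A sketch of the caller contract implied by the return value, with ping_and_maybe_evict as a hypothetical helper:

    # add_or_update_node returns None on success, or the (DHTID, Endpoint) of a
    # stale node to ping when the bucket is full and cannot be split further
    maybe_stale = routing_table.add_or_update_node(new_node_id, new_endpoint)
    if maybe_stale is not None:
        stale_id, stale_endpoint = maybe_stale
        # if the stale node fails to respond, deleting it lets a cached
        # replacement take over its slot (see KBucket.__delitem__ below)
        await ping_and_maybe_evict(stale_id, stale_endpoint)
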
@@ -66,24 +76,29 @@ class RoutingTable:
         self.buckets[index] = first
         self.buckets.insert(index + 1, second)

-    def get(self, node_id: DHTID, default=None) -> Optional[Endpoint]:
-        return self[node_id] if node_id in self else default
-
-    def get_id(self, peer: Endpoint, default=None) -> Optional[DHTID]:
-        return None #TODO(jheuristic)
+    def get(self, *, node_id: Optional[DHTID] = None, endpoint: Optional[Endpoint] = None, default=None):
+        """ Find endpoint for a given DHTID or vice versa """
+        assert (node_id is None) != (endpoint is None), "Please specify either node_id or endpoint, but not both"
+        if node_id is not None:
+            return self.uid_to_endpoint.get(node_id, default)
+        else:
+            return self.endpoint_to_uid.get(endpoint, default)

-    def __getitem__(self, node_id: DHTID) -> Endpoint:
-        return self.buckets[self.get_bucket_index(node_id)][node_id]
+    def __getitem__(self, item: Union[DHTID, Endpoint]) -> Union[Endpoint, DHTID]:
+        """ Find endpoint for a given DHTID or vice versa """
+        return self.uid_to_endpoint[item] if isinstance(item, DHTID) else self.endpoint_to_uid[item]

     def __setitem__(self, node_id: DHTID, addr: Endpoint) -> NotImplementedError:
-        raise NotImplementedError("KBucket doesn't support direct item assignment. Use KBucket.try_add_node instead")
+        raise NotImplementedError("RoutingTable doesn't support direct item assignment. Use table.try_add_node instead")

-    def __contains__(self, node_id: DHTID) -> bool:
-        return node_id in self.buckets[self.get_bucket_index(node_id)]
+    def __contains__(self, item: Union[DHTID, Endpoint]) -> bool:
+        return (item in self.uid_to_endpoint) if isinstance(item, DHTID) else (item in self.endpoint_to_uid)

     def __delitem__(self, node_id: DHTID):
-        node_bucket = self.buckets[self.get_bucket_index(node_id)]
-        del node_bucket[node_id]
+        del self.buckets[self.get_bucket_index(node_id)][node_id]
+        node_endpoint = self.uid_to_endpoint.pop(node_id)
+        if self.endpoint_to_uid.get(node_endpoint) == node_id:
+            del self.endpoint_to_uid[node_endpoint]

     def get_nearest_neighbors(
             self, query_id: DHTID, k: int, exclude: Optional[DHTID] = None) -> List[Tuple[DHTID, Endpoint]]:
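With both dicts maintained, the table is now a two-way mapping: indexing by DHTID yields an Endpoint and indexing by Endpoint yields a DHTID, each in O(1) rather than via a bucket scan. A short round-trip sketch, assuming node_id is already in the table and the freshly generated id is not:

    endpoint = routing_table[node_id]            # DHTID -> Endpoint
    assert routing_table[endpoint] == node_id    # Endpoint -> DHTID
    assert node_id in routing_table and endpoint in routing_table
    assert routing_table.get(node_id=DHTID.generate()) is None  # .get never raises
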
@@ -93,41 +108,45 @@ class RoutingTable:
         :param query_id: find neighbors of this node
         :param k: find this many neighbors. If there aren't enough nodes in the table, returns all nodes
         :param exclude: if True, results will not contain query_node_id even if it is in table
-        :returns: a list of tuples (node_id, endpoint) for up to k neighbors sorted from nearest to farthest
-
-        :note: this is a semi-exhaustive search of nodes that takes O(n * log k) time.
-            One can implement a more efficient knn search using a binary skip-tree in some
-            more elegant language such as c++ / cython / numba.
-            Here's a sketch
-
-            Preparation: construct a non-regular binary tree of depth (2 * DHTID.HASH_NBYTES)
-             Each leaf corresponds to a binary DHTID with '0' for every left turn and '1' for right turn
-             Each non-leaf corresponds to a certain prefix, e.g. 0010110???...???
-             If there are no nodes under some prefix xxxY???..., the corresponding node xxx????...
-             will only have one child.
-            Add(node):
-             Traverse down a tree, on i-th level go left if node_i == 0, right if node_i == 1
-             If the corresponding node is missing, simply create it
-            Search(query, k):
-             Traverse the tree with a depth-first search, on i-th level go left if query_i == 0, else right
-             If the corresponding node is missing, go the other way. Proceed until you found a leaf.
-             This leaf is your nearest neighbor. Now add more neighbors by considering alternative paths
-             bottom-up, i.e. if your nearest neighbor is 01011, first try 01010, then 0100x, then 011xx, ...
-
-            This results in O(num_nodes * bit_length) complexity for add and search
-            Better yet: use binary tree with skips for O(num_nodes * log(num_nodes))
+        :return: a list of tuples (node_id, endpoint) for up to k neighbors sorted from nearest to farthest
         """
-        all_nodes: Iterator[Tuple[DHTID, Endpoint]] = chain(*self.buckets)  # uses KBucket.__iter__
-        nearest_nodes_with_addr: List[Tuple[DHTID, Endpoint]] = heapq.nsmallest(
-            k + int(exclude is not None), all_nodes, lambda id_and_endpoint: query_id.xor_distance(id_and_endpoint[0]))
-        if exclude is not None:
-            for i, (node_i, addr_i) in enumerate(list(nearest_nodes_with_addr)):
-                if node_i == exclude:
-                    del nearest_nodes_with_addr[i]
-                    break
-            if len(nearest_nodes_with_addr) > k:
-                nearest_nodes_with_addr.pop()  # if excluded element is not among (k + 1) nearest, simply crop to k
-        return nearest_nodes_with_addr
+        # algorithm: first add up all buckets that can contain one of k nearest nodes, then heap-sort to find best
+        candidates: List[Tuple[int, DHTID, Endpoint]] = []  # min-heap based on xor distance to query_id
+
+        # step 1: add current bucket to the candidates heap
+        nearest_index = self.get_bucket_index(query_id)
+        nearest_bucket = self.buckets[nearest_index]
+        for node_id, endpoint in nearest_bucket.nodes_to_endpoint.items():
+            heapq.heappush(candidates, (query_id.xor_distance(node_id), node_id, endpoint))
+
+        # step 2: add adjacent buckets by ascending code tree one level at a time until you have enough nodes
+        left_index, right_index = nearest_index, nearest_index + 1  # bucket indices considered, [left, right)
+        current_lower, current_upper, current_depth = nearest_bucket.lower, nearest_bucket.upper, nearest_bucket.depth
+
+        while current_depth > 0 and len(candidates) < k + int(exclude is not None):
+            split_direction = current_lower // 2 ** (DHTID.HASH_NBYTES * 8 - current_depth) % 2
+            # ^-- current leaf direction from pre-leaf node, 0 = left, 1 = right
+            current_depth -= 1  # traverse one level closer to the root and add all child nodes to the candidates heap
+
+            if split_direction == 0:  # leaf was split on the left, merge its right peer(s)
+                current_upper += current_upper - current_lower
+                while right_index < len(self.buckets) and self.buckets[right_index].upper <= current_upper:
+                    for node_id, endpoint in self.buckets[right_index].nodes_to_endpoint.items():
+                        heapq.heappush(candidates, (query_id.xor_distance(node_id), node_id, endpoint))
+                    right_index += 1  # note: we may need to add more than one bucket if they are on a lower depth level
+                assert self.buckets[right_index - 1].upper == current_upper
+
+            else:  # split_direction == 1, leaf was split on the right, merge its left peer(s)
+                current_lower -= current_upper - current_lower
+                while left_index > 0 and self.buckets[left_index - 1].lower >= current_lower:
+                    left_index -= 1  # note: we may need to add more than one bucket if they are on a lower depth level
+                    for node_id, endpoint in self.buckets[left_index].nodes_to_endpoint.items():
+                        heapq.heappush(candidates, (query_id.xor_distance(node_id), node_id, endpoint))
+                assert self.buckets[left_index].lower == current_lower
+
+        # step 3: select k nearest vertices from candidates heap
+        heap_top: List[Tuple[int, DHTID, Endpoint]] = heapq.nsmallest(k + int(exclude is not None), candidates)
+        return [(node, endpoint) for _, node, endpoint in heap_top if node != exclude][:k]

     def __repr__(self):
         bucket_info = "\n".join(repr(bucket) for bucket in self.buckets)
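The rewritten search starts at the bucket that owns query_id and climbs the binary code tree one level per iteration, merging sibling buckets into a distance heap until at least k candidates are collected, so buckets that cannot hold a k-nearest node are never touched. A brute-force oracle is still useful for testing (tests/test_routing.py below does essentially this); a sketch:

    import heapq

    def reference_nearest(routing_table, query_id, k):
        # exhaustive O(n log k) scan over every active node in every bucket
        all_nodes = [(node_id, endpoint)
                     for bucket in routing_table.buckets
                     for node_id, endpoint in bucket.nodes_to_endpoint.items()]
        return heapq.nsmallest(k, all_nodes,
                               key=lambda pair: query_id.xor_distance(pair[0]))
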
@@ -143,7 +162,7 @@ class KBucket:
     def __init__(self, lower: int, upper: int, size: int, depth: int = 0):
         assert upper - lower == 2 ** (DHTID.HASH_NBYTES * 8 - depth)
         self.lower, self.upper, self.size, self.depth = lower, upper, size, depth
-        self.nodes_to_addr: Dict[DHTID, Endpoint] = {}
+        self.nodes_to_endpoint: Dict[DHTID, Endpoint] = {}
         self.replacement_nodes: Dict[DHTID, Endpoint] = {}
         self.nodes_requested_for_ping: Set[DHTID] = set()
         self.last_updated = get_dht_time()
@@ -164,11 +183,11 @@ class KBucket:
         if node_id in self.nodes_requested_for_ping:
             self.nodes_requested_for_ping.remove(node_id)
         self.last_updated = get_dht_time()
-        if node_id in self.nodes_to_addr:
-            del self.nodes_to_addr[node_id]
-            self.nodes_to_addr[node_id] = addr
-        elif len(self) < self.size:
-            self.nodes_to_addr[node_id] = addr
+        if node_id in self.nodes_to_endpoint:
+            del self.nodes_to_endpoint[node_id]
+            self.nodes_to_endpoint[node_id] = addr
+        elif len(self.nodes_to_endpoint) < self.size:
+            self.nodes_to_endpoint[node_id] = addr
         else:
             if node_id in self.replacement_nodes:
                 del self.replacement_nodes[node_id]
@@ -178,36 +197,27 @@ class KBucket:

     def request_ping_node(self) -> Optional[Tuple[DHTID, Endpoint]]:
         """ :returns: least-recently updated node that isn't already being pinged right now -- if such node exists """
-        for uid, endpoint in self.nodes_to_addr.items():
+        for uid, endpoint in self.nodes_to_endpoint.items():
             if uid not in self.nodes_requested_for_ping:
                 self.nodes_requested_for_ping.add(uid)
                 return uid, endpoint

     def __getitem__(self, node_id: DHTID) -> Endpoint:
-        return self.nodes_to_addr[node_id] if node_id in self.nodes_to_addr else self.replacement_nodes[node_id]
+        return self.nodes_to_endpoint[node_id] if node_id in self.nodes_to_endpoint else self.replacement_nodes[node_id]

     def __delitem__(self, node_id: DHTID):
-        if not (node_id in self.nodes_to_addr or node_id in self.replacement_nodes):
+        if not (node_id in self.nodes_to_endpoint or node_id in self.replacement_nodes):
             raise KeyError(f"KBucket does not contain node id={node_id}.")

         if node_id in self.replacement_nodes:
             del self.replacement_nodes[node_id]

-        if node_id in self.nodes_to_addr:
-            del self.nodes_to_addr[node_id]
+        if node_id in self.nodes_to_endpoint:
+            del self.nodes_to_endpoint[node_id]

             if self.replacement_nodes:
                 newnode_id, newnode = self.replacement_nodes.popitem()
-                self.nodes_to_addr[newnode_id] = newnode
-
-    def __contains__(self, node_id: DHTID):
-        return node_id in self.nodes_to_addr or node_id in self.replacement_nodes
-
-    def __len__(self):
-        return len(self.nodes_to_addr)
-
-    def __iter__(self):
-        return iter(self.nodes_to_addr.items())
+                self.nodes_to_endpoint[newnode_id] = newnode

     def split(self) -> Tuple[KBucket, KBucket]:
         """ Split bucket over midpoint, rounded down, assign nodes to according to their id """
@@ -215,13 +225,13 @@ class KBucket:
         assert self.lower < midpoint < self.upper, f"Bucket to small to be split: [{self.lower}: {self.upper})"
         left = KBucket(self.lower, midpoint, self.size, depth=self.depth + 1)
         right = KBucket(midpoint, self.upper, self.size, depth=self.depth + 1)
-        for node_id, addr in chain(self.nodes_to_addr.items(), self.replacement_nodes.items()):
+        for node_id, addr in chain(self.nodes_to_endpoint.items(), self.replacement_nodes.items()):
             bucket = left if int(node_id) <= midpoint else right
             bucket.add_or_update_node(node_id, addr)
         return left, right

     def __repr__(self):
-        return f"{self.__class__.__name__}({len(self.nodes_to_addr)} nodes" \
+        return f"{self.__class__.__name__}({len(self.nodes_to_endpoint)} nodes" \
                f" with {len(self.replacement_nodes)} replacements, depth={self.depth}, max size={self.size}" \
                f" lower={hex(self.lower)}, upper={hex(self.upper)})"

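Together with the nodes_to_addr -> nodes_to_endpoint rename, __delitem__ keeps the "use caching for replacement nodes" promise from the commit message: deleting an active node promotes a cached replacement into the freed slot. A sketch on a toy bucket of size 1, where id_a, id_b, addr_a, addr_b stand for arbitrary DHTIDs and endpoints:

    bucket = KBucket(lower=0, upper=2 ** (DHTID.HASH_NBYTES * 8), size=1)
    bucket.add_or_update_node(id_a, addr_a)   # fills the only active slot
    bucket.add_or_update_node(id_b, addr_b)   # bucket full: id_b is cached as a replacement
    del bucket[id_a]                          # id_b is promoted into the freed slot
    assert id_b in bucket.nodes_to_endpoint and not bucket.replacement_nodes
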

+ 31 - 9
tests/test_routing.py

@@ -35,17 +35,40 @@ def test_ids_depth():
 def test_routing_table_basic():
     node_id = DHTID.generate()
     routing_table = RoutingTable(node_id, bucket_size=20, depth_modulo=5)
+    added_nodes = []

     for phony_neighbor_port in random.sample(range(10000), 100):
         phony_id = DHTID.generate()
         routing_table.add_or_update_node(phony_id, f'{LOCALHOST}:{phony_neighbor_port}')
+        assert phony_id in routing_table
+        assert f'{LOCALHOST}:{phony_neighbor_port}' in routing_table
         assert routing_table[phony_id] == f'{LOCALHOST}:{phony_neighbor_port}'
+        assert routing_table[f'{LOCALHOST}:{phony_neighbor_port}'] == phony_id
+        added_nodes.append(phony_id)

     assert routing_table.buckets[0].lower == DHTID.MIN and routing_table.buckets[-1].upper == DHTID.MAX
     for bucket in routing_table.buckets:
         assert len(bucket.replacement_nodes) == 0, "There should be no replacement nodes in a table with 100 entries"
     assert 3 <= len(routing_table.buckets) <= 10, len(routing_table.buckets)

+    random_node = random.choice(added_nodes)
+    assert routing_table.get(node_id=random_node) == routing_table[random_node]
+    dummy_node = DHTID.generate()
+    assert (dummy_node not in routing_table) == (routing_table.get(node_id=dummy_node) is None)
+
+    for node in added_nodes:
+        found_bucket_index = routing_table.get_bucket_index(node)
+        for bucket_index, bucket in enumerate(routing_table.buckets):
+            if bucket.lower <= node < bucket.upper:
+                break
+        else:
+            raise ValueError("Naive search could not find bucket. Universe has gone crazy.")
+        assert bucket_index == found_bucket_index
+
+
+
+

 def test_routing_table_parameters():
     for (bucket_size, modulo, min_nbuckets, max_nbuckets) in [
@@ -59,7 +82,7 @@ def test_routing_table_parameters():
         for phony_neighbor_port in random.sample(range(1_000_000), 10_000):
             routing_table.add_or_update_node(DHTID.generate(), f'{LOCALHOST}:{phony_neighbor_port}')
         for bucket in routing_table.buckets:
-            assert len(bucket.replacement_nodes) == 0 or len(bucket.nodes_to_addr) <= bucket.size
+            assert len(bucket.replacement_nodes) == 0 or len(bucket.nodes_to_endpoint) <= bucket.size
         assert min_nbuckets <= len(routing_table.buckets) <= max_nbuckets, (
             f"Unexpected number of buckets: {min_nbuckets} <= {len(routing_table.buckets)} <= {max_nbuckets}")

@@ -75,20 +98,20 @@ def test_routing_table_search():

         for phony_neighbor_port in random.sample(range(1_000_000), table_size):
             routing_table.add_or_update_node(DHTID.generate(), f'{LOCALHOST}:{phony_neighbor_port}')
-            new_total = sum(len(bucket.nodes_to_addr) for bucket in routing_table.buckets)
+            new_total = sum(len(bucket.nodes_to_endpoint) for bucket in routing_table.buckets)
             num_added += new_total > total_nodes
             total_nodes = new_total
         num_replacements = sum(len(bucket.replacement_nodes) for bucket in routing_table.buckets)

         all_active_neighbors = list(chain(
-            *(bucket.nodes_to_addr.keys() for bucket in routing_table.buckets)
+            *(bucket.nodes_to_endpoint.keys() for bucket in routing_table.buckets)
         ))
         assert lower_active <= len(all_active_neighbors) <= upper_active
         assert len(all_active_neighbors) == num_added
         assert num_added + num_replacements == table_size

         # random queries
-        for i in range(500):
+        for i in range(1000):
             k = random.randint(1, 100)
             query_id = DHTID.generate()
             exclude = query_id if random.random() < 0.5 else None
@@ -96,15 +119,14 @@ def test_routing_table_search():
             reference_knn = heapq.nsmallest(k, all_active_neighbors, key=query_id.xor_distance)
             assert all(our == ref for our, ref in zip_longest(our_knn, reference_knn))
             assert all(our_addr == routing_table[our_node] for our_node, our_addr in zip(our_knn, our_addrs))
-
+
         # queries from table
-        for i in range(500):
+        for i in range(1000):
             k = random.randint(1, 100)
             query_id = random.choice(all_active_neighbors)
             our_knn, our_addrs = zip(*routing_table.get_nearest_neighbors(query_id, k=k, exclude=query_id))
-            reference_knn = heapq.nsmallest(
-                k + 1, all_active_neighbors,
-                key=lambda uid: query_id.xor_distance(uid))
+
+            reference_knn = heapq.nsmallest(k + 1, all_active_neighbors, key=query_id.xor_distance)
             if query_id in reference_knn:
                 reference_knn.remove(query_id)
             assert len(our_knn) == len(reference_knn)