|
@@ -27,16 +27,22 @@ class RoutingTable:
|
|
|
def __init__(self, node_id: DHTID, bucket_size: int, depth_modulo: int):
|
|
|
self.node_id, self.bucket_size, self.depth_modulo = node_id, bucket_size, depth_modulo
|
|
|
self.buckets = [KBucket(node_id.MIN, node_id.MAX, bucket_size)]
|
|
|
+ self.endpoint_to_uid: Dict[Endpoint, DHTID] = dict() # all nodes currently in buckets, including replacements
|
|
|
+ self.uid_to_endpoint: Dict[DHTID, Endpoint] = dict() # all nodes currently in buckets, including replacements
|
|
|
|
|
|
def get_bucket_index(self, node_id: DHTID) -> int:
|
|
|
""" Get the index of the bucket that the given node would fall into. """
|
|
|
- # TODO use binsearch aka from bisect import bisect.
|
|
|
- for index, bucket in enumerate(self.buckets):
|
|
|
- if bucket.lower <= node_id < bucket.upper:
|
|
|
- return index
|
|
|
- raise ValueError(f"Failed to get bucket for node_id={node_id}, this should not be possible.")
|
|
|
-
|
|
|
- def add_or_update_node(self, node_id: DHTID, addr: Endpoint) -> Optional[Tuple[DHTID, Endpoint]]:
|
|
|
+ lower_index, upper_index = 0, len(self.buckets)
|
|
|
+ while upper_index - lower_index > 1:
|
|
|
+ pivot_index = (lower_index + upper_index + 1) // 2
|
|
|
+ if node_id >= self.buckets[pivot_index].lower:
|
|
|
+ lower_index = pivot_index
|
|
|
+ else: # node_id < self.buckets[pivot_index].lower
|
|
|
+ upper_index = pivot_index
|
|
|
+ assert upper_index - lower_index == 1
|
|
|
+ return lower_index
|
|
|
+
|
|
|
+ def add_or_update_node(self, node_id: DHTID, endpoint: Endpoint) -> Optional[Tuple[DHTID, Endpoint]]:
|
|
|
"""
|
|
|
Update routing table after an incoming request from :addr: (host:port) or outgoing request to :addr:
|
|
|
|
|
@@ -47,18 +53,22 @@ class RoutingTable:
|
|
|
"""
|
|
|
bucket_index = self.get_bucket_index(node_id)
|
|
|
bucket = self.buckets[bucket_index]
|
|
|
+ store_success = bucket.add_or_update_node(node_id, endpoint)
|
|
|
|
|
|
- if bucket.add_or_update_node(node_id, addr):
|
|
|
- return # this will succeed unless the bucket is full
|
|
|
+ if node_id in bucket.nodes_to_endpoint or node_id in bucket.replacement_nodes:
|
|
|
+ # if we added node to bucket or as a replacement, throw it into lookup dicts as well
|
|
|
+ self.uid_to_endpoint[node_id] = endpoint
|
|
|
+ self.endpoint_to_uid[endpoint] = node_id
|
|
|
|
|
|
- # Per section 4.2 of paper, split if the bucket has node's own id in its range
|
|
|
- # or if bucket depth is not congruent to 0 mod $b$
|
|
|
- if bucket.has_in_range(self.node_id) or bucket.depth % self.depth_modulo != 0:
|
|
|
- self.split_bucket(bucket_index)
|
|
|
- return self.add_or_update_node(node_id, addr)
|
|
|
+ if not store_success:
|
|
|
+ # Per section 4.2 of paper, split if the bucket has node's own id in its range
|
|
|
+ # or if bucket depth is not congruent to 0 mod $b$
|
|
|
+ if bucket.has_in_range(self.node_id) or bucket.depth % self.depth_modulo != 0:
|
|
|
+ self.split_bucket(bucket_index)
|
|
|
+ return self.add_or_update_node(node_id, endpoint)
|
|
|
|
|
|
- # The bucket is full and won't split further. Return a node to ping (see this method's docstring)
|
|
|
- return bucket.request_ping_node()
|
|
|
+ # The bucket is full and won't split further. Return a node to ping (see this method's docstring)
|
|
|
+ return bucket.request_ping_node()
|
|
|
|
|
|
def split_bucket(self, index: int) -> None:
|
|
|
""" Split bucket range in two equal parts and reassign nodes to the appropriate half """
|
|
@@ -66,24 +76,29 @@ class RoutingTable:
|
|
|
self.buckets[index] = first
|
|
|
self.buckets.insert(index + 1, second)
|
|
|
|
|
|
- def get(self, node_id: DHTID, default=None) -> Optional[Endpoint]:
|
|
|
- return self[node_id] if node_id in self else default
|
|
|
-
|
|
|
- def get_id(self, peer: Endpoint, default=None) -> Optional[DHTID]:
|
|
|
- return None #TODO(jheuristic)
|
|
|
+ def get(self, *, node_id: Optional[DHTID] = None, endpoint: Optional[Endpoint] = None, default=None):
|
|
|
+ """ Find endpoint for a given DHTID or vice versa """
|
|
|
+ assert (node_id is None) != (endpoint is None), "Please specify either node_id or endpoint, but not both"
|
|
|
+ if node_id is not None:
|
|
|
+ return self.uid_to_endpoint.get(node_id, default)
|
|
|
+ else:
|
|
|
+ return self.endpoint_to_uid.get(endpoint, default)
|
|
|
|
|
|
- def __getitem__(self, node_id: DHTID) -> Endpoint:
|
|
|
- return self.buckets[self.get_bucket_index(node_id)][node_id]
|
|
|
+ def __getitem__(self, item: Union[DHTID, Endpoint]) -> Union[Endpoint, DHTID]:
|
|
|
+ """ Find endpoint for a given DHTID or vice versa """
|
|
|
+ return self.uid_to_endpoint[item] if isinstance(item, DHTID) else self.endpoint_to_uid[item]
|
|
|
|
|
|
def __setitem__(self, node_id: DHTID, addr: Endpoint) -> NotImplementedError:
|
|
|
- raise NotImplementedError("KBucket doesn't support direct item assignment. Use KBucket.try_add_node instead")
|
|
|
+ raise NotImplementedError("RoutingTable doesn't support direct item assignment. Use table.try_add_node instead")
|
|
|
|
|
|
- def __contains__(self, node_id: DHTID) -> bool:
|
|
|
- return node_id in self.buckets[self.get_bucket_index(node_id)]
|
|
|
+ def __contains__(self, item: Union[DHTID, Endpoint]) -> bool:
|
|
|
+ return (item in self.uid_to_endpoint) if isinstance(item, DHTID) else (item in self.endpoint_to_uid)
|
|
|
|
|
|
def __delitem__(self, node_id: DHTID):
|
|
|
- node_bucket = self.buckets[self.get_bucket_index(node_id)]
|
|
|
- del node_bucket[node_id]
|
|
|
+ del self.buckets[self.get_bucket_index(node_id)][node_id]
|
|
|
+ node_endpoint = self.uid_to_endpoint.pop(node_id)
|
|
|
+ if self.endpoint_to_uid.get(node_endpoint) == node_id:
|
|
|
+ del self.endpoint_to_uid[node_endpoint]
|
|
|
|
|
|
def get_nearest_neighbors(
|
|
|
self, query_id: DHTID, k: int, exclude: Optional[DHTID] = None) -> List[Tuple[DHTID, Endpoint]]:
|
|
@@ -93,41 +108,45 @@ class RoutingTable:
|
|
|
:param query_id: find neighbors of this node
|
|
|
:param k: find this many neighbors. If there aren't enough nodes in the table, returns all nodes
|
|
|
:param exclude: if True, results will not contain query_node_id even if it is in table
|
|
|
- :returns: a list of tuples (node_id, endpoint) for up to k neighbors sorted from nearest to farthest
|
|
|
-
|
|
|
- :note: this is a semi-exhaustive search of nodes that takes O(n * log k) time.
|
|
|
- One can implement a more efficient knn search using a binary skip-tree in some
|
|
|
- more elegant language such as c++ / cython / numba.
|
|
|
- Here's a sketch
|
|
|
-
|
|
|
- Preparation: construct a non-regular binary tree of depth (2 * DHTID.HASH_NBYTES)
|
|
|
- Each leaf corresponds to a binary DHTID with '0' for every left turn and '1' for right turn
|
|
|
- Each non-leaf corresponds to a certain prefix, e.g. 0010110???...???
|
|
|
- If there are no nodes under some prefix xxxY???..., the corresponding node xxx????...
|
|
|
- will only have one child.
|
|
|
- Add(node):
|
|
|
- Traverse down a tree, on i-th level go left if node_i == 0, right if node_i == 1
|
|
|
- If the corresponding node is missing, simply create it
|
|
|
- Search(query, k):
|
|
|
- Traverse the tree with a depth-first search, on i-th level go left if query_i == 0, else right
|
|
|
- If the corresponding node is missing, go the other way. Proceed until you found a leaf.
|
|
|
- This leaf is your nearest neighbor. Now add more neighbors by considering alternative paths
|
|
|
- bottom-up, i.e. if your nearest neighbor is 01011, first try 01010, then 0100x, then 011xx, ...
|
|
|
-
|
|
|
- This results in O(num_nodes * bit_length) complexity for add and search
|
|
|
- Better yet: use binary tree with skips for O(num_nodes * log(num_nodes))
|
|
|
+ :return: a list of tuples (node_id, endpoint) for up to k neighbors sorted from nearest to farthest
|
|
|
"""
|
|
|
- all_nodes: Iterator[Tuple[DHTID, Endpoint]] = chain(*self.buckets) # uses KBucket.__iter__
|
|
|
- nearest_nodes_with_addr: List[Tuple[DHTID, Endpoint]] = heapq.nsmallest(
|
|
|
- k + int(exclude is not None), all_nodes, lambda id_and_endpoint: query_id.xor_distance(id_and_endpoint[0]))
|
|
|
- if exclude is not None:
|
|
|
- for i, (node_i, addr_i) in enumerate(list(nearest_nodes_with_addr)):
|
|
|
- if node_i == exclude:
|
|
|
- del nearest_nodes_with_addr[i]
|
|
|
- break
|
|
|
- if len(nearest_nodes_with_addr) > k:
|
|
|
- nearest_nodes_with_addr.pop() # if excluded element is not among (k + 1) nearest, simply crop to k
|
|
|
- return nearest_nodes_with_addr
|
|
|
+        # algorithm: first gather every bucket that may contain one of the k nearest nodes, then heap-select the best
|
|
|
+ candidates: List[Tuple[int, DHTID, Endpoint]] = [] # min-heap based on xor distance to query_id
|
|
|
+
|
|
|
+ # step 1: add current bucket to the candidates heap
|
|
|
+ nearest_index = self.get_bucket_index(query_id)
|
|
|
+ nearest_bucket = self.buckets[nearest_index]
|
|
|
+ for node_id, endpoint in nearest_bucket.nodes_to_endpoint.items():
|
|
|
+ heapq.heappush(candidates, (query_id.xor_distance(node_id), node_id, endpoint))
|
|
|
+
|
|
|
+ # step 2: add adjacent buckets by ascending code tree one level at a time until you have enough nodes
|
|
|
+ left_index, right_index = nearest_index, nearest_index + 1 # bucket indices considered, [left, right)
|
|
|
+ current_lower, current_upper, current_depth = nearest_bucket.lower, nearest_bucket.upper, nearest_bucket.depth
|
|
|
+
|
|
|
+ while current_depth > 0 and len(candidates) < k + int(exclude is not None):
|
|
|
+ split_direction = current_lower // 2 ** (DHTID.HASH_NBYTES * 8 - current_depth) % 2
|
|
|
+            # ^-- direction of the current leaf relative to its parent node: 0 = left child, 1 = right child
|
|
|
+ current_depth -= 1 # traverse one level closer to the root and add all child nodes to the candidates heap
|
|
|
+
|
|
|
+ if split_direction == 0: # leaf was split on the left, merge its right peer(s)
|
|
|
+ current_upper += current_upper - current_lower
|
|
|
+ while right_index < len(self.buckets) and self.buckets[right_index].upper <= current_upper:
|
|
|
+ for node_id, endpoint in self.buckets[right_index].nodes_to_endpoint.items():
|
|
|
+ heapq.heappush(candidates, (query_id.xor_distance(node_id), node_id, endpoint))
|
|
|
+                right_index += 1  # note: we may need to add several buckets if the right siblings were split to a greater depth
|
|
|
+ assert self.buckets[right_index - 1].upper == current_upper
|
|
|
+
|
|
|
+ else: # split_direction == 1, leaf was split on the right, merge its left peer(s)
|
|
|
+ current_lower -= current_upper - current_lower
|
|
|
+ while left_index > 0 and self.buckets[left_index - 1].lower >= current_lower:
|
|
|
+                left_index -= 1  # note: we may need to add several buckets if the left siblings were split to a greater depth
|
|
|
+ for node_id, endpoint in self.buckets[left_index].nodes_to_endpoint.items():
|
|
|
+ heapq.heappush(candidates, (query_id.xor_distance(node_id), node_id, endpoint))
|
|
|
+ assert self.buckets[left_index].lower == current_lower
|
|
|
+
|
|
|
+ # step 3: select k nearest vertices from candidates heap
|
|
|
+ heap_top: List[Tuple[int, DHTID, Endpoint]] = heapq.nsmallest(k + int(exclude is not None), candidates)
|
|
|
+ return [(node, endpoint) for _, node, endpoint in heap_top if node != exclude][:k]
|
|
|
|
|
|
def __repr__(self):
|
|
|
bucket_info = "\n".join(repr(bucket) for bucket in self.buckets)
|
|
@@ -143,7 +162,7 @@ class KBucket:
|
|
|
def __init__(self, lower: int, upper: int, size: int, depth: int = 0):
|
|
|
assert upper - lower == 2 ** (DHTID.HASH_NBYTES * 8 - depth)
|
|
|
self.lower, self.upper, self.size, self.depth = lower, upper, size, depth
|
|
|
- self.nodes_to_addr: Dict[DHTID, Endpoint] = {}
|
|
|
+ self.nodes_to_endpoint: Dict[DHTID, Endpoint] = {}
|
|
|
self.replacement_nodes: Dict[DHTID, Endpoint] = {}
|
|
|
self.nodes_requested_for_ping: Set[DHTID] = set()
|
|
|
self.last_updated = get_dht_time()
|
|
@@ -164,11 +183,11 @@ class KBucket:
|
|
|
if node_id in self.nodes_requested_for_ping:
|
|
|
self.nodes_requested_for_ping.remove(node_id)
|
|
|
self.last_updated = get_dht_time()
|
|
|
- if node_id in self.nodes_to_addr:
|
|
|
- del self.nodes_to_addr[node_id]
|
|
|
- self.nodes_to_addr[node_id] = addr
|
|
|
- elif len(self) < self.size:
|
|
|
- self.nodes_to_addr[node_id] = addr
|
|
|
+ if node_id in self.nodes_to_endpoint:
|
|
|
+ del self.nodes_to_endpoint[node_id]
|
|
|
+ self.nodes_to_endpoint[node_id] = addr
|
|
|
+ elif len(self.nodes_to_endpoint) < self.size:
|
|
|
+ self.nodes_to_endpoint[node_id] = addr
|
|
|
else:
|
|
|
if node_id in self.replacement_nodes:
|
|
|
del self.replacement_nodes[node_id]
|
|
@@ -178,36 +197,27 @@ class KBucket:
|
|
|
|
|
|
def request_ping_node(self) -> Optional[Tuple[DHTID, Endpoint]]:
|
|
|
""" :returns: least-recently updated node that isn't already being pinged right now -- if such node exists """
|
|
|
- for uid, endpoint in self.nodes_to_addr.items():
|
|
|
+ for uid, endpoint in self.nodes_to_endpoint.items():
|
|
|
if uid not in self.nodes_requested_for_ping:
|
|
|
self.nodes_requested_for_ping.add(uid)
|
|
|
return uid, endpoint
|
|
|
|
|
|
def __getitem__(self, node_id: DHTID) -> Endpoint:
|
|
|
- return self.nodes_to_addr[node_id] if node_id in self.nodes_to_addr else self.replacement_nodes[node_id]
|
|
|
+ return self.nodes_to_endpoint[node_id] if node_id in self.nodes_to_endpoint else self.replacement_nodes[node_id]
|
|
|
|
|
|
def __delitem__(self, node_id: DHTID):
|
|
|
- if not (node_id in self.nodes_to_addr or node_id in self.replacement_nodes):
|
|
|
+ if not (node_id in self.nodes_to_endpoint or node_id in self.replacement_nodes):
|
|
|
raise KeyError(f"KBucket does not contain node id={node_id}.")
|
|
|
|
|
|
if node_id in self.replacement_nodes:
|
|
|
del self.replacement_nodes[node_id]
|
|
|
|
|
|
- if node_id in self.nodes_to_addr:
|
|
|
- del self.nodes_to_addr[node_id]
|
|
|
+ if node_id in self.nodes_to_endpoint:
|
|
|
+ del self.nodes_to_endpoint[node_id]
|
|
|
|
|
|
if self.replacement_nodes:
|
|
|
newnode_id, newnode = self.replacement_nodes.popitem()
|
|
|
- self.nodes_to_addr[newnode_id] = newnode
|
|
|
-
|
|
|
- def __contains__(self, node_id: DHTID):
|
|
|
- return node_id in self.nodes_to_addr or node_id in self.replacement_nodes
|
|
|
-
|
|
|
- def __len__(self):
|
|
|
- return len(self.nodes_to_addr)
|
|
|
-
|
|
|
- def __iter__(self):
|
|
|
- return iter(self.nodes_to_addr.items())
|
|
|
+ self.nodes_to_endpoint[newnode_id] = newnode
|
|
|
|
|
|
def split(self) -> Tuple[KBucket, KBucket]:
|
|
|
""" Split bucket over midpoint, rounded down, assign nodes to according to their id """
|
|
@@ -215,13 +225,13 @@ class KBucket:
|
|
|
assert self.lower < midpoint < self.upper, f"Bucket to small to be split: [{self.lower}: {self.upper})"
|
|
|
left = KBucket(self.lower, midpoint, self.size, depth=self.depth + 1)
|
|
|
right = KBucket(midpoint, self.upper, self.size, depth=self.depth + 1)
|
|
|
- for node_id, addr in chain(self.nodes_to_addr.items(), self.replacement_nodes.items()):
|
|
|
+ for node_id, addr in chain(self.nodes_to_endpoint.items(), self.replacement_nodes.items()):
|
|
|
bucket = left if int(node_id) <= midpoint else right
|
|
|
bucket.add_or_update_node(node_id, addr)
|
|
|
return left, right
|
|
|
|
|
|
def __repr__(self):
|
|
|
- return f"{self.__class__.__name__}({len(self.nodes_to_addr)} nodes" \
|
|
|
+ return f"{self.__class__.__name__}({len(self.nodes_to_endpoint)} nodes" \
|
|
|
f" with {len(self.replacement_nodes)} replacements, depth={self.depth}, max size={self.size}" \
|
|
|
f" lower={hex(self.lower)}, upper={hex(self.upper)})"
|
|
|
|