routing.py

from __future__ import annotations

import hashlib
import heapq
import os
import random
import time
from collections.abc import Iterable
from itertools import chain
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union

from ..utils import Endpoint, PickleSerializer


class RoutingTable:
    """
    A data structure that contains DHT peers bucketed according to their distance to node_id

    :param node_id: node id used to measure distance
    :param bucket_size: parameter $k$ from Kademlia paper Section 2.2
    :param depth_modulo: parameter $b$ from Kademlia paper Section 2.2
    :note: for a more detailed description of these parameters, see the Node class docstring in node.py
    :note: the Kademlia paper: https://pdos.csail.mit.edu/~petar/papers/maymounkov-kademlia-lncs.pdf
    """

    def __init__(self, node_id: DHTID, bucket_size: int, depth_modulo: int):
        self.node_id, self.bucket_size, self.depth_modulo = node_id, bucket_size, depth_modulo
        self.buckets = [KBucket(node_id.MIN, node_id.MAX, bucket_size)]
        self.endpoint_to_uid: Dict[Endpoint, DHTID] = dict()  # all nodes currently in buckets, including replacements
        self.uid_to_endpoint: Dict[DHTID, Endpoint] = dict()  # all nodes currently in buckets, including replacements

    def get_bucket_index(self, node_id: DHTID) -> int:
        """ Get the index of the bucket that the given node would fall into. """
        lower_index, upper_index = 0, len(self.buckets)
        while upper_index - lower_index > 1:
            pivot_index = (lower_index + upper_index + 1) // 2
            if node_id >= self.buckets[pivot_index].lower:
                lower_index = pivot_index
            else:  # node_id < self.buckets[pivot_index].lower
                upper_index = pivot_index
        assert upper_index - lower_index == 1
        return lower_index
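
    # Note: self.buckets is kept sorted by lower bound, and the buckets are contiguous, jointly
    # covering the whole [MIN, MAX) id range. This invariant is what makes the binary search above
    # valid and lets split_bucket() insert the right half immediately after the left one.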

    def add_or_update_node(self, node_id: DHTID, endpoint: Endpoint) -> Optional[Tuple[DHTID, Endpoint]]:
        """
        Update routing table after an incoming request from :endpoint: (host:port) or outgoing request to :endpoint:

        :returns: If we cannot add node_id to the routing table, return the least-recently-updated node (Section 2.2)
        :note: DHTProtocol calls this method for every incoming and outgoing request if there was a response.
          If this method returned a node to be ping-ed, the protocol will ping it to check availability and either
          move it to the start of the table or remove that node and replace it with a new one.
        """
        bucket_index = self.get_bucket_index(node_id)
        bucket = self.buckets[bucket_index]
        store_success = bucket.add_or_update_node(node_id, endpoint)

        if node_id in bucket.nodes_to_endpoint or node_id in bucket.replacement_nodes:
            # if we added node to bucket or as a replacement, add it to the lookup dicts as well
            self.uid_to_endpoint[node_id] = endpoint
            self.endpoint_to_uid[endpoint] = node_id

        if not store_success:
            # Per section 4.2 of the paper, split if the bucket has the node's own id in its range
            # or if the bucket depth is not congruent to 0 mod $b$
            if bucket.has_in_range(self.node_id) or bucket.depth % self.depth_modulo != 0:
                self.split_bucket(bucket_index)
                return self.add_or_update_node(node_id, endpoint)

            # The bucket is full and won't split further. Return a node to ping (see this method's docstring)
            return bucket.request_ping_node()
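
    # Sketch of the eviction flow described in the docstring above (method names on the protocol
    # side are hypothetical; in this codebase the caller is DHTProtocol): if add_or_update_node
    # returns (stale_id, stale_endpoint), the caller pings that endpoint and, if the ping fails,
    # runs `del routing_table[stale_id]`; KBucket.__delitem__ then promotes the most recent entry
    # from the bucket's replacement list (typically the newly seen peer) into the bucket.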

    def split_bucket(self, index: int) -> None:
        """ Split bucket range in two equal parts and reassign nodes to the appropriate half """
        first, second = self.buckets[index].split()
        self.buckets[index] = first
        self.buckets.insert(index + 1, second)

    def get(self, *, node_id: Optional[DHTID] = None, endpoint: Optional[Endpoint] = None, default=None):
        """ Find endpoint for a given DHTID or vice versa """
        assert (node_id is None) != (endpoint is None), "Please specify either node_id or endpoint, but not both"
        if node_id is not None:
            return self.uid_to_endpoint.get(node_id, default)
        else:
            return self.endpoint_to_uid.get(endpoint, default)

    def __getitem__(self, item: Union[DHTID, Endpoint]) -> Union[Endpoint, DHTID]:
        """ Find endpoint for a given DHTID or vice versa """
        return self.uid_to_endpoint[item] if isinstance(item, DHTID) else self.endpoint_to_uid[item]

    def __setitem__(self, node_id: DHTID, addr: Endpoint) -> NotImplementedError:
        raise NotImplementedError("RoutingTable doesn't support direct item assignment. Use table.add_or_update_node instead")

    def __contains__(self, item: Union[DHTID, Endpoint]) -> bool:
        return (item in self.uid_to_endpoint) if isinstance(item, DHTID) else (item in self.endpoint_to_uid)

    def __delitem__(self, node_id: DHTID):
        del self.buckets[self.get_bucket_index(node_id)][node_id]
        node_endpoint = self.uid_to_endpoint.pop(node_id)
        if self.endpoint_to_uid.get(node_endpoint) == node_id:
            del self.endpoint_to_uid[node_endpoint]

    def get_nearest_neighbors(
            self, query_id: DHTID, k: int, exclude: Optional[DHTID] = None) -> List[Tuple[DHTID, Endpoint]]:
        """
        Find k nearest neighbors from the routing table according to XOR distance; does NOT include self.node_id

        :param query_id: find neighbors of this node
        :param k: find this many neighbors. If there aren't enough nodes in the table, returns all nodes
        :param exclude: if provided, results will not contain this node id even if it is in the table
        :return: a list of tuples (node_id, endpoint) for up to k neighbors sorted from nearest to farthest
        """
        # algorithm: first add up all buckets that can contain one of the k nearest nodes, then heap-sort to find best
        candidates: List[Tuple[int, DHTID, Endpoint]] = []  # min-heap based on xor distance to query_id

        # step 1: add the nearest bucket to the candidates heap
        nearest_index = self.get_bucket_index(query_id)
        nearest_bucket = self.buckets[nearest_index]
        for node_id, endpoint in nearest_bucket.nodes_to_endpoint.items():
            heapq.heappush(candidates, (query_id.xor_distance(node_id), node_id, endpoint))

        # step 2: add adjacent buckets by ascending the code tree one level at a time until you have enough nodes
        left_index, right_index = nearest_index, nearest_index + 1  # bucket indices considered, [left, right)
        current_lower, current_upper, current_depth = nearest_bucket.lower, nearest_bucket.upper, nearest_bucket.depth

        while current_depth > 0 and len(candidates) < k + int(exclude is not None):
            split_direction = current_lower // 2 ** (DHTID.HASH_NBYTES * 8 - current_depth) % 2
            # ^-- direction of the current subtree relative to its parent, 0 = left, 1 = right
            current_depth -= 1  # traverse one level closer to the root and add all child nodes to the candidates heap

            if split_direction == 0:  # leaf was split on the left, merge its right peer(s)
                current_upper += current_upper - current_lower
                while right_index < len(self.buckets) and self.buckets[right_index].upper <= current_upper:
                    for node_id, endpoint in self.buckets[right_index].nodes_to_endpoint.items():
                        heapq.heappush(candidates, (query_id.xor_distance(node_id), node_id, endpoint))
                    right_index += 1  # note: we may need to add more than one bucket if they are at a lower depth level
                assert self.buckets[right_index - 1].upper == current_upper

            else:  # split_direction == 1, leaf was split on the right, merge its left peer(s)
                current_lower -= current_upper - current_lower
                while left_index > 0 and self.buckets[left_index - 1].lower >= current_lower:
                    left_index -= 1  # note: we may need to add more than one bucket if they are at a lower depth level
                    for node_id, endpoint in self.buckets[left_index].nodes_to_endpoint.items():
                        heapq.heappush(candidates, (query_id.xor_distance(node_id), node_id, endpoint))
                assert self.buckets[left_index].lower == current_lower

        # step 3: select k nearest vertices from candidates heap
        heap_top: List[Tuple[int, DHTID, Endpoint]] = heapq.nsmallest(k + int(exclude is not None), candidates)
        return [(node, endpoint) for _, node, endpoint in heap_top if node != exclude][:k]
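
    # Worked example of step 2 (values assumed for illustration): suppose the nearest bucket covers
    # [0, 2^158) at depth 2. Its direction bit is 0 (a left child), so its sibling subtree is
    # [2^158, 2^159); we merge every bucket inside that range (possibly several, if the sibling was
    # split deeper). The merged range [0, 2^159) sits at depth 1, and the next sibling to merge
    # is [2^159, 2^160).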

    def __repr__(self):
        bucket_info = "\n".join(repr(bucket) for bucket in self.buckets)
        return f"{self.__class__.__name__}(node_id={self.node_id}, bucket_size={self.bucket_size}," \
               f" modulo={self.depth_modulo},\nbuckets=[\n{bucket_info}])"


class KBucket:
    """
    A bucket containing up to :size: DHTIDs in the [lower, upper) semi-interval.
    Maps DHT node ids to their endpoints (hostname, port)
    """

    def __init__(self, lower: int, upper: int, size: int, depth: int = 0):
        assert upper - lower == 2 ** (DHTID.HASH_NBYTES * 8 - depth)
        self.lower, self.upper, self.size, self.depth = lower, upper, size, depth
        self.nodes_to_endpoint: Dict[DHTID, Endpoint] = {}
        self.replacement_nodes: Dict[DHTID, Endpoint] = {}
        self.nodes_requested_for_ping: Set[DHTID] = set()
        self.last_updated = get_dht_time()

    def has_in_range(self, node_id: DHTID):
        """ Check if node_id is between this bucket's lower and upper bounds """
        return self.lower <= node_id < self.upper

    def add_or_update_node(self, node_id: DHTID, addr: Endpoint) -> bool:
        """
        Add a node to the KBucket or update an existing one; return True if successful, False if the bucket is full.
        If the bucket is full, keep track of the node in a replacement list, per section 4.1 of the paper.

        :param node_id: dht node identifier that should be added or moved to the most-recently-updated position
        :param addr: a pair of (hostname, port) associated with that node id
        :note: this function has a side effect of resetting KBucket.last_updated time
        """
        if node_id in self.nodes_requested_for_ping:
            self.nodes_requested_for_ping.remove(node_id)
        self.last_updated = get_dht_time()
        if node_id in self.nodes_to_endpoint:
            del self.nodes_to_endpoint[node_id]
            self.nodes_to_endpoint[node_id] = addr
        elif len(self.nodes_to_endpoint) < self.size:
            self.nodes_to_endpoint[node_id] = addr
        else:
            if node_id in self.replacement_nodes:
                del self.replacement_nodes[node_id]
            self.replacement_nodes[node_id] = addr
            return False
        return True
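
    # Note: nodes_to_endpoint doubles as an LRU list by relying on insertion-ordered dicts
    # (guaranteed since Python 3.7): add_or_update_node deletes and re-inserts a known key to
    # move it to the back, so the first entry is always the least-recently-updated node, which
    # is exactly what request_ping_node below returns first.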

    def request_ping_node(self) -> Optional[Tuple[DHTID, Endpoint]]:
        """ :returns: least-recently updated node that isn't already being pinged right now -- if such a node exists """
        for uid, endpoint in self.nodes_to_endpoint.items():
            if uid not in self.nodes_requested_for_ping:
                self.nodes_requested_for_ping.add(uid)
                return uid, endpoint

    def __getitem__(self, node_id: DHTID) -> Endpoint:
        return self.nodes_to_endpoint[node_id] if node_id in self.nodes_to_endpoint else self.replacement_nodes[node_id]

    def __delitem__(self, node_id: DHTID):
        if not (node_id in self.nodes_to_endpoint or node_id in self.replacement_nodes):
            raise KeyError(f"KBucket does not contain node id={node_id}.")

        if node_id in self.replacement_nodes:
            del self.replacement_nodes[node_id]

        if node_id in self.nodes_to_endpoint:
            del self.nodes_to_endpoint[node_id]

            if self.replacement_nodes:
                newnode_id, newnode = self.replacement_nodes.popitem()
                self.nodes_to_endpoint[newnode_id] = newnode

    def split(self) -> Tuple[KBucket, KBucket]:
        """ Split the bucket at its midpoint and assign nodes to the left or right half according to their ids """
        midpoint = (self.lower + self.upper) // 2
        assert self.lower < midpoint < self.upper, f"Bucket too small to be split: [{self.lower}: {self.upper})"
        left = KBucket(self.lower, midpoint, self.size, depth=self.depth + 1)
        right = KBucket(midpoint, self.upper, self.size, depth=self.depth + 1)
        for node_id, addr in chain(self.nodes_to_endpoint.items(), self.replacement_nodes.items()):
            bucket = left if int(node_id) < midpoint else right  # buckets are semi-intervals: midpoint goes right
            bucket.add_or_update_node(node_id, addr)
        return left, right
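
    # Example: splitting the initial bucket [0, 2^160) at depth 0 yields [0, 2^159) and
    # [2^159, 2^160), both at depth 1; each half inherits the nodes and replacement candidates
    # that fall within its range.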

    def __repr__(self):
        return f"{self.__class__.__name__}({len(self.nodes_to_endpoint)} nodes" \
               f" with {len(self.replacement_nodes)} replacements, depth={self.depth}, max size={self.size}," \
               f" lower={hex(self.lower)}, upper={hex(self.upper)})"


class DHTID(int):
    HASH_FUNC = hashlib.sha1
    HASH_NBYTES = 20  # SHA1 produces a 20-byte (i.e. 160-bit) digest
    RANGE = MIN, MAX = 0, 2 ** (HASH_NBYTES * 8)  # inclusive min, exclusive max

    def __new__(cls, value: int):
        assert cls.MIN <= value < cls.MAX, f"DHTID must be in [{cls.MIN}, {cls.MAX}) but got {value}"
        return super().__new__(cls, value)

    @classmethod
    def generate(cls, source: Optional[Any] = None, nbits: int = 255):
        """
        Generate a random uid based on SHA1

        :param source: if provided, converts this value to bytes and uses it as input for the hashing function;
            by default, generates a random dhtid from :nbits: random bits
        """
        source = random.getrandbits(nbits).to_bytes((nbits + 7) // 8, byteorder='big') if source is None else source
        source = PickleSerializer.dumps(source) if not isinstance(source, bytes) else source
        raw_uid = cls.HASH_FUNC(source).digest()
        return cls(int(raw_uid.hex(), 16))

    def xor_distance(self, other: Union[DHTID, Sequence[DHTID]]) -> Union[int, List[int]]:
        """
        :param other: one or multiple DHTIDs. If given multiple DHTIDs as other, this function
            will compute the distance from self to each of the DHTIDs in other.
        :return: a number or a list of numbers whose binary representations equal the bitwise xor between DHTIDs.
        """
        if isinstance(other, Iterable):
            return list(map(self.xor_distance, other))
        return int(self) ^ int(other)
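
    # Example: DHTID(12).xor_distance(DHTID(10)) == 6, since 0b1100 ^ 0b1010 == 0b0110.
    # XOR distance is symmetric, and for any id `a` and distance `d` there is exactly one `b`
    # with a.xor_distance(b) == d -- the properties Kademlia relies on.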

    @classmethod
    def longest_common_prefix_length(cls, *ids: DHTID) -> int:
        ids_bits = [bin(uid)[2:].rjust(8 * cls.HASH_NBYTES, '0') for uid in ids]
        return len(os.path.commonprefix(ids_bits))

    def to_bytes(self, length=HASH_NBYTES, byteorder='big', *, signed=False) -> bytes:
        """ A standard way to serialize DHTID into bytes """
        return super().to_bytes(length, byteorder, signed=signed)

    @classmethod
    def from_bytes(cls, raw: bytes, byteorder='big', *, signed=False) -> DHTID:
        """ reverse of to_bytes """
        return DHTID(super().from_bytes(raw, byteorder=byteorder, signed=signed))

    def __repr__(self):
        return f"{self.__class__.__name__}({hex(self)})"

    def __bytes__(self):
        return self.to_bytes()


DHTKey, DHTValue, DHTExpiration, BinaryDHTID, BinaryDHTValue = Any, Any, float, bytes, bytes  # flavour types
get_dht_time = time.time  # time used by all dht functionality. You can replace this with any infrastructure-wide time
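

# Usage sketch (illustrative, not part of the module): a minimal demo of the routing table.
# It assumes that Endpoint accepts plain "host:port" strings and that the module is executed
# within its package (the relative import above prevents running this file as a standalone script).
if __name__ == "__main__":
    table = RoutingTable(DHTID.generate(), bucket_size=20, depth_modulo=5)
    for i in range(1000):
        # hypothetical peers with deterministic ids and unique fake endpoints
        table.add_or_update_node(DHTID.generate(source=i), f"192.0.2.{i % 256}:{31337 + i}")
    query_id = DHTID.generate(source="query")
    for node_id, endpoint in table.get_nearest_neighbors(query_id, k=5):
        print(f"distance={query_id.xor_distance(node_id)}, endpoint={endpoint}")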