routing.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301
  1. """ Utility data structures to represent DHT nodes (peers), data keys, and routing tables. """
  2. from __future__ import annotations
  3. import hashlib
  4. import heapq
  5. import os
  6. import random
  7. from collections.abc import Iterable
  8. from itertools import chain
  9. from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
  10. from hivemind.p2p import PeerID
  11. from hivemind.utils import DHTExpiration, MSGPackSerializer, get_dht_time
  12. DHTKey = Subkey = DHTValue = Any
  13. BinaryDHTID = BinaryDHTValue = bytes
  14. class RoutingTable:
  15. """
  16. A data structure that contains DHT peers bucketed according to their distance to node_id.
  17. Follows Kademlia routing table as described in https://pdos.csail.mit.edu/~petar/papers/maymounkov-kademlia-lncs.pdf
  18. :param node_id: node id used to measure distance
  19. :param bucket_size: parameter $k$ from Kademlia paper Section 2.2
  20. :param depth_modulo: parameter $b$ from Kademlia paper Section 2.2.
  21. :note: you can find a more detailed description of parameters in DHTNode, see node.py
  22. """
  23. def __init__(self, node_id: DHTID, bucket_size: int, depth_modulo: int):
  24. self.node_id, self.bucket_size, self.depth_modulo = node_id, bucket_size, depth_modulo
  25. self.buckets = [KBucket(node_id.MIN, node_id.MAX, bucket_size)]
  26. self.peer_id_to_uid: Dict[PeerID, DHTID] = dict() # all nodes currently in buckets, including replacements
  27. self.uid_to_peer_id: Dict[DHTID, PeerID] = dict() # all nodes currently in buckets, including replacements
  28. def get_bucket_index(self, node_id: DHTID) -> int:
  29. """Get the index of the bucket that the given node would fall into."""
  30. lower_index, upper_index = 0, len(self.buckets)
  31. while upper_index - lower_index > 1:
  32. pivot_index = (lower_index + upper_index + 1) // 2
  33. if node_id >= self.buckets[pivot_index].lower:
  34. lower_index = pivot_index
  35. else: # node_id < self.buckets[pivot_index].lower
  36. upper_index = pivot_index
  37. assert upper_index - lower_index == 1
  38. return lower_index
  39. def add_or_update_node(self, node_id: DHTID, peer_id: PeerID) -> Optional[Tuple[DHTID, PeerID]]:
  40. """
  41. Update routing table after an incoming request from :peer_id: or outgoing request to :peer_id:
  42. :returns: If we cannot add node_id to the routing table, return the least-recently-updated node (Section 2.2)
  43. :note: DHTProtocol calls this method for every incoming and outgoing request if there was a response.
  44. If this method returned a node to be ping-ed, the protocol will ping it to check and either move it to
  45. the start of the table or remove that node and replace it with
  46. """
  47. bucket_index = self.get_bucket_index(node_id)
  48. bucket = self.buckets[bucket_index]
  49. store_success = bucket.add_or_update_node(node_id, peer_id)
  50. if node_id in bucket.nodes_to_peer_id or node_id in bucket.replacement_nodes:
  51. # if we added node to bucket or as a replacement, throw it into lookup dicts as well
  52. self.uid_to_peer_id[node_id] = peer_id
  53. self.peer_id_to_uid[peer_id] = node_id
  54. if not store_success:
  55. # Per section 4.2 of paper, split if the bucket has node's own id in its range
  56. # or if bucket depth is not congruent to 0 mod $b$
  57. if bucket.has_in_range(self.node_id) or bucket.depth % self.depth_modulo != 0:
  58. self.split_bucket(bucket_index)
  59. return self.add_or_update_node(node_id, peer_id)
  60. # The bucket is full and won't split further. Return a node to ping (see this method's docstring)
  61. return bucket.request_ping_node()
  62. def split_bucket(self, index: int) -> None:
  63. """Split bucket range in two equal parts and reassign nodes to the appropriate half"""
  64. first, second = self.buckets[index].split()
  65. self.buckets[index] = first
  66. self.buckets.insert(index + 1, second)
  67. def get(self, *, node_id: Optional[DHTID] = None, peer_id: Optional[PeerID] = None, default=None):
  68. """Find peer_id for a given DHTID or vice versa"""
  69. assert (node_id is None) != (peer_id is None), "Please specify either node_id or peer_id, but not both"
  70. if node_id is not None:
  71. return self.uid_to_peer_id.get(node_id, default)
  72. else:
  73. return self.peer_id_to_uid.get(peer_id, default)
  74. def __getitem__(self, item: Union[DHTID, PeerID]) -> Union[PeerID, DHTID]:
  75. """Find peer_id for a given DHTID or vice versa"""
  76. return self.uid_to_peer_id[item] if isinstance(item, DHTID) else self.peer_id_to_uid[item]
  77. def __setitem__(self, node_id: DHTID, peer_id: PeerID) -> NotImplementedError:
  78. raise NotImplementedError(
  79. "RoutingTable doesn't support direct item assignment. Use table.try_add_node instead"
  80. )
  81. def __contains__(self, item: Union[DHTID, PeerID]) -> bool:
  82. return (item in self.uid_to_peer_id) if isinstance(item, DHTID) else (item in self.peer_id_to_uid)
  83. def __delitem__(self, node_id: DHTID):
  84. del self.buckets[self.get_bucket_index(node_id)][node_id]
  85. node_peer_id = self.uid_to_peer_id.pop(node_id)
  86. if self.peer_id_to_uid.get(node_peer_id) == node_id:
  87. del self.peer_id_to_uid[node_peer_id]
  88. def get_nearest_neighbors(
  89. self, query_id: DHTID, k: int, exclude: Optional[DHTID] = None
  90. ) -> List[Tuple[DHTID, PeerID]]:
  91. """
  92. Find k nearest neighbors from routing table according to XOR distance, does NOT include self.node_id
  93. :param query_id: find neighbors of this node
  94. :param k: find this many neighbors. If there aren't enough nodes in the table, returns all nodes
  95. :param exclude: if True, results will not contain query_node_id even if it is in table
  96. :return: a list of tuples (node_id, peer_id) for up to k neighbors sorted from nearest to farthest
  97. """
  98. # algorithm: first add up all buckets that can contain one of k nearest nodes, then heap-sort to find best
  99. candidates: List[Tuple[int, DHTID, PeerID]] = [] # min-heap based on xor distance to query_id
  100. # step 1: add current bucket to the candidates heap
  101. nearest_index = self.get_bucket_index(query_id)
  102. nearest_bucket = self.buckets[nearest_index]
  103. for node_id, peer_id in nearest_bucket.nodes_to_peer_id.items():
  104. heapq.heappush(candidates, (query_id.xor_distance(node_id), node_id, peer_id))
  105. # step 2: add adjacent buckets by ascending code tree one level at a time until you have enough nodes
  106. left_index, right_index = nearest_index, nearest_index + 1 # bucket indices considered, [left, right)
  107. current_lower, current_upper, current_depth = nearest_bucket.lower, nearest_bucket.upper, nearest_bucket.depth
  108. while current_depth > 0 and len(candidates) < k + int(exclude is not None):
  109. split_direction = current_lower // 2 ** (DHTID.HASH_NBYTES * 8 - current_depth) % 2
  110. # ^-- current leaf direction from pre-leaf node, 0 = left, 1 = right
  111. current_depth -= 1 # traverse one level closer to the root and add all child nodes to the candidates heap
  112. if split_direction == 0: # leaf was split on the left, merge its right peer(s)
  113. current_upper += current_upper - current_lower
  114. while right_index < len(self.buckets) and self.buckets[right_index].upper <= current_upper:
  115. for node_id, peer_id in self.buckets[right_index].nodes_to_peer_id.items():
  116. heapq.heappush(candidates, (query_id.xor_distance(node_id), node_id, peer_id))
  117. right_index += 1
  118. # note: we may need to add more than one bucket if they are on a lower depth level
  119. assert self.buckets[right_index - 1].upper == current_upper
  120. else: # split_direction == 1, leaf was split on the right, merge its left peer(s)
  121. current_lower -= current_upper - current_lower
  122. while left_index > 0 and self.buckets[left_index - 1].lower >= current_lower:
  123. left_index -= 1 # note: we may need to add more than one bucket if they are on a lower depth level
  124. for node_id, peer_id in self.buckets[left_index].nodes_to_peer_id.items():
  125. heapq.heappush(candidates, (query_id.xor_distance(node_id), node_id, peer_id))
  126. assert self.buckets[left_index].lower == current_lower
  127. # step 3: select k nearest vertices from candidates heap
  128. heap_top: List[Tuple[int, DHTID, PeerID]] = heapq.nsmallest(k + int(exclude is not None), candidates)
  129. return [(node, peer_id) for _, node, peer_id in heap_top if node != exclude][:k]
  130. def __repr__(self):
  131. bucket_info = "\n".join(repr(bucket) for bucket in self.buckets)
  132. return (
  133. f"{self.__class__.__name__}(node_id={self.node_id}, bucket_size={self.bucket_size},"
  134. f" modulo={self.depth_modulo},\nbuckets=[\n{bucket_info})"
  135. )
  136. class KBucket:
  137. """
  138. A bucket containing up to :size: of DHTIDs in [lower, upper) semi-interval.
  139. Maps DHT node ids to their peer_ids
  140. """
  141. def __init__(self, lower: int, upper: int, size: int, depth: int = 0):
  142. assert upper - lower == 2 ** (DHTID.HASH_NBYTES * 8 - depth)
  143. self.lower, self.upper, self.size, self.depth = lower, upper, size, depth
  144. self.nodes_to_peer_id: Dict[DHTID, PeerID] = {}
  145. self.replacement_nodes: Dict[DHTID, PeerID] = {}
  146. self.nodes_requested_for_ping: Set[DHTID] = set()
  147. self.last_updated = get_dht_time()
  148. def has_in_range(self, node_id: DHTID):
  149. """Check if node_id is between this bucket's lower and upper bounds"""
  150. return self.lower <= node_id < self.upper
  151. def add_or_update_node(self, node_id: DHTID, peer_id: PeerID) -> bool:
  152. """
  153. Add node to KBucket or update existing node, return True if successful, False if the bucket is full.
  154. If the bucket is full, keep track of node in a replacement list, per section 4.1 of the paper.
  155. :param node_id: dht node identifier that should be added or moved to the front of bucket
  156. :param peer_id: network address associated with that node id
  157. :note: this function has a side-effect of resetting KBucket.last_updated time
  158. """
  159. if node_id in self.nodes_requested_for_ping:
  160. self.nodes_requested_for_ping.remove(node_id)
  161. self.last_updated = get_dht_time()
  162. if node_id in self.nodes_to_peer_id:
  163. del self.nodes_to_peer_id[node_id]
  164. self.nodes_to_peer_id[node_id] = peer_id
  165. elif len(self.nodes_to_peer_id) < self.size:
  166. self.nodes_to_peer_id[node_id] = peer_id
  167. else:
  168. if node_id in self.replacement_nodes:
  169. del self.replacement_nodes[node_id]
  170. self.replacement_nodes[node_id] = peer_id
  171. return False
  172. return True
  173. def request_ping_node(self) -> Optional[Tuple[DHTID, PeerID]]:
  174. """:returns: least-recently updated node that isn't already being pinged right now -- if such node exists"""
  175. for uid, peer_id in self.nodes_to_peer_id.items():
  176. if uid not in self.nodes_requested_for_ping:
  177. self.nodes_requested_for_ping.add(uid)
  178. return uid, peer_id
  179. def __getitem__(self, node_id: DHTID) -> PeerID:
  180. return self.nodes_to_peer_id[node_id] if node_id in self.nodes_to_peer_id else self.replacement_nodes[node_id]
  181. def __delitem__(self, node_id: DHTID):
  182. if not (node_id in self.nodes_to_peer_id or node_id in self.replacement_nodes):
  183. raise KeyError(f"KBucket does not contain node id={node_id}")
  184. if node_id in self.replacement_nodes:
  185. del self.replacement_nodes[node_id]
  186. if node_id in self.nodes_to_peer_id:
  187. del self.nodes_to_peer_id[node_id]
  188. if self.replacement_nodes:
  189. newnode_id, newnode = self.replacement_nodes.popitem()
  190. self.nodes_to_peer_id[newnode_id] = newnode
  191. def split(self) -> Tuple[KBucket, KBucket]:
  192. """Split bucket over midpoint, rounded down, assign nodes to according to their id"""
  193. midpoint = (self.lower + self.upper) // 2
  194. assert self.lower < midpoint < self.upper, f"Bucket to small to be split: [{self.lower}: {self.upper})"
  195. left = KBucket(self.lower, midpoint, self.size, depth=self.depth + 1)
  196. right = KBucket(midpoint, self.upper, self.size, depth=self.depth + 1)
  197. for node_id, peer_id in chain(self.nodes_to_peer_id.items(), self.replacement_nodes.items()):
  198. bucket = left if int(node_id) <= midpoint else right
  199. bucket.add_or_update_node(node_id, peer_id)
  200. return left, right
  201. def __repr__(self):
  202. return (
  203. f"{self.__class__.__name__}({len(self.nodes_to_peer_id)} nodes"
  204. f" with {len(self.replacement_nodes)} replacements, depth={self.depth}, max size={self.size}"
  205. f" lower={hex(self.lower)}, upper={hex(self.upper)})"
  206. )
  207. class DHTID(int):
  208. HASH_FUNC = hashlib.sha1
  209. HASH_NBYTES = 20 # SHA1 produces a 20-byte (aka 160bit) number
  210. RANGE = MIN, MAX = 0, 2 ** (HASH_NBYTES * 8) # inclusive min, exclusive max
  211. def __new__(cls, value: int):
  212. assert cls.MIN <= value < cls.MAX, f"DHTID must be in [{cls.MIN}, {cls.MAX}) but got {value}"
  213. return super().__new__(cls, value)
  214. @classmethod
  215. def generate(cls, source: Optional[Any] = None, nbits: int = 255):
  216. """
  217. Generates random uid based on SHA1
  218. :param source: if provided, converts this value to bytes and uses it as input for hashing function;
  219. by default, generates a random dhtid from :nbits: random bits
  220. """
  221. source = random.getrandbits(nbits).to_bytes(nbits, byteorder="big") if source is None else source
  222. source = MSGPackSerializer.dumps(source) if not isinstance(source, bytes) else source
  223. raw_uid = cls.HASH_FUNC(source).digest()
  224. return cls(int(raw_uid.hex(), 16))
  225. def xor_distance(self, other: Union[DHTID, Sequence[DHTID]]) -> Union[int, List[int]]:
  226. """
  227. :param other: one or multiple DHTIDs. If given multiple DHTIDs as other, this function
  228. will compute distance from self to each of DHTIDs in other.
  229. :return: a number or a list of numbers whose binary representations equal bitwise xor between DHTIDs.
  230. """
  231. if isinstance(other, Iterable):
  232. return list(map(self.xor_distance, other))
  233. return int(self) ^ int(other)
  234. @classmethod
  235. def longest_common_prefix_length(cls, *ids: DHTID) -> int:
  236. ids_bits = [bin(uid)[2:].rjust(8 * cls.HASH_NBYTES, "0") for uid in ids]
  237. return len(os.path.commonprefix(ids_bits))
  238. def to_bytes(self, length=HASH_NBYTES, byteorder="big", *, signed=False) -> bytes:
  239. """A standard way to serialize DHTID into bytes"""
  240. return super().to_bytes(length, byteorder, signed=signed)
  241. @classmethod
  242. def from_bytes(cls, raw: bytes, byteorder="big", *, signed=False) -> DHTID:
  243. """reverse of to_bytes"""
  244. return DHTID(super().from_bytes(raw, byteorder=byteorder, signed=signed))
  245. def __repr__(self):
  246. return f"{self.__class__.__name__}({hex(self)})"
  247. def __bytes__(self):
  248. return self.to_bytes()