traverse.py

  1. """ Utility functions for crawling DHT nodes, used to get and store keys in a DHT """
  2. import asyncio
  3. import heapq
  4. from collections import Counter
  5. from typing import Dict, Awaitable, Callable, Any, Tuple, List, Set, Collection, Optional
  6. from hivemind.dht.routing import DHTID
  7. ROOT = 0 # alias for heap root
  8. async def simple_traverse_dht(query_id: DHTID, initial_nodes: Collection[DHTID], beam_size: int,
  9. get_neighbors: Callable[[DHTID], Awaitable[Tuple[Collection[DHTID], bool]]],
  10. visited_nodes: Collection[DHTID] = ()) -> Tuple[List[DHTID], Set[DHTID]]:
  11. """
  12. Traverse the DHT graph using get_neighbors function, find :beam_size: nearest nodes according to DHTID.xor_distance.
  13. :note: This is a simplified (but working) algorithm provided for documentation purposes. Actual DHTNode uses
  14. `traverse_dht` - a generalization of this this algorithm that allows multiple queries and concurrent workers.
  15. :param query_id: search query, find k_nearest neighbors of this DHTID
  16. :param initial_nodes: nodes used to pre-populate beam search heap, e.g. [my_own_DHTID, ...maybe_some_peers]
  17. :param beam_size: beam search will not give up until it exhausts this many nearest nodes (to query_id) from the heap
  18. Recommended value: A beam size of k_nearest * (2-5) will yield near-perfect results.
  19. :param get_neighbors: A function that returns neighbors of a given node and controls beam search stopping criteria.
  20. async def get_neighbors(node: DHTID) -> neighbors_of_that_node: List[DHTID], should_continue: bool
  21. If should_continue is False, beam search will halt and return k_nearest of whatever it found by then.
  22. :param visited_nodes: beam search will neither call get_neighbors on these nodes, nor return them as nearest
  23. :returns: a list of k nearest nodes (nearest to farthest), and a set of all visited nodes (including visited_nodes)
  24. """
  25. visited_nodes = set(visited_nodes) # note: copy visited_nodes because we will add more nodes to this collection.
  26. initial_nodes = [node_id for node_id in initial_nodes if node_id not in visited_nodes]
  27. if not initial_nodes:
  28. return [], visited_nodes
  29. unvisited_nodes = [(distance, uid) for uid, distance in zip(initial_nodes, query_id.xor_distance(initial_nodes))]
  30. heapq.heapify(unvisited_nodes) # nearest-first heap of candidates, unlimited size
  31. nearest_nodes = [(-distance, node_id) for distance, node_id in heapq.nsmallest(beam_size, unvisited_nodes)]
  32. heapq.heapify(nearest_nodes) # farthest-first heap of size beam_size, used for early-stopping and to select results
  33. while len(nearest_nodes) > beam_size:
  34. heapq.heappop(nearest_nodes)
  35. visited_nodes |= set(initial_nodes)
  36. upper_bound = -nearest_nodes[0][0] # distance to farthest element that is still in beam
  37. was_interrupted = False # will set to True if host triggered beam search to stop via get_neighbors
  38. while (not was_interrupted) and len(unvisited_nodes) != 0 and unvisited_nodes[0][0] <= upper_bound:
  39. _, node_id = heapq.heappop(unvisited_nodes) # note: this --^ is the smallest element in heap (see heapq)
  40. neighbors, was_interrupted = await get_neighbors(node_id)
  41. neighbors = [node_id for node_id in neighbors if node_id not in visited_nodes]
  42. visited_nodes.update(neighbors)
  43. for neighbor_id, distance in zip(neighbors, query_id.xor_distance(neighbors)):
  44. if distance <= upper_bound or len(nearest_nodes) < beam_size:
  45. heapq.heappush(unvisited_nodes, (distance, neighbor_id))
  46. heapq_add_or_replace = heapq.heappush if len(nearest_nodes) < beam_size else heapq.heappushpop
  47. heapq_add_or_replace(nearest_nodes, (-distance, neighbor_id))
  48. upper_bound = -nearest_nodes[0][0] # distance to beam_size-th nearest element found so far
  49. return [node_id for _, node_id in heapq.nlargest(beam_size, nearest_nodes)], visited_nodes
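# Minimal usage sketch for `simple_traverse_dht`: drive the beam search with an in-memory
# adjacency map instead of real network calls. `toy_routing_table`, `toy_get_neighbors` and
# `_demo_simple_traverse` are hypothetical names that exist only for this illustration.
async def _demo_simple_traverse(toy_routing_table: Dict[DHTID, List[DHTID]],
                                query_id: DHTID, beam_size: int = 8) -> List[DHTID]:
    async def toy_get_neighbors(node: DHTID) -> Tuple[List[DHTID], bool]:
        # return this node's known peers; False means "do not interrupt the beam search"
        return toy_routing_table.get(node, []), False

    nearest, visited = await simple_traverse_dht(
        query_id, initial_nodes=list(toy_routing_table)[:3], beam_size=beam_size,
        get_neighbors=toy_get_neighbors)
    return nearest

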
async def traverse_dht(
        queries: Collection[DHTID], initial_nodes: List[DHTID], beam_size: int, num_workers: int, queries_per_call: int,
        get_neighbors: Callable[[DHTID, Collection[DHTID]], Awaitable[Dict[DHTID, Tuple[List[DHTID], bool]]]],
        found_callback: Optional[Callable[[DHTID, List[DHTID], Set[DHTID]], Awaitable[Any]]] = None,
        await_all_tasks: bool = True, visited_nodes: Optional[Dict[DHTID, Set[DHTID]]] = (),
) -> Tuple[Dict[DHTID, List[DHTID]], Dict[DHTID, Set[DHTID]]]:
    """
    Search the DHT for nearest neighbors to :queries: (based on DHTID.xor_distance). Use get_neighbors to request peers.
    The algorithm can reuse intermediate results from each query to speed up the search for other (similar) queries.

    :param queries: a list of search queries, find beam_size neighbors for these DHTIDs
    :param initial_nodes: nodes used to pre-populate the beam search heap, e.g. [my_own_DHTID, ...maybe_some_peers]
    :param beam_size: beam search will not give up until it visits this many nearest nodes (to query_id) from the heap
    :param num_workers: run up to this many concurrent get_neighbors requests, each querying one peer for neighbors.
        When selecting a peer to request neighbors from, workers try to balance concurrent exploration across queries.
        A worker will expand the nearest candidate to the query that has the fewest concurrent requests from other workers.
        If several queries have the same number of concurrent requests, prefer the one with the nearest XOR distance.

    :param queries_per_call: workers can pack up to this many queries in one get_neighbors call. These queries contain
        the primary query (see num_workers above) and up to `queries_per_call - 1` nearest unfinished queries.

    :param get_neighbors: A function that requests a given peer to find nearest neighbors for multiple queries
        async def get_neighbors(peer, queries) -> {query1: ([nearest1, nearest2, ...], False), query2: ([...], True)}
        For each query in queries, return the nearest neighbors (known to the given peer) and a boolean "should_stop" flag.
        If should_stop is True, traverse_dht will no longer search for this query or request it from other peers.
        The search terminates iff each query is either stopped via should_stop or finds beam_size nearest nodes.

    :param found_callback: if specified, call this callback for each finished query the moment it finishes or is stopped.
        More specifically, run asyncio.create_task(found_callback(query, nearest_to_query, visited_for_query))
        Using this callback allows one to process results faster, before traverse_dht finishes for all queries.

    :param await_all_tasks: if True, wait for all tasks to finish before returning; otherwise return right after finding
        the nearest neighbors and let the remaining tasks (callbacks and queries to known-but-unvisited nodes) finish
        in the background.

    :param visited_nodes: for each query, do not call get_neighbors on these nodes, nor return them among the nearest.

    :note: the source code of this function can get tricky to read. Take a look at the `simple_traverse_dht` function
        for reference. That function implements a special case of traverse_dht with a single query and one worker.

    :returns: a dict of nearest nodes, and another dict of visited nodes
        nearest nodes: { query -> a list of up to beam_size nearest nodes, ordered nearest-first }
        visited nodes: { query -> a set of all nodes that received requests for a given query }
    """
    if len(queries) == 0:
        return {}, dict(visited_nodes)

    unfinished_queries = set(queries)                             # all queries that haven't triggered finish_search yet
    candidate_nodes: Dict[DHTID, List[Tuple[int, DHTID]]] = {}    # heap: unvisited nodes, ordered nearest-to-farthest
    nearest_nodes: Dict[DHTID, List[Tuple[int, DHTID]]] = {}      # heap: top-k nearest nodes, farthest-to-nearest
    known_nodes: Dict[DHTID, Set[DHTID]] = {}                     # all nodes ever added to the heap (for deduplication)
    visited_nodes: Dict[DHTID, Set[DHTID]] = dict(visited_nodes)  # where we requested get_neighbors for a given query
    pending_tasks = set()                                         # all active tasks (get_neighbors and found_callback)
    active_workers = Counter({q: 0 for q in queries})             # count workers that search for this query

    search_finished_event = asyncio.Event()  # used to immediately stop all workers when the search is finished
    heap_updated_event = asyncio.Event()     # if a worker has no nodes to explore, it will await other workers
    heap_updated_event.set()

    # initialize data structures
    for query in queries:
        distances = query.xor_distance(initial_nodes)
        candidate_nodes[query] = list(zip(distances, initial_nodes))
        nearest_nodes[query] = list(zip([-d for d in distances], initial_nodes))
        heapq.heapify(candidate_nodes[query])
        heapq.heapify(nearest_nodes[query])
        while len(nearest_nodes[query]) > beam_size:
            heapq.heappop(nearest_nodes[query])
        known_nodes[query] = set(initial_nodes)
        visited_nodes[query] = set(visited_nodes.get(query, ()))

    def heuristic_priority(heap_query: DHTID):
        """ Workers prioritize expanding nodes that lead to under-explored queries (by other workers) """
        if len(candidate_nodes[heap_query]) == 0:
            return float('inf'), float('inf')
        else:  # prefer candidates in heaps with the fewest concurrent workers, break ties by distance to query
            return active_workers[heap_query], candidate_nodes[heap_query][ROOT][0]

    def upper_bound(query: DHTID):
        """ Any node that is farther from query than upper_bound(query) will not be added to heaps """
        return -nearest_nodes[query][ROOT][0] if len(nearest_nodes[query]) >= beam_size else float('inf')

    def finish_search(query):
        """ Remove query from the list of search targets """
        unfinished_queries.remove(query)
        if len(unfinished_queries) == 0:
            search_finished_event.set()
        if found_callback:
            nearest_neighbors = [peer for _, peer in heapq.nlargest(beam_size, nearest_nodes[query])]
            pending_tasks.add(asyncio.create_task(found_callback(query, nearest_neighbors, set(visited_nodes[query]))))

    async def worker():
        while unfinished_queries:
            # select the heap based on priority
            chosen_query: DHTID = min(unfinished_queries, key=heuristic_priority)

            if len(candidate_nodes[chosen_query]) == 0:  # if there are no peers to explore...
                other_workers_pending = active_workers.most_common(1)[0][1] > 0
                if other_workers_pending:  # ... wait for other workers (if any) to add more peers
                    heap_updated_event.clear()
                    await heap_updated_event.wait()
                    continue
                else:  # ... or, if there is no hope of new nodes, finish the search immediately
                    for query in list(unfinished_queries):
                        finish_search(query)
                    break

            # select the vertex to be explored
            chosen_distance_to_query, chosen_peer = heapq.heappop(candidate_nodes[chosen_query])
            if chosen_peer in visited_nodes[chosen_query]:
                continue
            if chosen_distance_to_query > upper_bound(chosen_query):
                finish_search(chosen_query)
                continue

            # find additional queries to pack in the same request
            possible_additional_queries = [query for query in unfinished_queries
                                           if query != chosen_query and chosen_peer not in visited_nodes[query]]
            queries_to_call = [chosen_query] + heapq.nsmallest(
                queries_per_call - 1, possible_additional_queries, key=chosen_peer.xor_distance)

            # update priorities for subsequent workers
            active_workers.update(queries_to_call)
            for query_to_call in queries_to_call:
                visited_nodes[query_to_call].add(chosen_peer)

            # get nearest neighbors (over the network) and update search heaps. Abort if the search finishes early
            get_neighbors_task = asyncio.create_task(get_neighbors(chosen_peer, queries_to_call))
            pending_tasks.add(get_neighbors_task)
            await asyncio.wait([get_neighbors_task, search_finished_event.wait()], return_when=asyncio.FIRST_COMPLETED)
            if search_finished_event.is_set():
                break  # another worker triggered finish_search, exit immediately
            pending_tasks.remove(get_neighbors_task)

            # add nearest neighbors to their respective heaps
            for query, (neighbors_for_query, should_stop) in get_neighbors_task.result().items():
                if should_stop and (query in unfinished_queries):
                    finish_search(query)
                if query not in unfinished_queries:
                    continue  # either we finished this search or someone else did while we awaited
                for neighbor in neighbors_for_query:
                    if neighbor not in known_nodes[query]:
                        known_nodes[query].add(neighbor)
                        distance = query.xor_distance(neighbor)
                        if distance <= upper_bound(query) or len(nearest_nodes[query]) < beam_size:
                            heapq.heappush(candidate_nodes[query], (distance, neighbor))
                            if len(nearest_nodes[query]) < beam_size:
                                heapq.heappush(nearest_nodes[query], (-distance, neighbor))
                            else:
                                heapq.heappushpop(nearest_nodes[query], (-distance, neighbor))

            # we finished processing a request, update priorities for other workers
            active_workers.subtract(queries_to_call)
            heap_updated_event.set()

    # spawn all workers and wait for them to terminate; workers terminate after exhausting unfinished_queries
    await asyncio.wait([asyncio.create_task(worker()) for _ in range(num_workers)],
                       return_when=asyncio.FIRST_COMPLETED)  # the first worker finishes when the search is over
    assert len(unfinished_queries) == 0 and search_finished_event.is_set()

    if await_all_tasks:
        await asyncio.gather(*pending_tasks)

    nearest_neighbors_per_query = {
        query: [peer for _, peer in heapq.nlargest(beam_size, nearest_nodes[query])]
        for query in queries
    }
    return nearest_neighbors_per_query, visited_nodes
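

# Minimal usage sketch for `traverse_dht`: several queries searched over an in-memory routing
# table with a few concurrent workers. `toy_routing_table`, `toy_get_neighbors`, `report_finished`
# and `_demo_traverse` are hypothetical names used only for this illustration.
async def _demo_traverse(toy_routing_table: Dict[DHTID, List[DHTID]],
                         queries: List[DHTID]) -> Dict[DHTID, List[DHTID]]:
    async def toy_get_neighbors(peer: DHTID, packed_queries: Collection[DHTID]
                                ) -> Dict[DHTID, Tuple[List[DHTID], bool]]:
        # answer every packed query from the local table; False means "keep searching this query"
        return {query: (toy_routing_table.get(peer, []), False) for query in packed_queries}

    async def report_finished(query: DHTID, nearest: List[DHTID], visited: Set[DHTID]) -> None:
        # runs as soon as an individual query finishes, before traverse_dht returns
        print(f"query {query}: {len(nearest)} nearest, {len(visited)} visited")

    nearest_per_query, visited_per_query = await traverse_dht(
        queries, initial_nodes=list(toy_routing_table)[:3], beam_size=8, num_workers=2,
        queries_per_call=2, get_neighbors=toy_get_neighbors, found_callback=report_finished)
    return nearest_per_query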