# node.py
from __future__ import annotations
import asyncio
import random
from collections import namedtuple
from typing import Optional, Tuple, List, Dict, Collection, Union, Set
from warnings import warn

from hivemind.dht.protocol import DHTProtocol
from hivemind.dht.routing import DHTID, DHTExpiration, DHTKey, get_dht_time, DHTValue
from hivemind.dht.traverse import traverse_dht
from hivemind.utils import Endpoint, LOCALHOST, MSGPackSerializer


class DHTNode:
    """
    A low-level class that represents a DHT participant. Please see DHTNode.create for parameters.
    Each DHTNode has an identifier, a local storage and access to other nodes via DHTProtocol.

    :note: Hivemind DHT is optimized to store a lot of temporary metadata that is regularly updated.
     For example, an expert-alive timestamp emitted by the Server responsible for that expert.
     Such metadata does not require regular maintenance by peers or persistence on shutdown.
     Instead, DHTNode is designed to rapidly send bulk data and resolve conflicts.

    Every (key, value) pair in this DHT has an expiration time - a float computed as get_dht_time() (UnixTime by default).
    DHT nodes always prefer values with a higher expiration time and may delete any value past its expiration.

    Compared to the Kademlia RPC protocol, hivemind DHT has 3 RPCs:

    * ping - request peer's identifier and update routing table (same as Kademlia PING RPC)
    * store - send several (key, value, expiration) pairs to the same peer (like Kademlia STORE, but in bulk)
    * find - request one or several keys, get values & expiration (if the peer finds them locally) and :bucket_size: of
      nearest peers from the recipient's routing table (ordered nearest-to-farthest, not including the recipient itself).
      This RPC is a mixture between Kademlia FIND_NODE and FIND_VALUE with multiple keys per call.

    Formally, DHTNode follows this contract:

    - when asked to get(key), a node must find and return the value with the highest expiration time that it found across the DHT,
      IF that time has not come yet. If the expiration time is smaller than the current get_dht_time(), the node may return None;
    - when requested to store(key: value, expiration), a node must store (key => value) until that expiration time
      or until DHTNode gets the same key with a greater expiration time. If a node is asked to store a key but it already
      has the same key with a newer expiration, the older key will not be stored. Return True if stored, False if refused;
    - when requested to store(key: value, expiration, in_cache=True), stores (key => value) in a separate "cache".
      Cache operates the same as regular storage, but it has a limited size and evicts least recently used entries when full;
    """
    # fmt:off
    node_id: DHTID; port: int; num_replicas: int; cache_locally: bool; cache_nearest: int; num_workers: int
    refresh_timeout: float; protocol: DHTProtocol
    serializer = MSGPackSerializer  # used to pack/unpack DHT Values for transfer over the network
    # fmt:on

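    # Illustrative sketch (not part of the original source): the store/get contract above means that
    # a store with a higher expiration time supersedes an earlier one, while an older one is refused:
    #
    #   now = get_dht_time()
    #   await node.store('expert.alive', True, expiration_time=now + 30)   # -> True (stored)
    #   await node.store('expert.alive', False, expiration_time=now + 10)  # -> False (older expiration, refused)
    #   value, expiration = await node.get('expert.alive')                 # -> (True, now + 30)
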
    @classmethod
    async def create(
            cls, node_id: Optional[DHTID] = None, initial_peers: List[Endpoint] = (),
            bucket_size: int = 20, num_replicas: int = 5, depth_modulo: int = 5, parallel_rpc: int = None,
            wait_timeout: float = 5, refresh_timeout: Optional[float] = None, bootstrap_timeout: Optional[float] = None,
            num_workers: int = 1, cache_locally: bool = True, cache_nearest: int = 1, cache_size=None,
            listen: bool = True, listen_on: Endpoint = "0.0.0.0:*", **kwargs) -> DHTNode:
        """
        :param node_id: current node's identifier, determines which keys it will store locally, defaults to a random id
        :param initial_peers: connects to these peers to populate the routing table, defaults to no peers
        :param bucket_size: max number of nodes in one k-bucket (k). Trying to add the {k+1}st node will cause the bucket to
          either split in two buckets along the midpoint or reject the new node (but still save it as a replacement).
          Recommended value: k is chosen s.t. any given k nodes are very unlikely to all fail after staleness_timeout
        :param num_replicas: number of nearest nodes that will be asked to store a given key (≈k)
        :param depth_modulo: split a full k-bucket if it contains the root OR up to the nearest multiple of this value (≈b)
        :param parallel_rpc: maximum number of concurrent outgoing RPC requests emitted by DHTProtocol.
          Reduce this value if your RPC requests register no response despite the peer sending one.
        :param wait_timeout: a kademlia rpc request is deemed lost if we did not receive a reply in this many seconds
        :param refresh_timeout: refresh buckets if no node from that bucket was updated in this many seconds;
          if refresh_timeout is None, DHTNode will not refresh stale buckets (which is usually okay)
        :param bootstrap_timeout: after one of the peers responds, await the other peers for at most this many seconds
        :param num_workers: concurrent workers in traverse_dht (see traverse_dht num_workers param)
        :param cache_locally: if True, caches all values (stored or found) in a node-local cache
        :param cache_nearest: whenever DHTNode finds a value, it will also store (cache) this value on this many
          nearest nodes visited by the search algorithm. Prefers nodes that are nearest to :key: but have no value yet
        :param cache_size: if specified, the local cache will store up to this many records (as in an LRU cache)
        :param listen: if True (default), this node will accept incoming requests and otherwise be a DHT "citizen";
          if False, this node will refuse any incoming requests, effectively being only a "client"
        :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
        :param channel_options: options for grpc.aio.insecure_channel, e.g. [('grpc.enable_retries', 0)]
          see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
        :param kwargs: extra parameters used in grpc.aio.server
        """
        self = cls(_initialized_with_create=True)
        self.node_id = node_id = node_id if node_id is not None else DHTID.generate()
        self.num_replicas, self.num_workers = num_replicas, num_workers
        self.cache_locally, self.cache_nearest = cache_locally, cache_nearest
        self.refresh_timeout = refresh_timeout

        self.protocol = await DHTProtocol.create(self.node_id, bucket_size, depth_modulo, num_replicas, wait_timeout,
                                                 parallel_rpc, cache_size, listen, listen_on, **kwargs)
        self.port = self.protocol.port

        if initial_peers:
            # stage 1: ping initial_peers, add each other to the routing table
            bootstrap_timeout = bootstrap_timeout if bootstrap_timeout is not None else wait_timeout
            start_time = get_dht_time()
            ping_tasks = map(self.protocol.call_ping, initial_peers)
            finished_pings, unfinished_pings = await asyncio.wait(ping_tasks, return_when=asyncio.FIRST_COMPLETED)

            # stage 2: gather remaining peers (those who respond within bootstrap_timeout)
            if unfinished_pings:
                finished_in_time, stragglers = await asyncio.wait(
                    unfinished_pings, timeout=bootstrap_timeout - get_dht_time() + start_time)
                for straggler in stragglers:
                    straggler.cancel()
                finished_pings |= finished_in_time

            if not finished_pings:
                warn("DHTNode bootstrap failed: none of the initial_peers responded to a ping.")

            # stage 3: traverse the dht to find my own nearest neighbors and populate the routing table
            # ... and maybe receive some values that we are meant to store (see protocol.update_routing_table)
            # note: using asyncio.wait instead of wait_for because wait_for cancels the task on timeout
            await asyncio.wait([asyncio.create_task(self.find_nearest_nodes([self.node_id])),
                                asyncio.sleep(bootstrap_timeout - get_dht_time() + start_time)],
                               return_when=asyncio.FIRST_COMPLETED)

        if self.refresh_timeout is not None:
            asyncio.create_task(self._refresh_routing_table(period=self.refresh_timeout))
        return self

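    # A minimal bootstrap sketch (assumes a running asyncio event loop; not part of the original source):
    #
    #   first_node = await DHTNode.create(listen_on='0.0.0.0:*')
    #   second_node = await DHTNode.create(initial_peers=[f"{LOCALHOST}:{first_node.port}"])
    #   # second_node pings first_node, then traverses the DHT to populate its own routing table
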
    def __init__(self, *, _initialized_with_create=False):
        """ Internal init method. Please use DHTNode.create coroutine to spawn new node instances """
        assert _initialized_with_create, "Please use DHTNode.create coroutine to spawn new node instances"
        super().__init__()

    async def shutdown(self, timeout=None):
        """ Process existing requests, close all connections and stop the server """
        await self.protocol.shutdown(timeout)

    async def find_nearest_nodes(
            self, queries: Collection[DHTID], k_nearest: Optional[int] = None, beam_size: Optional[int] = None,
            num_workers: Optional[int] = None, node_to_endpoint: Optional[Dict[DHTID, Endpoint]] = None,
            exclude_self: bool = False, **kwargs) -> Dict[DHTID, Dict[DHTID, Endpoint]]:
        """
        :param queries: find k nearest nodes for each of these DHTIDs
        :param k_nearest: return this many nearest nodes for every query (if there are enough nodes)
        :param beam_size: replacement for self.beam_size, see traverse_dht beam_size param
        :param num_workers: replacement for self.num_workers, see traverse_dht num_workers param
        :param node_to_endpoint: if specified, uses this dict[node_id => endpoint] as initial peers
        :param exclude_self: if True, nearest nodes will not contain self.node_id (default = use local peers)
        :param kwargs: additional params passed to traverse_dht
        :returns: for every query, the nearest peers as an ordered dict[peer DHTID -> network Endpoint], nearest-first
        """
        queries = tuple(queries)
        k_nearest = k_nearest if k_nearest is not None else self.protocol.bucket_size
        num_workers = num_workers if num_workers is not None else self.num_workers
        beam_size = beam_size if beam_size is not None else max(self.protocol.bucket_size, k_nearest)
        if k_nearest > beam_size:
            warn("beam_size is too small, beam search is not guaranteed to find enough nodes")
        if node_to_endpoint is None:
            node_to_endpoint: Dict[DHTID, Endpoint] = dict()
            for query in queries:
                node_to_endpoint.update(
                    self.protocol.routing_table.get_nearest_neighbors(query, beam_size, exclude=self.node_id))

        async def get_neighbors(peer: DHTID, queries: Collection[DHTID]) -> Dict[DHTID, Tuple[List[DHTID], bool]]:
            response = await self.protocol.call_find(node_to_endpoint[peer], queries)
            if not response:
                return {query: ([], False) for query in queries}

            output: Dict[DHTID, Tuple[List[DHTID], bool]] = {}
            for query, (_, _, peers) in response.items():
                node_to_endpoint.update(peers)
                output[query] = list(peers.keys()), False  # False means "do not interrupt search"
            return output

        nearest_nodes_per_query, visited_nodes = await traverse_dht(
            queries, initial_nodes=list(node_to_endpoint), beam_size=beam_size, num_workers=num_workers,
            queries_per_call=int(len(queries) ** 0.5), get_neighbors=get_neighbors,
            visited_nodes={query: {self.node_id} for query in queries}, **kwargs)

        nearest_nodes_with_endpoints = {}
        for query, nearest_nodes in nearest_nodes_per_query.items():
            if not exclude_self:
                nearest_nodes = sorted(nearest_nodes + [self.node_id], key=query.xor_distance)
                node_to_endpoint[self.node_id] = f"{LOCALHOST}:{self.port}"
            nearest_nodes_with_endpoints[query] = {node: node_to_endpoint[node] for node in nearest_nodes[:k_nearest]}
        return nearest_nodes_with_endpoints

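    # Sketch (not from the original source): searching for peers near a random identifier,
    # assuming `node` was built with DHTNode.create:
    #
    #   query_id = DHTID.generate()
    #   nearest = await node.find_nearest_nodes([query_id], k_nearest=5)
    #   for peer_id, endpoint in nearest[query_id].items():  # ordered nearest-first
    #       print(peer_id, endpoint)
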
    async def store(self, key: DHTKey, value: DHTValue, expiration_time: DHTExpiration, **kwargs) -> bool:
        """
        Find num_replicas best nodes to store (key, value) and store it there at least until expiration time.
        :note: store is a simplified interface to store_many, all kwargs are forwarded there
        :returns: True if store succeeds, False if it fails (due to no response or a newer value)
        """
        store_ok = await self.store_many([key], [value], [expiration_time], **kwargs)
        return store_ok[key]

    async def store_many(
            self, keys: List[DHTKey], values: List[DHTValue], expiration: Union[DHTExpiration, List[DHTExpiration]],
            exclude_self: bool = False, await_all_replicas=True, **kwargs) -> Dict[DHTKey, bool]:
        """
        Traverse the DHT to find the best nodes to store multiple (key, value, expiration) pairs.

        :param keys: arbitrary serializable keys associated with each value
        :param values: serializable "payload" for each key
        :param expiration: either one expiration time for all keys or individual expiration times (see class doc)
        :param kwargs: any additional parameters passed to the traverse_dht function (e.g. num workers)
        :param exclude_self: if True, never store the value locally even if you are one of the nearest nodes
        :note: if exclude_self is True and self.cache_locally == True, the value will still be __cached__ locally
        :param await_all_replicas: if False, this function returns after the first store_ok and proceeds in background;
          if True, the function will wait for num_replicas successful stores or until it runs out of beam_size nodes
        :returns: for each key: True if store succeeds, False if it fails (due to no response or a newer value)
        """
        expiration = [expiration] * len(keys) if isinstance(expiration, DHTExpiration) else expiration
        assert len(keys) == len(values) == len(expiration), "Please provide an equal number of keys, values and expirations"

        key_ids = list(map(DHTID.generate, keys))
        id_to_original_key = dict(zip(key_ids, keys))
        binary_values_by_key_id = {key_id: self.serializer.dumps(value) for key_id, value in zip(key_ids, values)}
        expiration_by_key_id = {key_id: expiration_time for key_id, expiration_time in zip(key_ids, expiration)}
        unfinished_key_ids = set(key_ids)  # we use this set to ensure that each store request is finished
        store_ok = {key: False for key in keys}  # outputs, updated during search
        store_finished_events = {key: asyncio.Event() for key in keys}

        if self.cache_locally:
            for key_id in key_ids:
                self.protocol.cache.store(key_id, binary_values_by_key_id[key_id], expiration_by_key_id[key_id])

        # pre-populate node_to_endpoint
        node_to_endpoint: Dict[DHTID, Endpoint] = dict()
        for key_id in key_ids:
            node_to_endpoint.update(self.protocol.routing_table.get_nearest_neighbors(
                key_id, self.protocol.bucket_size, exclude=self.node_id))

        async def on_found(key_id: DHTID, nearest_nodes: List[DHTID], visited_nodes: Set[DHTID]) -> None:
            """ This will be called once per key when find_nearest_nodes is done for that particular key """
            # note: we use callbacks instead of returned values to call store immediately without waiting for stragglers
            assert key_id in unfinished_key_ids, "Internal error: traverse_dht finished the same query twice"
            assert self.node_id not in nearest_nodes
            unfinished_key_ids.remove(key_id)

            # ensure that num_replicas nodes store the value, optionally including self.node_id as a candidate
            num_successful_stores = 0
            pending_store_tasks = set()
            store_candidates = sorted(nearest_nodes + ([] if exclude_self else [self.node_id]),
                                      key=key_id.xor_distance, reverse=True)  # ordered so that .pop() returns the nearest

            while num_successful_stores < self.num_replicas and (store_candidates or pending_store_tasks):
                # spawn enough tasks to cover all replicas
                while store_candidates and num_successful_stores + len(pending_store_tasks) < self.num_replicas:
                    node_id: DHTID = store_candidates.pop()  # nearest untried candidate
                    if node_id == self.node_id:
                        self.protocol.storage.store(key_id, binary_values_by_key_id[key_id],
                                                    expiration_by_key_id[key_id])
                        store_ok[id_to_original_key[key_id]] = True
                        num_successful_stores += 1
                        if not await_all_replicas:
                            store_finished_events[id_to_original_key[key_id]].set()
                    else:
                        pending_store_tasks.add(asyncio.create_task(self.protocol.call_store(
                            node_to_endpoint[node_id], [key_id], [binary_values_by_key_id[key_id]],
                            [expiration_by_key_id[key_id]])))

                # await the nearest task; if it fails, dispatch more on the next iteration
                if pending_store_tasks:
                    finished_store_tasks, pending_store_tasks = await asyncio.wait(
                        pending_store_tasks, return_when=asyncio.FIRST_COMPLETED)
                    for task in finished_store_tasks:
                        if task.result()[0]:  # if store succeeded
                            store_ok[id_to_original_key[key_id]] = True
                            num_successful_stores += 1
                            if not await_all_replicas:
                                store_finished_events[id_to_original_key[key_id]].set()

            store_finished_events[id_to_original_key[key_id]].set()

        asyncio.create_task(self.find_nearest_nodes(
            queries=set(key_ids), k_nearest=self.num_replicas, node_to_endpoint=node_to_endpoint,
            found_callback=on_found, exclude_self=exclude_self, **kwargs))
        await asyncio.wait([evt.wait() for evt in store_finished_events.values()])  # await one (or all) store accepts
        assert len(unfinished_key_ids) == 0, "Internal error: traverse_dht didn't finish the search"
        return store_ok

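    # Illustrative sketch (not part of the original source; endpoint_0 / endpoint_1 are placeholder values):
    # storing several keys with one shared expiration time and checking which stores succeeded:
    #
    #   ok = await node.store_many(['expert.0', 'expert.1'], [endpoint_0, endpoint_1],
    #                              expiration=get_dht_time() + 30)
    #   assert all(ok.values())  # each key maps to True if its store succeeded
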
    async def get(self, key: DHTKey, latest=False, **kwargs) -> Tuple[Optional[DHTValue], Optional[DHTExpiration]]:
        """
        Search for a key across the DHT and return either the first or the latest entry.
        :param key: same key as in node.store(...)
        :param latest: if True, finds the latest value, otherwise finds any non-expired value (which is much faster)
        :param kwargs: parameters forwarded to get_many
        :returns: (value, expiration time); if the value was not found, returns (None, None)
        """
        if latest:
            kwargs["sufficient_expiration_time"] = float('inf')
        result = await self.get_many([key], **kwargs)
        return result[key]

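    # Sketch (not from the original source): a plain get returns any non-expired value it meets first,
    # while latest=True keeps searching for the entry with the highest expiration time:
    #
    #   value, expiration = await node.get('expert.0')               # fast, first non-expired match
    #   value, expiration = await node.get('expert.0', latest=True)  # slower, freshest match
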
    async def get_many(
            self, keys: Collection[DHTKey], sufficient_expiration_time: Optional[DHTExpiration] = None,
            num_workers: Optional[int] = None, beam_size: Optional[int] = None
    ) -> Dict[DHTKey, Tuple[Optional[DHTValue], Optional[DHTExpiration]]]:
        """
        :param keys: traverse the DHT and find the value for each of these keys (or (None, None) if the key is not found)
        :param sufficient_expiration_time: if the search finds a value that expires after this time, it may stop
          searching for that key early; default = time of call, i.e. find any value that has not expired by the time of call.
          If sufficient_expiration_time=float('inf'), this method will find the value with the _latest_ expiration
        :param beam_size: maintains up to this many nearest nodes when crawling the dht, default beam_size = bucket_size
        :param num_workers: override for the default num_workers, see traverse_dht num_workers param
        :returns: for each key: the value and its expiration time. If nothing is found, returns (None, None) for that key
        :note: in order to check whether get returned a value, please check whether expiration_time is None
        """
        key_ids = [DHTID.generate(key) for key in keys]
        id_to_original_key = dict(zip(key_ids, keys))
        sufficient_expiration_time = sufficient_expiration_time or get_dht_time()
        beam_size = beam_size if beam_size is not None else self.protocol.bucket_size
        num_workers = num_workers if num_workers is not None else self.num_workers

        # search metadata
        unfinished_key_ids = set(key_ids)  # track key ids for which the search is not terminated
        node_to_endpoint: Dict[DHTID, Endpoint] = dict()  # global routing table for all queries

        SearchResult = namedtuple("SearchResult", ["binary_value", "expiration", "source_node_id"])
        latest_results = {key_id: SearchResult(b'', -float('inf'), None) for key_id in key_ids}

        # stage 1: the value may already be in our local storage or cache
        for key_id in key_ids:
            maybe_value, maybe_expiration = self.protocol.storage.get(key_id)
            if maybe_expiration is None:
                maybe_value, maybe_expiration = self.protocol.cache.get(key_id)
            if maybe_expiration is not None and maybe_expiration > latest_results[key_id].expiration:
                latest_results[key_id] = SearchResult(maybe_value, maybe_expiration, self.node_id)
                if maybe_expiration >= sufficient_expiration_time:
                    unfinished_key_ids.remove(key_id)

        # stage 2: traverse the DHT for any unfinished keys
        for key_id in unfinished_key_ids:
            node_to_endpoint.update(self.protocol.routing_table.get_nearest_neighbors(
                key_id, self.protocol.bucket_size, exclude=self.node_id))

        async def get_neighbors(peer: DHTID, queries: Collection[DHTID]) -> Dict[DHTID, Tuple[List[DHTID], bool]]:
            queries = list(queries)
            response = await self.protocol.call_find(node_to_endpoint[peer], queries)
            if not response:
                return {query: ([], False) for query in queries}

            output: Dict[DHTID, Tuple[List[DHTID], bool]] = {}
            for key_id, (maybe_value, maybe_expiration, peers) in response.items():
                node_to_endpoint.update(peers)
                if maybe_expiration is not None and maybe_expiration > latest_results[key_id].expiration:
                    latest_results[key_id] = SearchResult(maybe_value, maybe_expiration, peer)
                should_interrupt = (latest_results[key_id].expiration >= sufficient_expiration_time)
                output[key_id] = list(peers.keys()), should_interrupt
            return output

        nearest_nodes_per_query, visited_nodes = await traverse_dht(
            queries=list(unfinished_key_ids), initial_nodes=list(node_to_endpoint),
            beam_size=beam_size, num_workers=num_workers, queries_per_call=int(len(unfinished_key_ids) ** 0.5),
            get_neighbors=get_neighbors, visited_nodes={key_id: {self.node_id} for key_id in unfinished_key_ids})

        # stage 3: cache any new results depending on the caching parameters
        for key_id, nearest_nodes in nearest_nodes_per_query.items():
            latest_value_bytes, latest_expiration, latest_node_id = latest_results[key_id]
            should_cache = latest_expiration >= sufficient_expiration_time  # if we found a new-enough value, cache it
            if should_cache and self.cache_locally:
                self.protocol.cache.store(key_id, latest_value_bytes, latest_expiration)
            if should_cache and self.cache_nearest:
                num_cached_nodes = 0
                for node_id in nearest_nodes:
                    if node_id == latest_node_id:
                        continue
                    asyncio.create_task(self.protocol.call_store(
                        node_to_endpoint[node_id], [key_id], [latest_value_bytes], [latest_expiration], in_cache=True))
                    num_cached_nodes += 1
                    if num_cached_nodes >= self.cache_nearest:
                        break

        # stage 4: deserialize the data and assemble the function output
        find_result: Dict[DHTKey, Tuple[Optional[DHTValue], Optional[DHTExpiration]]] = {}
        for key_id, (latest_value_bytes, latest_expiration, _) in latest_results.items():
            if latest_expiration != -float('inf'):
                find_result[id_to_original_key[key_id]] = self.serializer.loads(latest_value_bytes), latest_expiration
            else:
                find_result[id_to_original_key[key_id]] = None, None
        return find_result

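    # Sketch (not from the original source): fetching several keys at once and checking which were found:
    #
    #   results = await node.get_many(['expert.0', 'expert.1'])
    #   for key, (value, expiration) in results.items():
    #       if expiration is None:
    #           print(f"{key}: not found")
    #       else:
    #           print(f"{key}: {value} (valid until {expiration})")
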
    async def _refresh_routing_table(self, *, period: Optional[float]) -> None:
        """ Tries to find new nodes for buckets that were unused for more than self.refresh_timeout """
        while period is not None:  # if period is None, do not refresh buckets at all
            refresh_time = get_dht_time()
            staleness_threshold = refresh_time - period
            stale_buckets = [bucket for bucket in self.protocol.routing_table.buckets
                             if bucket.last_updated < staleness_threshold]
            for bucket in stale_buckets:
                refresh_id = DHTID(random.randint(bucket.lower, bucket.upper - 1))
                await self.find_nearest_nodes([refresh_id])

            await asyncio.sleep(max(0.0, period - (get_dht_time() - refresh_time)))
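
# Putting it together (illustrative sketch, not part of the original source): a two-node round trip.
# `run_demo` and the literal key/value below are made up for demonstration purposes.
#
#   async def run_demo():
#       alice = await DHTNode.create(listen_on='0.0.0.0:*')
#       bob = await DHTNode.create(initial_peers=[f"{LOCALHOST}:{alice.port}"])
#       assert await bob.store('meaning_of_life', 42, expiration_time=get_dht_time() + 60)
#       value, expiration = await alice.get('meaning_of_life')
#       assert value == 42
#       await asyncio.gather(alice.shutdown(), bob.shutdown())
#
#   asyncio.run(run_demo())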