# node.py

from __future__ import annotations

import asyncio
import random
from collections import defaultdict, Counter
from dataclasses import dataclass, field
from functools import partial
from typing import Optional, Tuple, List, Dict, DefaultDict, Collection, Union, Set, Awaitable, Callable, Any

from sortedcontainers import SortedSet

from hivemind.dht.protocol import DHTProtocol
from hivemind.dht.routing import DHTID, DHTExpiration, DHTKey, get_dht_time, DHTValue, BinaryDHTValue, Subkey
from hivemind.dht.storage import DictionaryDHTValue
from hivemind.dht.traverse import traverse_dht
from hivemind.utils import Endpoint, LOCALHOST, MSGPackSerializer, get_logger, SerializerBase
from hivemind.utils.timed_storage import TimedStorage, ValueWithExpiration

logger = get_logger(__name__)


class DHTNode:
    """
    Asyncio-based class that represents one DHT participant. Created via await DHTNode.create(...)
    Each DHTNode has an identifier, a local storage and access to other nodes via DHTProtocol.

    :note: Hivemind DHT is optimized to store a lot of temporary metadata that is regularly updated,
     for example, an expert heartbeat emitted by a hivemind.Server responsible for that expert.
     Such metadata does not require regular maintenance by peers or persistence on shutdown.
     Instead, DHTNode is designed to rapidly send bulk data and resolve conflicts.

    Every (key, value) pair in this DHT has an expiration time - a float computed as get_dht_time() (UnixTime by default).
    DHT nodes always prefer values with higher expiration time and may delete any value past its expiration.

    Similar to the Kademlia RPC protocol, hivemind DHT has 3 RPCs:

    * ping - request peer's identifier and update routing table (same as Kademlia PING RPC)
    * store - send several (key, value, expiration_time) pairs to the same peer (like Kademlia STORE, but in bulk)
    * find - request one or several keys, get values and expiration (if peer finds it locally) and :bucket_size: of
      nearest peers from recipient's routing table (ordered nearest-to-farthest, not including recipient itself).
      This RPC is a mixture between Kademlia FIND_NODE and FIND_VALUE with multiple keys per call.

    A DHTNode follows the following contract:

    - when asked to get(key), a node must find and return the value with the highest expiration time that it found
      across the DHT, IF that time has not come yet. If the expiration time is smaller than current get_dht_time(),
      the node may return None;
    - when requested to store(key: value, expiration_time), a node must store (key => value) at least until expiration
      time or until DHTNode gets the same key with a greater expiration time. If a node is asked to store a key but it
      already has the same key with a newer expiration, the store will be rejected. Store returns True if accepted,
      False if rejected;
    - when requested to store(key: value, expiration_time, subkey=subkey), adds a sub-key to a dictionary value type.
      Dictionary values can have multiple sub-keys stored by different peers with individual expiration times. A subkey
      will be accepted to a dictionary either if there is no such sub-key or if the new subkey's expiration is later
      than the previous expiration under that subkey. See DHTProtocol.call_store for details.

    DHTNode also features several (optional) caching policies:

    - cache_locally: after GET, store the result in node's own local cache
    - cache_nearest: after GET, send the result to this many nearest nodes that don't have that value yet (see Kademlia)
    - cache_on_store: after STORE, either save or remove that key from node's own cache depending on store status
    - cache_refresh_before_expiry: if a value in cache was used and is about to expire, try to GET it this many seconds
      before expiration. The motivation here is that some frequent keys should always be kept in cache to avoid latency.
    - reuse_get_requests: if there are several concurrent GET requests, when one request finishes, DHTNode will attempt
      to reuse the result of this GET request for other requests with the same key. Useful for batch-parallel requests.
    """

    # fmt:off
    node_id: DHTID; is_alive: bool; port: int; num_replicas: int; num_workers: int; protocol: DHTProtocol
    chunk_size: int; refresh_timeout: float; cache_locally: bool; cache_nearest: int; cache_refresh_before_expiry: float
    cache_on_store: bool; reuse_get_requests: bool; pending_get_requests: DefaultDict[DHTID, SortedSet[_SearchState]]
    cache_refresh_task: Optional[asyncio.Task]; cache_refresh_evt: asyncio.Event; cache_refresh_queue: CacheRefreshQueue
    blacklist: Blacklist
    # fmt:on

    @classmethod
    async def create(
            cls, node_id: Optional[DHTID] = None, initial_peers: List[Endpoint] = (),
            bucket_size: int = 20, num_replicas: int = 5, depth_modulo: int = 5, parallel_rpc: Optional[int] = None,
            wait_timeout: float = 3, refresh_timeout: Optional[float] = None, bootstrap_timeout: Optional[float] = None,
            cache_locally: bool = True, cache_nearest: int = 1, cache_size=None, cache_refresh_before_expiry: float = 5,
            cache_on_store: bool = True, reuse_get_requests: bool = True, num_workers: int = 1, chunk_size: int = 16,
            blacklist_time: float = 5.0, backoff_rate: float = 2.0,
            listen: bool = True, listen_on: Endpoint = "0.0.0.0:*", endpoint: Optional[Endpoint] = None,
            validate: bool = True, strict: bool = True, **kwargs) -> DHTNode:
        """
        :param node_id: current node's identifier, determines which keys it will store locally, defaults to random id
        :param initial_peers: connects to these peers to populate routing table, defaults to no peers
        :param bucket_size: max number of nodes in one k-bucket (k). Trying to add the {k+1}-st node will cause a bucket
          to either split in two buckets along the midpoint or reject the new node (but still save it as a replacement).
          Recommended value: k is chosen s.t. any given k nodes are very unlikely to all fail after staleness_timeout
        :param num_replicas: number of nearest nodes that will be asked to store a given key (default: 5)
        :param depth_modulo: split a full k-bucket if it contains the root OR up to the nearest multiple of this value (≈b)
        :param parallel_rpc: maximum number of concurrent outgoing RPC requests emitted by DHTProtocol.
          Reduce this value if your RPC requests register no response despite the peer sending the response.
        :param wait_timeout: a kademlia rpc request is deemed lost if we did not receive a reply in this many seconds
        :param refresh_timeout: refresh buckets if no node from that bucket was updated in this many seconds;
          if refresh_timeout is None, DHTNode will not refresh stale buckets (which is usually okay)
        :param bootstrap_timeout: after one of peers responds, await other peers for at most this many seconds
        :param cache_locally: if True, caches all values (stored or found) in a node-local cache
        :param cache_on_store: if True, update cache entries for a key after storing a new item for that key
        :param cache_nearest: whenever DHTNode finds a value, it will also store (cache) this value on up to this many
          nearest nodes visited by the search algorithm. Prefers nodes that are nearest to :key: but have no value yet
        :param cache_size: if specified, local cache will store up to this many records (as in LRU cache)
        :param cache_refresh_before_expiry: if nonzero, refreshes locally cached values
          if they are accessed this many seconds before expiration time.
        :param reuse_get_requests: if True, DHTNode allows only one traverse_dht procedure for every key;
          all concurrent get requests for the same key will reuse the procedure that is currently in progress
        :param num_workers: concurrent workers in traverse_dht (see traverse_dht num_workers param)
        :param chunk_size: maximum number of concurrent calls in get_many and cache refresh queue
        :param blacklist_time: excludes non-responsive peers from search for this many seconds (set 0 to disable)
        :param backoff_rate: blacklist time will be multiplied by :backoff_rate: for each successive non-response
        :param validate: if True, use initial peers to validate that this node is accessible and synchronized
        :param strict: if True, any error encountered in validation will interrupt the creation of DHTNode
        :param listen: if True (default), this node will accept incoming requests and otherwise be a full DHT "citizen";
          if False, this node will refuse any incoming requests, effectively being only a "client"
        :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
        :param endpoint: if specified, this is the peer's preferred public endpoint. Otherwise let peers infer the endpoint
        :param channel_options: options for grpc.aio.insecure_channel, e.g. [('grpc.enable_retries', 0)]
          see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
        :param kwargs: extra parameters used in grpc.aio.server
        """
        self = cls(_initialized_with_create=True)
        self.node_id = node_id = node_id if node_id is not None else DHTID.generate()
        self.num_replicas, self.num_workers, self.chunk_size = num_replicas, num_workers, chunk_size
        self.is_alive = True  # if set to False, cancels all background jobs such as routing table refresh

        self.reuse_get_requests = reuse_get_requests
        self.pending_get_requests = defaultdict(partial(SortedSet, key=lambda _res: -_res.sufficient_expiration_time))

        # caching policy
        self.refresh_timeout = refresh_timeout
        self.cache_locally, self.cache_nearest, self.cache_on_store = cache_locally, cache_nearest, cache_on_store
        self.cache_refresh_before_expiry = cache_refresh_before_expiry
        self.blacklist = Blacklist(blacklist_time, backoff_rate)
        self.cache_refresh_queue = CacheRefreshQueue()
        self.cache_refresh_evt = asyncio.Event()
        self.cache_refresh_task = None

        self.protocol = await DHTProtocol.create(self.node_id, bucket_size, depth_modulo, num_replicas, wait_timeout,
                                                 parallel_rpc, cache_size, listen, listen_on, endpoint, **kwargs)
        self.port = self.protocol.port

        if initial_peers:
            # stage 1: ping initial_peers, add each other to the routing table
            bootstrap_timeout = bootstrap_timeout if bootstrap_timeout is not None else wait_timeout
            start_time = get_dht_time()
            ping_tasks = set(asyncio.create_task(self.protocol.call_ping(peer, validate=validate, strict=strict))
                             for peer in initial_peers)
            finished_pings, unfinished_pings = await asyncio.wait(ping_tasks, return_when=asyncio.FIRST_COMPLETED)

            # stage 2: gather remaining peers (those who respond within bootstrap_timeout)
            if unfinished_pings:
                finished_in_time, stragglers = await asyncio.wait(
                    unfinished_pings, timeout=bootstrap_timeout - get_dht_time() + start_time)
                for straggler in stragglers:
                    straggler.cancel()
                finished_pings |= finished_in_time

            if not finished_pings or all(ping.result() is None for ping in finished_pings):
                logger.warning("DHTNode bootstrap failed: none of the initial_peers responded to a ping.")

            if strict:
                for task in asyncio.as_completed(finished_pings):
                    await task  # propagate exceptions

            # stage 3: traverse dht to find my own nearest neighbors and populate the routing table
            # ... maybe receive some values that we are meant to store (see protocol.update_routing_table)
            # note: using asyncio.wait instead of wait_for because wait_for cancels the task on timeout
            await asyncio.wait([asyncio.create_task(self.find_nearest_nodes([self.node_id])),
                                asyncio.create_task(asyncio.sleep(bootstrap_timeout - get_dht_time() + start_time))],
                               return_when=asyncio.FIRST_COMPLETED)

        if self.refresh_timeout is not None:
            asyncio.create_task(self._refresh_routing_table(period=self.refresh_timeout))
        return self
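
    # Sketch of the two create() configurations described above (the port below is hypothetical):
    #
    #   full_node = await DHTNode.create(listen_on="0.0.0.0:1337")             # accepts incoming RPCs, stores data
    #   client_node = await DHTNode.create(initial_peers=[f"{LOCALHOST}:1337"],
    #                                      listen=False)                       # client mode: outgoing requests only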

    def __init__(self, *, _initialized_with_create=False):
        """ Internal init method. Please use DHTNode.create coroutine to spawn new node instances """
        assert _initialized_with_create, "Please use DHTNode.create coroutine to spawn new node instances"
        super().__init__()

    async def shutdown(self, timeout=None):
        """ Process existing requests, close all connections and stop the server """
        self.is_alive = False
        if self.protocol.server:
            await self.protocol.shutdown(timeout)

    async def find_nearest_nodes(
            self, queries: Collection[DHTID], k_nearest: Optional[int] = None, beam_size: Optional[int] = None,
            num_workers: Optional[int] = None, node_to_endpoint: Optional[Dict[DHTID, Endpoint]] = None,
            exclude_self: bool = False, **kwargs) -> Dict[DHTID, Dict[DHTID, Endpoint]]:
        """
        :param queries: find k nearest nodes for each of these DHTIDs
        :param k_nearest: return this many nearest nodes for every query (if there are enough nodes)
        :param beam_size: replacement for self.beam_size, see traverse_dht beam_size param
        :param num_workers: replacement for self.num_workers, see traverse_dht num_workers param
        :param node_to_endpoint: if specified, uses this dict[node_id => endpoint] as initial peers
        :param exclude_self: if True, nearest nodes will not contain self.node_id (default = use local peers)
        :param kwargs: additional params passed to traverse_dht
        :returns: for every query, return nearest peers ordered dict[peer DHTID -> network Endpoint], nearest-first
        """
        queries = tuple(queries)
        k_nearest = k_nearest if k_nearest is not None else self.protocol.bucket_size
        num_workers = num_workers if num_workers is not None else self.num_workers
        beam_size = beam_size if beam_size is not None else max(self.protocol.bucket_size, k_nearest)
        if k_nearest > beam_size:
            logger.warning("beam_size is too small, beam search is not guaranteed to find enough nodes")
        if node_to_endpoint is None:
            node_to_endpoint: Dict[DHTID, Endpoint] = dict()
            for query in queries:
                neighbors = self.protocol.routing_table.get_nearest_neighbors(query, beam_size, exclude=self.node_id)
                node_to_endpoint.update(self._filter_blacklisted(dict(neighbors)))

        async def get_neighbors(peer: DHTID, queries: Collection[DHTID]) -> Dict[DHTID, Tuple[Tuple[DHTID], bool]]:
            response = await self._call_find_with_blacklist(node_to_endpoint[peer], queries)
            if not response:
                return {query: ([], False) for query in queries}

            output: Dict[DHTID, Tuple[Tuple[DHTID], bool]] = {}
            for query, (_, peers) in response.items():
                node_to_endpoint.update(peers)
                output[query] = tuple(peers.keys()), False  # False means "do not interrupt search"
            return output

        nearest_nodes_per_query, visited_nodes = await traverse_dht(
            queries, initial_nodes=list(node_to_endpoint), beam_size=beam_size, num_workers=num_workers,
            queries_per_call=int(len(queries) ** 0.5), get_neighbors=get_neighbors,
            visited_nodes={query: {self.node_id} for query in queries}, **kwargs)

        nearest_nodes_with_endpoints = {}
        for query, nearest_nodes in nearest_nodes_per_query.items():
            if not exclude_self:
                nearest_nodes = sorted(nearest_nodes + [self.node_id], key=query.xor_distance)
                node_to_endpoint[self.node_id] = f"{LOCALHOST}:{self.port}"
            nearest_nodes_with_endpoints[query] = {node: node_to_endpoint[node] for node in nearest_nodes[:k_nearest]}
        return nearest_nodes_with_endpoints
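
    # Illustrative call (assumes `node` is a bootstrapped DHTNode; the key is hypothetical):
    #
    #   query_id = DHTID.generate(source="some_key")
    #   nearest = await node.find_nearest_nodes([query_id], k_nearest=3)
    #   # nearest[query_id] is a dict {peer DHTID -> Endpoint} ordered nearest-first,
    #   # and may include node.node_id itself unless exclude_self=True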

    async def store(self, key: DHTKey, value: DHTValue, expiration_time: DHTExpiration,
                    subkey: Optional[Subkey] = None, **kwargs) -> bool:
        """
        Find num_replicas best nodes to store (key, value) and store it there at least until expiration time.
        :note: store is a simplified interface to store_many, all kwargs are forwarded there
        :returns: True if store succeeds, False if it fails (due to no response or newer value)
        """
        store_ok = await self.store_many([key], [value], [expiration_time], subkeys=[subkey], **kwargs)
        return store_ok[(key, subkey) if subkey is not None else key]

    async def store_many(self, keys: List[DHTKey], values: List[DHTValue],
                         expiration_time: Union[DHTExpiration, List[DHTExpiration]],
                         subkeys: Optional[Union[Subkey, List[Optional[Subkey]]]] = None,
                         exclude_self: bool = False, await_all_replicas=True, **kwargs) -> Dict[DHTKey, bool]:
        """
        Traverse DHT to find up to :num_replicas: best nodes to store multiple (key, value, expiration_time) pairs.
        :param keys: arbitrary serializable keys associated with each value
        :param values: serializable "payload" for each key
        :param expiration_time: either one expiration time for all keys or individual expiration times (see class doc)
        :param subkeys: an optional list of the same shape as keys. If specified, each value is stored under its
          corresponding subkey inside a dictionary value (see class docstring on sub-keys)
        :param kwargs: any additional parameters passed to traverse_dht function (e.g. num workers)
        :param exclude_self: if True, never store value locally even if you are one of the nearest nodes
        :note: if exclude_self is True and self.cache_locally == True, value will still be __cached__ locally
        :param await_all_replicas: if False, this function returns after the first store_ok and proceeds in background;
          if True, the function will wait for num_replicas successful stores or until it runs out of beam_size nodes
        :returns: for each key: True if store succeeds, False if it fails (due to no response or newer value)
        """
        if isinstance(expiration_time, DHTExpiration):
            expiration_time = [expiration_time] * len(keys)
        if subkeys is None:
            subkeys = [None] * len(keys)

        assert len(keys) == len(subkeys) == len(values) == len(expiration_time), \
            "Invalid input: keys, values, subkeys and expiration_time must have the same length"

        key_id_to_data: DefaultDict[DHTID, List[Tuple[DHTKey, Subkey, DHTValue, DHTExpiration]]] = defaultdict(list)
        for key, subkey, value, expiration in zip(keys, subkeys, values, expiration_time):
            key_id_to_data[DHTID.generate(source=key)].append((key, subkey, value, expiration))

        unfinished_key_ids = set(key_id_to_data.keys())  # use this set to ensure that each store request is finished
        store_ok = {(key, subkey): None for key, subkey in zip(keys, subkeys)}  # outputs, updated during search
        store_finished_events = {(key, subkey): asyncio.Event() for key, subkey in zip(keys, subkeys)}

        # pre-populate node_to_endpoint
        node_to_endpoint: Dict[DHTID, Endpoint] = dict()
        for key_id in unfinished_key_ids:
            node_to_endpoint.update(self.protocol.routing_table.get_nearest_neighbors(
                key_id, self.protocol.bucket_size, exclude=self.node_id))

        async def on_found(key_id: DHTID, nearest_nodes: List[DHTID], visited_nodes: Set[DHTID]) -> None:
            """ This will be called once per key when find_nearest_nodes is done for a particular key """
            # note: we use callbacks instead of returned values to call store immediately without waiting for stragglers
            assert key_id in unfinished_key_ids, "Internal error: traverse_dht finished the same query twice"
            assert self.node_id not in nearest_nodes
            unfinished_key_ids.remove(key_id)

            # ensure k nodes stored the value, optionally include self.node_id as a candidate
            num_successful_stores = 0
            pending_store_tasks = set()
            store_candidates = sorted(nearest_nodes + ([] if exclude_self else [self.node_id]),
                                      key=key_id.xor_distance, reverse=True)  # ordered so that .pop() returns nearest
            [original_key, *_], current_subkeys, current_values, current_expirations = zip(*key_id_to_data[key_id])
            binary_values: List[bytes] = list(map(self.protocol.serializer.dumps, current_values))

            while num_successful_stores < self.num_replicas and (store_candidates or pending_store_tasks):
                while store_candidates and num_successful_stores + len(pending_store_tasks) < self.num_replicas:
                    node_id: DHTID = store_candidates.pop()  # nearest untried candidate

                    if node_id == self.node_id:
                        num_successful_stores += 1
                        for subkey, value, expiration_time in zip(current_subkeys, binary_values, current_expirations):
                            store_ok[original_key, subkey] = self.protocol.storage.store(
                                key_id, value, expiration_time, subkey=subkey)
                            if not await_all_replicas:
                                store_finished_events[original_key, subkey].set()
                    else:
                        pending_store_tasks.add(asyncio.create_task(self.protocol.call_store(
                            node_to_endpoint[node_id], keys=[key_id] * len(current_values), values=binary_values,
                            expiration_time=current_expirations, subkeys=current_subkeys)))

                # await the nearest task. If it fails, dispatch more on the next iteration
                if pending_store_tasks:
                    finished_store_tasks, pending_store_tasks = await asyncio.wait(
                        pending_store_tasks, return_when=asyncio.FIRST_COMPLETED)
                    for task in finished_store_tasks:
                        if task.result() is not None:
                            num_successful_stores += 1
                            for subkey, store_status in zip(current_subkeys, task.result()):
                                store_ok[original_key, subkey] = store_status
                                if not await_all_replicas:
                                    store_finished_events[original_key, subkey].set()

            if self.cache_on_store:
                self._update_cache_on_store(key_id, current_subkeys, binary_values, current_expirations,
                                            store_ok=[store_ok[original_key, subkey] for subkey in current_subkeys])

            for subkey, value_bytes, expiration in zip(current_subkeys, binary_values, current_expirations):
                store_finished_events[original_key, subkey].set()

        store_task = asyncio.create_task(self.find_nearest_nodes(
            queries=set(unfinished_key_ids), k_nearest=self.num_replicas, node_to_endpoint=node_to_endpoint,
            found_callback=on_found, exclude_self=exclude_self, **kwargs))
        try:
            await asyncio.gather(store_task, *(evt.wait() for evt in store_finished_events.values()))
            assert len(unfinished_key_ids) == 0, "Internal error: traverse_dht didn't finish search"
            return {(key, subkey) if subkey is not None else key: status or False
                    for (key, subkey), status in store_ok.items()}
        except asyncio.CancelledError as e:
            store_task.cancel()
            raise e
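
    # Illustrative sketch of a bulk store with sub-keys (keys and values below are hypothetical;
    # `node` is assumed to be a bootstrapped DHTNode):
    #
    #   now = get_dht_time()
    #   status = await node.store_many(keys=["expert.0", "expert.0"], values=["host_a", "host_b"],
    #                                  expiration_time=now + 30, subkeys=["replica_a", "replica_b"])
    #   # status maps (key, subkey) pairs to booleans, e.g. {("expert.0", "replica_a"): True, ...};
    #   # a subsequent get("expert.0") returns a dictionary of {subkey: ValueWithExpiration} entries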

    def _update_cache_on_store(self, key_id: DHTID, subkeys: List[Subkey], binary_values: List[bytes],
                               expirations: List[DHTExpiration], store_ok: List[bool]):
        """ Update local cache after finishing a store for one key (with perhaps several subkeys) """
        store_succeeded = any(store_ok)
        is_dictionary = any(subkey is not None for subkey in subkeys)

        if store_succeeded and not is_dictionary:  # stored a new regular value, cache it!
            stored_value_bytes, stored_expiration = max(zip(binary_values, expirations), key=lambda p: p[1])
            self.protocol.cache.store(key_id, stored_value_bytes, stored_expiration)
        elif not store_succeeded and not is_dictionary:  # store rejected, check if local cache is also obsolete
            rejected_value, rejected_expiration = max(zip(binary_values, expirations), key=lambda p: p[1])
            cached_item = self.protocol.cache.get(key_id)
            if cached_item is not None and cached_item.expiration_time <= rejected_expiration:  # cached value is also obsolete
                self._schedule_for_refresh(key_id, refresh_time=get_dht_time())  # fetch new key in background (asap)
        elif is_dictionary and key_id in self.protocol.cache:  # there can be other subkeys and we should update them
            for subkey, stored_value_bytes, expiration_time in zip(subkeys, binary_values, expirations):
                self.protocol.cache.store_subkey(key_id, subkey, stored_value_bytes, expiration_time)
            self._schedule_for_refresh(key_id, refresh_time=get_dht_time())  # fetch new key in background (asap)

    async def get(self, key: DHTKey, latest=False, **kwargs) -> Optional[ValueWithExpiration[DHTValue]]:
        """
        Search for a key across DHT and return either first or latest entry (if found).
        :param key: same key as in node.store(...)
        :param latest: if True, finds the latest value, otherwise finds any non-expired value (which is much faster)
        :param kwargs: parameters forwarded to get_many_by_id
        :returns: (value, expiration time); if value was not found, returns None
        """
        if latest:
            kwargs["sufficient_expiration_time"] = float('inf')
        result = await self.get_many([key], **kwargs)
        return result[key]

    async def get_many(self, keys: Collection[DHTKey], sufficient_expiration_time: Optional[DHTExpiration] = None,
                       **kwargs) -> Dict[DHTKey, Union[Optional[ValueWithExpiration[DHTValue]],
                                                       Awaitable[Optional[ValueWithExpiration[DHTValue]]]]]:
        """
        Traverse DHT to find a list of keys. For each key, return latest (value, expiration) or None if not found.
        :param keys: traverse the DHT and find the value for each of these keys (or None if the key is not found)
        :param sufficient_expiration_time: if the search finds a value that expires after this time,
          it can stop searching and return that value right away;
          default = time of call, i.e. find any value that did not expire by the time of call.
          If sufficient_expiration_time=float('inf'), this method will find a value with _latest_ expiration
        :param kwargs: for a full list of parameters, see DHTNode.get_many_by_id
        :returns: for each key: value and its expiration time. If nothing is found, returns None for that key
        :note: in order to check if get returned a value, check that the result is not None
        """
        keys = tuple(keys)
        key_ids = [DHTID.generate(key) for key in keys]
        id_to_original_key = dict(zip(key_ids, keys))
        results_by_id = await self.get_many_by_id(key_ids, sufficient_expiration_time, **kwargs)
        return {id_to_original_key[key]: result_or_future for key, result_or_future in results_by_id.items()}

    async def get_many_by_id(
            self, key_ids: Collection[DHTID], sufficient_expiration_time: Optional[DHTExpiration] = None,
            num_workers: Optional[int] = None, beam_size: Optional[int] = None, return_futures: bool = False,
            _is_refresh=False) -> Dict[DHTID, Union[Optional[ValueWithExpiration[DHTValue]],
                                                    Awaitable[Optional[ValueWithExpiration[DHTValue]]]]]:
        """
        Traverse DHT to find a list of DHTIDs. For each key, return latest (value, expiration) or None if not found.
        :param key_ids: traverse the DHT and find the value for each of these keys (or None if the key is not found)
        :param sufficient_expiration_time: if the search finds a value that expires after this time,
          it can stop searching and return that value right away;
          default = time of call, i.e. find any value that did not expire by the time of call.
          If sufficient_expiration_time=float('inf'), this method will find a value with _latest_ expiration
        :param beam_size: maintains up to this many nearest nodes when crawling dht, default beam_size = bucket_size
        :param num_workers: override for default num_workers, see traverse_dht num_workers param
        :param return_futures: if True, immediately return an asyncio.Future for every key before interacting with
          the network. The algorithm will populate these futures with (value, expiration) when it finds the
          corresponding key. Note: canceling a future will stop the search for the corresponding key
        :param _is_refresh: internal flag, set to True by an internal cache refresher (if enabled)
        :returns: for each key: value and its expiration time. If nothing is found, returns None for that key
        :note: in order to check if get returned a value, check that the result is not None
        """
        sufficient_expiration_time = sufficient_expiration_time or get_dht_time()
        beam_size = beam_size if beam_size is not None else self.protocol.bucket_size
        num_workers = num_workers if num_workers is not None else self.num_workers
        search_results: Dict[DHTID, _SearchState] = {key_id: _SearchState(
            key_id, sufficient_expiration_time, serializer=self.protocol.serializer) for key_id in key_ids}

        if not _is_refresh:  # if we're already refreshing cache, there's no need to trigger subsequent refreshes
            for key_id in key_ids:
                search_results[key_id].add_done_callback(self._trigger_cache_refresh)

        # if there are concurrent get requests for some of the same keys, subscribe to their results
        if self.reuse_get_requests:
            for key_id, search_result in search_results.items():
                self.pending_get_requests[key_id].add(search_result)
                search_result.add_done_callback(self._reuse_finished_search_result)

        # stage 1: check for value in this node's local storage and cache
        for key_id in key_ids:
            search_results[key_id].add_candidate(self.protocol.storage.get(key_id), source_node_id=self.node_id)
            if not _is_refresh:
                search_results[key_id].add_candidate(self.protocol.cache.get(key_id), source_node_id=self.node_id)

        # stage 2: traverse the DHT to get the remaining keys from remote peers
        unfinished_key_ids = [key_id for key_id in key_ids if not search_results[key_id].finished]
        node_to_endpoint: Dict[DHTID, Endpoint] = dict()  # global routing table for all keys
        for key_id in unfinished_key_ids:
            node_to_endpoint.update(self.protocol.routing_table.get_nearest_neighbors(
                key_id, self.protocol.bucket_size, exclude=self.node_id))

        # V-- this function will be called every time traverse_dht decides to request neighbors from a remote peer
        async def get_neighbors(peer: DHTID, queries: Collection[DHTID]) -> Dict[DHTID, Tuple[Tuple[DHTID], bool]]:
            queries = list(queries)
            response = await self._call_find_with_blacklist(node_to_endpoint[peer], queries)
            if not response:
                return {query: ([], False) for query in queries}

            output: Dict[DHTID, Tuple[Tuple[DHTID], bool]] = {}
            for key_id, (maybe_value_with_expiration, peers) in response.items():
                node_to_endpoint.update(peers)
                search_results[key_id].add_candidate(maybe_value_with_expiration, source_node_id=peer)
                output[key_id] = tuple(peers.keys()), search_results[key_id].finished
                # note: we interrupt the search if the key is either found or finished otherwise (e.g. cancelled by user)
            return output

        # V-- this function will be called exactly once when traverse_dht finishes search for a given key
        async def found_callback(key_id: DHTID, nearest_nodes: List[DHTID], _visited: Set[DHTID]):
            search_results[key_id].finish_search()  # finish search whether or not we found something
            self._cache_new_result(search_results[key_id], nearest_nodes, node_to_endpoint, _is_refresh=_is_refresh)

        asyncio.create_task(traverse_dht(
            queries=list(unfinished_key_ids), initial_nodes=list(node_to_endpoint), beam_size=beam_size,
            num_workers=num_workers, queries_per_call=min(int(len(unfinished_key_ids) ** 0.5), self.chunk_size),
            get_neighbors=get_neighbors, visited_nodes={key_id: {self.node_id} for key_id in unfinished_key_ids},
            found_callback=found_callback, await_all_tasks=False))

        if return_futures:
            return {key_id: search_result.future for key_id, search_result in search_results.items()}
        else:
            try:
                # note: this should be the first time we await anything, so there's no need to "try" the entire function
                return {key_id: await search_result.future for key_id, search_result in search_results.items()}
            except asyncio.CancelledError as e:  # terminate remaining tasks ASAP
                for key_id, search_result in search_results.items():
                    search_result.future.cancel()
                raise e
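
    # Illustrative use of return_futures=True (the key id is hypothetical; `node` is assumed to be bootstrapped):
    #
    #   futures = await node.get_many_by_id([some_key_id], return_futures=True)
    #   ...  # do other work while the search proceeds in the background
    #   maybe_value = await futures[some_key_id]  # ValueWithExpiration or None if nothing was found
    #   # note: cancelling futures[some_key_id] stops the search for that key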

    def _reuse_finished_search_result(self, finished: _SearchState):
        pending_requests = self.pending_get_requests[finished.key_id]
        if finished.found_something:
            search_result = ValueWithExpiration(finished.binary_value, finished.expiration_time)
            expiration_time_threshold = max(finished.expiration_time, finished.sufficient_expiration_time)
            # note: pending_requests is sorted in the order of descending sufficient_expiration_time
            while pending_requests and expiration_time_threshold >= pending_requests[-1].sufficient_expiration_time:
                pending_requests[-1].add_candidate(search_result, source_node_id=finished.source_node_id)
                pending_requests[-1].finish_search()
                pending_requests.pop()
        else:
            pending_requests.discard(finished)

    async def _call_find_with_blacklist(self, endpoint: Endpoint, keys: Collection[DHTID]):
        """ same as call_find, but skip if :endpoint: is blacklisted; also exclude blacklisted neighbors from result """
        if endpoint in self.blacklist:
            return None
        response = await self.protocol.call_find(endpoint, keys)
        if response:
            self.blacklist.register_success(endpoint)
            return {key: (maybe_value, self._filter_blacklisted(nearest_peers))
                    for key, (maybe_value, nearest_peers) in response.items()}
        else:
            self.blacklist.register_failure(endpoint)
            return None

    def _filter_blacklisted(self, peer_endpoints: Dict[DHTID, Endpoint]):
        return {peer: endpoint for peer, endpoint in peer_endpoints.items() if endpoint not in self.blacklist}

    def _trigger_cache_refresh(self, search: _SearchState):
        """ Called after a get request is finished (whether it was found, not found, hit cache, cancelled, or reused) """
        if search.found_something and search.source_node_id == self.node_id:
            if self.cache_refresh_before_expiry and search.key_id in self.protocol.cache:
                self._schedule_for_refresh(search.key_id, search.expiration_time - self.cache_refresh_before_expiry)

    def _schedule_for_refresh(self, key_id: DHTID, refresh_time: DHTExpiration):
        """ Add key to a refresh queue, refresh at :refresh_time: or later """
        if self.cache_refresh_task is None or self.cache_refresh_task.done() or self.cache_refresh_task.cancelled():
            self.cache_refresh_task = asyncio.create_task(self._refresh_stale_cache_entries())
            logger.debug("Spawned cache refresh task.")
        earliest_key, earliest_item = self.cache_refresh_queue.top()
        if earliest_item is None or refresh_time < earliest_item.expiration_time:
            self.cache_refresh_evt.set()  # if the new element is now the earliest, notify the refresh task
        self.cache_refresh_queue.store(key_id, value=refresh_time, expiration_time=refresh_time)

    async def _refresh_stale_cache_entries(self):
        """ periodically refresh near-expired cached keys that were accessed at least once during their lifetime """
        while self.is_alive:
            while len(self.cache_refresh_queue) == 0:
                await self.cache_refresh_evt.wait()
                self.cache_refresh_evt.clear()
            key_id, (_, nearest_refresh_time) = self.cache_refresh_queue.top()

            try:
                # step 1: await until :cache_refresh_before_expiry: seconds before the earliest element expires
                time_to_wait = nearest_refresh_time - get_dht_time()
                await asyncio.wait_for(self.cache_refresh_evt.wait(), timeout=time_to_wait)
                # note: the line above will raise TimeoutError when we are ready to refresh cache
                self.cache_refresh_evt.clear()  # no timeout error => someone added a new entry to the queue ...
                continue  # ... and that element is earlier than nearest_expiration, so we should refresh it first

            except asyncio.TimeoutError:  # caught TimeoutError => it is time to refresh the nearest cached entry
                # step 2: find all keys that should already be refreshed and remove them from the queue
                current_time = get_dht_time()
                keys_to_refresh = {key_id}
                max_expiration_time = nearest_refresh_time
                del self.cache_refresh_queue[key_id]  # we pledge to refresh this key_id in the nearest batch
                while self.cache_refresh_queue and len(keys_to_refresh) < self.chunk_size:
                    key_id, (_, nearest_refresh_time) = self.cache_refresh_queue.top()
                    if nearest_refresh_time > current_time:
                        break
                    del self.cache_refresh_queue[key_id]  # we pledge to refresh this key_id in the nearest batch
                    keys_to_refresh.add(key_id)
                    cached_item = self.protocol.cache.get(key_id)
                    if cached_item is not None and cached_item.expiration_time > max_expiration_time:
                        max_expiration_time = cached_item.expiration_time

                # step 3: search for newer versions of these keys, cache them as a side-effect of self.get_many_by_id
                sufficient_expiration_time = max_expiration_time + self.cache_refresh_before_expiry + 1
                await self.get_many_by_id(keys_to_refresh, sufficient_expiration_time, _is_refresh=True)

    def _cache_new_result(self, search: _SearchState, nearest_nodes: List[DHTID],
                          node_to_endpoint: Dict[DHTID, Endpoint], _is_refresh: bool = False):
        """ after key_id is found, update cache according to caching policy. used internally in get and get_many """
        if search.found_something:
            _, storage_expiration_time = self.protocol.storage.get(search.key_id) or (None, -float('inf'))
            _, cache_expiration_time = self.protocol.cache.get(search.key_id) or (None, -float('inf'))

            if search.expiration_time > max(storage_expiration_time, cache_expiration_time):
                if self.cache_locally or _is_refresh:
                    self.protocol.cache.store(search.key_id, search.binary_value, search.expiration_time)
                if self.cache_nearest:
                    num_cached_nodes = 0
                    for node_id in nearest_nodes:
                        if node_id == search.source_node_id:
                            continue
                        asyncio.create_task(self.protocol.call_store(
                            node_to_endpoint[node_id], [search.key_id], [search.binary_value],
                            [search.expiration_time], in_cache=True))
                        num_cached_nodes += 1
                        if num_cached_nodes >= self.cache_nearest:
                            break

    async def _refresh_routing_table(self, *, period: Optional[float]) -> None:
        """ Tries to find new nodes for buckets that were unused for more than self.refresh_timeout """
        while self.is_alive and period is not None:  # if period is None, skip refreshes entirely; otherwise run until shutdown
            refresh_time = get_dht_time()
            staleness_threshold = refresh_time - period
            stale_buckets = [bucket for bucket in self.protocol.routing_table.buckets
                             if bucket.last_updated < staleness_threshold]
            for bucket in stale_buckets:
                refresh_id = DHTID(random.randint(bucket.lower, bucket.upper - 1))
                await self.find_nearest_nodes([refresh_id])

            await asyncio.sleep(max(0.0, period - (get_dht_time() - refresh_time)))


@dataclass(init=True, repr=True, frozen=False, order=False)
class _SearchState:
    """ A helper class that stores current-best GET results with metadata """
    key_id: DHTID
    sufficient_expiration_time: DHTExpiration
    binary_value: Optional[Union[BinaryDHTValue, DictionaryDHTValue]] = None
    expiration_time: Optional[DHTExpiration] = None  # best expiration time so far
    source_node_id: Optional[DHTID] = None  # node that gave us the value
    future: asyncio.Future[Optional[ValueWithExpiration[DHTValue]]] = field(default_factory=asyncio.Future)
    serializer: type(SerializerBase) = MSGPackSerializer

    def add_candidate(self, candidate: Optional[ValueWithExpiration[Union[BinaryDHTValue, DictionaryDHTValue]]],
                      source_node_id: Optional[DHTID]):
        if self.finished or candidate is None:
            return
        elif isinstance(candidate.value, DictionaryDHTValue) and isinstance(self.binary_value, DictionaryDHTValue):
            self.binary_value.maxsize = max(self.binary_value.maxsize, candidate.value.maxsize)
            for subkey, subentry in candidate.value.items():
                self.binary_value.store(subkey, subentry.value, subentry.expiration_time)
        elif candidate.expiration_time > (self.expiration_time or float('-inf')):
            self.binary_value = candidate.value

        if candidate.expiration_time > (self.expiration_time or float('-inf')):
            self.expiration_time = candidate.expiration_time
            self.source_node_id = source_node_id
            if self.expiration_time >= self.sufficient_expiration_time:
                self.finish_search()

    def add_done_callback(self, callback: Callable[[_SearchState], Any]):
        """ Add callback that will be called when _SearchState is done (found OR cancelled by user) """
        self.future.add_done_callback(lambda _future: callback(self))

    def finish_search(self):
        if self.future.done():
            return  # either user cancelled our search or someone else sent the result before us. Nothing more to do here.
        elif not self.found_something:
            self.future.set_result(None)
        elif isinstance(self.binary_value, BinaryDHTValue):
            self.future.set_result(ValueWithExpiration(self.serializer.loads(self.binary_value), self.expiration_time))
        elif isinstance(self.binary_value, DictionaryDHTValue):
            dict_with_subkeys = {key: ValueWithExpiration(self.serializer.loads(value), item_expiration_time)
                                 for key, (value, item_expiration_time) in self.binary_value.items()}
            self.future.set_result(ValueWithExpiration(dict_with_subkeys, self.expiration_time))
        else:
            logger.error(f"Invalid value type: {type(self.binary_value)}")

    @property
    def found_something(self) -> bool:
        """ Whether or not we have found at least some value, regardless of its expiration time """
        return self.expiration_time is not None

    @property
    def finished(self) -> bool:
        return self.future.done()

    def __lt__(self, other: _SearchState):
        """ _SearchState instances will be sorted by their target expiration time """
        return self.sufficient_expiration_time < other.sufficient_expiration_time

    def __hash__(self):
        return hash(self.key_id)


class Blacklist:
    """
    A temporary blacklist of non-responding peers with an exponential backoff policy
    :param base_time: peers are suspended for this many seconds by default
    :param backoff_rate: suspension time increases by this factor after each successive failure
    """

    def __init__(self, base_time: float, backoff_rate: float, **kwargs):
        self.base_time, self.backoff = base_time, backoff_rate
        self.banned_peers = TimedStorage[Endpoint, int](**kwargs)
        self.ban_counter = Counter()

    def register_failure(self, peer: Endpoint):
        """ peer failed to respond, add him to blacklist or increase his downtime """
        if peer not in self.banned_peers and self.base_time > 0:
            ban_duration = self.base_time * self.backoff ** self.ban_counter[peer]
            self.banned_peers.store(peer, self.ban_counter[peer], expiration_time=get_dht_time() + ban_duration)
            self.ban_counter[peer] += 1

    def register_success(self, peer):
        """ peer responded successfully, remove him from blacklist and reset his ban time """
        del self.banned_peers[peer], self.ban_counter[peer]

    def __contains__(self, peer: Endpoint) -> bool:
        return peer in self.banned_peers

    def __repr__(self):
        return f"{self.__class__.__name__}(base_time={self.base_time}, backoff={self.backoff}, " \
               f"banned_peers={len(self.banned_peers)})"

    def clear(self):
        self.banned_peers.clear()
        self.ban_counter.clear()
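
# Worked example of the backoff policy above: with base_time=5.0 and backoff_rate=2.0 (the DHTNode.create defaults),
# successive failures of the same peer yield ban durations of 5s, 10s, 20s, ... (base_time * backoff ** prior_failures),
# and a single successful response removes the ban and resets the failure counter.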


class CacheRefreshQueue(TimedStorage[DHTID, DHTExpiration]):
    """ a queue of keys scheduled for refresh in future, used in DHTNode """
    frozen = True