- from __future__ import annotations
- import asyncio
- import random
- from collections import defaultdict, Counter
- from dataclasses import dataclass, field
- from functools import partial
- from typing import Optional, Tuple, List, Dict, DefaultDict, Collection, Union, Set, Awaitable, Callable, Any, Type
- from sortedcontainers import SortedSet
- from hivemind.dht.protocol import DHTProtocol
- from hivemind.dht.routing import DHTID, DHTExpiration, DHTKey, get_dht_time, DHTValue, BinaryDHTValue, Subkey
- from hivemind.dht.storage import DictionaryDHTValue
- from hivemind.dht.traverse import traverse_dht
- from hivemind.utils import Endpoint, LOCALHOST, MSGPackSerializer, get_logger, SerializerBase
- from hivemind.utils.timed_storage import TimedStorage, ValueWithExpiration
- logger = get_logger(__name__)
- class DHTNode:
- """
- Asyncio-based class that represents one DHT participant. Created via await DHTNode.create(...)
- Each DHTNode has an identifier, a local storage and access to other nodes via DHTProtocol.
- :note: Hivemind DHT is optimized to store a lot of temporary metadata that is regularly updated.
- For example, expert heartbeat emitted by a hivemind.Server responsible for that expert.
- Such metadata does not require regular maintenance by peers or persistence on shutdown.
- Instead, DHTNode is designed to rapidly send bulk data and resolve conflicts.
- Every (key, value) pair in this DHT has an expiration time: a floating-point timestamp computed as get_dht_time() (Unix time by default)
- DHT nodes always prefer values with higher expiration time and may delete any value past its expiration.
- Similar to Kademlia RPC protocol, hivemind DHT has 3 RPCs:
- * ping - request peer's identifier and update routing table (same as Kademlia PING RPC)
- * store - send several (key, value, expiration_time) pairs to the same peer (like Kademlia STORE, but in bulk)
- * find - request one or several keys, get values and expiration (if peer finds it locally) and :bucket_size: of
- nearest peers from recipient's routing table (ordered nearest-to-farthest, not including recipient itself)
- This RPC is a mixture between Kademlia FIND_NODE and FIND_VALUE with multiple keys per call.
- A DHTNode follows the following contract:
- when asked to get(key), a node must find and return the value with the highest expiration time that it found across the DHT,
- provided that this time has not passed yet. If the expiration time is earlier than the current get_dht_time(), the node may return None;
- when requested to store(key: value, expiration_time), a node must store (key => value) at least until expiration time
- or until DHTNode gets the same key with a greater expiration time. If a node is asked to store a key but it already
- has the same key with a newer expiration, the store will be rejected. Store returns True if accepted, False if rejected;
- - when requested to store(key: value, expiration_time, subkey=subkey), adds a sub-key to a dictionary value type.
- Dictionary values can have multiple sub-keys stored by different peers with individual expiration times. A subkey
- will be accepted into the dictionary either if there is no such sub-key yet or if the new sub-key's expiration is later than the
- previous expiration under that subkey. See DHTProtocol.call_store for details.
- DHTNode also features several (optional) caching policies:
- - cache_locally: after GET, store the result in node's own local cache
- - cache_nearest: after GET, send the result to this many nearest nodes that don't have that value yet (see Kademlia)
- - cache_on_store: after STORE, either save or remove that key from node's own cache depending on store status
- - cache_refresh_before_expiry: if a value in cache was used and is about to expire, try to GET it this many seconds
- before expiration. The motivation here is that some frequent keys should be always kept in cache to avoid latency.
- - reuse_get_requests: if there are several concurrent GET requests, when one request finishes, DHTNode will attempt
- to reuse the result of this GET request for other requests with the same key. Useful for batch-parallel requests.
- """
- # fmt:off
- node_id: DHTID; is_alive: bool; port: int; num_replicas: int; num_workers: int; protocol: DHTProtocol
- chunk_size: int; refresh_timeout: float; cache_locally: bool; cache_nearest: int; cache_refresh_before_expiry: float
- cache_on_store: bool; reuse_get_requests: bool; pending_get_requests: DefaultDict[DHTID, SortedSet[_SearchState]]
- cache_refresh_task: Optional[asyncio.Task]; cache_refresh_evt: asyncio.Event; cache_refresh_queue: CacheRefreshQueue
- blacklist: Blacklist
- # fmt:on
- @classmethod
- async def create(
- cls, node_id: Optional[DHTID] = None, initial_peers: List[Endpoint] = (),
- bucket_size: int = 20, num_replicas: int = 5, depth_modulo: int = 5, parallel_rpc: int = None,
- wait_timeout: float = 3, refresh_timeout: Optional[float] = None, bootstrap_timeout: Optional[float] = None,
- cache_locally: bool = True, cache_nearest: int = 1, cache_size=None, cache_refresh_before_expiry: float = 5,
- cache_on_store: bool = True, reuse_get_requests: bool = True, num_workers: int = 1, chunk_size: int = 16,
- blacklist_time: float = 5.0, backoff_rate: float = 2.0,
- listen: bool = True, listen_on: Endpoint = "0.0.0.0:*", endpoint: Optional[Endpoint] = None,
- validate: bool = True, strict: bool = True, **kwargs) -> DHTNode:
- """
- :param node_id: current node's identifier, determines which keys it will store locally, defaults to random id
- :param initial_peers: connects to these peers to populate routing table, defaults to no peers
- :param bucket_size: max number of nodes in one k-bucket (k). Trying to add the (k+1)-th node will cause the bucket to
- either split into two buckets along the midpoint or reject the new node (but still save it as a replacement)
- Recommended value: k is chosen s.t. any given k nodes are very unlikely to all fail within refresh_timeout
- :param num_replicas: number of nearest nodes that will be asked to store a given key (default = 5)
- :param depth_modulo: split full k-bucket if it contains root OR up to the nearest multiple of this value (≈b)
- :param parallel_rpc: maximum number of concurrent outgoing RPC requests emitted by DHTProtocol
- Reduce this value if your RPC requests register no response despite the peer sending the response.
- :param wait_timeout: a kademlia rpc request is deemed lost if we did not receive a reply in this many seconds
- :param refresh_timeout: refresh buckets if no node from that bucket was updated in this many seconds
- if refresh_timeout is None, DHTNode will not refresh stale buckets (which is usually okay)
- :param bootstrap_timeout: after one of peers responds, await other peers for at most this many seconds
- :param cache_locally: if True, caches all values (stored or found) in a node-local cache
- :param cache_on_store: if True, update cache entries for a key after storing a new item for that key
- :param cache_nearest: whenever DHTNode finds a value, it will also store (cache) this value on this many of the
- nearest nodes visited by the search algorithm. Prefers nodes that are nearest to :key: but have no value yet
- :param cache_size: if specified, local cache will store up to this many records (as in LRU cache)
- :param cache_refresh_before_expiry: if nonzero, refreshes locally cached values
- if they are accessed this many seconds before expiration time.
- :param reuse_get_requests: if True, DHTNode allows only one traverse_dht procedure for every key
- all concurrent get requests for the same key will reuse the procedure that is currently in progress
- :param num_workers: concurrent workers in traverse_dht (see traverse_dht num_workers param)
- :param chunk_size: maximum number of concurrent calls in get_many and cache refresh queue
- :param blacklist_time: excludes non-responsive peers from search for this many seconds (set 0 to disable)
- :param backoff_rate: blacklist time will be multiplied by :backoff_rate: for each successive non-response
- :param validate: if True, use initial peers to validate that this node is accessible and synchronized
- :param strict: if True, any error encountered in validation will interrupt the creation of DHTNode
- :param listen: if True (default), this node will accept incoming requests and otherwise be a full DHT "citizen"
- if False, this node will refuse any incoming requests, effectively being only a "client"
- :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
- :param endpoint: if specified, this is peer's preferred public endpoint. Otherwise let peers infer endpoint
- :param channel_options: options for grpc.aio.insecure_channel, e.g. [('grpc.enable_retries', 0)]
- see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
- :param kwargs: extra parameters used in grpc.aio.server
- """
- self = cls(_initialized_with_create=True)
- self.node_id = node_id = node_id if node_id is not None else DHTID.generate()
- self.num_replicas, self.num_workers, self.chunk_size = num_replicas, num_workers, chunk_size
- self.is_alive = True # if set to False, cancels all background jobs such as routing table refresh
- self.reuse_get_requests = reuse_get_requests
- self.pending_get_requests = defaultdict(partial(SortedSet, key=lambda _res: - _res.sufficient_expiration_time))
- # caching policy
- self.refresh_timeout = refresh_timeout
- self.cache_locally, self.cache_nearest, self.cache_on_store = cache_locally, cache_nearest, cache_on_store
- self.cache_refresh_before_expiry = cache_refresh_before_expiry
- self.blacklist = Blacklist(blacklist_time, backoff_rate)
- self.cache_refresh_queue = CacheRefreshQueue()
- self.cache_refresh_evt = asyncio.Event()
- self.cache_refresh_task = None
- self.protocol = await DHTProtocol.create(self.node_id, bucket_size, depth_modulo, num_replicas, wait_timeout,
- parallel_rpc, cache_size, listen, listen_on, endpoint, **kwargs)
- self.port = self.protocol.port
- if initial_peers:
- # stage 1: ping initial_peers, add each other to the routing table
- bootstrap_timeout = bootstrap_timeout if bootstrap_timeout is not None else wait_timeout
- start_time = get_dht_time()
- ping_tasks = set(asyncio.create_task(self.protocol.call_ping(peer, validate=validate, strict=strict))
- for peer in initial_peers)
- finished_pings, unfinished_pings = await asyncio.wait(ping_tasks, return_when=asyncio.FIRST_COMPLETED)
- # stage 2: gather remaining peers (those who respond within bootstrap_timeout)
- if unfinished_pings:
- finished_in_time, stragglers = await asyncio.wait(
- unfinished_pings, timeout=bootstrap_timeout - get_dht_time() + start_time)
- for straggler in stragglers:
- straggler.cancel()
- finished_pings |= finished_in_time
- if not finished_pings or all(ping.result() is None for ping in finished_pings):
- logger.warning("DHTNode bootstrap failed: none of the initial_peers responded to a ping.")
- if strict:
- for task in asyncio.as_completed(finished_pings):
- await task # propagate exceptions
- # stage 3: traverse dht to find my own nearest neighbors and populate the routing table
- # ... maybe receive some values that we are meant to store (see protocol.update_routing_table)
- # note: using asyncio.wait instead of wait_for because wait_for cancels task on timeout
- await asyncio.wait([asyncio.create_task(self.find_nearest_nodes([self.node_id])),
- asyncio.create_task(asyncio.sleep(bootstrap_timeout - get_dht_time() + start_time))],
- return_when=asyncio.FIRST_COMPLETED)
- if self.refresh_timeout is not None:
- asyncio.create_task(self._refresh_routing_table(period=self.refresh_timeout))
- return self
- def __init__(self, *, _initialized_with_create=False):
- """ Internal init method. Please use DHTNode.create coroutine to spawn new node instances """
- assert _initialized_with_create, " Please use DHTNode.create coroutine to spawn new node instances "
- super().__init__()
- async def shutdown(self, timeout=None):
- """ Process existing requests, close all connections and stop the server """
- self.is_alive = False
- if self.protocol.server:
- await self.protocol.shutdown(timeout)
- async def find_nearest_nodes(
- self, queries: Collection[DHTID], k_nearest: Optional[int] = None, beam_size: Optional[int] = None,
- num_workers: Optional[int] = None, node_to_endpoint: Optional[Dict[DHTID, Endpoint]] = None,
- exclude_self: bool = False, **kwargs) -> Dict[DHTID, Dict[DHTID, Endpoint]]:
- """
- :param queries: find k nearest nodes for each of these DHTIDs
- :param k_nearest: return this many nearest nodes for every query (if there are enough nodes)
- :param beam_size: replacement for self.beam_size, see traverse_dht beam_size param
- :param num_workers: replacement for self.num_workers, see traverse_dht num_workers param
- :param node_to_endpoint: if specified, uses this dict[node_id => endpoint] as initial peers
- :param exclude_self: if True, nearest nodes will not contain self.node_id (default = False, i.e. self may be included)
- :param kwargs: additional params passed to traverse_dht
- :returns: for every query, return nearest peers ordered dict[peer DHTID -> network Endpoint], nearest-first
- """
- queries = tuple(queries)
- k_nearest = k_nearest if k_nearest is not None else self.protocol.bucket_size
- num_workers = num_workers if num_workers is not None else self.num_workers
- beam_size = beam_size if beam_size is not None else max(self.protocol.bucket_size, k_nearest)
- if k_nearest > beam_size:
- logger.warning("Warning: beam_size is too small, beam search is not guaranteed to find enough nodes")
- if node_to_endpoint is None:
- node_to_endpoint: Dict[DHTID, Endpoint] = dict()
- for query in queries:
- neighbors = self.protocol.routing_table.get_nearest_neighbors(query, beam_size, exclude=self.node_id)
- node_to_endpoint.update(self._filter_blacklisted(dict(neighbors)))
- async def get_neighbors(peer: DHTID, queries: Collection[DHTID]) -> Dict[DHTID, Tuple[Tuple[DHTID], bool]]:
- response = await self._call_find_with_blacklist(node_to_endpoint[peer], queries)
- if not response:
- return {query: ([], False) for query in queries}
- output: Dict[DHTID, Tuple[Tuple[DHTID], bool]] = {}
- for query, (_, peers) in response.items():
- node_to_endpoint.update(peers)
- output[query] = tuple(peers.keys()), False # False means "do not interrupt search"
- return output
- nearest_nodes_per_query, visited_nodes = await traverse_dht(
- queries, initial_nodes=list(node_to_endpoint), beam_size=beam_size, num_workers=num_workers,
- queries_per_call=int(len(queries) ** 0.5), get_neighbors=get_neighbors,
- visited_nodes={query: {self.node_id} for query in queries}, **kwargs)
- nearest_nodes_with_endpoints = {}
- for query, nearest_nodes in nearest_nodes_per_query.items():
- if not exclude_self:
- nearest_nodes = sorted(nearest_nodes + [self.node_id], key=query.xor_distance)
- node_to_endpoint[self.node_id] = f"{LOCALHOST}:{self.port}"
- nearest_nodes_with_endpoints[query] = {node: node_to_endpoint[node] for node in nearest_nodes[:k_nearest]}
- return nearest_nodes_with_endpoints
- async def store(self, key: DHTKey, value: DHTValue, expiration_time: DHTExpiration,
- subkey: Optional[Subkey] = None, **kwargs) -> bool:
- """
- Find num_replicas best nodes to store (key, value) and store it there at least until expiration time.
- :note: store is a simplified interface to store_many, all kwargs are forwarded there
- :returns: True if store succeeds, False if it fails (due to no response or newer value)
- """
- store_ok = await self.store_many([key], [value], [expiration_time], subkeys=[subkey], **kwargs)
- return store_ok[(key, subkey) if subkey is not None else key]
- async def store_many(self, keys: List[DHTKey], values: List[DHTValue],
- expiration_time: Union[DHTExpiration, List[DHTExpiration]],
- subkeys: Optional[Union[Subkey, List[Optional[Subkey]]]] = None,
- exclude_self: bool = False, await_all_replicas=True, **kwargs) -> Dict[DHTKey, bool]:
- """
- Traverse DHT to find up to :num_replicas: best nodes to store multiple (key, value, expiration_time) pairs.
- :param keys: arbitrary serializable keys associated with each value
- :param values: serializable "payload" for each key
- :param expiration_time: either one expiration time for all keys or individual expiration times (see class doc)
- :param subkeys: an optional list of the same shape as keys. If specified, each value is stored under its corresponding sub-key of a dictionary value (see class docstring and DHTProtocol.call_store)
- :param kwargs: any additional parameters passed to traverse_dht function (e.g. num workers)
- :param exclude_self: if True, never store value locally even if you are one of the nearest nodes
- :note: if exclude_self is True and self.cache_locally == True, value will still be __cached__ locally
- :param await_all_replicas: if False, this function returns after first store_ok and proceeds in background
- if True, the function will wait for num_replicas successful stores or running out of beam_size nodes
- :returns: for each key: True if store succeeds, False if it fails (due to no response or newer value)
- """
- if isinstance(expiration_time, DHTExpiration):
- expiration_time = [expiration_time] * len(keys)
- if subkeys is None:
- subkeys = [None] * len(keys)
- assert len(keys) == len(subkeys) == len(values) == len(expiration_time), \
- "Either of keys, values, subkeys or expiration timestamps have different sequence lengths."
- key_id_to_data: DefaultDict[DHTID, List[Tuple[DHTKey, Subkey, DHTValue, DHTExpiration]]] = defaultdict(list)
- for key, subkey, value, expiration in zip(keys, subkeys, values, expiration_time):
- key_id_to_data[DHTID.generate(source=key)].append((key, subkey, value, expiration))
- unfinished_key_ids = set(key_id_to_data.keys()) # use this set to ensure that each store request is finished
- store_ok = {(key, subkey): None for key, subkey in zip(keys, subkeys)} # outputs, updated during search
- store_finished_events = {(key, subkey): asyncio.Event() for key, subkey in zip(keys, subkeys)}
- # pre-populate node_to_endpoint
- node_to_endpoint: Dict[DHTID, Endpoint] = dict()
- for key_id in unfinished_key_ids:
- node_to_endpoint.update(self.protocol.routing_table.get_nearest_neighbors(
- key_id, self.protocol.bucket_size, exclude=self.node_id))
- async def on_found(key_id: DHTID, nearest_nodes: List[DHTID], visited_nodes: Set[DHTID]) -> None:
- """ This will be called once per key when find_nearest_nodes is done for a particular node """
- # note: we use callbacks instead of returned values to call store immediately without waiting for stragglers
- assert key_id in unfinished_key_ids, "Internal error: traverse_dht finished the same query twice"
- assert self.node_id not in nearest_nodes
- unfinished_key_ids.remove(key_id)
- # ensure k nodes stored the value, optionally include self.node_id as a candidate
- num_successful_stores = 0
- pending_store_tasks = set()
- store_candidates = sorted(nearest_nodes + ([] if exclude_self else [self.node_id]),
- key=key_id.xor_distance, reverse=True) # ordered so that .pop() returns nearest
- [original_key, *_], current_subkeys, current_values, current_expirations = zip(*key_id_to_data[key_id])
- binary_values: List[bytes] = list(map(self.protocol.serializer.dumps, current_values))
- while num_successful_stores < self.num_replicas and (store_candidates or pending_store_tasks):
- while store_candidates and num_successful_stores + len(pending_store_tasks) < self.num_replicas:
- node_id: DHTID = store_candidates.pop() # nearest untried candidate
- if node_id == self.node_id:
- num_successful_stores += 1
- for subkey, value, expiration_time in zip(current_subkeys, binary_values, current_expirations):
- store_ok[original_key, subkey] = self.protocol.storage.store(
- key_id, value, expiration_time, subkey=subkey)
- if not await_all_replicas:
- store_finished_events[original_key, subkey].set()
- else:
- pending_store_tasks.add(asyncio.create_task(self.protocol.call_store(
- node_to_endpoint[node_id], keys=[key_id] * len(current_values), values=binary_values,
- expiration_time=current_expirations, subkeys=current_subkeys)))
- # await nearest task. If it fails, dispatch more on the next iteration
- if pending_store_tasks:
- finished_store_tasks, pending_store_tasks = await asyncio.wait(
- pending_store_tasks, return_when=asyncio.FIRST_COMPLETED)
- for task in finished_store_tasks:
- if task.result() is not None:
- num_successful_stores += 1
- for subkey, store_status in zip(current_subkeys, task.result()):
- store_ok[original_key, subkey] = store_status
- if not await_all_replicas:
- store_finished_events[original_key, subkey].set()
- if self.cache_on_store:
- self._update_cache_on_store(key_id, current_subkeys, binary_values, current_expirations,
- store_ok=[store_ok[original_key, subkey] for subkey in current_subkeys])
- for subkey, value_bytes, expiration in zip(current_subkeys, binary_values, current_expirations):
- store_finished_events[original_key, subkey].set()
- store_task = asyncio.create_task(self.find_nearest_nodes(
- queries=set(unfinished_key_ids), k_nearest=self.num_replicas, node_to_endpoint=node_to_endpoint,
- found_callback=on_found, exclude_self=exclude_self, **kwargs))
- try:
- await asyncio.gather(store_task, *(evt.wait() for evt in store_finished_events.values()))
- assert len(unfinished_key_ids) == 0, "Internal error: traverse_dht didn't finish search"
- return {(key, subkey) if subkey is not None else key: status or False for (key, subkey), status in store_ok.items()}
- except asyncio.CancelledError as e:
- store_task.cancel()
- raise e
- def _update_cache_on_store(self, key_id: DHTID, subkeys: List[Subkey], binary_values: List[bytes],
- expirations: List[DHTExpiration], store_ok: List[bool]):
- """ Update local cache after finishing a store for one key (with perhaps several subkeys) """
- store_succeeded = any(store_ok)
- is_dictionary = any(subkey is not None for subkey in subkeys)
- if store_succeeded and not is_dictionary: # stored a new regular value, cache it!
- stored_value_bytes, stored_expiration = max(zip(binary_values, expirations), key=lambda p: p[1])
- self.protocol.cache.store(key_id, stored_value_bytes, stored_expiration)
- elif not store_succeeded and not is_dictionary: # store rejected, check if local cache is also obsolete
- rejected_value, rejected_expiration = max(zip(binary_values, expirations), key=lambda p: p[1])
- if (self.protocol.cache.get(key_id) or (None, -float("inf")))[1] <= rejected_expiration: # cache would be rejected
- self._schedule_for_refresh(key_id, refresh_time=get_dht_time()) # fetch new key in background (asap)
- elif is_dictionary and key_id in self.protocol.cache: # there can be other subkeys in the cached dictionary and we should update it
- for subkey, stored_value_bytes, expiration_time in zip(subkeys, binary_values, expirations):
- self.protocol.cache.store_subkey(key_id, subkey, stored_value_bytes, expiration_time)
- self._schedule_for_refresh(key_id, refresh_time=get_dht_time()) # fetch new key in background (asap)
- async def get(self, key: DHTKey, latest=False, **kwargs) -> Optional[ValueWithExpiration[DHTValue]]:
- """
- Search for a key across DHT and return either first or latest entry (if found).
- :param key: same key as in node.store(...)
- :param latest: if True, finds the latest value, otherwise finds any non-expired value (which is much faster)
- :param kwargs: parameters forwarded to get_many_by_id
- :returns: (value, expiration time); if value was not found, returns None
- """
- if latest:
- kwargs["sufficient_expiration_time"] = float('inf')
- result = await self.get_many([key], **kwargs)
- return result[key]
- async def get_many(self, keys: Collection[DHTKey], sufficient_expiration_time: Optional[DHTExpiration] = None,
- **kwargs) -> Dict[DHTKey, Union[Optional[ValueWithExpiration[DHTValue]],
- Awaitable[Optional[ValueWithExpiration[DHTValue]]]]]:
- """
- Traverse DHT to find a list of keys. For each key, return latest (value, expiration) or None if not found.
- :param keys: traverse the DHT and find the value for each of these keys (or None if the key is not found)
- :param sufficient_expiration_time: if the search finds a value that expires after this time, it stops and returns that value;
- default = time of call, i.e. find any value that has not expired by the time of call
- If sufficient_expiration_time=float('inf'), this method will find the value with the _latest_ expiration
- :param kwargs: for full list of parameters, see DHTNode.get_many_by_id
- :returns: for each key: value and its expiration time. If nothing is found, returns None for that key
- :note: in order to check if get returned a value, check whether the result for that key is None
- """
- keys = tuple(keys)
- key_ids = [DHTID.generate(key) for key in keys]
- id_to_original_key = dict(zip(key_ids, keys))
- results_by_id = await self.get_many_by_id(key_ids, sufficient_expiration_time, **kwargs)
- return {id_to_original_key[key]: result_or_future for key, result_or_future in results_by_id.items()}
- async def get_many_by_id(
- self, key_ids: Collection[DHTID], sufficient_expiration_time: Optional[DHTExpiration] = None,
- num_workers: Optional[int] = None, beam_size: Optional[int] = None, return_futures: bool = False,
- _is_refresh=False) -> Dict[DHTID, Union[Optional[ValueWithExpiration[DHTValue]],
- Awaitable[Optional[ValueWithExpiration[DHTValue]]]]]:
- """
- Traverse DHT to find a list of DHTIDs. For each key, return latest (value, expiration) or None if not found.
- :param key_ids: traverse the DHT and find the value for each of these keys (or None if the key is not found)
- :param sufficient_expiration_time: if the search finds a value that expires after this time, it stops and returns that value;
- default = time of call, i.e. find any value that has not expired by the time of call
- If sufficient_expiration_time=float('inf'), this method will find the value with the _latest_ expiration
- :param beam_size: maintains up to this many nearest nodes when crawling dht, default beam_size = bucket_size
- :param num_workers: override for default num_workers, see traverse_dht num_workers param
- :param return_futures: if True, immediately return an asyncio.Future for every key before interacting with the network.
- The algorithm will populate these futures with (value, expiration) when it finds the corresponding key
- Note: canceling a future will stop search for the corresponding key
- :param _is_refresh: internal flag, set to True by an internal cache refresher (if enabled)
- :returns: for each key: value and its expiration time. If nothing is found, returns None for that key
- :note: in order to check if get returned a value, check whether the result for that key is None
- """
- sufficient_expiration_time = sufficient_expiration_time or get_dht_time()
- beam_size = beam_size if beam_size is not None else self.protocol.bucket_size
- num_workers = num_workers if num_workers is not None else self.num_workers
- search_results: Dict[DHTID, _SearchState] = {key_id: _SearchState(
- key_id, sufficient_expiration_time, serializer=self.protocol.serializer) for key_id in key_ids}
- if not _is_refresh: # if we're already refreshing cache, there's no need to trigger subsequent refreshes
- for key_id in key_ids:
- search_results[key_id].add_done_callback(self._trigger_cache_refresh)
- # if we have concurrent get request for some of the same keys, subscribe to their results
- if self.reuse_get_requests:
- for key_id, search_result in search_results.items():
- self.pending_get_requests[key_id].add(search_result)
- search_result.add_done_callback(self._reuse_finished_search_result)
- # stage 1: check for value in this node's local storage and cache
- for key_id in key_ids:
- search_results[key_id].add_candidate(self.protocol.storage.get(key_id), source_node_id=self.node_id)
- if not _is_refresh:
- search_results[key_id].add_candidate(self.protocol.cache.get(key_id), source_node_id=self.node_id)
- # stage 2: traverse the DHT to get the remaining keys from remote peers
- unfinished_key_ids = [key_id for key_id in key_ids if not search_results[key_id].finished]
- node_to_endpoint: Dict[DHTID, Endpoint] = dict() # global routing table for all keys
- for key_id in unfinished_key_ids:
- node_to_endpoint.update(self.protocol.routing_table.get_nearest_neighbors(
- key_id, self.protocol.bucket_size, exclude=self.node_id))
- # V-- this function will be called every time traverse_dht decides to request neighbors from a remote peer
- async def get_neighbors(peer: DHTID, queries: Collection[DHTID]) -> Dict[DHTID, Tuple[Tuple[DHTID], bool]]:
- queries = list(queries)
- response = await self._call_find_with_blacklist(node_to_endpoint[peer], queries)
- if not response:
- return {query: ([], False) for query in queries}
- output: Dict[DHTID, Tuple[Tuple[DHTID], bool]] = {}
- for key_id, (maybe_value_with_expiration, peers) in response.items():
- node_to_endpoint.update(peers)
- search_results[key_id].add_candidate(maybe_value_with_expiration, source_node_id=peer)
- output[key_id] = tuple(peers.keys()), search_results[key_id].finished
- # note: we interrupt the search if the key is either found or otherwise finished (e.g. cancelled by user)
- return output
- # V-- this function will be called exactly once when traverse_dht finishes search for a given key
- async def found_callback(key_id: DHTID, nearest_nodes: List[DHTID], _visited: Set[DHTID]):
- search_results[key_id].finish_search() # finish search whether or not we found something
- self._cache_new_result(search_results[key_id], nearest_nodes, node_to_endpoint, _is_refresh=_is_refresh)
- asyncio.create_task(traverse_dht(
- queries=list(unfinished_key_ids), initial_nodes=list(node_to_endpoint), beam_size=beam_size,
- num_workers=num_workers, queries_per_call=min(int(len(unfinished_key_ids) ** 0.5), self.chunk_size),
- get_neighbors=get_neighbors, visited_nodes={key_id: {self.node_id} for key_id in unfinished_key_ids},
- found_callback=found_callback, await_all_tasks=False))
- if return_futures:
- return {key_id: search_result.future for key_id, search_result in search_results.items()}
- else:
- try:
- # note: this should be first time when we await something, there's no need to "try" the entire function
- return {key_id: await search_result.future for key_id, search_result in search_results.items()}
- except asyncio.CancelledError as e: # terminate remaining tasks ASAP
- for key_id, search_result in search_results.items():
- search_result.future.cancel()
- raise e
- def _reuse_finished_search_result(self, finished: _SearchState):
- pending_requests = self.pending_get_requests[finished.key_id]
- if finished.found_something:
- search_result = ValueWithExpiration(finished.binary_value, finished.expiration_time)
- expiration_time_threshold = max(finished.expiration_time, finished.sufficient_expiration_time)
- # note: pending_requests is sorted in the order of descending sufficient_expiration_time
- while pending_requests and expiration_time_threshold >= pending_requests[-1].sufficient_expiration_time:
- pending_requests[-1].add_candidate(search_result, source_node_id=finished.source_node_id)
- pending_requests[-1].finish_search()
- pending_requests.pop()
- else:
- pending_requests.discard(finished)
- async def _call_find_with_blacklist(self, endpoint: Endpoint, keys: Collection[DHTID]):
- """ same as call_find, but skip if :endpoint: is blacklisted; also exclude blacklisted neighbors from result """
- if endpoint in self.blacklist:
- return None
- response = await self.protocol.call_find(endpoint, keys)
- if response:
- self.blacklist.register_success(endpoint)
- return {key: (maybe_value, self._filter_blacklisted(nearest_peers))
- for key, (maybe_value, nearest_peers) in response.items()}
- else:
- self.blacklist.register_failure(endpoint)
- return None
- def _filter_blacklisted(self, peer_endpoints: Dict[DHTID, Endpoint]):
- return {peer: endpoint for peer, endpoint in peer_endpoints.items() if endpoint not in self.blacklist}
- def _trigger_cache_refresh(self, search: _SearchState):
- """ Called after get request is finished (whether it was found, not found, hit cache, cancelled, or reused) """
- if search.found_something and search.source_node_id == self.node_id:
- if self.cache_refresh_before_expiry and search.key_id in self.protocol.cache:
- self._schedule_for_refresh(search.key_id, search.expiration_time - self.cache_refresh_before_expiry)
- def _schedule_for_refresh(self, key_id: DHTID, refresh_time: DHTExpiration):
- """ Add key to a refresh queue, refresh at :refresh_time: or later """
- if self.cache_refresh_task is None or self.cache_refresh_task.done() or self.cache_refresh_task.cancelled():
- self.cache_refresh_task = asyncio.create_task(self._refresh_stale_cache_entries())
- logger.debug("Spawned cache refresh task.")
- earliest_key, earliest_item = self.cache_refresh_queue.top()
- if earliest_item is None or refresh_time < earliest_item.expiration_time:
- self.cache_refresh_evt.set() # if the new element is now the earliest, notify the refresh task
- self.cache_refresh_queue.store(key_id, value=refresh_time, expiration_time=refresh_time)
- async def _refresh_stale_cache_entries(self):
- """ periodically refresh keys near-expired keys that were accessed at least once during previous lifetime """
- while self.is_alive:
- while len(self.cache_refresh_queue) == 0:
- await self.cache_refresh_evt.wait()
- self.cache_refresh_evt.clear()
- key_id, (_, nearest_refresh_time) = self.cache_refresh_queue.top()
- try:
- # step 1: await until :cache_refresh_before_expiry: seconds before the earliest queued element expires
- time_to_wait = nearest_refresh_time - get_dht_time()
- await asyncio.wait_for(self.cache_refresh_evt.wait(), timeout=time_to_wait)
- # note: the line above will cause TimeoutError when we are ready to refresh cache
- self.cache_refresh_evt.clear() # no timeout error => someone added new entry to queue and ...
- continue # ... and this element is earlier than nearest_expiration. we should refresh this entry first
- except asyncio.TimeoutError: # caught TimeoutError => it is time to refresh the earliest-scheduled cache entry
- # step 2: find all keys that we should already refresh and remove them from queue
- current_time = get_dht_time()
- keys_to_refresh = {key_id}
- max_expiration_time = nearest_refresh_time
- del self.cache_refresh_queue[key_id] # we pledge to refresh this key_id in the nearest batch
- while self.cache_refresh_queue and len(keys_to_refresh) < self.chunk_size:
- key_id, (_, nearest_refresh_time) = self.cache_refresh_queue.top()
- if nearest_refresh_time > current_time:
- break
- del self.cache_refresh_queue[key_id] # we pledge to refresh this key_id in the nearest batch
- keys_to_refresh.add(key_id)
- cached_item = self.protocol.cache.get(key_id)
- if cached_item is not None and cached_item.expiration_time > max_expiration_time:
- max_expiration_time = cached_item.expiration_time
- # step 3: search newer versions of these keys, cache them as a side-effect of self.get_many_by_id
- sufficient_expiration_time = max_expiration_time + self.cache_refresh_before_expiry + 1
- await self.get_many_by_id(keys_to_refresh, sufficient_expiration_time, _is_refresh=True)
- def _cache_new_result(self, search: _SearchState, nearest_nodes: List[DHTID],
- node_to_endpoint: Dict[DHTID, Endpoint], _is_refresh: bool = False):
- """ after key_id is found, update cache according to caching policy. used internally in get and get_many """
- if search.found_something:
- _, storage_expiration_time = self.protocol.storage.get(search.key_id) or (None, -float('inf'))
- _, cache_expiration_time = self.protocol.cache.get(search.key_id) or (None, -float('inf'))
- if search.expiration_time > max(storage_expiration_time, cache_expiration_time):
- if self.cache_locally or _is_refresh:
- self.protocol.cache.store(search.key_id, search.binary_value, search.expiration_time)
- if self.cache_nearest:
- num_cached_nodes = 0
- for node_id in nearest_nodes:
- if node_id == search.source_node_id:
- continue
- asyncio.create_task(self.protocol.call_store(
- node_to_endpoint[node_id], [search.key_id], [search.binary_value], [search.expiration_time],
- in_cache=True))
- num_cached_nodes += 1
- if num_cached_nodes >= self.cache_nearest:
- break
- async def _refresh_routing_table(self, *, period: Optional[float]) -> None:
- """ Tries to find new nodes for buckets that were unused for more than self.staleness_timeout """
- while self.is_alive and period is not None: # if None run once, otherwise run forever
- refresh_time = get_dht_time()
- staleness_threshold = refresh_time - period
- stale_buckets = [bucket for bucket in self.protocol.routing_table.buckets
- if bucket.last_updated < staleness_threshold]
- for bucket in stale_buckets:
- refresh_id = DHTID(random.randint(bucket.lower, bucket.upper - 1))
- await self.find_nearest_nodes([refresh_id])
- await asyncio.sleep(max(0.0, period - (get_dht_time() - refresh_time)))
- @dataclass(init=True, repr=True, frozen=False, order=False)
- class _SearchState:
- """ A helper class that stores current-best GET results with metadata """
- key_id: DHTID
- sufficient_expiration_time: DHTExpiration
- binary_value: Optional[Union[BinaryDHTValue, DictionaryDHTValue]] = None
- expiration_time: Optional[DHTExpiration] = None # best expiration time so far
- source_node_id: Optional[DHTID] = None # node that gave us the value
- future: asyncio.Future[Optional[ValueWithExpiration[DHTValue]]] = field(default_factory=asyncio.Future)
- serializer: Type[SerializerBase] = MSGPackSerializer
- def add_candidate(self, candidate: Optional[ValueWithExpiration[Union[BinaryDHTValue, DictionaryDHTValue]]],
- source_node_id: Optional[DHTID]):
- if self.finished or candidate is None:
- return
- elif isinstance(candidate.value, DictionaryDHTValue) and isinstance(self.binary_value, DictionaryDHTValue):
- self.binary_value.maxsize = max(self.binary_value.maxsize, candidate.value.maxsize)
- for subkey, subentry in candidate.value.items():
- self.binary_value.store(subkey, subentry.value, subentry.expiration_time)
- elif candidate.expiration_time > (self.expiration_time or float('-inf')):
- self.binary_value = candidate.value
- if candidate.expiration_time > (self.expiration_time or float('-inf')):
- self.expiration_time = candidate.expiration_time
- self.source_node_id = source_node_id
- if self.expiration_time >= self.sufficient_expiration_time:
- self.finish_search()
- def add_done_callback(self, callback: Callable[[_SearchState], Any]):
- """ Add callback that will be called when _SearchState is done (found OR cancelled by user) """
- self.future.add_done_callback(lambda _future: callback(self))
- def finish_search(self):
- if self.future.done():
- return # either the user cancelled our search or the result was already set elsewhere. Nothing more to do here.
- elif not self.found_something:
- self.future.set_result(None)
- elif isinstance(self.binary_value, BinaryDHTValue):
- self.future.set_result(ValueWithExpiration(self.serializer.loads(self.binary_value), self.expiration_time))
- elif isinstance(self.binary_value, DictionaryDHTValue):
- dict_with_subkeys = {key: ValueWithExpiration(self.serializer.loads(value), item_expiration_time)
- for key, (value, item_expiration_time) in self.binary_value.items()}
- self.future.set_result(ValueWithExpiration(dict_with_subkeys, self.expiration_time))
- else:
- logger.error(f"Invalid value type: {type(self.binary_value)}")
- @property
- def found_something(self) -> bool:
- """ Whether or not we have found at least some value, regardless of its expiration time """
- return self.expiration_time is not None
- @property
- def finished(self) -> bool:
- return self.future.done()
- def __lt__(self, other: _SearchState):
- """ _SearchState instances will be sorted by their target expiration time """
- return self.sufficient_expiration_time < other.sufficient_expiration_time
- def __hash__(self):
- return hash(self.key_id)
- class Blacklist:
- """
- A temporary blacklist of non-responding peers with exponential backoff policy
- :param base_time: peers are suspended for this many seconds by default
- :param backoff_rate: suspension time increases by this factor after each successive failure
- """
- def __init__(self, base_time: float, backoff_rate: float, **kwargs):
- self.base_time, self.backoff = base_time, backoff_rate
- self.banned_peers = TimedStorage[Endpoint, int](**kwargs)
- self.ban_counter = Counter()
- def register_failure(self, peer: Endpoint):
- """ peer failed to respond, add him to blacklist or increase his downtime """
- if peer not in self.banned_peers and self.base_time > 0:
- ban_duration = self.base_time * self.backoff ** self.ban_counter[peer]
- self.banned_peers.store(peer, self.ban_counter[peer], expiration_time=get_dht_time() + ban_duration)
- self.ban_counter[peer] += 1
- def register_success(self, peer: Endpoint):
- """ peer responded successfully, remove it from the blacklist and reset its ban counter """
- del self.banned_peers[peer], self.ban_counter[peer]
- def __contains__(self, peer: Endpoint) -> bool:
- return peer in self.banned_peers
- def __repr__(self):
- return f"{self.__class__.__name__}(base_time={self.base_time}, backoff={self.backoff}, " \
- f"banned_peers={len(self.banned_peers)})"
- def clear(self):
- self.banned_peers.clear()
- self.ban_counter.clear()
- class CacheRefreshQueue(TimedStorage[DHTID, DHTExpiration]):
- """ a queue of keys scheduled for refresh in future, used in DHTNode """
- frozen = True