|
@@ -1,15 +1,21 @@
|
|
|
from __future__ import annotations
|
|
|
|
|
|
import asyncio
|
|
|
+
|
|
|
import random
|
|
|
-from collections import namedtuple
|
|
|
-from typing import Optional, Tuple, List, Dict, Collection, Union, Set
|
|
|
+from collections import defaultdict
|
|
|
+from dataclasses import dataclass, field
|
|
|
+from typing import Optional, Tuple, List, Dict, DefaultDict, Collection, Union, Set, Awaitable, Callable, Any, Iterable
|
|
|
+from sortedcontainers import SortedList
|
|
|
+from functools import partial
|
|
|
from warnings import warn
|
|
|
|
|
|
-from hivemind.dht.protocol import DHTProtocol
|
|
|
-from hivemind.dht.routing import DHTID, DHTExpiration, DHTKey, get_dht_time, DHTValue
|
|
|
+from hivemind.dht.protocol import DHTProtocol, LocalStorage
|
|
|
+from hivemind.dht.routing import DHTID, DHTExpiration, DHTKey, get_dht_time, DHTValue, BinaryDHTValue
|
|
|
from hivemind.dht.traverse import traverse_dht
|
|
|
-from hivemind.utils import Endpoint, LOCALHOST, MSGPackSerializer
|
|
|
+from hivemind.utils import Endpoint, LOCALHOST, MSGPackSerializer, get_logger, SerializerBase
|
|
|
+
|
|
|
+logger = get_logger(__name__)
|
|
|
|
|
|
|
|
|
class DHTNode:
|
|
@@ -45,8 +51,10 @@ class DHTNode:
|
|
|
|
|
|
"""
|
|
|
# fmt:off
|
|
|
- node_id: DHTID; port: int; num_replicas: int; cache_locally: bool; cache_nearest: int; num_workers: int
|
|
|
- refresh_timeout: float; protocol: DHTProtocol
|
|
|
+ node_id: DHTID; is_alive: bool; port: int; num_replicas: int; num_workers: int; protocol: DHTProtocol
|
|
|
+ refresh_timeout: float; cache_locally: bool; cache_nearest: int; cache_refresh_before_expiry: float
|
|
|
+ cache_refresh_available: asyncio.Event; cache_refresh_queue: LocalStorage
|
|
|
+ reuse_get_requests: bool; pending_get_requests: DefaultDict[DHTID, SortedList[_IntermediateResult]]
|
|
|
serializer = MSGPackSerializer # used to pack/unpack DHT Values for transfer over network
|
|
|
# fmt:on
|
|
|
|
|
@@ -55,8 +63,9 @@ class DHTNode:
|
|
|
cls, node_id: Optional[DHTID] = None, initial_peers: List[Endpoint] = (),
|
|
|
bucket_size: int = 20, num_replicas: int = 5, depth_modulo: int = 5, parallel_rpc: int = None,
|
|
|
wait_timeout: float = 5, refresh_timeout: Optional[float] = None, bootstrap_timeout: Optional[float] = None,
|
|
|
- num_workers: int = 1, cache_locally: bool = True, cache_nearest: int = 1, cache_size=None,
|
|
|
- listen: bool = True, listen_on: Endpoint = "0.0.0.0:*", **kwargs) -> DHTNode:
|
|
|
+ cache_locally: bool = True, cache_nearest: int = 1, cache_size=None, cache_refresh_before_expiry: float = 5,
|
|
|
+ reuse_get_requests: bool = True, num_workers: int = 1, listen: bool = True,
|
|
|
+ listen_on: Endpoint = "0.0.0.0:*", **kwargs) -> DHTNode:
|
|
|
"""
|
|
|
:param node_id: current node's identifier, determines which keys it will store locally, defaults to random id
|
|
|
:param initial_peers: connects to these peers to populate routing table, defaults to no peers
|
|
@@ -71,11 +80,15 @@ class DHTNode:
|
|
|
:param refresh_timeout: refresh buckets if no node from that bucket was updated in this many seconds
|
|
|
if staleness_timeout is None, DHTNode will not refresh stale buckets (which is usually okay)
|
|
|
:param bootstrap_timeout: after one of peers responds, await other peers for at most this many seconds
|
|
|
- :param num_workers: concurrent workers in traverse_dht (see traverse_dht num_workers param)
|
|
|
:param cache_locally: if True, caches all values (stored or found) in a node-local cache
|
|
|
:param cache_nearest: whenever DHTNode finds a value, it will also store (cache) this value on this many
|
|
|
nodes nearest nodes visited by search algorithm. Prefers nodes that are nearest to :key: but have no value yet
|
|
|
:param cache_size: if specified, local cache will store up to this many records (as in LRU cache)
|
|
|
+ :param cache_refresh_before_expiry: if nonzero, refreshes locally cached values
|
|
|
+          this many seconds before they expire, provided they were accessed at least once while cached.
|
|
|
+ :param reuse_get_requests: if True, DHTNode allows only one traverse_dht procedure for every key
|
|
|
+ all concurrent get requests for the same key will reuse the procedure that is currently in progress
|
|
|
+ :param num_workers: concurrent workers in traverse_dht (see traverse_dht num_workers param)
|
|
|
:param listen: if True (default), this node will accept incoming request and otherwise be a DHT "citzen"
|
|
|
if False, this node will refuse any incoming request, effectively being only a "client"
|
|
|
:param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
|
|
@@ -83,11 +96,26 @@ class DHTNode:
|
|
|
see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
|
|
|
:param kwargs: extra parameters used in grpc.aio.server
|
|
|
"""
|
|
|
+ if cache_refresh_before_expiry > 0 and not cache_locally:
|
|
|
+ logger.warning("If cache_locally is False, cache_refresh_before_expiry has no effect. To silence this"
|
|
|
+ " warning, please specify cache_refresh_before_expiry=0")
|
|
|
+
|
|
|
self = cls(_initialized_with_create=True)
|
|
|
self.node_id = node_id = node_id if node_id is not None else DHTID.generate()
|
|
|
self.num_replicas, self.num_workers = num_replicas, num_workers
|
|
|
- self.cache_locally, self.cache_nearest = cache_locally, cache_nearest
|
|
|
+ self.is_alive = True # if set to False, cancels all background jobs such as routing table refresh
|
|
|
+
|
|
|
+ self.reuse_get_requests = reuse_get_requests
|
|
|
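+        # V-- one SortedList of concurrent get requests per key, ordered by descending sufficient_expiration_time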
+ self.pending_get_requests = defaultdict(partial(SortedList, key=lambda _res: - _res.sufficient_expiration_time))
|
|
|
+
|
|
|
+ # caching policy
|
|
|
self.refresh_timeout = refresh_timeout
|
|
|
+ self.cache_locally, self.cache_nearest = cache_locally, cache_nearest
|
|
|
+ self.cache_refresh_before_expiry = cache_refresh_before_expiry
|
|
|
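+        # V-- cached entries scheduled for refresh; the earliest-expiring entry sits at the top of the queue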
+ self.cache_refresh_queue = LocalStorage()
|
|
|
+ self.cache_refresh_available = asyncio.Event()
|
|
|
+ if cache_refresh_before_expiry:
|
|
|
+ asyncio.create_task(self._refresh_stale_cache_entries())
|
|
|
|
|
|
self.protocol = await DHTProtocol.create(self.node_id, bucket_size, depth_modulo, num_replicas, wait_timeout,
|
|
|
parallel_rpc, cache_size, listen, listen_on, **kwargs)
|
|
@@ -129,7 +157,9 @@ class DHTNode:
|
|
|
|
|
|
async def shutdown(self, timeout=None):
|
|
|
""" Process existing requests, close all connections and stop the server """
|
|
|
- await self.protocol.shutdown(timeout)
|
|
|
+ self.is_alive = False
|
|
|
+ if self.protocol.server:
|
|
|
+ await self.protocol.shutdown(timeout)
|
|
|
|
|
|
async def find_nearest_nodes(
|
|
|
self, queries: Collection[DHTID], k_nearest: Optional[int] = None, beam_size: Optional[int] = None,
|
|
@@ -157,15 +187,15 @@ class DHTNode:
|
|
|
node_to_endpoint.update(
|
|
|
self.protocol.routing_table.get_nearest_neighbors(query, beam_size, exclude=self.node_id))
|
|
|
|
|
|
- async def get_neighbors(peer: DHTID, queries: Collection[DHTID]) -> Dict[DHTID, Tuple[List[DHTID], bool]]:
|
|
|
+ async def get_neighbors(peer: DHTID, queries: Collection[DHTID]) -> Dict[DHTID, Tuple[Tuple[DHTID], bool]]:
|
|
|
response = await self.protocol.call_find(node_to_endpoint[peer], queries)
|
|
|
if not response:
|
|
|
return {query: ([], False) for query in queries}
|
|
|
|
|
|
- output: Dict[DHTID, Tuple[List[DHTID], bool]] = {}
|
|
|
+ output: Dict[DHTID, Tuple[Tuple[DHTID], bool]] = {}
|
|
|
for query, (_, _, peers) in response.items():
|
|
|
node_to_endpoint.update(peers)
|
|
|
- output[query] = list(peers.keys()), False # False means "do not interrupt search"
|
|
|
+ output[query] = tuple(peers.keys()), False # False means "do not interrupt search"
|
|
|
return output
|
|
|
|
|
|
nearest_nodes_per_query, visited_nodes = await traverse_dht(
|
|
@@ -289,7 +319,7 @@ class DHTNode:
|
|
|
Search for a key across DHT and return either first or latest entry.
|
|
|
:param key: same key as in node.store(...)
|
|
|
:param latest: if True, finds the latest value, otherwise finds any non-expired value (which is much faster)
|
|
|
- :param kwargs: parameters forwarded to get_many
|
|
|
+ :param kwargs: parameters forwarded to get_many_by_id
|
|
|
:returns: (value, expiration time); if value was not found, returns (None, None)
|
|
|
"""
|
|
|
if latest:
|
|
@@ -297,100 +327,190 @@ class DHTNode:
|
|
|
result = await self.get_many([key])
|
|
|
return result[key]
|
|
|
|
|
|
- async def get_many(
|
|
|
- self, keys: Collection[DHTKey], sufficient_expiration_time: Optional[DHTExpiration] = None,
|
|
|
- num_workers: Optional[int] = None, beam_size: Optional[int] = None
|
|
|
- ) -> Dict[DHTKey, Tuple[Optional[DHTValue], Optional[DHTExpiration]]]:
|
|
|
+ async def get_many(self, keys: Collection[DHTKey], sufficient_expiration_time: Optional[DHTExpiration] = None,
|
|
|
+ **kwargs) -> Dict[DHTKey, Union[Tuple[Optional[DHTValue], Optional[DHTExpiration]],
|
|
|
+ Awaitable[Tuple[Optional[DHTValue], Optional[DHTExpiration]]]]]:
|
|
|
"""
|
|
|
+ Traverse DHT to find a list of keys. For each key, return latest (value, expiration) or None if not found.
|
|
|
+
|
|
|
:param keys: traverse the DHT and find the value for each of these keys (or (None, None) if not key found)
|
|
|
+ :param sufficient_expiration_time: if the search finds a value that expires after this time,
|
|
|
+ default = time of call, find any value that did not expire by the time of call
|
|
|
+            If sufficient_expiration_time=float('inf'), this method will find a value with _latest_ expiration
|
|
|
+ :param kwargs: for full list of parameters, see DHTNode.get_many_by_id
|
|
|
+ :returns: for each key: value and its expiration time. If nothing is found, returns (None, None) for that key
|
|
|
+ :note: in order to check if get returned a value, please check if (expiration_time is None)
|
|
|
+ """
|
|
|
+ keys = tuple(keys)
|
|
|
+ key_ids = [DHTID.generate(key) for key in keys]
|
|
|
+ id_to_original_key = dict(zip(key_ids, keys))
|
|
|
+ results_by_id = await self.get_many_by_id(key_ids, sufficient_expiration_time, **kwargs)
|
|
|
+ return {id_to_original_key[key]: result_or_future for key, result_or_future in results_by_id.items()}
|
|
|
+
|
|
|
+ async def get_many_by_id(
|
|
|
+ self, key_ids: Collection[DHTID], sufficient_expiration_time: Optional[DHTExpiration] = None,
|
|
|
+ num_workers: Optional[int] = None, beam_size: Optional[int] = None, return_futures: bool = False,
|
|
|
+ _refresh_cache=True) -> Dict[DHTID, Union[Tuple[Optional[DHTValue], Optional[DHTExpiration]],
|
|
|
+ Awaitable[Tuple[Optional[DHTValue], Optional[DHTExpiration]]]]]:
|
|
|
+ """
|
|
|
+ Traverse DHT to find a list of DHTIDs. For each key, return latest (value, expiration) or None if not found.
|
|
|
+
|
|
|
+ :param key_ids: traverse the DHT and find the value for each of these keys (or (None, None) if not key found)
|
|
|
:param sufficient_expiration_time: if the search finds a value that expires after this time,
|
|
|
default = time of call, find any value that did not expire by the time of call
|
|
|
If min_expiration_time=float('inf'), this method will find a value with _latest_ expiration
|
|
|
:param beam_size: maintains up to this many nearest nodes when crawling dht, default beam_size = bucket_size
|
|
|
:param num_workers: override for default num_workers, see traverse_dht num_workers param
|
|
|
- :returns: for each key: value and its expiration time. If nothing is found , returns (None, None) for that key
|
|
|
+        :param return_futures: if True, immediately return asyncio.Future for every key before interacting with the network.
|
|
|
+ The algorithm will populate these futures with (value, expiration) when it finds the corresponding key
|
|
|
+ Note: canceling a future will stop search for the corresponding key
|
|
|
+        :param _refresh_cache: internal flag, whether or not to call self._trigger_cache_refresh
|
|
|
+ :returns: for each key: value and its expiration time. If nothing is found, returns (None, None) for that key
|
|
|
:note: in order to check if get returned a value, please check (expiration_time is None)
|
|
|
"""
|
|
|
- key_ids = [DHTID.generate(key) for key in keys]
|
|
|
- id_to_original_key = dict(zip(key_ids, keys))
|
|
|
sufficient_expiration_time = sufficient_expiration_time or get_dht_time()
|
|
|
beam_size = beam_size if beam_size is not None else self.protocol.bucket_size
|
|
|
num_workers = num_workers if num_workers is not None else self.num_workers
|
|
|
+ search_results: Dict[DHTID, _IntermediateResult] = {key_id: _IntermediateResult(
|
|
|
+ key_id, sufficient_expiration_time, serializer=self.serializer) for key_id in key_ids}
|
|
|
|
|
|
- # search metadata
|
|
|
- unfinished_key_ids = set(key_ids) # track key ids for which the search is not terminated
|
|
|
- node_to_endpoint: Dict[DHTID, Endpoint] = dict() # global routing table for all queries
|
|
|
+ if _refresh_cache:
|
|
|
+ for key_id in key_ids:
|
|
|
+ search_results[key_id].add_done_callback(self._trigger_cache_refresh)
|
|
|
|
|
|
- SearchResult = namedtuple("SearchResult", ["binary_value", "expiration_time", "source_node_id"])
|
|
|
- latest_results = {key_id: SearchResult(b'', -float('inf'), None) for key_id in key_ids}
|
|
|
+ # if we have concurrent get request for some of the same keys, subscribe to their results
|
|
|
+ if self.reuse_get_requests:
|
|
|
+ for key_id, search_result in search_results.items():
|
|
|
+ self.pending_get_requests[key_id].add(search_result)
|
|
|
+ search_result.add_done_callback(self._reuse_finished_search_result)
|
|
|
|
|
|
- # stage 1: value can be stored in our local cache
|
|
|
+ # stage 1: check for value in this node's local storage and cache
|
|
|
for key_id in key_ids:
|
|
|
- maybe_value, maybe_expiration_time = self.protocol.storage.get(key_id)
|
|
|
- if maybe_expiration_time is None:
|
|
|
- maybe_value, maybe_expiration_time = self.protocol.cache.get(key_id)
|
|
|
- if maybe_expiration_time is not None and maybe_expiration_time > latest_results[key_id].expiration_time:
|
|
|
- latest_results[key_id] = SearchResult(maybe_value, maybe_expiration_time, self.node_id)
|
|
|
- if maybe_expiration_time >= sufficient_expiration_time:
|
|
|
- unfinished_key_ids.remove(key_id)
|
|
|
-
|
|
|
- # stage 2: traverse the DHT for any unfinished keys
|
|
|
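+            # note: add_candidate silently ignores entries with expiration_time=None, i.e. keys not found locally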
+ search_results[key_id].add_candidate(*self.protocol.storage.get(key_id), source_node_id=self.node_id)
|
|
|
+ search_results[key_id].add_candidate(*self.protocol.cache.get(key_id), source_node_id=self.node_id)
|
|
|
+
|
|
|
+ # stage 2: traverse the DHT to get the remaining keys from remote peers
|
|
|
+ unfinished_key_ids = [key_id for key_id in key_ids if not search_results[key_id].finished]
|
|
|
+ node_to_endpoint: Dict[DHTID, Endpoint] = dict() # global routing table for all keys
|
|
|
for key_id in unfinished_key_ids:
|
|
|
node_to_endpoint.update(self.protocol.routing_table.get_nearest_neighbors(
|
|
|
key_id, self.protocol.bucket_size, exclude=self.node_id))
|
|
|
|
|
|
- async def get_neighbors(peer: DHTID, queries: Collection[DHTID]) -> Dict[DHTID, Tuple[List[DHTID], bool]]:
|
|
|
+ # V-- this function will be called every time traverse_dht decides to request neighbors from a remote peer
|
|
|
+ async def get_neighbors(peer: DHTID, queries: Collection[DHTID]) -> Dict[DHTID, Tuple[Tuple[DHTID], bool]]:
|
|
|
queries = list(queries)
|
|
|
response = await self.protocol.call_find(node_to_endpoint[peer], queries)
|
|
|
if not response:
|
|
|
return {query: ([], False) for query in queries}
|
|
|
|
|
|
- output: Dict[DHTID, Tuple[List[DHTID], bool]] = {}
|
|
|
- for key_id, (maybe_value, maybe_expiration_time, peers) in response.items():
|
|
|
+ output: Dict[DHTID, Tuple[Tuple[DHTID], bool]] = {}
|
|
|
+ for key_id, (maybe_value_bytes, maybe_expiration_time, peers) in response.items():
|
|
|
node_to_endpoint.update(peers)
|
|
|
- if maybe_expiration_time is not None and maybe_expiration_time > latest_results[key_id].expiration_time:
|
|
|
- latest_results[key_id] = SearchResult(maybe_value, maybe_expiration_time, peer)
|
|
|
- should_interrupt = (latest_results[key_id].expiration_time >= sufficient_expiration_time)
|
|
|
- output[key_id] = list(peers.keys()), should_interrupt
|
|
|
+ search_results[key_id].add_candidate(maybe_value_bytes, maybe_expiration_time, source_node_id=peer)
|
|
|
+ output[key_id] = tuple(peers.keys()), search_results[key_id].finished
|
|
|
+                # note: we interrupt the search once the key is either found or otherwise finished (e.g. cancelled by user)
|
|
|
return output
|
|
|
|
|
|
- nearest_nodes_per_query, visited_nodes = await traverse_dht(
|
|
|
+ # V-- this function will be called exactly once when traverse_dht finishes search for a given key
|
|
|
+ async def found_callback(key_id: DHTID, nearest_nodes: List[DHTID], _visited: Set[DHTID]):
|
|
|
+            search_results[key_id].finish_search()  # finish the search whether or not we found something
|
|
|
+ self._cache_new_result(search_results[key_id], nearest_nodes, node_to_endpoint)
|
|
|
+
|
|
|
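+        # V-- run the traversal in the background; results are delivered through the futures in search_results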
+ asyncio.create_task(traverse_dht(
|
|
|
queries=list(unfinished_key_ids), initial_nodes=list(node_to_endpoint),
|
|
|
beam_size=beam_size, num_workers=num_workers, queries_per_call=int(len(unfinished_key_ids) ** 0.5),
|
|
|
- get_neighbors=get_neighbors, visited_nodes={key_id: {self.node_id} for key_id in unfinished_key_ids})
|
|
|
-
|
|
|
- # stage 3: cache any new results depending on caching parameters
|
|
|
- for key_id, nearest_nodes in nearest_nodes_per_query.items():
|
|
|
- latest_value_bytes, latest_expiration_time, latest_node_id = latest_results[key_id]
|
|
|
- should_cache = latest_expiration_time >= sufficient_expiration_time # if we found a newer value, cache it
|
|
|
- if should_cache and self.cache_locally:
|
|
|
- self.protocol.cache.store(key_id, latest_value_bytes, latest_expiration_time)
|
|
|
-
|
|
|
- if should_cache and self.cache_nearest:
|
|
|
- num_cached_nodes = 0
|
|
|
- for node_id in nearest_nodes:
|
|
|
- if node_id == latest_node_id:
|
|
|
- continue
|
|
|
- asyncio.create_task(self.protocol.call_store(
|
|
|
- node_to_endpoint[node_id], [key_id], [latest_value_bytes], [latest_expiration_time],
|
|
|
- in_cache=True))
|
|
|
- num_cached_nodes += 1
|
|
|
- if num_cached_nodes >= self.cache_nearest:
|
|
|
- break
|
|
|
-
|
|
|
- # stage 4: deserialize data and assemble function output
|
|
|
- find_result: Dict[DHTKey, Tuple[Optional[DHTValue], Optional[DHTExpiration]]] = {}
|
|
|
- for key_id, (latest_value_bytes, latest_expiration_time, _) in latest_results.items():
|
|
|
- if latest_expiration_time != -float('inf'):
|
|
|
- latest_value = self.serializer.loads(latest_value_bytes)
|
|
|
- find_result[id_to_original_key[key_id]] = (latest_value, latest_expiration_time)
|
|
|
- else:
|
|
|
- find_result[id_to_original_key[key_id]] = None, None
|
|
|
- return find_result
|
|
|
+ get_neighbors=get_neighbors, visited_nodes={key_id: {self.node_id} for key_id in unfinished_key_ids},
|
|
|
+ found_callback=found_callback, await_all_tasks=False))
|
|
|
+
|
|
|
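+        # the futures below are filled in by finish_search() once the search for the corresponding key terminates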
+ if return_futures:
|
|
|
+ return {key_id: search_result.future for key_id, search_result in search_results.items()}
|
|
|
+ else:
|
|
|
+ try:
|
|
|
+                # note: this should be the first time we await anything, so there is no need to "try" the entire function
|
|
|
+ return {key_id: await search_result.future for key_id, search_result in search_results.items()}
|
|
|
+ except asyncio.CancelledError as e: # terminate remaining tasks ASAP
|
|
|
+ for key_id, search_result in search_results.items():
|
|
|
+ search_result.future.cancel()
|
|
|
+ raise e
|
|
|
+
|
|
|
+ def _reuse_finished_search_result(self, finished: _IntermediateResult):
|
|
|
+ expiration_time_threshold = max(finished.expiration_time or -float('inf'), finished.sufficient_expiration_time)
|
|
|
+ concurrent_requests: SortedList[_IntermediateResult] = self.pending_get_requests[finished.key_id]
|
|
|
+        # note: concurrent_requests is sorted in the order of descending sufficient_expiration_time
|
|
|
+ while concurrent_requests and expiration_time_threshold >= concurrent_requests[-1].sufficient_expiration_time:
|
|
|
+ concurrent_requests[-1].add_candidate(finished.binary_value, finished.expiration_time,
|
|
|
+ source_node_id=finished.source_node_id)
|
|
|
+ concurrent_requests[-1].finish_search()
|
|
|
+ concurrent_requests.pop(-1)
|
|
|
+
|
|
|
+ def _trigger_cache_refresh(self, result: _IntermediateResult):
|
|
|
+ """ Called after get request is finished (whether it was found, not found, hit cache, cancelled, or reused) """
|
|
|
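+        # schedule a refresh only if the value was served from this node's own storage or cache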
+ if result.found_something and result.source_node_id == self.node_id:
|
|
|
+ with self.protocol.cache.freeze(): # do not clear outdated cache for now...
|
|
|
+ if self.cache_refresh_before_expiry and result.key_id in self.protocol.cache:
|
|
|
+ previous_earliest_item: Tuple[DHTID, BinaryDHTValue, DHTExpiration] = self.cache_refresh_queue.top()
|
|
|
+ self.cache_refresh_queue.store(result.key_id, result.binary_value, result.expiration_time)
|
|
|
+ if previous_earliest_item is None or result.expiration_time < previous_earliest_item[-1]:
|
|
|
+                        self.cache_refresh_available.set()  # if the new element is now the earliest, notify the cache refresh task
|
|
|
+
|
|
|
+ async def _refresh_stale_cache_entries(self):
|
|
|
+        """ periodically refresh near-expired cached keys that were accessed at least once during their lifetime """
|
|
|
+ while self.is_alive:
|
|
|
+ with self.cache_refresh_queue.freeze():
|
|
|
+ while len(self.cache_refresh_queue) == 0:
|
|
|
+ await self.cache_refresh_available.wait()
|
|
|
+ self.cache_refresh_available.clear()
|
|
|
+ key_id, _, nearest_expiration = self.cache_refresh_queue.top()
|
|
|
+
|
|
|
+ try:
|
|
|
+                # step 1: await until :cache_refresh_before_expiry: seconds before the earliest element expires
|
|
|
+ time_to_wait = nearest_expiration - get_dht_time() - self.cache_refresh_before_expiry
|
|
|
+ await asyncio.wait_for(self.cache_refresh_available.wait(), timeout=time_to_wait)
|
|
|
+ # note: the line above will cause TimeoutError when we are ready to refresh cache
|
|
|
+                self.cache_refresh_available.clear()  # no timeout error => someone added a new entry to the queue and ...
|
|
|
+                continue  # ... that entry expires earlier than nearest_expiration, so we should refresh it first
|
|
|
+
|
|
|
+            except asyncio.TimeoutError:  # caught TimeoutError => it is time to refresh the earliest-expiring cached entry
|
|
|
+ # step 2: find all keys that we should already refresh and remove them from queue
|
|
|
+ with self.cache_refresh_queue.freeze():
|
|
|
+ keys_to_refresh = {key_id}
|
|
|
+ del self.cache_refresh_queue[key_id] # we pledge to refresh this key_id in the nearest batch
|
|
|
+ while self.cache_refresh_queue:
|
|
|
+ key_id, _, nearest_expiration = self.cache_refresh_queue.top()
|
|
|
+ if nearest_expiration > get_dht_time() + self.cache_refresh_before_expiry:
|
|
|
+ break
|
|
|
+ del self.cache_refresh_queue[key_id] # we pledge to refresh this key_id in the nearest batch
|
|
|
+ keys_to_refresh.add(key_id)
|
|
|
+
|
|
|
+                # step 3: search for newer versions of these keys; they are cached as a side effect of self.get_many_by_id
|
|
|
+ await self.get_many_by_id(
|
|
|
+ keys_to_refresh, sufficient_expiration_time=nearest_expiration + self.cache_refresh_before_expiry,
|
|
|
+                    _refresh_cache=False)  # if we find the value locally, we should not trigger another refresh
|
|
|
+
|
|
|
+ def _cache_new_result(self, result: _IntermediateResult, nearest_nodes: List[DHTID],
|
|
|
+ node_to_endpoint: Dict[DHTID, Endpoint]):
|
|
|
+ """ after key_id is found, update cache according to caching policy. used internally in get and get_many """
|
|
|
+ if result.found_something:
|
|
|
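+            # V-- the best expiration time we already have for this key, either in local storage or in cache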
+ previous_expiration_time = max(self.protocol.storage.get(result.key_id)[1] or -float('inf'),
|
|
|
+ self.protocol.cache.get(result.key_id)[1] or -float('inf'))
|
|
|
+ if result.expiration_time > previous_expiration_time: # if this value has better expiration
|
|
|
+ if self.cache_locally:
|
|
|
+ self.protocol.cache.store(result.key_id, result.binary_value, result.expiration_time)
|
|
|
+ if self.cache_nearest:
|
|
|
+ num_cached_nodes = 0
|
|
|
+ for node_id in nearest_nodes:
|
|
|
+ if node_id == result.source_node_id:
|
|
|
+ continue
|
|
|
+ asyncio.create_task(self.protocol.call_store(
|
|
|
+ node_to_endpoint[node_id], [result.key_id], [result.binary_value], [result.expiration_time],
|
|
|
+ in_cache=True))
|
|
|
+ num_cached_nodes += 1
|
|
|
+ if num_cached_nodes >= self.cache_nearest:
|
|
|
+ break
|
|
|
|
|
|
async def _refresh_routing_table(self, *, period: Optional[float]) -> None:
|
|
|
""" Tries to find new nodes for buckets that were unused for more than self.staleness_timeout """
|
|
|
- while period is not None: # if None run once, otherwise run forever
|
|
|
+ while self.is_alive and period is not None: # if None run once, otherwise run forever
|
|
|
refresh_time = get_dht_time()
|
|
|
staleness_threshold = refresh_time - period
|
|
|
stale_buckets = [bucket for bucket in self.protocol.routing_table.buckets
|
|
@@ -400,3 +520,45 @@ class DHTNode:
|
|
|
await self.find_nearest_nodes(refresh_id)
|
|
|
|
|
|
await asyncio.sleep(max(0.0, period - (get_dht_time() - refresh_time)))
|
|
|
+
|
|
|
+
|
|
|
+@dataclass(init=True, repr=True, frozen=False, order=False)
|
|
|
+class _IntermediateResult:
|
|
|
+ """ A helper class that stores current-best GET results with metadata """
|
|
|
+ key_id: DHTID
|
|
|
+ sufficient_expiration_time: DHTExpiration
|
|
|
+ binary_value: Optional[BinaryDHTValue] = None
|
|
|
+ expiration_time: Optional[DHTExpiration] = None # best expiration time so far
|
|
|
+ source_node_id: Optional[DHTID] = None # node that gave us the value
|
|
|
+ future: asyncio.Future[Tuple[Optional[DHTValue], Optional[DHTExpiration]]] = field(default_factory=asyncio.Future)
|
|
|
+ serializer: type(SerializerBase) = MSGPackSerializer
|
|
|
+
|
|
|
+ def add_candidate(self, binary_value: Optional[BinaryDHTValue], expiration_time: Optional[DHTExpiration],
|
|
|
+ source_node_id: Optional[DHTID]):
|
|
|
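+        # accept the candidate only if the search is still in progress and its expiration time beats the current best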
+ if not self.finished and (expiration_time or -float('inf')) > (self.expiration_time or -float('inf')):
|
|
|
+ self.binary_value, self.expiration_time, self.source_node_id = binary_value, expiration_time, source_node_id
|
|
|
+ if self.expiration_time >= self.sufficient_expiration_time:
|
|
|
+ self.finish_search()
|
|
|
+
|
|
|
+ def add_done_callback(self, callback: Callable[[_IntermediateResult], Any]):
|
|
|
+        """ Add callback that will be called when _IntermediateResult is done (found OR cancelled by user) """
|
|
|
+ self.future.add_done_callback(lambda _future: callback(self))
|
|
|
+
|
|
|
+ def finish_search(self):
|
|
|
+ if self.future.done():
|
|
|
+            return  # either the user cancelled our result or someone else set it before us. Nothing more to do here.
|
|
|
+ deserialized_value = self.serializer.loads(self.binary_value) if self.found_something else None
|
|
|
+ self.future.set_result((deserialized_value, self.expiration_time))
|
|
|
+
|
|
|
+ @property
|
|
|
+ def found_something(self) -> bool:
|
|
|
+ """ Whether or not we have at least some result, regardless of its expiration time """
|
|
|
+ return self.expiration_time is not None
|
|
|
+
|
|
|
+ @property
|
|
|
+ def finished(self) -> bool:
|
|
|
+ return self.future.done()
|
|
|
+
|
|
|
+ def __lt__(self, other: _IntermediateResult):
|
|
|
+ """ _IntermediateResult instances will be sorted by their target expiration time """
|
|
|
+ return self.sufficient_expiration_time < other.sufficient_expiration_time
|