@@ -1,17 +1,18 @@
 from __future__ import annotations

 import asyncio
-
 import random
 from collections import defaultdict
 from dataclasses import dataclass, field
-from typing import Optional, Tuple, List, Dict, DefaultDict, Collection, Union, Set, Awaitable, Callable, Any, Iterable
-from sortedcontainers import SortedList
 from functools import partial
+from typing import Optional, Tuple, List, Dict, DefaultDict, Collection, Union, Set, Awaitable, Callable, Any
 from warnings import warn

-from hivemind.dht.protocol import DHTProtocol, LocalStorage
-from hivemind.dht.routing import DHTID, DHTExpiration, DHTKey, get_dht_time, DHTValue, BinaryDHTValue
+from sortedcontainers import SortedList
+
+from hivemind.dht.protocol import DHTProtocol
+from hivemind.dht.routing import DHTID, DHTExpiration, DHTKey, get_dht_time, DHTValue, BinaryDHTValue, Subkey
+from hivemind.dht.storage import CacheRefreshQueue, DictionaryDHTValue
 from hivemind.dht.traverse import traverse_dht
 from hivemind.utils import Endpoint, LOCALHOST, MSGPackSerializer, get_logger, SerializerBase

@@ -39,7 +40,7 @@ class DHTNode:
       nearest peers from recipient's routing table (ordered nearest-to-farthest, not including recipient itself)
     This RPC is a mixture between Kademlia FIND_NODE and FIND_VALUE with multiple keys per call.

-    Formally, DHTNode follows the following contract:
+    A DHTNode follows the following contract:

     - when asked to get(key), a node must find and return a value with highest expiration time that it found across DHT
       IF that time has not come yet. if expiration time is smaller than current get_dht_time(), node may return None;
@@ -53,9 +54,8 @@ class DHTNode:
     # fmt:off
     node_id: DHTID; is_alive: bool; port: int; num_replicas: int; num_workers: int; protocol: DHTProtocol
     refresh_timeout: float; cache_locally: bool; cache_nearest: int; cache_refresh_before_expiry: float
-    cache_refresh_available: asyncio.Event; cache_refresh_queue: LocalStorage
-    reuse_get_requests: bool; pending_get_requests: DefaultDict[DHTID, SortedList[_IntermediateResult]]
-    serializer = MSGPackSerializer  # used to pack/unpack DHT Values for transfer over network
+    cache_on_store: bool; reuse_get_requests: bool; pending_get_requests: DefaultDict[DHTID, SortedList[_SearchState]]
+    cache_refresh_task: Optional[asyncio.Task]; cache_refresh_evt: asyncio.Event; cache_refresh_queue: CacheRefreshQueue
     # fmt:on

     @classmethod
@@ -64,8 +64,8 @@ class DHTNode:
                      bucket_size: int = 20, num_replicas: int = 5, depth_modulo: int = 5, parallel_rpc: int = None,
                      wait_timeout: float = 5, refresh_timeout: Optional[float] = None, bootstrap_timeout: Optional[float] = None,
                      cache_locally: bool = True, cache_nearest: int = 1, cache_size=None, cache_refresh_before_expiry: float = 5,
-                     reuse_get_requests: bool = True, num_workers: int = 1, listen: bool = True,
-                     listen_on: Endpoint = "0.0.0.0:*", **kwargs) -> DHTNode:
+                     cache_on_store: bool = True, reuse_get_requests: bool = True, num_workers: int = 1,
+                     listen: bool = True, listen_on: Endpoint = "0.0.0.0:*", **kwargs) -> DHTNode:
         """
         :param node_id: current node's identifier, determines which keys it will store locally, defaults to random id
         :param initial_peers: connects to these peers to populate routing table, defaults to no peers
@@ -81,6 +81,7 @@ class DHTNode:
           if staleness_timeout is None, DHTNode will not refresh stale buckets (which is usually okay)
         :param bootstrap_timeout: after one of peers responds, await other peers for at most this many seconds
         :param cache_locally: if True, caches all values (stored or found) in a node-local cache
+        :param cache_on_store: if True, update cache entries for a key after storing a new item for that key
         :param cache_nearest: whenever DHTNode finds a value, it will also store (cache) this value on this many
           nodes nearest nodes visited by search algorithm. Prefers nodes that are nearest to :key: but have no value yet
         :param cache_size: if specified, local cache will store up to this many records (as in LRU cache)
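
For illustration, a minimal usage sketch of the caching options documented above (the endpoints, variable names and parameter values here are hypothetical and not part of this patch):

    import asyncio
    from hivemind.dht.node import DHTNode

    async def demo():
        # first node: caches found/stored values locally and refreshes them 5 s before they expire
        alice = await DHTNode.create(cache_locally=True, cache_on_store=True, cache_refresh_before_expiry=5)
        # second node bootstraps its routing table from the first node's endpoint
        bob = await DHTNode.create(initial_peers=[f"127.0.0.1:{alice.port}"], cache_nearest=1)
        return alice, bob

    asyncio.run(demo())
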
@@ -96,10 +97,6 @@ class DHTNode:
           see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
         :param kwargs: extra parameters used in grpc.aio.server
         """
-        if cache_refresh_before_expiry > 0 and not cache_locally:
-            logger.warning("If cache_locally is False, cache_refresh_before_expiry has no effect. To silence this"
-                           " warning, please specify cache_refresh_before_expiry=0")
-
         self = cls(_initialized_with_create=True)
         self.node_id = node_id = node_id if node_id is not None else DHTID.generate()
         self.num_replicas, self.num_workers = num_replicas, num_workers
@@ -110,12 +107,11 @@ class DHTNode:

         # caching policy
         self.refresh_timeout = refresh_timeout
-        self.cache_locally, self.cache_nearest = cache_locally, cache_nearest
+        self.cache_locally, self.cache_nearest, self.cache_on_store = cache_locally, cache_nearest, cache_on_store
         self.cache_refresh_before_expiry = cache_refresh_before_expiry
-        self.cache_refresh_queue = LocalStorage()
-        self.cache_refresh_available = asyncio.Event()
-        if cache_refresh_before_expiry:
-            asyncio.create_task(self._refresh_stale_cache_entries())
+        self.cache_refresh_queue = CacheRefreshQueue()
+        self.cache_refresh_evt = asyncio.Event()
+        self.cache_refresh_task = None

         self.protocol = await DHTProtocol.create(self.node_id, bucket_size, depth_modulo, num_replicas, wait_timeout,
                                                  parallel_rpc, cache_size, listen, listen_on, **kwargs)
@@ -211,25 +207,27 @@ class DHTNode:
             nearest_nodes_with_endpoints[query] = {node: node_to_endpoint[node] for node in nearest_nodes[:k_nearest]}
         return nearest_nodes_with_endpoints

-    async def store(self, key: DHTKey, value: DHTValue, expiration_time: DHTExpiration, **kwargs) -> bool:
+    async def store(self, key: DHTKey, value: DHTValue, expiration_time: DHTExpiration,
+                    subkey: Optional[Subkey] = None, **kwargs) -> bool:
         """
         Find num_replicas best nodes to store (key, value) and store it there at least until expiration time.
-
         :note: store is a simplified interface to store_many, all kwargs are forwarded there
         :returns: True if store succeeds, False if it fails (due to no response or newer value)
         """
-        store_ok = await self.store_many([key], [value], [expiration_time], **kwargs)
-        return store_ok[key]
+        store_ok = await self.store_many([key], [value], [expiration_time], subkeys=[subkey], **kwargs)
+        return store_ok[(key, subkey) if subkey is not None else key]

     async def store_many(self, keys: List[DHTKey], values: List[DHTValue],
                          expiration_time: Union[DHTExpiration, List[DHTExpiration]],
+                         subkeys: Optional[Union[Subkey, List[Optional[Subkey]]]] = None,
                          exclude_self: bool = False, await_all_replicas=True, **kwargs) -> Dict[DHTKey, bool]:
         """
-        Traverse DHT to find up to best nodes to store multiple (key, value, expiration_time) pairs.
+        Traverse DHT to find up to :num_replicas: best nodes to store multiple (key, value, expiration_time) pairs.

         :param keys: arbitrary serializable keys associated with each value
         :param values: serializable "payload" for each key
         :param expiration_time: either one expiration time for all keys or individual expiration times (see class doc)
+        :param subkeys: an optional list of same shape as keys. If specified, each value is stored under its subkey as one entry of a dictionary value for that key
         :param kwargs: any additional parameters passed to traverse_dht function (e.g. num workers)
         :param exclude_self: if True, never store value locally even if you are one of the nearest nodes
         :note: if exclude_self is True and self.cache_locally == True, value will still be __cached__ locally
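
A usage sketch of the extended store interface described above (key names are made up; node is a running DHTNode and get_dht_time is imported at the top of this file):

    expiration = get_dht_time() + 60
    # a regular value: the result is indexed by the key itself
    assert await node.store("bytes_per_second", 10 ** 9, expiration)
    # a dictionary entry: the result is indexed by (key, subkey)
    assert await node.store("active_experts", b"expert.3.metadata", expiration, subkey="expert.3")
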
@@ -239,24 +237,23 @@ class DHTNode:
         """
         if isinstance(expiration_time, DHTExpiration):
             expiration_time = [expiration_time] * len(keys)
-        assert len(keys) == len(values) == len(expiration_time), "Number of keys, values and expiration doesn't match."
+        if subkeys is None or isinstance(subkeys, Subkey):
+            subkeys = [subkeys] * len(keys)

-        key_ids = list(map(DHTID.generate, keys))
-        id_to_original_key = dict(zip(key_ids, keys))
-        binary_values_by_key_id = {key_id: self.serializer.dumps(value) for key_id, value in zip(key_ids, values)}
-        expiration_by_key_id = {key_id: expiration_time for key_id, expiration_time in zip(key_ids, expiration_time)}
-        unfinished_key_ids = set(key_ids)  # we use this set to ensure that each store request is finished
+        assert len(keys) == len(subkeys) == len(values) == len(expiration_time), \
+            "The numbers of keys, subkeys, values and expiration timestamps must match."

-        store_ok = {key: False for key in keys}  # outputs, updated during search
-        store_finished_events = {key: asyncio.Event() for key in keys}
+        key_id_to_data: DefaultDict[DHTID, List[Tuple[DHTKey, Subkey, DHTValue, DHTExpiration]]] = defaultdict(list)
+        for key, subkey, value, expiration in zip(keys, subkeys, values, expiration_time):
+            key_id_to_data[DHTID.generate(source=key)].append((key, subkey, value, expiration))

-        if self.cache_locally:
-            for key_id in key_ids:
-                self.protocol.cache.store(key_id, binary_values_by_key_id[key_id], expiration_by_key_id[key_id])
+        unfinished_key_ids = set(key_id_to_data.keys())  # use this set to ensure that each store request is finished
+        store_ok = {(key, subkey): None for key, subkey in zip(keys, subkeys)}  # outputs, updated during search
+        store_finished_events = {(key, subkey): asyncio.Event() for key, subkey in zip(keys, subkeys)}

         # pre-populate node_to_endpoint
         node_to_endpoint: Dict[DHTID, Endpoint] = dict()
-        for key_id in key_ids:
+        for key_id in unfinished_key_ids:
             node_to_endpoint.update(self.protocol.routing_table.get_nearest_neighbors(
                 key_id, self.protocol.bucket_size, exclude=self.node_id))

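A sketch of the corresponding store_many call and the shape of its result under the new (key, subkey) bookkeeping introduced above (all key and subkey names are hypothetical):

    store_ok = await node.store_many(
        keys=["theta", "experts", "experts"],
        values=[1.0, b"alpha", b"beta"],
        expiration_time=get_dht_time() + 99,
        subkeys=[None, "expert.0", "expert.1"])
    # plain values are reported under their key, subkey'd values under (key, subkey), e.g.
    # {"theta": True, ("experts", "expert.0"): True, ("experts", "expert.1"): True}
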
@@ -272,48 +269,73 @@ class DHTNode:
             pending_store_tasks = set()
             store_candidates = sorted(nearest_nodes + ([] if exclude_self else [self.node_id]),
                                       key=key_id.xor_distance, reverse=True)  # ordered so that .pop() returns nearest
+            [original_key, *_], current_subkeys, current_values, current_expirations = zip(*key_id_to_data[key_id])
+            binary_values: List[bytes] = list(map(self.protocol.serializer.dumps, current_values))

             while num_successful_stores < self.num_replicas and (store_candidates or pending_store_tasks):
-                # spawn enough tasks to cover all replicas
                 while store_candidates and num_successful_stores + len(pending_store_tasks) < self.num_replicas:
                     node_id: DHTID = store_candidates.pop()  # nearest untried candidate
+
                     if node_id == self.node_id:
-                        self.protocol.storage.store(key_id, binary_values_by_key_id[key_id],
-                                                    expiration_by_key_id[key_id])
-                        store_ok[id_to_original_key[key_id]] = True
                         num_successful_stores += 1
-                        if not await_all_replicas:
-                            store_finished_events[id_to_original_key[key_id]].set()
-
+                        for subkey, value, expiration_time in zip(current_subkeys, binary_values, current_expirations):
+                            store_ok[original_key, subkey] = self.protocol.storage.store(
+                                key_id, value, expiration_time, subkey=subkey)
+                            if not await_all_replicas:
+                                store_finished_events[original_key, subkey].set()
                     else:
                         pending_store_tasks.add(asyncio.create_task(self.protocol.call_store(
-                            node_to_endpoint[node_id], [key_id], [binary_values_by_key_id[key_id]],
-                            [expiration_by_key_id[key_id]])))
+                            node_to_endpoint[node_id], keys=[key_id] * len(current_values), values=binary_values,
+                            expiration_time=current_expirations, subkeys=current_subkeys)))

                 # await nearest task. If it fails, dispatch more on the next iteration
                 if pending_store_tasks:
                     finished_store_tasks, pending_store_tasks = await asyncio.wait(
                         pending_store_tasks, return_when=asyncio.FIRST_COMPLETED)
                     for task in finished_store_tasks:
-                        if task.result()[0]:  # if store succeeded
-                            store_ok[id_to_original_key[key_id]] = True
+                        if task.result() is not None:
                             num_successful_stores += 1
-                            if not await_all_replicas:
-                                store_finished_events[id_to_original_key[key_id]].set()
+                            for subkey, store_status in zip(current_subkeys, task.result()):
+                                store_ok[original_key, subkey] = store_status
+                                if not await_all_replicas:
+                                    store_finished_events[original_key, subkey].set()

-            store_finished_events[id_to_original_key[key_id]].set()
+            if self.cache_on_store:
+                self._update_cache_on_store(key_id, current_subkeys, binary_values, current_expirations,
+                                            store_ok=[store_ok[original_key, subkey] for subkey in current_subkeys])
+
+            for subkey, value_bytes, expiration in zip(current_subkeys, binary_values, current_expirations):
+                store_finished_events[original_key, subkey].set()

         store_task = asyncio.create_task(self.find_nearest_nodes(
-            queries=set(key_ids), k_nearest=self.num_replicas, node_to_endpoint=node_to_endpoint,
+            queries=set(unfinished_key_ids), k_nearest=self.num_replicas, node_to_endpoint=node_to_endpoint,
             found_callback=on_found, exclude_self=exclude_self, **kwargs))
         try:
             await asyncio.wait([evt.wait() for evt in store_finished_events.values()])  # wait for items to be stored
             assert len(unfinished_key_ids) == 0, "Internal error: traverse_dht didn't finish search"
-            return store_ok
+            return {(key, subkey) if subkey is not None else key: status or False for (key, subkey), status in store_ok.items()}
         except asyncio.CancelledError as e:
             store_task.cancel()
             raise e

+    def _update_cache_on_store(self, key_id: DHTID, subkeys: List[Subkey], binary_values: List[bytes],
+                               expirations: List[DHTExpiration], store_ok: List[bool]):
+        """ Update local cache after finishing a store for one key (with perhaps several subkeys) """
+        store_succeeded = any(store_ok)
+        is_dictionary = any(subkey is not None for subkey in subkeys)
+        if store_succeeded and not is_dictionary:  # stored a new regular value, cache it!
+            stored_value_bytes, stored_expiration = max(zip(binary_values, expirations), key=lambda p: p[1])
+            self.protocol.cache.store(key_id, stored_value_bytes, stored_expiration)
+        elif not store_succeeded and not is_dictionary:  # store rejected, check if local cache is also obsolete
+            rejected_value, rejected_expiration = max(zip(binary_values, expirations), key=lambda p: p[1])
+            self.protocol.cache.store(key_id, rejected_value, rejected_expiration)  # can still be better than cache
+            if (self.protocol.cache.get(key_id)[1] or float("inf")) <= rejected_expiration:  # cache would be rejected
+                self._schedule_for_refresh(key_id, refresh_time=get_dht_time())  # fetch new key in background (asap)
+        else:  # stored a dictionary (or failed to store), either way, there can be other keys and we should update
+            for subkey, stored_value_bytes, expiration_time in zip(subkeys, binary_values, expirations):
+                self.protocol.cache.store_subkey(key_id, subkey, stored_value_bytes, expiration_time)
+            self._schedule_for_refresh(key_id, refresh_time=get_dht_time())  # fetch new key in background (asap)
+
     async def get(self, key: DHTKey, latest=False, **kwargs) -> Tuple[Optional[DHTValue], Optional[DHTExpiration]]:
         """
         Search for a key across DHT and return either first or latest entry.
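
For context, a sketch of the two get modes mentioned in the docstring above (the key is hypothetical; node is a DHTNode):

    value, expiration = await node.get("some_key")               # return the first sufficiently fresh entry found
    value, expiration = await node.get("some_key", latest=True)  # keep searching for the entry with the highest expiration
    if expiration is None:
        print("key not found")
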
@@ -350,8 +372,8 @@ class DHTNode:
     async def get_many_by_id(
             self, key_ids: Collection[DHTID], sufficient_expiration_time: Optional[DHTExpiration] = None,
             num_workers: Optional[int] = None, beam_size: Optional[int] = None, return_futures: bool = False,
-            _refresh_cache=True) -> Dict[DHTID, Union[Tuple[Optional[DHTValue], Optional[DHTExpiration]],
-                                                      Awaitable[Tuple[Optional[DHTValue], Optional[DHTExpiration]]]]]:
+            _is_refresh=False) -> Dict[DHTID, Union[Tuple[Optional[DHTValue], Optional[DHTExpiration]],
+                                                    Awaitable[Tuple[Optional[DHTValue], Optional[DHTExpiration]]]]]:
         """
         Traverse DHT to find a list of DHTIDs. For each key, return latest (value, expiration) or None if not found.
@@ -364,17 +386,17 @@ class DHTNode:
         :param return_futures: if True, immediately return asyncio.Future for every key before interacting with the network.
           The algorithm will populate these futures with (value, expiration) when it finds the corresponding key
           Note: canceling a future will stop search for the corresponding key
-        :param _refresh_cache: internal flag, whether or not to self._trigger_cache_refresh
+        :param _is_refresh: internal flag, set to True by an internal cache refresher (if enabled)
         :returns: for each key: value and its expiration time. If nothing is found, returns (None, None) for that key
         :note: in order to check if get returned a value, please check (expiration_time is None)
         """
         sufficient_expiration_time = sufficient_expiration_time or get_dht_time()
         beam_size = beam_size if beam_size is not None else self.protocol.bucket_size
         num_workers = num_workers if num_workers is not None else self.num_workers
-        search_results: Dict[DHTID, _IntermediateResult] = {key_id: _IntermediateResult(
-            key_id, sufficient_expiration_time, serializer=self.serializer) for key_id in key_ids}
+        search_results: Dict[DHTID, _SearchState] = {key_id: _SearchState(
+            key_id, sufficient_expiration_time, serializer=self.protocol.serializer) for key_id in key_ids}

-        if _refresh_cache:
+        if not _is_refresh:  # if we're already refreshing cache, there's no need to trigger subsequent refreshes
             for key_id in key_ids:
                 search_results[key_id].add_done_callback(self._trigger_cache_refresh)

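A sketch of the return_futures mode described above (keys are hypothetical; DHTID is imported at the top of this file):

    key_ids = [DHTID.generate(source=key) for key in ("alpha", "beta")]
    futures = await node.get_many_by_id(key_ids, return_futures=True)
    value, expiration = await futures[key_ids[0]]  # resolves once the search for that key finishes
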
@@ -387,7 +409,8 @@ class DHTNode:
         # stage 1: check for value in this node's local storage and cache
         for key_id in key_ids:
             search_results[key_id].add_candidate(*self.protocol.storage.get(key_id), source_node_id=self.node_id)
-            search_results[key_id].add_candidate(*self.protocol.cache.get(key_id), source_node_id=self.node_id)
+            if not _is_refresh:
+                search_results[key_id].add_candidate(*self.protocol.cache.get(key_id), source_node_id=self.node_id)

         # stage 2: traverse the DHT to get the remaining keys from remote peers
         unfinished_key_ids = [key_id for key_id in key_ids if not search_results[key_id].finished]
@@ -414,7 +437,7 @@ class DHTNode:
         # V-- this function will be called exactly once when traverse_dht finishes search for a given key
         async def found_callback(key_id: DHTID, nearest_nodes: List[DHTID], _visited: Set[DHTID]):
             search_results[key_id].finish_search()  # finish search whether or not we found something
-            self._cache_new_result(search_results[key_id], nearest_nodes, node_to_endpoint)
+            self._cache_new_result(search_results[key_id], nearest_nodes, node_to_endpoint, _is_refresh=_is_refresh)

         asyncio.create_task(traverse_dht(
             queries=list(unfinished_key_ids), initial_nodes=list(node_to_endpoint),
@@ -433,9 +456,9 @@ class DHTNode:
                 search_result.future.cancel()
             raise e

-    def _reuse_finished_search_result(self, finished: _IntermediateResult):
+    def _reuse_finished_search_result(self, finished: _SearchState):
         expiration_time_threshold = max(finished.expiration_time or -float('inf'), finished.sufficient_expiration_time)
-        concurrent_requests: SortedList[_IntermediateResult] = self.pending_get_requests[finished.key_id]
+        concurrent_requests: SortedList[_SearchState] = self.pending_get_requests[finished.key_id]
         # note: concurrent_requests is sorted in the order of descending sufficient_expiration_time
         while concurrent_requests and expiration_time_threshold >= concurrent_requests[-1].sufficient_expiration_time:
             concurrent_requests[-1].add_candidate(finished.binary_value, finished.expiration_time,
@@ -443,66 +466,72 @@ class DHTNode:
             concurrent_requests[-1].finish_search()
             concurrent_requests.pop(-1)

-    def _trigger_cache_refresh(self, result: _IntermediateResult):
+    def _trigger_cache_refresh(self, search: _SearchState):
         """ Called after get request is finished (whether it was found, not found, hit cache, cancelled, or reused) """
-        if result.found_something and result.source_node_id == self.node_id:
-            with self.protocol.cache.freeze():  # do not clear outdated cache for now...
-                if self.cache_refresh_before_expiry and result.key_id in self.protocol.cache:
-                    previous_earliest_item: Tuple[DHTID, BinaryDHTValue, DHTExpiration] = self.cache_refresh_queue.top()
-                    self.cache_refresh_queue.store(result.key_id, result.binary_value, result.expiration_time)
-                    if previous_earliest_item is None or result.expiration_time < previous_earliest_item[-1]:
-                        self.cache_refresh_available.set()  # if we new element is now earliest, notify the cache queue
+        if search.found_something and search.source_node_id == self.node_id:
+            if self.cache_refresh_before_expiry and search.key_id in self.protocol.cache:
+                self._schedule_for_refresh(search.key_id, search.expiration_time - self.cache_refresh_before_expiry)
+
+    def _schedule_for_refresh(self, key_id: DHTID, refresh_time: DHTExpiration):
+        """ Add key to a refresh queue, refresh at :refresh_time: or later """
+        if self.cache_refresh_task is None or self.cache_refresh_task.done() or self.cache_refresh_task.cancelled():
+            self.cache_refresh_task = asyncio.create_task(self._refresh_stale_cache_entries())
+            logger.debug("Spawned cache refresh task.")
+        previous_earliest_item: Tuple[DHTID, Any, DHTExpiration] = self.cache_refresh_queue.top()
+        if previous_earliest_item is None or refresh_time < previous_earliest_item[-1]:
+            self.cache_refresh_evt.set()  # if the new element is now the earliest, notify the cache queue
+        self.cache_refresh_queue.store(key_id, value=refresh_time, expiration_time=refresh_time)

     async def _refresh_stale_cache_entries(self):
""" periodically refresh keys near-expired keys that were accessed at least once during previous lifetime """
|
|
|
         while self.is_alive:
-            with self.cache_refresh_queue.freeze():
-                while len(self.cache_refresh_queue) == 0:
-                    await self.cache_refresh_available.wait()
-                    self.cache_refresh_available.clear()
-                key_id, _, nearest_expiration = self.cache_refresh_queue.top()
+            while len(self.cache_refresh_queue) == 0:
+                await self.cache_refresh_evt.wait()
+                self.cache_refresh_evt.clear()
+            key_id, _, nearest_refresh_time = self.cache_refresh_queue.top()

             try:
                 # step 1: await until :cache_refresh_before_expiry: seconds before earliest first element expires
-                time_to_wait = nearest_expiration - get_dht_time() - self.cache_refresh_before_expiry
-                await asyncio.wait_for(self.cache_refresh_available.wait(), timeout=time_to_wait)
+                time_to_wait = nearest_refresh_time - get_dht_time()
+                await asyncio.wait_for(self.cache_refresh_evt.wait(), timeout=time_to_wait)
                 # note: the line above will cause TimeoutError when we are ready to refresh cache
-                self.cache_refresh_available.clear()  # no timeout error => someone added new entry to queue and ...
+                self.cache_refresh_evt.clear()  # no timeout error => someone added new entry to queue and ...
                 continue  # ... and this element is earlier than nearest_expiration. we should refresh this entry first

             except asyncio.TimeoutError:  # caught TimeoutError => it is time to refresh the most recent cached entry
                 # step 2: find all keys that we should already refresh and remove them from queue
-                with self.cache_refresh_queue.freeze():
-                    keys_to_refresh = {key_id}
+                current_time = get_dht_time()
+                keys_to_refresh = {key_id}
+                max_expiration_time = self.protocol.cache.get(key_id)[1] or current_time
+                del self.cache_refresh_queue[key_id]  # we pledge to refresh this key_id in the nearest batch
+                while self.cache_refresh_queue:
+                    key_id, _, nearest_refresh_time = self.cache_refresh_queue.top()
+                    if nearest_refresh_time > current_time:
+                        break
                     del self.cache_refresh_queue[key_id]  # we pledge to refresh this key_id in the nearest batch
-                    while self.cache_refresh_queue:
-                        key_id, _, nearest_expiration = self.cache_refresh_queue.top()
-                        if nearest_expiration > get_dht_time() + self.cache_refresh_before_expiry:
-                            break
-                        del self.cache_refresh_queue[key_id]  # we pledge to refresh this key_id in the nearest batch
-                        keys_to_refresh.add(key_id)
+                    keys_to_refresh.add(key_id)
+                    max_expiration_time = max(max_expiration_time, self.protocol.cache.get(key_id)[1] or current_time)

                 # step 3: search newer versions of these keys, cache them as a side-effect of self.get_many_by_id
-                await self.get_many_by_id(
-                    keys_to_refresh, sufficient_expiration_time=nearest_expiration + self.cache_refresh_before_expiry,
-                    _refresh_cache=False)  # if we found value locally, we shouldn't trigger another refresh
+                sufficient_expiration_time = max_expiration_time + self.cache_refresh_before_expiry + 1
+                await self.get_many_by_id(keys_to_refresh, sufficient_expiration_time, _is_refresh=True)

-    def _cache_new_result(self, result: _IntermediateResult, nearest_nodes: List[DHTID],
-                          node_to_endpoint: Dict[DHTID, Endpoint]):
+    def _cache_new_result(self, search: _SearchState, nearest_nodes: List[DHTID],
+                          node_to_endpoint: Dict[DHTID, Endpoint], _is_refresh: bool = False):
         """ after key_id is found, update cache according to caching policy. used internally in get and get_many """
-        if result.found_something:
-            previous_expiration_time = max(self.protocol.storage.get(result.key_id)[1] or -float('inf'),
-                                           self.protocol.cache.get(result.key_id)[1] or -float('inf'))
-            if result.expiration_time > previous_expiration_time:  # if this value has better expiration
-                if self.cache_locally:
-                    self.protocol.cache.store(result.key_id, result.binary_value, result.expiration_time)
+        if search.found_something:
+            previous_expiration_time = max(self.protocol.storage.get(search.key_id)[1] or -float('inf'),
+                                           self.protocol.cache.get(search.key_id)[1] or -float('inf'))
+            if search.expiration_time > previous_expiration_time:  # if this value has better expiration
+                if self.cache_locally or _is_refresh:
+                    self.protocol.cache.store(search.key_id, search.binary_value, search.expiration_time)
                 if self.cache_nearest:
                     num_cached_nodes = 0
                     for node_id in nearest_nodes:
-                        if node_id == result.source_node_id:
+                        if node_id == search.source_node_id:
                             continue
                         asyncio.create_task(self.protocol.call_store(
-                            node_to_endpoint[node_id], [result.key_id], [result.binary_value], [result.expiration_time],
+                            node_to_endpoint[node_id], [search.key_id], [search.binary_value], [search.expiration_time],
                             in_cache=True))
                         num_cached_nodes += 1
                         if num_cached_nodes >= self.cache_nearest:
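
A worked example of the refresh schedule implemented above, using the attribute names introduced in this patch (the numbers are made up):

    cache_refresh_before_expiry = 5
    expiration_time = 160.0                                       # expiration of a locally served cache entry
    refresh_time = expiration_time - cache_refresh_before_expiry  # = 155.0, passed to _schedule_for_refresh
    # around DHT time 155 the background task wakes up, batches every queued key whose refresh_time
    # has passed, and re-fetches them via get_many_by_id(..., _is_refresh=True), repopulating the cache
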
@@ -523,11 +552,11 @@ class DHTNode:


 @dataclass(init=True, repr=True, frozen=False, order=False)
-class _IntermediateResult:
+class _SearchState:
     """ A helper class that stores current-best GET results with metadata """
     key_id: DHTID
     sufficient_expiration_time: DHTExpiration
-    binary_value: Optional[BinaryDHTValue] = None
+    binary_value: Optional[Union[BinaryDHTValue, DictionaryDHTValue]] = None
     expiration_time: Optional[DHTExpiration] = None  # best expiration time so far
     source_node_id: Optional[DHTID] = None  # node that gave us the value
     future: asyncio.Future[Tuple[Optional[DHTValue], Optional[DHTExpiration]]] = field(default_factory=asyncio.Future)
@@ -540,25 +569,33 @@ class _IntermediateResult:
             if self.expiration_time >= self.sufficient_expiration_time:
                 self.finish_search()

-    def add_done_callback(self, callback: Callable[[_IntermediateResult], Any]):
-        """ Add callback that will be called when _IntermediateSearchResult is done (found OR cancelled by user) """
+    def add_done_callback(self, callback: Callable[[_SearchState], Any]):
+        """ Add callback that will be called when _SearchState is done (found OR cancelled by user) """
         self.future.add_done_callback(lambda _future: callback(self))

     def finish_search(self):
         if self.future.done():
-            return  # either user cancelled our result or someone sent it before us. Nothing more to do here.
-        deserialized_value = self.serializer.loads(self.binary_value) if self.found_something else None
-        self.future.set_result((deserialized_value, self.expiration_time))
+            return  # either user cancelled our search or someone sent it before us. Nothing more to do here.
+        elif not self.found_something:
+            self.future.set_result((None, None))
+        elif isinstance(self.binary_value, BinaryDHTValue):
+            self.future.set_result((self.serializer.loads(self.binary_value), self.expiration_time))
+        elif isinstance(self.binary_value, DictionaryDHTValue):
+            dict_value = {key: (self.serializer.loads(value), item_expiration_time)
+                          for key, value, item_expiration_time in self.binary_value.items()}
+            self.future.set_result((dict_value, self.expiration_time))
+        else:
+            logger.error(f"Invalid value type: {type(self.binary_value)}")

     @property
     def found_something(self) -> bool:
-        """ Whether or not we have at least some result, regardless of its expiration time """
+        """ Whether or not we have found at least some value, regardless of its expiration time """
         return self.expiration_time is not None

     @property
     def finished(self) -> bool:
         return self.future.done()

-    def __lt__(self, other: _IntermediateResult):
-        """ _IntermediateResult instances will be sorted by their target expiration time """
+    def __lt__(self, other: _SearchState):
+        """ _SearchState instances will be sorted by their target expiration time """
         return self.sufficient_expiration_time < other.sufficient_expiration_time
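
Given finish_search above, a sketch of the two result shapes a caller can receive from get (key names are hypothetical):

    value, expiration = await node.get("experts")
    if isinstance(value, dict):
        # the key held a DictionaryDHTValue: {subkey: (deserialized_value, item_expiration_time), ...}
        for subkey, (item, item_expiration) in value.items():
            print(subkey, item, item_expiration)
    else:
        print(value, expiration)  # a regular value, deserialized with the protocol serializer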