|
@@ -12,7 +12,7 @@ from sortedcontainers import SortedList
|
|
|
|
|
|
from hivemind.dht.protocol import DHTProtocol
|
|
|
from hivemind.dht.routing import DHTID, DHTExpiration, DHTKey, get_dht_time, DHTValue, BinaryDHTValue, Subkey
|
|
|
-from hivemind.dht.storage import CacheRefreshQueue, DictionaryDHTValue
|
|
|
+from hivemind.dht.storage import CacheRefreshQueue, DictionaryDHTValue, ValueWithExpiration
|
|
|
from hivemind.dht.traverse import traverse_dht
|
|
|
from hivemind.utils import Endpoint, LOCALHOST, MSGPackSerializer, get_logger, SerializerBase
|
|
|
|
|
@@ -189,7 +189,7 @@ class DHTNode:
|
|
|
return {query: ([], False) for query in queries}
|
|
|
|
|
|
output: Dict[DHTID, Tuple[Tuple[DHTID], bool]] = {}
|
|
|
- for query, (_, _, peers) in response.items():
|
|
|
+ for query, (_, peers) in response.items():
|
|
|
node_to_endpoint.update(peers)
|
|
|
output[query] = tuple(peers.keys()), False # False means "do not interrupt search"
|
|
|
return output
|
|
@@ -237,8 +237,8 @@ class DHTNode:
|
|
|
"""
|
|
|
if isinstance(expiration_time, DHTExpiration):
|
|
|
expiration_time = [expiration_time] * len(keys)
|
|
|
- if subkeys is None or isinstance(subkeys, Subkey):
|
|
|
- subkeys = [subkeys] * len(keys)
|
|
|
+ if subkeys is None:
|
|
|
+ subkeys = [None] * len(keys)
|
|
|
|
|
|
assert len(keys) == len(subkeys) == len(values) == len(expiration_time), \
|
|
|
"Either of keys, values, subkeys or expiration timestamps have different sequence lengths."
|
|
@@ -336,13 +336,13 @@ class DHTNode:
|
|
|
self.protocol.cache.store_subkey(key_id, subkey, stored_value_bytes, expiration_time)
|
|
|
self._schedule_for_refresh(key_id, refresh_time=get_dht_time()) # fetch new key in background (asap)
|
|
|
|
|
|
- async def get(self, key: DHTKey, latest=False, **kwargs) -> Tuple[Optional[DHTValue], Optional[DHTExpiration]]:
|
|
|
+ async def get(self, key: DHTKey, latest=False, **kwargs) -> Optional[ValueWithExpiration[DHTValue]]:
|
|
|
"""
|
|
|
- Search for a key across DHT and return either first or latest entry.
|
|
|
+ Search for a key across DHT and return either first or latest entry (if found).
|
|
|
:param key: same key as in node.store(...)
|
|
|
:param latest: if True, finds the latest value, otherwise finds any non-expired value (which is much faster)
|
|
|
:param kwargs: parameters forwarded to get_many_by_id
|
|
|
- :returns: (value, expiration time); if value was not found, returns (None, None)
|
|
|
+ :returns: (value, expiration time); if value was not found, returns None
|
|
|
"""
|
|
|
if latest:
|
|
|
kwargs["sufficient_expiration_time"] = float('inf')
|
|
@@ -350,8 +350,8 @@ class DHTNode:
|
|
|
return result[key]
|
|
|
|
|
|
async def get_many(self, keys: Collection[DHTKey], sufficient_expiration_time: Optional[DHTExpiration] = None,
|
|
|
- **kwargs) -> Dict[DHTKey, Union[Tuple[Optional[DHTValue], Optional[DHTExpiration]],
|
|
|
- Awaitable[Tuple[Optional[DHTValue], Optional[DHTExpiration]]]]]:
|
|
|
+ **kwargs) -> Dict[DHTKey, Union[Optional[ValueWithExpiration[DHTValue]],
|
|
|
+ Awaitable[Optional[ValueWithExpiration[DHTValue]]]]]:
|
|
|
"""
|
|
|
Traverse DHT to find a list of keys. For each key, return latest (value, expiration) or None if not found.
|
|
|
|
|
@@ -372,8 +372,8 @@ class DHTNode:
|
|
|
async def get_many_by_id(
|
|
|
self, key_ids: Collection[DHTID], sufficient_expiration_time: Optional[DHTExpiration] = None,
|
|
|
num_workers: Optional[int] = None, beam_size: Optional[int] = None, return_futures: bool = False,
|
|
|
- _is_refresh=False) -> Dict[DHTID, Union[Tuple[Optional[DHTValue], Optional[DHTExpiration]],
|
|
|
- Awaitable[Tuple[Optional[DHTValue], Optional[DHTExpiration]]]]]:
|
|
|
+ _is_refresh=False) -> Dict[DHTID, Union[Optional[ValueWithExpiration[DHTValue]],
|
|
|
+ Awaitable[Optional[ValueWithExpiration[DHTValue]]]]]:
|
|
|
"""
|
|
|
Traverse DHT to find a list of DHTIDs. For each key, return latest (value, expiration) or None if not found.
|
|
|
|
|
@@ -408,9 +408,9 @@ class DHTNode:
|
|
|
|
|
|
# stage 1: check for value in this node's local storage and cache
|
|
|
for key_id in key_ids:
|
|
|
- search_results[key_id].add_candidate(*self.protocol.storage.get(key_id), source_node_id=self.node_id)
|
|
|
+ search_results[key_id].add_candidate(self.protocol.storage.get(key_id), source_node_id=self.node_id)
|
|
|
if not _is_refresh:
|
|
|
- search_results[key_id].add_candidate(*self.protocol.cache.get(key_id), source_node_id=self.node_id)
|
|
|
+ search_results[key_id].add_candidate(self.protocol.cache.get(key_id), source_node_id=self.node_id)
|
|
|
|
|
|
# stage 2: traverse the DHT to get the remaining keys from remote peers
|
|
|
unfinished_key_ids = [key_id for key_id in key_ids if not search_results[key_id].finished]
|
|
@@ -427,9 +427,9 @@ class DHTNode:
|
|
|
return {query: ([], False) for query in queries}
|
|
|
|
|
|
output: Dict[DHTID, Tuple[Tuple[DHTID], bool]] = {}
|
|
|
- for key_id, (maybe_value_bytes, maybe_expiration_time, peers) in response.items():
|
|
|
+ for key_id, (maybe_value_with_expiration, peers) in response.items():
|
|
|
node_to_endpoint.update(peers)
|
|
|
- search_results[key_id].add_candidate(maybe_value_bytes, maybe_expiration_time, source_node_id=peer)
|
|
|
+ search_results[key_id].add_candidate(maybe_value_with_expiration, source_node_id=peer)
|
|
|
output[key_id] = tuple(peers.keys()), search_results[key_id].finished
|
|
|
# note: we interrupt search either if key is either found or finished otherwise (e.g. cancelled by user)
|
|
|
return output
|
|
@@ -457,12 +457,12 @@ class DHTNode:
|
|
|
raise e
|
|
|
|
|
|
def _reuse_finished_search_result(self, finished: _SearchState):
|
|
|
+ search_result = ValueWithExpiration(finished.binary_value, finished.expiration_time)
|
|
|
expiration_time_threshold = max(finished.expiration_time or -float('inf'), finished.sufficient_expiration_time)
|
|
|
concurrent_requests: SortedList[_SearchState] = self.pending_get_requests[finished.key_id]
|
|
|
- # note: concurrent_requests is sorded in the order of descending sufficient_expiration_time
|
|
|
+ # note: concurrent_requests is sorted in the order of descending sufficient_expiration_time
|
|
|
while concurrent_requests and expiration_time_threshold >= concurrent_requests[-1].sufficient_expiration_time:
|
|
|
- concurrent_requests[-1].add_candidate(finished.binary_value, finished.expiration_time,
|
|
|
- source_node_id=finished.source_node_id)
|
|
|
+ concurrent_requests[-1].add_candidate(search_result, source_node_id=finished.source_node_id)
|
|
|
concurrent_requests[-1].finish_search()
|
|
|
concurrent_requests.pop(-1)
|
|
|
|
|
@@ -477,8 +477,8 @@ class DHTNode:
|
|
|
if self.cache_refresh_task is None or self.cache_refresh_task.done() or self.cache_refresh_task.cancelled():
|
|
|
self.cache_refresh_task = asyncio.create_task(self._refresh_stale_cache_entries())
|
|
|
logger.debug("Spawned cache refresh task.")
|
|
|
- previous_earliest_item: Tuple[DHTID, Any, DHTExpiration] = self.cache_refresh_queue.top()
|
|
|
- if previous_earliest_item is None or refresh_time < previous_earliest_item[-1]:
|
|
|
+ earliest_key, earliest_item = self.cache_refresh_queue.top()
|
|
|
+ if earliest_item is None or refresh_time < earliest_item.expiration_time:
|
|
|
self.cache_refresh_evt.set() # if we new element is now earliest, notify the cache queue
|
|
|
self.cache_refresh_queue.store(key_id, value=refresh_time, expiration_time=refresh_time)
|
|
|
|
|
@@ -488,7 +488,7 @@ class DHTNode:
|
|
|
while len(self.cache_refresh_queue) == 0:
|
|
|
await self.cache_refresh_evt.wait()
|
|
|
self.cache_refresh_evt.clear()
|
|
|
- key_id, _, nearest_refresh_time = self.cache_refresh_queue.top()
|
|
|
+ key_id, (_, nearest_refresh_time) = self.cache_refresh_queue.top()
|
|
|
|
|
|
try:
|
|
|
# step 1: await until :cache_refresh_before_expiry: seconds before earliest first element expires
|
|
@@ -505,12 +505,14 @@ class DHTNode:
|
|
|
max_expiration_time = self.protocol.cache.get(key_id)[1] or current_time
|
|
|
del self.cache_refresh_queue[key_id] # we pledge to refresh this key_id in the nearest batch
|
|
|
while self.cache_refresh_queue:
|
|
|
- key_id, _, nearest_refresh_time = self.cache_refresh_queue.top()
|
|
|
+ key_id, (_, nearest_refresh_time) = self.cache_refresh_queue.top()
|
|
|
if nearest_refresh_time > current_time:
|
|
|
break
|
|
|
del self.cache_refresh_queue[key_id] # we pledge to refresh this key_id in the nearest batch
|
|
|
keys_to_refresh.add(key_id)
|
|
|
- max_expiration_time = max(max_expiration_time, self.protocol.cache.get(key_id)[1] or current_time)
|
|
|
+ cached_item = self.protocol.cache.get(key_id)
|
|
|
+ if cached_item is not None and cached_item.expiration_time > max_expiration_time:
|
|
|
+ max_expiration_time = cached_item.expiration_time
|
|
|
|
|
|
# step 3: search newer versions of these keys, cache them as a side-effect of self.get_many_by_id
|
|
|
sufficient_expiration_time = max_expiration_time + self.cache_refresh_before_expiry + 1
|
|
@@ -520,9 +522,10 @@ class DHTNode:
|
|
|
node_to_endpoint: Dict[DHTID, Endpoint], _is_refresh: bool = False):
|
|
|
""" after key_id is found, update cache according to caching policy. used internally in get and get_many """
|
|
|
if search.found_something:
|
|
|
- previous_expiration_time = max(self.protocol.storage.get(search.key_id)[1] or -float('inf'),
|
|
|
- self.protocol.cache.get(search.key_id)[1] or -float('inf'))
|
|
|
- if search.expiration_time > previous_expiration_time: # if this value has better expiration
|
|
|
+ _, storage_expiration_time = self.protocol.storage.get(search.key_id) or (None, -float('inf'))
|
|
|
+ _, cache_expiration_time = self.protocol.cache.get(search.key_id) or (None, -float('inf'))
|
|
|
+
|
|
|
+ if search.expiration_time > max(storage_expiration_time, cache_expiration_time):
|
|
|
if self.cache_locally or _is_refresh:
|
|
|
self.protocol.cache.store(search.key_id, search.binary_value, search.expiration_time)
|
|
|
if self.cache_nearest:
|
|
@@ -559,12 +562,12 @@ class _SearchState:
|
|
|
binary_value: Optional[Union[BinaryDHTValue, DictionaryDHTValue]] = None
|
|
|
expiration_time: Optional[DHTExpiration] = None # best expiration time so far
|
|
|
source_node_id: Optional[DHTID] = None # node that gave us the value
|
|
|
- future: asyncio.Future[Tuple[Optional[DHTValue], Optional[DHTExpiration]]] = field(default_factory=asyncio.Future)
|
|
|
+ future: asyncio.Future[Optional[ValueWithExpiration[DHTValue]]] = field(default_factory=asyncio.Future)
|
|
|
serializer: type(SerializerBase) = MSGPackSerializer
|
|
|
|
|
|
- def add_candidate(self, binary_value: Optional[BinaryDHTValue], expiration_time: Optional[DHTExpiration],
|
|
|
- source_node_id: Optional[DHTID]):
|
|
|
- if not self.finished and (expiration_time or -float('inf')) > (self.expiration_time or -float('inf')):
|
|
|
+ def add_candidate(self, candidate: Optional[ValueWithExpiration[BinaryDHTValue]], source_node_id: Optional[DHTID]):
|
|
|
+ binary_value, expiration_time = candidate or (None, -float('inf'))
|
|
|
+ if not self.finished and expiration_time > (self.expiration_time or -float('inf')):
|
|
|
self.binary_value, self.expiration_time, self.source_node_id = binary_value, expiration_time, source_node_id
|
|
|
if self.expiration_time >= self.sufficient_expiration_time:
|
|
|
self.finish_search()
|
|
@@ -577,13 +580,13 @@ class _SearchState:
|
|
|
if self.future.done():
|
|
|
return # either user cancelled our search or someone sent it before us. Nothing more to do here.
|
|
|
elif not self.found_something:
|
|
|
- self.future.set_result((None, None))
|
|
|
+ self.future.set_result(None)
|
|
|
elif isinstance(self.binary_value, BinaryDHTValue):
|
|
|
- self.future.set_result((self.serializer.loads(self.binary_value), self.expiration_time))
|
|
|
+ self.future.set_result(ValueWithExpiration(self.serializer.loads(self.binary_value), self.expiration_time))
|
|
|
elif isinstance(self.binary_value, DictionaryDHTValue):
|
|
|
- dict_value = {key: (self.serializer.loads(value), item_expiration_time)
|
|
|
- for key, value, item_expiration_time in self.binary_value.items()}
|
|
|
- self.future.set_result((dict_value, self.expiration_time))
|
|
|
+ dict_with_subkeys = {key: ValueWithExpiration(self.serializer.loads(value), item_expiration_time)
|
|
|
+ for key, (value, item_expiration_time) in self.binary_value.items()}
|
|
|
+ self.future.set_result(ValueWithExpiration(dict_with_subkeys, self.expiration_time))
|
|
|
else:
|
|
|
logger.error(f"Invalid value type: {type(self.binary_value)}")
|
|
|
|