|
@@ -7,7 +7,7 @@ from dataclasses import dataclass, field
|
|
|
from functools import partial
|
|
|
from typing import Optional, Tuple, List, Dict, DefaultDict, Collection, Union, Set, Awaitable, Callable, Any
|
|
|
|
|
|
-from sortedcontainers import SortedList
|
|
|
+from sortedcontainers import SortedSet
|
|
|
|
|
|
from hivemind.dht.protocol import DHTProtocol
|
|
|
from hivemind.dht.routing import DHTID, DHTExpiration, DHTKey, get_dht_time, DHTValue, BinaryDHTValue, Subkey
|
|
@@ -65,7 +65,7 @@ class DHTNode:
|
|
|
# fmt:off
|
|
|
node_id: DHTID; is_alive: bool; port: int; num_replicas: int; num_workers: int; protocol: DHTProtocol
|
|
|
chunk_size: int; refresh_timeout: float; cache_locally: bool; cache_nearest: int; cache_refresh_before_expiry: float
|
|
|
- cache_on_store: bool; reuse_get_requests: bool; pending_get_requests: DefaultDict[DHTID, SortedList[_SearchState]]
|
|
|
+ cache_on_store: bool; reuse_get_requests: bool; pending_get_requests: DefaultDict[DHTID, SortedSet[_SearchState]]
|
|
|
cache_refresh_task: Optional[asyncio.Task]; cache_refresh_evt: asyncio.Event; cache_refresh_queue: CacheRefreshQueue
|
|
|
# fmt:on
|
|
|
|
|
@@ -115,7 +115,7 @@ class DHTNode:
|
|
|
self.is_alive = True # if set to False, cancels all background jobs such as routing table refresh
|
|
|
|
|
|
self.reuse_get_requests = reuse_get_requests
|
|
|
- self.pending_get_requests = defaultdict(partial(SortedList, key=lambda _res: - _res.sufficient_expiration_time))
|
|
|
+ self.pending_get_requests = defaultdict(partial(SortedSet, key=lambda _res: - _res.sufficient_expiration_time))
|
|
|
|
|
|
# caching policy
|
|
|
self.refresh_timeout = refresh_timeout
|
|
@@ -468,14 +468,17 @@ class DHTNode:
|
|
|
raise e
|
|
|
|
|
|
def _reuse_finished_search_result(self, finished: _SearchState):
|
|
|
- search_result = ValueWithExpiration(finished.binary_value, finished.expiration_time)
|
|
|
- expiration_time_threshold = max(finished.expiration_time or -float('inf'), finished.sufficient_expiration_time)
|
|
|
- concurrent_requests: SortedList[_SearchState] = self.pending_get_requests[finished.key_id]
|
|
|
- # note: concurrent_requests is sorted in the order of descending sufficient_expiration_time
|
|
|
- while concurrent_requests and expiration_time_threshold >= concurrent_requests[-1].sufficient_expiration_time:
|
|
|
- concurrent_requests[-1].add_candidate(search_result, source_node_id=finished.source_node_id)
|
|
|
- concurrent_requests[-1].finish_search()
|
|
|
- concurrent_requests.pop(-1)
|
|
|
+ pending_requests = self.pending_get_requests[finished.key_id]
|
|
|
+ if finished.found_something:
|
|
|
+ search_result = ValueWithExpiration(finished.binary_value, finished.expiration_time)
|
|
|
+ expiration_time_threshold = max(finished.expiration_time, finished.sufficient_expiration_time)
|
|
|
+ # note: pending_requests is sorted in the order of descending sufficient_expiration_time
|
|
|
+ while pending_requests and expiration_time_threshold >= pending_requests[-1].sufficient_expiration_time:
|
|
|
+ pending_requests[-1].add_candidate(search_result, source_node_id=finished.source_node_id)
|
|
|
+ pending_requests[-1].finish_search()
|
|
|
+ pending_requests.pop()
|
|
|
+ else:
|
|
|
+ pending_requests.discard(finished)
|
|
|
|
|
|
def _trigger_cache_refresh(self, search: _SearchState):
|
|
|
""" Called after get request is finished (whether it was found, not found, hit cache, cancelled, or reused) """
|
|
@@ -613,3 +616,6 @@ class _SearchState:
|
|
|
def __lt__(self, other: _SearchState):
|
|
|
""" _SearchState instances will be sorted by their target expiration time """
|
|
|
return self.sufficient_expiration_time < other.sufficient_expiration_time
|
|
|
+
|
|
|
+ def __hash__(self):
|
|
|
+ return hash(self.key_id)
|