@@ -6,8 +6,21 @@ import random
from collections import defaultdict, Counter
from dataclasses import dataclass, field
from functools import partial
-from typing import (Any, Awaitable, Callable, Collection, DefaultDict, Dict, List, Optional, Sequence, Set, Tuple,
-                    Type, Union)
+from typing import (
+    Any,
+    Awaitable,
+    Callable,
+    Collection,
+    DefaultDict,
+    Dict,
+    List,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    Type,
+    Union,
+)

from multiaddr import Multiaddr
from sortedcontainers import SortedSet
@@ -69,6 +82,7 @@ class DHTNode:
to reuse the result of this GET request for other requests with the same key. Useful for batch-parallel requests.

"""
+
# fmt:off
node_id: DHTID; is_alive: bool; peer_id: PeerID; num_replicas: int; num_workers: int; protocol: DHTProtocol
chunk_size: int; refresh_timeout: float; cache_locally: bool; cache_nearest: int; cache_refresh_before_expiry: float
@@ -79,19 +93,34 @@ class DHTNode:

@classmethod
async def create(
-            cls,
-            p2p: Optional[P2P] = None,
-            node_id: Optional[DHTID] = None,
-            initial_peers: Optional[Sequence[Union[Multiaddr, str]]] = None,
-            bucket_size: int = 20, num_replicas: int = 5, depth_modulo: int = 5, parallel_rpc: int = None,
-            wait_timeout: float = 3, refresh_timeout: Optional[float] = None, bootstrap_timeout: Optional[float] = None,
-            cache_locally: bool = True, cache_nearest: int = 1, cache_size=None, cache_refresh_before_expiry: float = 5,
-            cache_on_store: bool = True, reuse_get_requests: bool = True, num_workers: int = 1, chunk_size: int = 16,
-            blacklist_time: float = 5.0, backoff_rate: float = 2.0,
-            listen: bool = True,
-            record_validator: Optional[RecordValidatorBase] = None,
-            authorizer: Optional[AuthorizerBase] = None,
-            validate: bool = True, strict: bool = True, **kwargs) -> DHTNode:
+        cls,
+        p2p: Optional[P2P] = None,
+        node_id: Optional[DHTID] = None,
+        initial_peers: Optional[Sequence[Union[Multiaddr, str]]] = None,
+        bucket_size: int = 20,
+        num_replicas: int = 5,
+        depth_modulo: int = 5,
+        parallel_rpc: int = None,
+        wait_timeout: float = 3,
+        refresh_timeout: Optional[float] = None,
+        bootstrap_timeout: Optional[float] = None,
+        cache_locally: bool = True,
+        cache_nearest: int = 1,
+        cache_size=None,
+        cache_refresh_before_expiry: float = 5,
+        cache_on_store: bool = True,
+        reuse_get_requests: bool = True,
+        num_workers: int = 1,
+        chunk_size: int = 16,
+        blacklist_time: float = 5.0,
+        backoff_rate: float = 2.0,
+        listen: bool = True,
+        record_validator: Optional[RecordValidatorBase] = None,
+        authorizer: Optional[AuthorizerBase] = None,
+        validate: bool = True,
+        strict: bool = True,
+        **kwargs,
+    ) -> DHTNode:
"""
:param p2p: instance of hivemind.p2p.P2P that will be used for communication.
If None, DHTNode will create and manage its own P2P instance with given initial_peers and
@@ -139,7 +168,7 @@ class DHTNode:
self.is_alive = True  # if set to False, cancels all background jobs such as routing table refresh

self.reuse_get_requests = reuse_get_requests
-        self.pending_get_requests = defaultdict(partial(SortedSet, key=lambda _res: - _res.sufficient_expiration_time))
+        self.pending_get_requests = defaultdict(partial(SortedSet, key=lambda _res: -_res.sufficient_expiration_time))

# caching policy
self.refresh_timeout = refresh_timeout
@@ -151,38 +180,52 @@ class DHTNode:
self.cache_refresh_task = None

if p2p is None:
-            if not kwargs.get('use_ipfs'):
-                kwargs['initial_peers'] = initial_peers
+            if not kwargs.get("use_ipfs"):
+                kwargs["initial_peers"] = initial_peers
p2p = await P2P.create(**kwargs)
self._should_shutdown_p2p = True
else:
if kwargs:
raise ValueError(
-                    f'**kwargs in DHTNode.create() should be empty if hivemind.p2p.P2P instance is provided'
-                    f'in the constructor. Got kwargs = {kwargs} instead. '
-                    f'You may have a typo in a DHTNode.create() parameter name')
+                    f"**kwargs in DHTNode.create() should be empty if hivemind.p2p.P2P instance is provided"
+                    f"in the constructor. Got kwargs = {kwargs} instead. "
+                    f"You may have a typo in a DHTNode.create() parameter name"
+                )
self._should_shutdown_p2p = False
self.p2p = p2p

self.protocol = await DHTProtocol.create(
-            p2p, self.node_id, bucket_size, depth_modulo, num_replicas, wait_timeout,
-            parallel_rpc, cache_size, listen, record_validator, authorizer)
+            p2p,
+            self.node_id,
+            bucket_size,
+            depth_modulo,
+            num_replicas,
+            wait_timeout,
+            parallel_rpc,
+            cache_size,
+            listen,
+            record_validator,
+            authorizer,
+        )
self.peer_id = p2p.id

if initial_peers:
-            initial_peers = {PeerID.from_base58(Multiaddr(item)['p2p']) for item in initial_peers}
+            initial_peers = {PeerID.from_base58(Multiaddr(item)["p2p"]) for item in initial_peers}

# stage 1: ping initial_peers, add each other to the routing table
bootstrap_timeout = bootstrap_timeout if bootstrap_timeout is not None else wait_timeout
start_time = get_dht_time()
-            ping_tasks = set(asyncio.create_task(self.protocol.call_ping(peer, validate=validate, strict=strict))
-                             for peer in initial_peers)
+            ping_tasks = set(
+                asyncio.create_task(self.protocol.call_ping(peer, validate=validate, strict=strict))
+                for peer in initial_peers
+            )
finished_pings, unfinished_pings = await asyncio.wait(ping_tasks, return_when=asyncio.FIRST_COMPLETED)

# stage 2: gather remaining peers (those who respond within bootstrap_timeout)
if unfinished_pings:
finished_in_time, stragglers = await asyncio.wait(
-                    unfinished_pings, timeout=bootstrap_timeout - get_dht_time() + start_time)
+                    unfinished_pings, timeout=bootstrap_timeout - get_dht_time() + start_time
+                )
for straggler in stragglers:
straggler.cancel()
finished_pings |= finished_in_time
@@ -197,29 +240,39 @@ class DHTNode:
# stage 3: traverse dht to find my own nearest neighbors and populate the routing table
# ... maybe receive some values that we are meant to store (see protocol.update_routing_table)
# note: using asyncio.wait instead of wait_for because wait_for cancels task on timeout
-        await asyncio.wait([asyncio.create_task(self.find_nearest_nodes([self.node_id])),
-                            asyncio.sleep(bootstrap_timeout - get_dht_time() + start_time)],
-                           return_when=asyncio.FIRST_COMPLETED)
+        await asyncio.wait(
+            [
+                asyncio.create_task(self.find_nearest_nodes([self.node_id])),
+                asyncio.sleep(bootstrap_timeout - get_dht_time() + start_time),
+            ],
+            return_when=asyncio.FIRST_COMPLETED,
+        )

if self.refresh_timeout is not None:
asyncio.create_task(self._refresh_routing_table(period=self.refresh_timeout))
return self

def __init__(self, *, _initialized_with_create=False):
-        """ Internal init method. Please use DHTNode.create coroutine to spawn new node instances """
+        """Internal init method. Please use DHTNode.create coroutine to spawn new node instances"""
assert _initialized_with_create, " Please use DHTNode.create coroutine to spawn new node instances "
super().__init__()

async def shutdown(self):
-        """ Process existing requests, close all connections and stop the server """
+        """Process existing requests, close all connections and stop the server"""
self.is_alive = False
if self._should_shutdown_p2p:
await self.p2p.shutdown()

async def find_nearest_nodes(
-            self, queries: Collection[DHTID], k_nearest: Optional[int] = None, beam_size: Optional[int] = None,
-            num_workers: Optional[int] = None, node_to_peer_id: Optional[Dict[DHTID, PeerID]] = None,
-            exclude_self: bool = False, **kwargs) -> Dict[DHTID, Dict[DHTID, PeerID]]:
+        self,
+        queries: Collection[DHTID],
+        k_nearest: Optional[int] = None,
+        beam_size: Optional[int] = None,
+        num_workers: Optional[int] = None,
+        node_to_peer_id: Optional[Dict[DHTID, PeerID]] = None,
+        exclude_self: bool = False,
+        **kwargs,
+    ) -> Dict[DHTID, Dict[DHTID, PeerID]]:
"""
:param queries: find k nearest nodes for each of these DHTIDs
:param k_nearest: return this many nearest nodes for every query (if there are enough nodes)
@@ -254,9 +307,15 @@ class DHTNode:
return output

nearest_nodes_per_query, visited_nodes = await traverse_dht(
-            queries, initial_nodes=list(node_to_peer_id), beam_size=beam_size, num_workers=num_workers,
-            queries_per_call=int(len(queries) ** 0.5), get_neighbors=get_neighbors,
-            visited_nodes={query: {self.node_id} for query in queries}, **kwargs)
+            queries,
+            initial_nodes=list(node_to_peer_id),
+            beam_size=beam_size,
+            num_workers=num_workers,
+            queries_per_call=int(len(queries) ** 0.5),
+            get_neighbors=get_neighbors,
+            visited_nodes={query: {self.node_id} for query in queries},
+            **kwargs,
+        )

nearest_nodes_with_peer_ids = {}
for query, nearest_nodes in nearest_nodes_per_query.items():
@@ -266,8 +325,9 @@ class DHTNode:
nearest_nodes_with_peer_ids[query] = {node: node_to_peer_id[node] for node in nearest_nodes[:k_nearest]}
return nearest_nodes_with_peer_ids

-    async def store(self, key: DHTKey, value: DHTValue, expiration_time: DHTExpiration,
-                    subkey: Optional[Subkey] = None, **kwargs) -> bool:
+    async def store(
+        self, key: DHTKey, value: DHTValue, expiration_time: DHTExpiration, subkey: Optional[Subkey] = None, **kwargs
+    ) -> bool:
"""
Find num_replicas best nodes to store (key, value) and store it there at least until expiration time.
:note: store is a simplified interface to store_many, all kwargs are forwarded there
@@ -276,10 +336,16 @@ class DHTNode:
store_ok = await self.store_many([key], [value], [expiration_time], subkeys=[subkey], **kwargs)
return store_ok[(key, subkey) if subkey is not None else key]

-    async def store_many(self, keys: List[DHTKey], values: List[DHTValue],
-                         expiration_time: Union[DHTExpiration, List[DHTExpiration]],
-                         subkeys: Optional[Union[Subkey, List[Optional[Subkey]]]] = None,
-                         exclude_self: bool = False, await_all_replicas=True, **kwargs) -> Dict[DHTKey, bool]:
+    async def store_many(
+        self,
+        keys: List[DHTKey],
+        values: List[DHTValue],
+        expiration_time: Union[DHTExpiration, List[DHTExpiration]],
+        subkeys: Optional[Union[Subkey, List[Optional[Subkey]]]] = None,
+        exclude_self: bool = False,
+        await_all_replicas=True,
+        **kwargs,
+    ) -> Dict[DHTKey, bool]:
"""
Traverse DHT to find up to :num_replicas: best nodes to store multiple (key, value, expiration_time) pairs.

@@ -299,8 +365,9 @@ class DHTNode:
if subkeys is None:
subkeys = [None] * len(keys)

-        assert len(keys) == len(subkeys) == len(values) == len(expiration_time), \
-            "Either of keys, values, subkeys or expiration timestamps have different sequence lengths."
+        assert (
+            len(keys) == len(subkeys) == len(values) == len(expiration_time)
+        ), "Either of keys, values, subkeys or expiration timestamps have different sequence lengths."

key_id_to_data: DefaultDict[DHTID, List[Tuple[DHTKey, Subkey, DHTValue, DHTExpiration]]] = defaultdict(list)
for key, subkey, value, expiration in zip(keys, subkeys, values, expiration_time):
@@ -313,11 +380,14 @@ class DHTNode:
# pre-populate node_to_peer_id
node_to_peer_id: Dict[DHTID, PeerID] = dict()
for key_id in unfinished_key_ids:
-            node_to_peer_id.update(self.protocol.routing_table.get_nearest_neighbors(
-                key_id, self.protocol.bucket_size, exclude=self.node_id))
+            node_to_peer_id.update(
+                self.protocol.routing_table.get_nearest_neighbors(
+                    key_id, self.protocol.bucket_size, exclude=self.node_id
+                )
+            )

async def on_found(key_id: DHTID, nearest_nodes: List[DHTID], visited_nodes: Set[DHTID]) -> None:
-            """ This will be called once per key when find_nearest_nodes is done for a particular node """
+            """This will be called once per key when find_nearest_nodes is done for a particular node"""
# note: we use callbacks instead of returned values to call store immediately without waiting for stragglers
assert key_id in unfinished_key_ids, "Internal error: traverse_dht finished the same query twice"
assert self.node_id not in nearest_nodes
@@ -326,15 +396,15 @@ class DHTNode:
# ensure k nodes stored the value, optionally include self.node_id as a candidate
num_successful_stores = 0
pending_store_tasks = set()
-            store_candidates = sorted(nearest_nodes + ([] if exclude_self else [self.node_id]),
-                                      key=key_id.xor_distance, reverse=True)  # ordered so that .pop() returns nearest
+            store_candidates = sorted(
+                nearest_nodes + ([] if exclude_self else [self.node_id]), key=key_id.xor_distance, reverse=True
+            )  # ordered so that .pop() returns nearest
[original_key, *_], current_subkeys, current_values, current_expirations = zip(*key_id_to_data[key_id])

key_bytes = key_id.to_bytes()
binary_values = []
stored_records = []
-            for subkey, value, expiration_time in zip(
-                    current_subkeys, current_values, current_expirations):
+            for subkey, value, expiration_time in zip(current_subkeys, current_values, current_expirations):
subkey_bytes = self.protocol.serializer.dumps(subkey)
value_bytes = self.protocol.serializer.dumps(value)
record = DHTRecord(key_bytes, subkey_bytes, value_bytes, expiration_time)
@@ -351,23 +421,34 @@ class DHTNode:
if node_id == self.node_id:
num_successful_stores += 1
for subkey, record in zip(current_subkeys, stored_records):
-                        if (self.protocol.record_validator is None or
-                                self.protocol.record_validator.validate(record)):
+                        if self.protocol.record_validator is None or self.protocol.record_validator.validate(
+                            record
+                        ):
store_ok[original_key, subkey] = self.protocol.storage.store(
-                                key_id, record.value, record.expiration_time, subkey=subkey)
+                                key_id, record.value, record.expiration_time, subkey=subkey
+                            )
else:
store_ok[original_key, subkey] = False
if not await_all_replicas:
store_finished_events[original_key, subkey].set()
else:
-                    pending_store_tasks.add(asyncio.create_task(self.protocol.call_store(
-                        node_to_peer_id[node_id], keys=[key_id] * len(current_values), values=binary_values,
-                        expiration_time=current_expirations, subkeys=current_subkeys)))
+                    pending_store_tasks.add(
+                        asyncio.create_task(
+                            self.protocol.call_store(
+                                node_to_peer_id[node_id],
+                                keys=[key_id] * len(current_values),
+                                values=binary_values,
+                                expiration_time=current_expirations,
+                                subkeys=current_subkeys,
+                            )
+                        )
+                    )

# await nearest task. If it fails, dispatch more on the next iteration
if pending_store_tasks:
finished_store_tasks, pending_store_tasks = await asyncio.wait(
-                        pending_store_tasks, return_when=asyncio.FIRST_COMPLETED)
+                        pending_store_tasks, return_when=asyncio.FIRST_COMPLETED
+                    )
for task in finished_store_tasks:
if task.result() is not None:
num_successful_stores += 1
@@ -377,27 +458,47 @@ class DHTNode:
store_finished_events[original_key, subkey].set()

if self.cache_on_store:
-                self._update_cache_on_store(key_id, current_subkeys, binary_values, current_expirations,
-                                            store_ok=[store_ok[original_key, subkey] for subkey in current_subkeys])
+                self._update_cache_on_store(
+                    key_id,
+                    current_subkeys,
+                    binary_values,
+                    current_expirations,
+                    store_ok=[store_ok[original_key, subkey] for subkey in current_subkeys],
+                )

for subkey, value_bytes, expiration in zip(current_subkeys, binary_values, current_expirations):
store_finished_events[original_key, subkey].set()

-        store_task = asyncio.create_task(self.find_nearest_nodes(
-            queries=set(unfinished_key_ids), k_nearest=self.num_replicas, node_to_peer_id=node_to_peer_id,
-            found_callback=on_found, exclude_self=exclude_self, **kwargs))
+        store_task = asyncio.create_task(
+            self.find_nearest_nodes(
+                queries=set(unfinished_key_ids),
+                k_nearest=self.num_replicas,
+                node_to_peer_id=node_to_peer_id,
+                found_callback=on_found,
+                exclude_self=exclude_self,
+                **kwargs,
+            )
+        )
try:
await asyncio.gather(store_task, *(evt.wait() for evt in store_finished_events.values()))
assert len(unfinished_key_ids) == 0, "Internal error: traverse_dht didn't finish search"
-            return {(key, subkey) if subkey is not None else key: status or False
-                    for (key, subkey), status in store_ok.items()}
+            return {
+                (key, subkey) if subkey is not None else key: status or False
+                for (key, subkey), status in store_ok.items()
+            }
except asyncio.CancelledError as e:
store_task.cancel()
raise e

-    def _update_cache_on_store(self, key_id: DHTID, subkeys: List[Subkey], binary_values: List[bytes],
-                               expirations: List[DHTExpiration], store_ok: List[bool]):
-        """ Update local cache after finishing a store for one key (with perhaps several subkeys) """
+    def _update_cache_on_store(
+        self,
+        key_id: DHTID,
+        subkeys: List[Subkey],
+        binary_values: List[bytes],
+        expirations: List[DHTExpiration],
+        store_ok: List[bool],
+    ):
+        """Update local cache after finishing a store for one key (with perhaps several subkeys)"""
store_succeeded = any(store_ok)
is_dictionary = any(subkey is not None for subkey in subkeys)
if store_succeeded and not is_dictionary:  # stored a new regular value, cache it!
@@ -406,12 +507,14 @@ class DHTNode:
elif not store_succeeded and not is_dictionary:  # store rejected, check if local cache is also obsolete
rejected_expiration, rejected_value = max(zip(expirations, binary_values))
cached_value = self.protocol.cache.get(key_id)
-            if (cached_value is not None and
-                    cached_value.expiration_time <= rejected_expiration):  # cache would be rejected
+            if (
+                cached_value is not None and cached_value.expiration_time <= rejected_expiration
+            ):  # cache would be rejected
self._schedule_for_refresh(key_id, refresh_time=get_dht_time())  # fetch new key in background (asap)
elif is_dictionary and key_id in self.protocol.cache:  # there can be other keys and we should update
for subkey, stored_value_bytes, expiration_time, accepted in zip(
-                    subkeys, binary_values, expirations, store_ok):
+                subkeys, binary_values, expirations, store_ok
+            ):
if accepted:
self.protocol.cache.store_subkey(key_id, subkey, stored_value_bytes, expiration_time)
self._schedule_for_refresh(key_id, refresh_time=get_dht_time())  # fetch new key in background (asap)
@@ -425,13 +528,15 @@ class DHTNode:
:returns: (value, expiration time); if value was not found, returns None
"""
if latest:
-            kwargs["sufficient_expiration_time"] = float('inf')
+            kwargs["sufficient_expiration_time"] = float("inf")
result = await self.get_many([key], **kwargs)
return result[key]

-    async def get_many(self, keys: Collection[DHTKey], sufficient_expiration_time: Optional[DHTExpiration] = None,
-                       **kwargs) -> Dict[DHTKey, Union[Optional[ValueWithExpiration[DHTValue]],
-                                                       Awaitable[Optional[ValueWithExpiration[DHTValue]]]]]:
+    async def get_many(
+        self, keys: Collection[DHTKey], sufficient_expiration_time: Optional[DHTExpiration] = None, **kwargs
+    ) -> Dict[
+        DHTKey, Union[Optional[ValueWithExpiration[DHTValue]], Awaitable[Optional[ValueWithExpiration[DHTValue]]]]
+    ]:
"""
Traverse DHT to find a list of keys. For each key, return latest (value, expiration) or None if not found.

@@ -450,10 +555,16 @@ class DHTNode:
return {id_to_original_key[key]: result_or_future for key, result_or_future in results_by_id.items()}

async def get_many_by_id(
-            self, key_ids: Collection[DHTID], sufficient_expiration_time: Optional[DHTExpiration] = None,
-            num_workers: Optional[int] = None, beam_size: Optional[int] = None, return_futures: bool = False,
-            _is_refresh=False) -> Dict[DHTID, Union[Optional[ValueWithExpiration[DHTValue]],
-                                                    Awaitable[Optional[ValueWithExpiration[DHTValue]]]]]:
+        self,
+        key_ids: Collection[DHTID],
+        sufficient_expiration_time: Optional[DHTExpiration] = None,
+        num_workers: Optional[int] = None,
+        beam_size: Optional[int] = None,
+        return_futures: bool = False,
+        _is_refresh=False,
+    ) -> Dict[
+        DHTID, Union[Optional[ValueWithExpiration[DHTValue]], Awaitable[Optional[ValueWithExpiration[DHTValue]]]]
+    ]:
"""
Traverse DHT to find a list of DHTIDs. For each key, return latest (value, expiration) or None if not found.

@@ -473,10 +584,15 @@ class DHTNode:
sufficient_expiration_time = sufficient_expiration_time or get_dht_time()
beam_size = beam_size if beam_size is not None else self.protocol.bucket_size
num_workers = num_workers if num_workers is not None else self.num_workers
-        search_results: Dict[DHTID, _SearchState] = {key_id: _SearchState(
-            key_id, sufficient_expiration_time,
-            serializer=self.protocol.serializer,
-            record_validator=self.protocol.record_validator) for key_id in key_ids}
+        search_results: Dict[DHTID, _SearchState] = {
+            key_id: _SearchState(
+                key_id,
+                sufficient_expiration_time,
+                serializer=self.protocol.serializer,
+                record_validator=self.protocol.record_validator,
+            )
+            for key_id in key_ids
+        }

if not _is_refresh:  # if we're already refreshing cache, there's no need to trigger subsequent refreshes
for key_id in key_ids:
@@ -498,8 +614,11 @@ class DHTNode:
unfinished_key_ids = [key_id for key_id in key_ids if not search_results[key_id].finished]
node_to_peer_id: Dict[DHTID, PeerID] = dict()  # global routing table for all keys
for key_id in unfinished_key_ids:
-            node_to_peer_id.update(self.protocol.routing_table.get_nearest_neighbors(
-                key_id, self.protocol.bucket_size, exclude=self.node_id))
+            node_to_peer_id.update(
+                self.protocol.routing_table.get_nearest_neighbors(
+                    key_id, self.protocol.bucket_size, exclude=self.node_id
+                )
+            )

# V-- this function will be called every time traverse_dht decides to request neighbors from a remote peer
async def get_neighbors(peer: DHTID, queries: Collection[DHTID]) -> Dict[DHTID, Tuple[Tuple[DHTID], bool]]:
@@ -521,11 +640,19 @@ class DHTNode:
search_results[key_id].finish_search()  # finish search whether or not we found something
self._cache_new_result(search_results[key_id], nearest_nodes, node_to_peer_id, _is_refresh=_is_refresh)

-        asyncio.create_task(traverse_dht(
-            queries=list(unfinished_key_ids), initial_nodes=list(node_to_peer_id), beam_size=beam_size,
-            num_workers=num_workers, queries_per_call=min(int(len(unfinished_key_ids) ** 0.5), self.chunk_size),
-            get_neighbors=get_neighbors, visited_nodes={key_id: {self.node_id} for key_id in unfinished_key_ids},
-            found_callback=found_callback, await_all_tasks=False))
+        asyncio.create_task(
+            traverse_dht(
+                queries=list(unfinished_key_ids),
+                initial_nodes=list(node_to_peer_id),
+                beam_size=beam_size,
+                num_workers=num_workers,
+                queries_per_call=min(int(len(unfinished_key_ids) ** 0.5), self.chunk_size),
+                get_neighbors=get_neighbors,
+                visited_nodes={key_id: {self.node_id} for key_id in unfinished_key_ids},
+                found_callback=found_callback,
+                await_all_tasks=False,
+            )
+        )

if return_futures:
return {key_id: search_result.future for key_id, search_result in search_results.items()}
@@ -552,14 +679,16 @@ class DHTNode:
pending_requests.discard(finished)

async def _call_find_with_blacklist(self, peer_id: PeerID, keys: Collection[DHTID]):
-        """ same as call_find, but skip if :peer_id: is blacklisted; also exclude blacklisted neighbors from result """
+        """same as call_find, but skip if :peer_id: is blacklisted; also exclude blacklisted neighbors from result"""
if peer_id in self.blacklist:
return None
response = await self.protocol.call_find(peer_id, keys)
if response:
self.blacklist.register_success(peer_id)
-            return {key: (maybe_value, self._filter_blacklisted(nearest_peers))
-                    for key, (maybe_value, nearest_peers) in response.items()}
+            return {
+                key: (maybe_value, self._filter_blacklisted(nearest_peers))
+                for key, (maybe_value, nearest_peers) in response.items()
+            }
else:
self.blacklist.register_failure(peer_id)
return None
@@ -568,13 +697,13 @@ class DHTNode:
return {peer: peer_id for peer, peer_id in peer_ids.items() if peer_id not in self.blacklist}

def _trigger_cache_refresh(self, search: _SearchState):
-        """ Called after get request is finished (whether it was found, not found, hit cache, cancelled, or reused) """
+        """Called after get request is finished (whether it was found, not found, hit cache, cancelled, or reused)"""
if search.found_something and search.source_node_id == self.node_id:
if self.cache_refresh_before_expiry and search.key_id in self.protocol.cache:
self._schedule_for_refresh(search.key_id, search.expiration_time - self.cache_refresh_before_expiry)

def _schedule_for_refresh(self, key_id: DHTID, refresh_time: DHTExpiration):
-        """ Add key to a refresh queue, refresh at :refresh_time: or later """
+        """Add key to a refresh queue, refresh at :refresh_time: or later"""
if self.cache_refresh_task is None or self.cache_refresh_task.done() or self.cache_refresh_task.cancelled():
self.cache_refresh_task = asyncio.create_task(self._refresh_stale_cache_entries())
logger.debug("Spawned cache refresh task.")
@@ -584,7 +713,7 @@ class DHTNode:
self.cache_refresh_queue.store(key_id, value=refresh_time, expiration_time=refresh_time)

async def _refresh_stale_cache_entries(self):
-        """ periodically refresh keys near-expired keys that were accessed at least once during previous lifetime """
+        """periodically refresh near-expired keys that were accessed at least once during previous lifetime"""
while self.is_alive:
while len(self.cache_refresh_queue) == 0:
await self.cache_refresh_evt.wait()
@@ -619,12 +748,17 @@ class DHTNode:
sufficient_expiration_time = max_expiration_time + self.cache_refresh_before_expiry + 1
await self.get_many_by_id(keys_to_refresh, sufficient_expiration_time, _is_refresh=True)

-    def _cache_new_result(self, search: _SearchState, nearest_nodes: List[DHTID],
-                          node_to_peer_id: Dict[DHTID, PeerID], _is_refresh: bool = False):
-        """ after key_id is found, update cache according to caching policy. used internally in get and get_many """
+    def _cache_new_result(
+        self,
+        search: _SearchState,
+        nearest_nodes: List[DHTID],
+        node_to_peer_id: Dict[DHTID, PeerID],
+        _is_refresh: bool = False,
+    ):
+        """after key_id is found, update cache according to caching policy. used internally in get and get_many"""
if search.found_something:
-            _, storage_expiration_time = self.protocol.storage.get(search.key_id) or (None, -float('inf'))
-            _, cache_expiration_time = self.protocol.cache.get(search.key_id) or (None, -float('inf'))
+            _, storage_expiration_time = self.protocol.storage.get(search.key_id) or (None, -float("inf"))
+            _, cache_expiration_time = self.protocol.cache.get(search.key_id) or (None, -float("inf"))

if search.expiration_time > max(storage_expiration_time, cache_expiration_time):
if self.cache_locally or _is_refresh:
@@ -634,20 +768,27 @@ class DHTNode:
for node_id in nearest_nodes:
if node_id == search.source_node_id:
continue
-                    asyncio.create_task(self.protocol.call_store(
-                        node_to_peer_id[node_id], [search.key_id], [search.binary_value], [search.expiration_time],
-                        in_cache=True))
+                    asyncio.create_task(
+                        self.protocol.call_store(
+                            node_to_peer_id[node_id],
+                            [search.key_id],
+                            [search.binary_value],
+                            [search.expiration_time],
+                            in_cache=True,
+                        )
+                    )
num_cached_nodes += 1
if num_cached_nodes >= self.cache_nearest:
break

async def _refresh_routing_table(self, *, period: Optional[float]) -> None:
-        """ Tries to find new nodes for buckets that were unused for more than self.staleness_timeout """
+        """Tries to find new nodes for buckets that were unused for more than self.staleness_timeout"""
while self.is_alive and period is not None:  # if None run once, otherwise run forever
refresh_time = get_dht_time()
staleness_threshold = refresh_time - period
-            stale_buckets = [bucket for bucket in self.protocol.routing_table.buckets
-                             if bucket.last_updated < staleness_threshold]
+            stale_buckets = [
+                bucket for bucket in self.protocol.routing_table.buckets if bucket.last_updated < staleness_threshold
+            ]
for bucket in stale_buckets:
refresh_id = DHTID(random.randint(bucket.lower, bucket.upper - 1))
await self.find_nearest_nodes(refresh_id)
@@ -660,7 +801,8 @@ class DHTNode:

@dataclass(init=True, repr=True, frozen=False, order=False)
class _SearchState:
-    """ A helper class that stores current-best GET results with metadata """
+    """A helper class that stores current-best GET results with metadata"""
+
key_id: DHTID
sufficient_expiration_time: DHTExpiration
binary_value: Optional[Union[BinaryDHTValue, DictionaryDHTValue]] = None
@@ -670,25 +812,28 @@ class _SearchState:
serializer: Type[SerializerBase] = MSGPackSerializer
record_validator: Optional[RecordValidatorBase] = None

-    def add_candidate(self, candidate: Optional[ValueWithExpiration[Union[BinaryDHTValue, DictionaryDHTValue]]],
-                      source_node_id: Optional[DHTID]):
+    def add_candidate(
+        self,
+        candidate: Optional[ValueWithExpiration[Union[BinaryDHTValue, DictionaryDHTValue]]],
+        source_node_id: Optional[DHTID],
+    ):
if self.finished or candidate is None:
return
elif isinstance(candidate.value, DictionaryDHTValue) and isinstance(self.binary_value, DictionaryDHTValue):
self.binary_value.maxsize = max(self.binary_value.maxsize, candidate.value.maxsize)
for subkey, subentry in candidate.value.items():
self.binary_value.store(subkey, subentry.value, subentry.expiration_time)
-        elif candidate.expiration_time > (self.expiration_time or float('-inf')):
+        elif candidate.expiration_time > (self.expiration_time or float("-inf")):
self.binary_value = candidate.value

-        if candidate.expiration_time > (self.expiration_time or float('-inf')):
+        if candidate.expiration_time > (self.expiration_time or float("-inf")):
self.expiration_time = candidate.expiration_time
self.source_node_id = source_node_id
if self.expiration_time >= self.sufficient_expiration_time:
self.finish_search()

def add_done_callback(self, callback: Callable[[_SearchState], Any]):
-        """ Add callback that will be called when _SearchState is done (found OR cancelled by user) """
+        """Add callback that will be called when _SearchState is done (found OR cancelled by user)"""
self.future.add_done_callback(lambda _future: callback(self))

def finish_search(self):
@@ -699,30 +844,30 @@ class _SearchState:
elif isinstance(self.binary_value, BinaryDHTValue):
value_bytes = self.binary_value
if self.record_validator is not None:
-                record = DHTRecord(self.key_id.to_bytes(), DHTProtocol.IS_REGULAR_VALUE,
-                                   value_bytes, self.expiration_time)
+                record = DHTRecord(
+                    self.key_id.to_bytes(), DHTProtocol.IS_REGULAR_VALUE, value_bytes, self.expiration_time
+                )
value_bytes = self.record_validator.strip_value(record)

-            self.future.set_result(
-                ValueWithExpiration(self.serializer.loads(value_bytes), self.expiration_time))
+            self.future.set_result(ValueWithExpiration(self.serializer.loads(value_bytes), self.expiration_time))
elif isinstance(self.binary_value, DictionaryDHTValue):
dict_with_subkeys = {}
for subkey, (value_bytes, item_expiration_time) in self.binary_value.items():
if self.record_validator is not None:
subkey_bytes = self.serializer.dumps(subkey)
-                    record = DHTRecord(self.key_id.to_bytes(), subkey_bytes,
-                                       value_bytes, item_expiration_time)
+                    record = DHTRecord(self.key_id.to_bytes(), subkey_bytes, value_bytes, item_expiration_time)
value_bytes = self.record_validator.strip_value(record)

dict_with_subkeys[subkey] = ValueWithExpiration(
-                    self.serializer.loads(value_bytes), item_expiration_time)
+                    self.serializer.loads(value_bytes), item_expiration_time
+                )
self.future.set_result(ValueWithExpiration(dict_with_subkeys, self.expiration_time))
else:
logger.error(f"Invalid value type: {type(self.binary_value)}")

@property
def found_something(self) -> bool:
-        """ Whether or not we have found at least some value, regardless of its expiration time """
+        """Whether or not we have found at least some value, regardless of its expiration time"""
return self.expiration_time is not None

@property
@@ -730,7 +875,7 @@ class _SearchState:
return self.future.done()

def __lt__(self, other: _SearchState):
-        """ _SearchState instances will be sorted by their target expiration time """
+        """_SearchState instances will be sorted by their target expiration time"""
return self.sufficient_expiration_time < other.sufficient_expiration_time

def __hash__(self):
@@ -750,22 +895,24 @@ class Blacklist:
self.ban_counter = Counter()

def register_failure(self, peer: PeerID):
-        """ peer failed to respond, add him to blacklist or increase his downtime """
+        """peer failed to respond, add him to blacklist or increase his downtime"""
if peer not in self.banned_peers and self.base_time > 0:
ban_duration = self.base_time * self.backoff ** self.ban_counter[peer]
self.banned_peers.store(peer, self.ban_counter[peer], expiration_time=get_dht_time() + ban_duration)
self.ban_counter[peer] += 1

def register_success(self, peer):
-        """ peer responded successfully, remove him from blacklist and reset his ban time """
+        """peer responded successfully, remove him from blacklist and reset his ban time"""
del self.banned_peers[peer], self.ban_counter[peer]

def __contains__(self, peer: PeerID) -> bool:
return peer in self.banned_peers

def __repr__(self):
-        return f"{self.__class__.__name__}(base_time={self.base_time}, backoff={self.backoff}, " \
-               f"banned_peers={len(self.banned_peers)})"
+        return (
+            f"{self.__class__.__name__}(base_time={self.base_time}, backoff={self.backoff}, "
+            f"banned_peers={len(self.banned_peers)})"
+        )

def clear(self):
self.banned_peers.clear()
@@ -773,5 +920,6 @@ class Blacklist:


class CacheRefreshQueue(TimedStorage[DHTID, DHTExpiration]):
-    """ a queue of keys scheduled for refresh in future, used in DHTNode """
+    """a queue of keys scheduled for refresh in future, used in DHTNode"""
+
frozen = True
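Note (not part of the patch): this diff only re-wraps signatures and string quoting, so the public API is unchanged. For reviewers who want to sanity-check the reformatted DHTNode.create() / store() / get() signatures, here is a minimal usage sketch. It assumes DHTNode is importable from hivemind.dht.node and get_dht_time from hivemind.utils; treat the import paths and the single-node setup as illustrative assumptions rather than a tested example.

import asyncio

from hivemind.dht.node import DHTNode    # assumed import path
from hivemind.utils import get_dht_time  # assumed import path


async def demo() -> None:
    # With p2p=None and no initial_peers, create() spawns and manages its own P2P instance.
    node = await DHTNode.create()

    # store() is the single-key wrapper around store_many(); the value lives until the expiration time.
    stored = await node.store("example_key", "example_value", expiration_time=get_dht_time() + 60)
    assert stored

    # get() returns a ValueWithExpiration (value, expiration_time) or None if nothing was found.
    found = await node.get("example_key")
    if found is not None:
        print(found.value, found.expiration_time)

    await node.shutdown()


asyncio.run(demo())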