@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import asyncio
 import random
-from collections import defaultdict
+from collections import defaultdict, Counter
 from dataclasses import dataclass, field
 from functools import partial
 from typing import Optional, Tuple, List, Dict, DefaultDict, Collection, Union, Set, Awaitable, Callable, Any
@@ -11,9 +11,10 @@ from sortedcontainers import SortedSet
 
 from hivemind.dht.protocol import DHTProtocol
 from hivemind.dht.routing import DHTID, DHTExpiration, DHTKey, get_dht_time, DHTValue, BinaryDHTValue, Subkey
-from hivemind.dht.storage import CacheRefreshQueue, DictionaryDHTValue
+from hivemind.dht.storage import DictionaryDHTValue
 from hivemind.dht.traverse import traverse_dht
-from hivemind.utils import Endpoint, LOCALHOST, MSGPackSerializer, get_logger, SerializerBase, ValueWithExpiration
+from hivemind.utils import Endpoint, LOCALHOST, MSGPackSerializer, get_logger, SerializerBase
+from hivemind.utils.timed_storage import TimedStorage, ValueWithExpiration
 
 logger = get_logger(__name__)
 
@@ -67,6 +68,7 @@ class DHTNode:
     chunk_size: int; refresh_timeout: float; cache_locally: bool; cache_nearest: int; cache_refresh_before_expiry: float
     cache_on_store: bool; reuse_get_requests: bool; pending_get_requests: DefaultDict[DHTID, SortedSet[_SearchState]]
     cache_refresh_task: Optional[asyncio.Task]; cache_refresh_evt: asyncio.Event; cache_refresh_queue: CacheRefreshQueue
+    blacklist: Blacklist
     # fmt:on
 
     @classmethod
@@ -76,7 +78,9 @@ class DHTNode:
             wait_timeout: float = 5, refresh_timeout: Optional[float] = None, bootstrap_timeout: Optional[float] = None,
             cache_locally: bool = True, cache_nearest: int = 1, cache_size=None, cache_refresh_before_expiry: float = 5,
             cache_on_store: bool = True, reuse_get_requests: bool = True, num_workers: int = 1, chunk_size: int = 16,
-            listen: bool = True, listen_on: Endpoint = "0.0.0.0:*", **kwargs) -> DHTNode:
+            blacklist_time: float = 5.0, backoff_rate: float = 2.0,
+            listen: bool = True, listen_on: Endpoint = "0.0.0.0:*", endpoint: Optional[Endpoint] = None,
+            validate: bool = True, strict: bool = True, **kwargs) -> DHTNode:
         """
         :param node_id: current node's identifier, determines which keys it will store locally, defaults to random id
         :param initial_peers: connects to these peers to populate routing table, defaults to no peers
@@ -102,9 +106,14 @@ class DHTNode:
           all concurrent get requests for the same key will reuse the procedure that is currently in progress
         :param num_workers: concurrent workers in traverse_dht (see traverse_dht num_workers param)
         :param chunk_size: maximum number of concurrent calls in get_many and cache refresh queue
+        :param blacklist_time: excludes non-responsive peers from search for this many seconds (set 0 to disable)
+        :param backoff_rate: blacklist time will be multiplied by :backoff_rate: for each successive non-response
+        :param validate: if True, use initial peers to validate that this node is accessible and synchronized
+        :param strict: if True, any error encountered in validation will interrupt the creation of DHTNode
         :param listen: if True (default), this node will accept incoming requests and otherwise be a DHT "citizen"
           if False, this node will refuse any incoming request, effectively being only a "client"
         :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
+        :param endpoint: if specified, this is the node's preferred public endpoint; otherwise, let peers infer it
         :param channel_options: options for grpc.aio.insecure_channel, e.g. [('grpc.enable_retries', 0)]
           see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
         :param kwargs: extra parameters used in grpc.aio.server
@@ -121,19 +130,21 @@ class DHTNode:
         self.refresh_timeout = refresh_timeout
         self.cache_locally, self.cache_nearest, self.cache_on_store = cache_locally, cache_nearest, cache_on_store
         self.cache_refresh_before_expiry = cache_refresh_before_expiry
+        self.blacklist = Blacklist(blacklist_time, backoff_rate)
         self.cache_refresh_queue = CacheRefreshQueue()
         self.cache_refresh_evt = asyncio.Event()
         self.cache_refresh_task = None
 
         self.protocol = await DHTProtocol.create(self.node_id, bucket_size, depth_modulo, num_replicas, wait_timeout,
-                                                 parallel_rpc, cache_size, listen, listen_on, **kwargs)
+                                                 parallel_rpc, cache_size, listen, listen_on, endpoint, **kwargs)
         self.port = self.protocol.port
 
         if initial_peers:
             # stage 1: ping initial_peers, add each other to the routing table
             bootstrap_timeout = bootstrap_timeout if bootstrap_timeout is not None else wait_timeout
             start_time = get_dht_time()
-            ping_tasks = map(self.protocol.call_ping, initial_peers)
+            ping_tasks = set(asyncio.create_task(self.protocol.call_ping(peer, validate=validate, strict=strict))
+                             for peer in initial_peers)
             finished_pings, unfinished_pings = await asyncio.wait(ping_tasks, return_when=asyncio.FIRST_COMPLETED)
 
             # stage 2: gather remaining peers (those who respond within bootstrap_timeout)
@@ -147,6 +158,10 @@ class DHTNode:
             if not finished_pings:
                 logger.warning("DHTNode bootstrap failed: none of the initial_peers responded to a ping.")
 
+            if strict:
+                for task in asyncio.as_completed(finished_pings):
+                    await task  # propagate exceptions
+
             # stage 3: traverse dht to find my own nearest neighbors and populate the routing table
             # ... maybe receive some values that we are meant to store (see protocol.update_routing_table)
             # note: using asyncio.wait instead of wait_for because wait_for cancels task on timeout
@@ -192,11 +207,11 @@ class DHTNode:
         if node_to_endpoint is None:
             node_to_endpoint: Dict[DHTID, Endpoint] = dict()
             for query in queries:
-                node_to_endpoint.update(
-                    self.protocol.routing_table.get_nearest_neighbors(query, beam_size, exclude=self.node_id))
+                neighbors = self.protocol.routing_table.get_nearest_neighbors(query, beam_size, exclude=self.node_id)
+                node_to_endpoint.update(self._filter_blacklisted(dict(neighbors)))
 
         async def get_neighbors(peer: DHTID, queries: Collection[DHTID]) -> Dict[DHTID, Tuple[Tuple[DHTID], bool]]:
-            response = await self.protocol.call_find(node_to_endpoint[peer], queries)
+            response = await self._call_find_with_blacklist(node_to_endpoint[peer], queries)
             if not response:
                 return {query: ([], False) for query in queries}
 
@@ -433,7 +448,7 @@ class DHTNode:
         # V-- this function will be called every time traverse_dht decides to request neighbors from a remote peer
         async def get_neighbors(peer: DHTID, queries: Collection[DHTID]) -> Dict[DHTID, Tuple[Tuple[DHTID], bool]]:
             queries = list(queries)
-            response = await self.protocol.call_find(node_to_endpoint[peer], queries)
+            response = await self._call_find_with_blacklist(node_to_endpoint[peer], queries)
             if not response:
                 return {query: ([], False) for query in queries}
 
@@ -480,6 +495,22 @@ class DHTNode:
             else:
                 pending_requests.discard(finished)
 
+    async def _call_find_with_blacklist(self, endpoint: Endpoint, keys: Collection[DHTID]):
+        """ same as call_find, but skip if :endpoint: is blacklisted; also exclude blacklisted neighbors from result """
+        if endpoint in self.blacklist:
+            return None
+        response = await self.protocol.call_find(endpoint, keys)
+        if response:
+            self.blacklist.register_success(endpoint)
+            return {key: (maybe_value, self._filter_blacklisted(nearest_peers))
+                    for key, (maybe_value, nearest_peers) in response.items()}
+        else:
+            self.blacklist.register_failure(endpoint)
+            return None
+
+    def _filter_blacklisted(self, peer_endpoints: Dict[DHTID, Endpoint]):
+        return {peer: endpoint for peer, endpoint in peer_endpoints.items() if endpoint not in self.blacklist}
+
     def _trigger_cache_refresh(self, search: _SearchState):
         """ Called after get request is finished (whether it was found, not found, hit cache, cancelled, or reused) """
         if search.found_something and search.source_node_id == self.node_id:
@@ -629,3 +660,42 @@ class _SearchState:
 
     def __hash__(self):
         return hash(self.key_id)
+
+
+class Blacklist:
+    """
+    A temporary blacklist of non-responding peers with exponential backoff policy
+    :param base_time: peers are suspended for this many seconds by default
+    :param backoff_rate: suspension time increases by this factor after each successive failure
+    """
+    def __init__(self, base_time: float, backoff_rate: float, **kwargs):
+        self.base_time, self.backoff = base_time, backoff_rate
+        self.banned_peers = TimedStorage[Endpoint, int](**kwargs)
+        self.ban_counter = Counter()
+
+    def register_failure(self, peer: Endpoint):
+        """ peer failed to respond, add it to the blacklist or increase its downtime """
+        if peer not in self.banned_peers and self.base_time > 0:
+            ban_duration = self.base_time * self.backoff ** self.ban_counter[peer]
+            self.banned_peers.store(peer, self.ban_counter[peer], expiration_time=get_dht_time() + ban_duration)
+            self.ban_counter[peer] += 1
+
+    def register_success(self, peer):
+        """ peer responded successfully, remove it from the blacklist and reset its ban time """
+        del self.banned_peers[peer], self.ban_counter[peer]
+
+    def __contains__(self, peer: Endpoint) -> bool:
+        return peer in self.banned_peers
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(base_time={self.base_time}, backoff={self.backoff}, " \
+               f"banned_peers={len(self.banned_peers)})"
+
+    def clear(self):
+        self.banned_peers.clear()
+        self.ban_counter.clear()
+
+
+class CacheRefreshQueue(TimedStorage[DHTID, DHTExpiration]):
+    """ a queue of keys scheduled for refresh in the future, used in DHTNode """
+    frozen = True
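
A minimal usage sketch of the exponential-backoff blacklist added above (illustrative only: it assumes this patch is applied and that Blacklist is importable from hivemind.dht.node; the peer address is a made-up example):

from hivemind.dht.node import Blacklist

blacklist = Blacklist(base_time=5.0, backoff_rate=2.0)
peer = "192.168.1.2:31337"  # hypothetical Endpoint (in hivemind, an Endpoint is a plain "host:port" string)

blacklist.register_failure(peer)   # first failure: banned for base_time = 5 seconds
assert peer in blacklist
blacklist.register_failure(peer)   # no-op while the ban is still active (the ban counter is not incremented)

# once the first ban expires, the next failure bans the peer for 5 * 2**1 = 10 s, then 20 s, 40 s, ...
blacklist.register_success(peer)   # a successful response lifts the ban and resets the backoff counter
assert peer not in blacklist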