
Add DHT peer validation, add DHT.get_visible_address, add blacklist for unresponsive peers (#137)

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic 4 years ago
parent
commit
c36b5b1a9b

+ 1 - 1
hivemind/__init__.py

@@ -3,4 +3,4 @@ from hivemind.dht import *
 from hivemind.server import *
 from hivemind.utils import *

-__version__ = '0.8.22'
+__version__ = '0.8.23'

+ 18 - 19
hivemind/client/averaging/__init__.py

@@ -47,6 +47,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
       note - this expiration time only applies to looking for group, passing tensors in allreduce may take more time
     :param compression_type: optionally compress tensors with this compression algorithm before sending them to peers
     :param allreduce_timeout: spend at most this many seconds for allreduce (after group is formed)
+    :param averaging_alpha: optional "learning rate" for averaging. If specified, local parameters will be shifted
+      towards the (estimated) average by this coefficient. By default, local parameters are set equal to average.
     :param request_timeout: when looking for group, wait for a response from leader for at most this many seconds.
     :note: request_timeout must be smaller than averaging_expiration to avoid potential deadlocks.
     :param chunk_size_bytes: tensors for AllReduce will be divided into chunks of this size (to improve gRPC throughput)
@@ -66,8 +68,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin

     def __init__(self, averaged_tensors: Sequence[torch.Tensor], dht: hivemind.dht.DHT, *, start: bool,
                  prefix: str, target_group_size: int, min_group_size: int = 2, initial_group_bits: Optional[str] = None,
-                 averaging_expiration: float = 15, allreduce_timeout: Optional[float] = None,
-                 request_timeout: float = 3, chunk_size_bytes: int = 2 ** 16,
+                 averaging_expiration: float = 15, request_timeout: float = 3, chunk_size_bytes: int = 2 ** 16,
+                 allreduce_timeout: Optional[float] = None, averaging_alpha: float = 1.0,
                  compression_type: runtime_pb2.CompressionType = runtime_pb2.CompressionType.NONE,
                  listen_on: Endpoint = '0.0.0.0:*', receiver_threads: int = 1, daemon: bool = True,
                  channel_options: Optional[Sequence[Tuple[str, Any]]] = None, **kwargs):
@@ -95,7 +97,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             prefix=prefix, initial_group_bits=initial_group_bits, target_group_size=target_group_size,
             min_group_size=min_group_size, averaging_expiration=averaging_expiration, request_timeout=request_timeout,
             chunk_size_bytes=chunk_size_bytes, compression_type=compression_type)
-        self.allreduce_timeout = allreduce_timeout
+        self.averaging_alpha, self.allreduce_timeout = averaging_alpha, allreduce_timeout
         self._running_groups: Dict[GroupID, AllReduceRunner] = {}  # one or more assembled groups that run all-reduce

         self._pipe, self.pipe = mp.Pipe(duplex=True)  # a control pipe used to communicate with a background process
@@ -193,8 +195,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                 group_id = allreduce_group.group_id
                 self._running_groups[group_id] = allreduce_group
                 self._pending_group_assembled.set()
-                averaging_deltas = await asyncio.wait_for(allreduce_group.run(), self.allreduce_timeout)
-                update_ok = await loop.run_in_executor(None, lambda: self.update_tensors(averaging_deltas, add=True))
+                await asyncio.wait_for(allreduce_group.run(), self.allreduce_timeout)
+                update_ok = await loop.run_in_executor(None, self.update_tensors, allreduce_group)

                 # averaging is finished, exit the loop
                 future.set_result(update_ok)
@@ -213,23 +215,20 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                 _ = self._running_groups.pop(group_id, None)
                 self._pending_group_assembled.set()

-    def update_tensors(self, tensors: Sequence[torch.Tensor], *, add: bool = False) -> bool:
+    def update_tensors(self, allreduce_group: AllReduceRunner) -> bool:
         """
-        Set or change the values of self.averaged_tensors.
+        a private (extendable) method that applies changes from a finished allreduce to local tensors

-        :param tensors: list/tuple of tensors of same shape as self.averaged_tensors
-        :param add: if True, add tensors to self.averaged_tensors in-place
-          by default, simply write the values of :tensors: to self.averaged_tensors
-        :note: if there may be updates running in background, it is recommended to use add=True
+        :return: True on success, False on failure
         """
-        assert len(tensors) == len(self._averaged_tensors)
-        with torch.no_grad(), self.lock_averaged_tensors:
-            for tensor, update in zip(self._averaged_tensors, tensors):
-                if add:
-                    tensor += update
-                else:
-                    tensor[...] = update
-        return True
+        assert allreduce_group.return_deltas and allreduce_group.future.done()
+        averaging_deltas = allreduce_group.future.result()
+
+        with torch.no_grad(), self.get_tensors() as local_tensors:
+            assert len(local_tensors) == len(self._averaged_tensors)
+            for tensor, update in zip(local_tensors, averaging_deltas):
+                tensor.add_(update, alpha=self.averaging_alpha)
+            return True

     @contextlib.contextmanager
     def get_tensors(self) -> Sequence[torch.Tensor]:

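The new averaging_alpha parameter controls how far local tensors are shifted towards the group average. A minimal sketch of the update rule applied in update_tensors, assuming the finished all-reduce returns deltas of the form (group_average - local_tensor), as the docstring above implies:

import torch

local = torch.tensor([0.0, 2.0, 4.0])
group_average = torch.tensor([1.0, 1.0, 1.0])
delta = group_average - local                 # the kind of delta a finished all-reduce would return

averaging_alpha = 0.5
local.add_(delta, alpha=averaging_alpha)      # local <- local + alpha * (average - local)
assert torch.allclose(local, torch.tensor([0.5, 1.5, 2.5]))
# with the default averaging_alpha = 1.0, local parameters are set exactly to the group average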
+ 54 - 2
hivemind/dht/__init__.py

@@ -21,13 +21,12 @@ from collections import deque
 from concurrent.futures import ThreadPoolExecutor
 from typing import List, Tuple, Optional, Sequence, Union, Dict, Deque, NamedTuple, Iterator, Set

-import uvloop
 from numpy import nextafter

 from hivemind.client import RemoteExpert
 from hivemind.dht.node import DHTNode, DHTID, DHTExpiration
 from hivemind.dht.routing import get_dht_time, DHTValue
-from hivemind.utils import MPFuture, Endpoint, get_logger, switch_to_uvloop
+from hivemind.utils import MPFuture, Endpoint, Hostname, get_logger, switch_to_uvloop, strip_port

 logger = get_logger(__name__)

@@ -181,6 +180,59 @@ class DHT(mp.Process):
     def port(self) -> Optional[int]:
         return self._port.value if self._port.value != 0 else None

+    def get_visible_address(self, num_peers: Optional[int] = None, peers: Sequence[Endpoint] = ()) -> Hostname:
+        """
+        Get this machine's visible address by requesting other peers or using pre-specified network addresses.
+        If no parameters are specified, this function will check for manual endpoint; if unavailable, ask 1 random peer.
+
+        :param num_peers: if specified, ask multiple peers and check that they perceive the same endpoint
+        :param peers: if specified, ask these exact peers instead of choosing random known peers
+        :note: if this node has no known peers in routing table, one must specify :peers: manually
+        """
+        assert num_peers is None or peers == (), "please specify either a num_peers or the list of peers, not both"
+        assert not isinstance(peers, str) and isinstance(peers, Sequence), "Please send a list / tuple of endpoints"
+        future, _future = MPFuture.make_pair()
+        self.pipe.send(('_get_visible_address', [], dict(num_peers=num_peers, peers=peers, future=_future)))
+        return future.result()
+
+    async def _get_visible_address(self, node: DHTNode, num_peers: Optional[int], peers: Sequence[Endpoint],
+                                   future: Optional[MPFuture]):
+        if not peers and (num_peers or not node.protocol.node_info.endpoint):
+            # if we can't resolve the endpoint locally, ask one random peer
+            peers_and_endpoints = node.protocol.routing_table.get_nearest_neighbors(
+                DHTID.generate(), num_peers or 1, exclude=node.node_id)
+            peers = tuple(endpoint for node_id, endpoint in peers_and_endpoints)
+
+        chosen_address = None
+        if peers:
+            possible_endpoints: Sequence[Optional[Endpoint]] = await asyncio.gather(*(
+                node.protocol.get_outgoing_request_endpoint(peer) for peer in peers))
+
+            for endpoint in possible_endpoints:
+                if endpoint is None:
+                    continue
+                address = strip_port(endpoint)
+                if chosen_address is not None and address != chosen_address:
+                    logger.warning("At least two peers returned different visible addresses for this node:"
+                                   f"{address} and {chosen_address} (keeping the former one)")
+                else:
+                    chosen_address = address
+
+            if chosen_address is None:
+                logger.warning(f"None of the selected peers responded with an address ({peers})")
+
+        if node.protocol.node_info.endpoint:
+            address = strip_port(node.protocol.node_info.endpoint)
+            if chosen_address is not None and address != chosen_address:
+                logger.warning(f"Node was manually given endpoint {address} , but other peers report {chosen_address}")
+            chosen_address = chosen_address or address
+
+        if chosen_address:
+            future.set_result(chosen_address)
+        else:
+            future.set_exception(ValueError(f"Can't get address: DHT node has no peers and no public endpoint."
+                                            f" Please ensure the node is connected or specify peers=... manually."))
+
     def declare_experts(self, uids: Sequence[ExpertUID], endpoint: Endpoint, wait: bool = True,
                         timeout: Optional[float] = None) -> Dict[ExpertUID, bool]:
         """

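A minimal usage sketch of the new DHT.get_visible_address, mirroring the test added in tests/test_dht.py further down; the localhost endpoints are illustrative:

import hivemind

first = hivemind.DHT(start=True, listen_on='0.0.0.0:*')
second = hivemind.DHT(start=True, listen_on='0.0.0.0:*', initial_peers=[f'127.0.0.1:{first.port}'])

print(second.get_visible_address())                                   # ask one random known peer
print(second.get_visible_address(peers=[f'127.0.0.1:{first.port}']))  # or ask specific peers explicitly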
+ 80 - 10
hivemind/dht/node.py

@@ -2,7 +2,7 @@ from __future__ import annotations

 import asyncio
 import random
-from collections import defaultdict
+from collections import defaultdict, Counter
 from dataclasses import dataclass, field
 from functools import partial
 from typing import Optional, Tuple, List, Dict, DefaultDict, Collection, Union, Set, Awaitable, Callable, Any
@@ -11,9 +11,10 @@ from sortedcontainers import SortedSet

 from hivemind.dht.protocol import DHTProtocol
 from hivemind.dht.routing import DHTID, DHTExpiration, DHTKey, get_dht_time, DHTValue, BinaryDHTValue, Subkey
-from hivemind.dht.storage import CacheRefreshQueue, DictionaryDHTValue
+from hivemind.dht.storage import DictionaryDHTValue
 from hivemind.dht.traverse import traverse_dht
-from hivemind.utils import Endpoint, LOCALHOST, MSGPackSerializer, get_logger, SerializerBase, ValueWithExpiration
+from hivemind.utils import Endpoint, LOCALHOST, MSGPackSerializer, get_logger, SerializerBase
+from hivemind.utils.timed_storage import TimedStorage, ValueWithExpiration

 logger = get_logger(__name__)

@@ -67,6 +68,7 @@ class DHTNode:
     chunk_size: int; refresh_timeout: float; cache_locally: bool; cache_nearest: int; cache_refresh_before_expiry: float
     cache_on_store: bool; reuse_get_requests: bool; pending_get_requests: DefaultDict[DHTID, SortedSet[_SearchState]]
     cache_refresh_task: Optional[asyncio.Task]; cache_refresh_evt: asyncio.Event; cache_refresh_queue: CacheRefreshQueue
+    blacklist: Blacklist
     # fmt:on

     @classmethod
@@ -76,7 +78,9 @@ class DHTNode:
             wait_timeout: float = 5, refresh_timeout: Optional[float] = None, bootstrap_timeout: Optional[float] = None,
             cache_locally: bool = True, cache_nearest: int = 1, cache_size=None, cache_refresh_before_expiry: float = 5,
             cache_on_store: bool = True, reuse_get_requests: bool = True, num_workers: int = 1, chunk_size: int = 16,
-            listen: bool = True, listen_on: Endpoint = "0.0.0.0:*", **kwargs) -> DHTNode:
+            blacklist_time: float = 5.0, backoff_rate: float = 2.0,
+            listen: bool = True, listen_on: Endpoint = "0.0.0.0:*", endpoint: Optional[Endpoint] = None,
+            validate: bool = True, strict: bool = True, **kwargs) -> DHTNode:
         """
         :param node_id: current node's identifier, determines which keys it will store locally, defaults to random id
         :param initial_peers: connects to these peers to populate routing table, defaults to no peers
@@ -102,9 +106,14 @@ class DHTNode:
           all concurrent get requests for the same key will reuse the procedure that is currently in progress
         :param num_workers: concurrent workers in traverse_dht (see traverse_dht num_workers param)
         :param chunk_size: maximum number of concurrent calls in get_many and cache refresh queue
+        :param blacklist_time: excludes non-responsive peers from search for this many seconds (set 0 to disable)
+        :param backoff_rate: blacklist time will be multiplied by :backoff_rate: for each successive non-response
+        :param validate: if True, use initial peers to validate that this node is accessible and synchronized
+        :param strict: if True, any error encountered in validation will interrupt the creation of DHTNode
         :param listen: if True (default), this node will accept incoming request and otherwise be a DHT "citzen"
           if False, this node will refuse any incoming request, effectively being only a "client"
         :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
+        :param endpoint: if specified, this is peer's preferred public endpoint. Otherwise let peers infer endpoint
         :param channel_options: options for grpc.aio.insecure_channel, e.g. [('grpc.enable_retries', 0)]
           see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
         :param kwargs: extra parameters used in grpc.aio.server
@@ -121,19 +130,21 @@ class DHTNode:
         self.refresh_timeout = refresh_timeout
         self.cache_locally, self.cache_nearest, self.cache_on_store = cache_locally, cache_nearest, cache_on_store
         self.cache_refresh_before_expiry = cache_refresh_before_expiry
+        self.blacklist = Blacklist(blacklist_time, backoff_rate)
         self.cache_refresh_queue = CacheRefreshQueue()
         self.cache_refresh_evt = asyncio.Event()
         self.cache_refresh_task = None

         self.protocol = await DHTProtocol.create(self.node_id, bucket_size, depth_modulo, num_replicas, wait_timeout,
-                                                 parallel_rpc, cache_size, listen, listen_on, **kwargs)
+                                                 parallel_rpc, cache_size, listen, listen_on, endpoint, **kwargs)
         self.port = self.protocol.port

         if initial_peers:
             # stage 1: ping initial_peers, add each other to the routing table
             bootstrap_timeout = bootstrap_timeout if bootstrap_timeout is not None else wait_timeout
             start_time = get_dht_time()
-            ping_tasks = map(self.protocol.call_ping, initial_peers)
+            ping_tasks = set(asyncio.create_task(self.protocol.call_ping(peer, validate=validate, strict=strict))
+                             for peer in initial_peers)
             finished_pings, unfinished_pings = await asyncio.wait(ping_tasks, return_when=asyncio.FIRST_COMPLETED)

             # stage 2: gather remaining peers (those who respond within bootstrap_timeout)
@@ -147,6 +158,10 @@ class DHTNode:
             if not finished_pings:
                 logger.warning("DHTNode bootstrap failed: none of the initial_peers responded to a ping.")

+            if strict:
+                for task in asyncio.as_completed(finished_pings):
+                    await task  # propagate exceptions
+
             # stage 3: traverse dht to find my own nearest neighbors and populate the routing table
             # ... maybe receive some values that we are meant to store (see protocol.update_routing_table)
             # note: using asyncio.wait instead of wait_for because wait_for cancels task on timeout
@@ -192,11 +207,11 @@ class DHTNode:
         if node_to_endpoint is None:
             node_to_endpoint: Dict[DHTID, Endpoint] = dict()
             for query in queries:
-                node_to_endpoint.update(
-                    self.protocol.routing_table.get_nearest_neighbors(query, beam_size, exclude=self.node_id))
+                neighbors = self.protocol.routing_table.get_nearest_neighbors(query, beam_size, exclude=self.node_id)
+                node_to_endpoint.update(self._filter_blacklisted(dict(neighbors)))

         async def get_neighbors(peer: DHTID, queries: Collection[DHTID]) -> Dict[DHTID, Tuple[Tuple[DHTID], bool]]:
-            response = await self.protocol.call_find(node_to_endpoint[peer], queries)
+            response = await self._call_find_with_blacklist(node_to_endpoint[peer], queries)
             if not response:
                 return {query: ([], False) for query in queries}

@@ -433,7 +448,7 @@ class DHTNode:
         # V-- this function will be called every time traverse_dht decides to request neighbors from a remote peer
         async def get_neighbors(peer: DHTID, queries: Collection[DHTID]) -> Dict[DHTID, Tuple[Tuple[DHTID], bool]]:
             queries = list(queries)
-            response = await self.protocol.call_find(node_to_endpoint[peer], queries)
+            response = await self._call_find_with_blacklist(node_to_endpoint[peer], queries)
             if not response:
                 return {query: ([], False) for query in queries}

@@ -480,6 +495,22 @@ class DHTNode:
         else:
             pending_requests.discard(finished)

+    async def _call_find_with_blacklist(self, endpoint: Endpoint, keys: Collection[DHTID]):
+        """ same as call_find, but skip if :endpoint: is blacklisted; also exclude blacklisted neighbors from result """
+        if endpoint in self.blacklist:
+            return None
+        response = await self.protocol.call_find(endpoint, keys)
+        if response:
+            self.blacklist.register_success(endpoint)
+            return {key: (maybe_value, self._filter_blacklisted(nearest_peers))
+                    for key, (maybe_value, nearest_peers) in response.items()}
+        else:
+            self.blacklist.register_failure(endpoint)
+            return None
+
+    def _filter_blacklisted(self, peer_endpoints: Dict[DHTID, Endpoint]):
+        return {peer: endpoint for peer, endpoint in peer_endpoints.items() if endpoint not in self.blacklist}
+
     def _trigger_cache_refresh(self, search: _SearchState):
         """ Called after get request is finished (whether it was found, not found, hit cache, cancelled, or reused) """
         if search.found_something and search.source_node_id == self.node_id:
@@ -629,3 +660,42 @@ class _SearchState:

     def __hash__(self):
         return hash(self.key_id)
+
+
+class Blacklist:
+    """
+    A temporary blacklist of non-responding peers with exponential backoff policy
+    :param base_time: peers are suspended for this many seconds by default
+    :param backoff_rate: suspension time increases by this factor after each successive failure
+    """
+    def __init__(self, base_time: float, backoff_rate: float, **kwargs):
+        self.base_time, self.backoff = base_time, backoff_rate
+        self.banned_peers = TimedStorage[Endpoint, int](**kwargs)
+        self.ban_counter = Counter()
+
+    def register_failure(self, peer: Endpoint):
+        """ peer failed to respond, add him to blacklist or increase his downtime """
+        if peer not in self.banned_peers and self.base_time > 0:
+            ban_duration = self.base_time * self.backoff ** self.ban_counter[peer]
+            self.banned_peers.store(peer, self.ban_counter[peer], expiration_time=get_dht_time() + ban_duration)
+            self.ban_counter[peer] += 1
+
+    def register_success(self, peer):
+        """ peer responded successfully, remove him from blacklist and reset his ban time """
+        del self.banned_peers[peer], self.ban_counter[peer]
+
+    def __contains__(self, peer: Endpoint) -> bool:
+        return peer in self.banned_peers
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(base_time={self.base_time}, backoff={self.backoff}, " \
+               f"banned_peers={len(self.banned_peers)})"
+
+    def clear(self):
+        self.banned_peers.clear()
+        self.ban_counter.clear()
+
+
+class CacheRefreshQueue(TimedStorage[DHTID, DHTExpiration]):
+    """ a queue of keys scheduled for refresh in future, used in DHTNode """
+    frozen = True

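A small sketch of the exponential backoff policy implemented by the Blacklist class above, assuming it stays importable from hivemind.dht.node; the peer endpoint is made up:

from hivemind.dht.node import Blacklist

peer = '198.51.100.7:31337'
blacklist = Blacklist(base_time=5.0, backoff_rate=2.0)

blacklist.register_failure(peer)   # first failure: suspended for base_time = 5 seconds
assert peer in blacklist
blacklist.register_failure(peer)   # further failures while the peer is still suspended are ignored
# once that ban expires, the next failure suspends the peer for 5 * 2 ** 1 = 10 s, then 20 s, ...

blacklist.register_success(peer)   # a successful response lifts the ban and resets the counter
assert peer not in blacklist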
+ 77 - 23
hivemind/dht/protocol.py

@@ -10,7 +10,7 @@ from hivemind.dht.routing import RoutingTable, DHTID, BinaryDHTValue, DHTExpirat
 from hivemind.dht.storage import DHTLocalStorage, DictionaryDHTValue
 from hivemind.proto import dht_pb2, dht_pb2_grpc as dht_grpc
 from hivemind.utils import Endpoint, get_logger, replace_port, MSGPackSerializer, ChannelCache, ValueWithExpiration
-from hivemind.utils.grpc import GRPC_KEEPALIVE_OPTIONS
+from hivemind.utils import get_dht_time, GRPC_KEEPALIVE_OPTIONS, MAX_DHT_TIME_DISCREPANCY_SECONDS

 logger = get_logger(__name__)

@@ -18,7 +18,7 @@ logger = get_logger(__name__)
 class DHTProtocol(dht_grpc.DHTServicer):
     # fmt:off
     node_id: DHTID; port: int; bucket_size: int; num_replicas: int; wait_timeout: float; node_info: dht_pb2.NodeInfo
-    channel_options: Sequence[Tuple[str, Any]]; server: grpc.aio.Server
+    channel_options: Tuple[Tuple[str, Any]]; server: grpc.aio.Server
     storage: DHTLocalStorage; cache: DHTLocalStorage; routing_table: RoutingTable; rpc_semaphore: asyncio.Semaphore
     # fmt:on

@@ -28,7 +28,8 @@ class DHTProtocol(dht_grpc.DHTServicer):
     @classmethod
     async def create(
             cls, node_id: DHTID, bucket_size: int, depth_modulo: int, num_replicas: int, wait_timeout: float,
-            parallel_rpc: Optional[int] = None, cache_size: Optional[int] = None, listen=True, listen_on='0.0.0.0:*',
+            parallel_rpc: Optional[int] = None, cache_size: Optional[int] = None,
+            listen=True, listen_on='0.0.0.0:*', endpoint: Optional[Endpoint] = None,
             channel_options: Sequence[Tuple[str, Any]] = (), **kwargs) -> DHTProtocol:
         """
         A protocol that allows DHT nodes to request keys/neighbors from other DHT nodes.
@@ -54,10 +55,12 @@ class DHTProtocol(dht_grpc.DHTServicer):
             self.server = grpc.aio.server(**kwargs, options=GRPC_KEEPALIVE_OPTIONS)
             dht_grpc.add_DHTServicer_to_server(self, self.server)

-            found_port = self.server.add_insecure_port(listen_on)
-            assert found_port != 0, f"Failed to listen to {listen_on}"
-            self.node_info = dht_pb2.NodeInfo(node_id=node_id.to_bytes(), rpc_port=found_port)
-            self.port = found_port
+            self.port = self.server.add_insecure_port(listen_on)
+            assert self.port != 0, f"Failed to listen to {listen_on}"
+            if endpoint is not None and endpoint.endswith('*'):
+                endpoint = replace_port(endpoint, self.port)
+            self.node_info = dht_pb2.NodeInfo(node_id=node_id.to_bytes(), rpc_port=self.port,
+                                              endpoint=endpoint or dht_pb2.NodeInfo.endpoint.DESCRIPTOR.default_value)
             await self.server.start()
         else:  # not listening to incoming requests, client-only mode
             # note: use empty node_info so peers won't add you to their routing tables
@@ -83,32 +86,78 @@ class DHTProtocol(dht_grpc.DHTServicer):
         """ get a DHTStub that sends requests to a given peer """
         return ChannelCache.get_stub(peer, dht_grpc.DHTStub, aio=True, options=self.channel_options)

-    async def call_ping(self, peer: Endpoint) -> Optional[DHTID]:
+    async def call_ping(self, peer: Endpoint, validate: bool = False, strict: bool = True) -> Optional[DHTID]:
         """
         Get peer's node id and add him to the routing table. If peer doesn't respond, return None
         :param peer: string network address, e.g. 123.123.123.123:1337 or [2a21:6с8:b192:2105]:8888
+        :param validate: if True, validates that node's endpoint is available
+        :param strict: if strict=True, validation will raise exception on fail, otherwise it will only warn
         :note: if DHTProtocol was created with listen=True, also request peer to add you to his routing table

         :return: node's DHTID, if peer responded and decided to send his node_id
         """
         try:
             async with self.rpc_semaphore:
-                peer_info = await self._get_dht_stub(peer).rpc_ping(self.node_info, timeout=self.wait_timeout)
+                ping_request = dht_pb2.PingRequest(peer=self.node_info, validate=validate)
+                time_requested = get_dht_time()
+                response = await self._get_dht_stub(peer).rpc_ping(ping_request, timeout=self.wait_timeout)
+                time_responded = get_dht_time()
         except grpc.aio.AioRpcError as error:
-            logger.warning(f"DHTProtocol failed to ping {peer}: {error.code()}")
-            peer_info = None
-        responded = bool(peer_info and peer_info.node_id)
-        peer_id = DHTID.from_bytes(peer_info.node_id) if responded else None
+            logger.debug(f"DHTProtocol failed to ping {peer}: {error.code()}")
+            response = None
+        responded = bool(response and response.peer and response.peer.node_id)
+
+        if responded and validate:
+            try:
+                if self.server is not None and not response.available:
+                    raise ValidationError(f"peer {peer} couldn't access this node at {response.sender_endpoint} .")
+
+                if response.dht_time != dht_pb2.PingResponse.dht_time.DESCRIPTOR.default_value:
+                    if response.dht_time < time_requested - MAX_DHT_TIME_DISCREPANCY_SECONDS or \
+                            response.dht_time > time_responded + MAX_DHT_TIME_DISCREPANCY_SECONDS:
+                        raise ValidationError(f"local time must be within {MAX_DHT_TIME_DISCREPANCY_SECONDS} seconds "
+                                              f" of others(local: {time_requested:.5f}, peer: {response.dht_time:.5f})")
+            except ValidationError as e:
+                if strict:
+                    raise
+                else:
+                    logger.warning(repr(e))
+
+        peer_id = DHTID.from_bytes(response.peer.node_id) if responded else None
         asyncio.create_task(self.update_routing_table(peer_id, peer, responded=responded))
         return peer_id

-    async def rpc_ping(self, peer_info: dht_pb2.NodeInfo, context: grpc.ServicerContext):
+    async def get_outgoing_request_endpoint(self, peer: Endpoint) -> Optional[Endpoint]:
+        """ ask this peer how it perceives this node's outgoing request address """
+        try:
+            async with self.rpc_semaphore:
+                ping_request = dht_pb2.PingRequest(peer=None, validate=False)
+                response = await self._get_dht_stub(peer).rpc_ping(ping_request, timeout=self.wait_timeout)
+                if response.sender_endpoint != dht_pb2.PingResponse.sender_endpoint.DESCRIPTOR.default_value:
+                    return response.sender_endpoint
+        except grpc.aio.AioRpcError as error:
+            logger.debug(f"DHTProtocol failed to ping {peer}: {error.code()}")
+
+    async def rpc_ping(self, request: dht_pb2.PingRequest, context: grpc.ServicerContext):
         """ Some node wants us to add it to our routing table. """
-        if peer_info.node_id and peer_info.rpc_port:
-            sender_id = DHTID.from_bytes(peer_info.node_id)
-            rpc_endpoint = replace_port(context.peer(), new_port=peer_info.rpc_port)
-            asyncio.create_task(self.update_routing_table(sender_id, rpc_endpoint))
-        return self.node_info
+        response = dht_pb2.PingResponse(peer=self.node_info, sender_endpoint=context.peer(),
+                                        dht_time=get_dht_time(), available=False)
+
+        if request.peer and request.peer.node_id and request.peer.rpc_port:
+            sender_id = DHTID.from_bytes(request.peer.node_id)
+            if request.peer.endpoint != dht_pb2.NodeInfo.endpoint.DESCRIPTOR.default_value:
+                sender_endpoint = request.peer.endpoint  # if peer has preferred endpoint, use it
+            else:
+                sender_endpoint = replace_port(context.peer(), new_port=request.peer.rpc_port)
+
+            response.sender_endpoint = sender_endpoint
+            if request.validate:
+                response.available = await self.call_ping(response.sender_endpoint, validate=False) == sender_id
+
+            asyncio.create_task(self.update_routing_table(sender_id, sender_endpoint,
+                                                          responded=response.available or not request.validate))
+
+        return response

     async def call_store(self, peer: Endpoint, keys: Sequence[DHTID],
                          values: Sequence[Union[BinaryDHTValue, DictionaryDHTValue]],
@@ -161,14 +210,14 @@ class DHTProtocol(dht_grpc.DHTServicer):
                 asyncio.create_task(self.update_routing_table(peer_id, peer, responded=True))
             return response.store_ok
         except grpc.aio.AioRpcError as error:
-            logger.warning(f"DHTProtocol failed to store at {peer}: {error.code()}")
+            logger.debug(f"DHTProtocol failed to store at {peer}: {error.code()}")
             asyncio.create_task(self.update_routing_table(self.routing_table.get(endpoint=peer), peer, responded=False))
             return None

     async def rpc_store(self, request: dht_pb2.StoreRequest, context: grpc.ServicerContext) -> dht_pb2.StoreResponse:
         """ Some node wants us to store this (key, value) pair """
         if request.peer:  # if requested, add peer to the routing table
-            asyncio.create_task(self.rpc_ping(request.peer, context))
+            asyncio.create_task(self.rpc_ping(dht_pb2.PingRequest(peer=request.peer), context))
         assert len(request.keys) == len(request.values) == len(request.expiration_time) == len(request.in_cache)
         response = dht_pb2.StoreResponse(store_ok=[], peer=self.node_info)
         keys = map(DHTID.from_bytes, request.keys)
@@ -225,7 +274,7 @@ class DHTProtocol(dht_grpc.DHTServicer):

             return output
         except grpc.aio.AioRpcError as error:
-            logger.warning(f"DHTProtocol failed to find at {peer}: {error.code()}")
+            logger.debug(f"DHTProtocol failed to find at {peer}: {error.code()}")
             asyncio.create_task(self.update_routing_table(self.routing_table.get(endpoint=peer), peer, responded=False))

     async def rpc_find(self, request: dht_pb2.FindRequest, context: grpc.ServicerContext) -> dht_pb2.FindResponse:
@@ -234,7 +283,8 @@ class DHTProtocol(dht_grpc.DHTServicer):
         Also return :bucket_size: nearest neighbors from our routing table for each key (whether or not we found value)
         """
         if request.peer:  # if requested, add peer to the routing table
-            asyncio.create_task(self.rpc_ping(request.peer, context))
+            asyncio.create_task(self.rpc_ping(dht_pb2.PingRequest(peer=request.peer), context))
+
         response = dht_pb2.FindResponse(results=[], peer=self.node_info)
         for i, key_id in enumerate(map(DHTID.from_bytes, request.keys)):
             maybe_item = self.storage.get(key_id)
@@ -294,3 +344,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
         else:  # we sent outgoing request and peer did not respond
             if node_id is not None and node_id in self.routing_table:
                 del self.routing_table[node_id]
+
+
+class ValidationError(Exception):
+    """ This exception is thrown if DHT node didn't pass validation by other nodes. """

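The validate branch of call_ping above checks that the peer could reach this node back (response.available) and that the peer's clock is roughly synchronized. A sketch of the clock check in isolation, assuming MAX_DHT_TIME_DISCREPANCY_SECONDS is exported from hivemind.utils as in the import change above; the helper name is hypothetical and only restates the condition from the diff:

from hivemind.utils import MAX_DHT_TIME_DISCREPANCY_SECONDS

def peer_clock_is_synchronized(peer_dht_time: float, time_requested: float, time_responded: float) -> bool:
    """ the peer's reported dht_time must fall inside the request/response window, padded by the allowed discrepancy """
    return (time_requested - MAX_DHT_TIME_DISCREPANCY_SECONDS
            <= peer_dht_time
            <= time_responded + MAX_DHT_TIME_DISCREPANCY_SECONDS)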
+ 0 - 3
hivemind/dht/storage.py

@@ -65,6 +65,3 @@ class DHTLocalStorage(TimedStorage[DHTID, Union[BinaryDHTValue, DictionaryDHTVal
             return False


-class CacheRefreshQueue(TimedStorage[DHTID, DHTExpiration]):
-    """ a queue of keys scheduled for refresh in future, used in DHTNode """
-    frozen = True

+ 14 - 1
hivemind/proto/dht.proto

@@ -5,7 +5,7 @@ syntax = "proto3";

 service DHT {
   // find out recipient's DHTID and possibly update its routing table
-  rpc rpc_ping(NodeInfo) returns (NodeInfo);
+  rpc rpc_ping(PingRequest) returns (PingResponse);

   // request a node to store one or multiple data items (key - value - expiration)
   rpc rpc_store(StoreRequest) returns (StoreResponse);
@@ -19,6 +19,19 @@ message NodeInfo {
   // if either node_id or port is absent, simply request recipient info (for client-only mode)
   bytes node_id = 1;                   // sender's own node id serialized with DHTID.to_bytes()
   int32 rpc_port = 2;                  // port to which sender listens for DHT RPCs
+  string endpoint = 3;                 // (optional) node's preferred return address
+}
+
+message PingRequest {
+  NodeInfo peer = 1;                   // (optional) sender's own node info, same behavior as in DHT.rpc_ping
+  bool validate = 2;                   // set to True if sender wants to validate that he is accessible and synchronized
+}
+
+message PingResponse {
+  NodeInfo peer = 1;                   // respondent's node id, for you to update routing table
+  string sender_endpoint = 2;          // echo sender's visible endpoint - used to infer his ip address
+  double dht_time = 3;                 // recipient's local DHT time - used to soft-synchronize peers
+  bool available = 4;                  // if validate = True, this flag asserts that the sender is available for ping
 }

 message StoreRequest {

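For reference, a hypothetical round-trip of the new ping messages built with the regenerated protobuf stubs (hivemind.proto.dht_pb2); all field values below are placeholders:

from hivemind.proto import dht_pb2

request = dht_pb2.PingRequest(
    peer=dht_pb2.NodeInfo(node_id=b'\x01\x02', rpc_port=31337, endpoint='203.0.113.5:31337'),
    validate=True,   # ask the recipient to ping us back and report its local DHT time
)
response = dht_pb2.PingResponse(
    peer=dht_pb2.NodeInfo(node_id=b'\x03\x04', rpc_port=8888),
    sender_endpoint='203.0.113.5:54321',   # how the recipient saw our outgoing address
    dht_time=1600000000.0,                 # recipient's local DHT time, used for soft synchronization
    available=True,                        # recipient managed to ping us back at our endpoint
)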
+ 6 - 0
hivemind/utils/networking.py

@@ -21,6 +21,12 @@ def replace_port(endpoint: Endpoint, new_port: Port) -> Endpoint:
     return f"{endpoint[:endpoint.rindex(':')]}:{new_port}"


+def strip_port(endpoint: Endpoint) -> Hostname:
+    """ Removes port from the end of endpoint. If port is not specified, does nothing """
+    maybe_port = endpoint[endpoint.rindex(':') + 1:]
+    return endpoint[:endpoint.rindex(':')] if maybe_port.isdigit() or maybe_port == '*' else endpoint
+
+
 def find_open_port(params=(socket.AF_INET, socket.SOCK_STREAM), opt=(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)):
     """ Finds a tcp port that can be occupied with a socket with *params and use *opt options """
     try:

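Expected behavior of the new strip_port helper next to the existing replace_port, as a quick sketch (outputs follow the implementation above):

from hivemind.utils.networking import strip_port, replace_port

assert strip_port('123.45.67.89:1337') == '123.45.67.89'
assert strip_port('123.45.67.89:*') == '123.45.67.89'
assert strip_port('[2a21:6c8:b192::2105]') == '[2a21:6c8:b192::2105]'   # no trailing port: left unchanged
assert replace_port('0.0.0.0:*', 8080) == '0.0.0.0:8080'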
+ 1 - 1
tests/benchmark_dht.py

@@ -34,7 +34,7 @@ def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_b
     print(f"Sampled {len(expert_uids)} unique ids (after deduplication)")
     random.shuffle(expert_uids)

-    print(f"Storing peers to dht in batches of {expert_batch_size}...")
+    print(f"Storing experts to dht in batches of {expert_batch_size}...")
     successful_stores = total_stores = total_store_time = 0
     benchmark_started = time.perf_counter()
     endpoints = []

+ 18 - 3
tests/test_dht_experts.py → tests/test_dht.py

@@ -2,10 +2,9 @@ import random
 import numpy as np
 import pytest
 import asyncio
-import multiprocessing as mp

 import hivemind
-from hivemind import LOCALHOST, UidEndpoint
+from hivemind import LOCALHOST, UidEndpoint, strip_port


 @pytest.mark.forked
@@ -21,7 +20,7 @@ def test_store_get_experts():
     expert_uids = [f"my_expert.{i}" for i in range(110)]
     batch_size = 10
     for batch_start in range(0, len(expert_uids), batch_size):
-        you.declare_experts(expert_uids[batch_start: batch_start + batch_size], 'localhost', 1234)
+        you.declare_experts(expert_uids[batch_start: batch_start + batch_size], 'localhost:1234')

     found = theguyshetoldyounottoworryabout.get_experts(random.sample(expert_uids, 5) + ['foo', 'bar'])
     assert all(res is not None for res in found[:-2]), "Could not find some existing experts"
@@ -37,6 +36,22 @@ def test_store_get_experts():
         peer.shutdown()


+@pytest.mark.forked
+def test_dht_get_address(addr=LOCALHOST, dummy_endpoint='123.45.67.89:*'):
+    node1 = hivemind.DHT(start=True, listen_on=f"0.0.0.0:*")
+    node2 = hivemind.DHT(start=True, listen_on=f"0.0.0.0:*", initial_peers=[f"{addr}:{node1.port}"])
+    node3 = hivemind.DHT(start=True, listen_on=f"0.0.0.0:*", initial_peers=[f"{addr}:{node2.port}"])
+    assert addr in node3.get_visible_address(num_peers=2)
+
+    node4 = hivemind.DHT(start=True, listen_on=f"0.0.0.0:*")
+    with pytest.raises(ValueError):
+        node4.get_visible_address()
+    assert node4.get_visible_address(peers=[f'{addr}:{node1.port}']).endswith(addr)
+
+    node5 = hivemind.DHT(start=True, listen_on=f"0.0.0.0:*", endpoint=f"{dummy_endpoint}")
+    assert node5.get_visible_address() == strip_port(dummy_endpoint)
+
+
 @pytest.mark.forked
 def test_beam_search(dht_size=20, total_experts=128, batch_size=32, initial_peers=3, beam_size=4, parallel_rpc=256,
                      grid_dims=(32, 32, 32)):

+ 46 - 2
tests/test_dht_node.py

@@ -9,9 +9,9 @@ import pytest
 import hivemind
 from typing import List, Dict

-from hivemind import get_dht_time
+from hivemind import get_dht_time, replace_port
 from hivemind.dht.node import DHTID, Endpoint, DHTNode, LOCALHOST
-from hivemind.dht.protocol import DHTProtocol
+from hivemind.dht.protocol import DHTProtocol, ValidationError
 from hivemind.dht.storage import DictionaryDHTValue


@@ -104,6 +104,8 @@ def test_dht_protocol():
         assert recv_dict.data[subkey1] == (protocol.serializer.dumps(value1), expiration)
         assert recv_dict.data[subkey2] == (protocol.serializer.dumps(value2), expiration + 5)

+        assert LOCALHOST in loop.run_until_complete(protocol.get_outgoing_request_endpoint(f'{LOCALHOST}:{peer1_port}'))
+
         if listen:
             loop.run_until_complete(protocol.shutdown())
         print("DHTProtocol test finished successfully!")
@@ -390,3 +392,45 @@ async def test_dhtnode_reuse_get():
     assert (await futures1['k1'])[0] == 123
     assert await futures1['k2'] == await futures2['k2'] and (await futures1['k2'])[0] == 567
     assert await futures2['k3'] == await futures3['k3'] and (await futures3['k3']) is None
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_dhtnode_blacklist():
+    node1 = await hivemind.DHTNode.create(blacklist_time=999)
+    node2 = await hivemind.DHTNode.create(blacklist_time=999, initial_peers=[f"{LOCALHOST}:{node1.port}"])
+    node3 = await hivemind.DHTNode.create(blacklist_time=999, initial_peers=[f"{LOCALHOST}:{node1.port}"])
+    node4 = await hivemind.DHTNode.create(blacklist_time=999, initial_peers=[f"{LOCALHOST}:{node1.port}"])
+
+    assert await node2.store('abc', 123, expiration_time=hivemind.get_dht_time() + 99)
+    assert len(node2.blacklist.ban_counter) == 0
+
+    await node3.shutdown()
+    await node4.shutdown()
+
+    assert await node2.store('def', 456, expiration_time=hivemind.get_dht_time() + 99)
+
+    assert len(node2.blacklist.ban_counter) == 2
+
+    for banned_peer in node2.blacklist.ban_counter:
+        assert any(banned_peer.endswith(str(port)) for port in [node3.port, node4.port])
+
+    node3_endpoint = await node3.protocol.get_outgoing_request_endpoint(f"{hivemind.LOCALHOST}:{node1.port}")
+    node3_endpoint = replace_port(node3_endpoint, node3.port)
+    assert await node1.get('abc', latest=True)  # force node1 to crawl dht and discover unresponsive peers
+    assert node3_endpoint in node1.blacklist
+
+    node2_endpoint = await node2.protocol.get_outgoing_request_endpoint(f"{hivemind.LOCALHOST}:{node1.port}")
+    node2_endpoint = replace_port(node2_endpoint, node2.port)
+    assert await node1.get('abc', latest=True)  # force node1 to crawl dht and discover unresponsive peers
+    assert node2_endpoint not in node1.blacklist
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_dhtnode_validate(fake_endpoint='127.0.0.721:*'):
+
+    node1 = await hivemind.DHTNode.create(blacklist_time=999)
+    with pytest.raises(ValidationError):
+        node2 = await hivemind.DHTNode.create(blacklist_time=999, initial_peers=[f"{LOCALHOST}:{node1.port}"],
+                                              endpoint=fake_endpoint)