
Add DHT peer validation, add DHT.get_visible_address, add blacklist for unresponsive peers (#137)

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic 4 years ago
parent
commit
c36b5b1a9b

+ 1 - 1
hivemind/__init__.py

@@ -3,4 +3,4 @@ from hivemind.dht import *
 from hivemind.server import *
 from hivemind.utils import *
 
-__version__ = '0.8.22'
+__version__ = '0.8.23'

+ 18 - 19
hivemind/client/averaging/__init__.py

@@ -47,6 +47,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
       note - this expiration time only applies to looking for group, passing tensors in allreduce may take more time
     :param compression_type: optionally compress tensors with this compression algorithm before sending them to peers
     :param allreduce_timeout: spend at most this many seconds for allreduce (after group is formed)
+    :param averaging_alpha: optional "learning rate" for averaging. If specified, local parameters will be shifted
+      towards the (estimated) average by this coefficient. By default, local parameters are set equal to average.
     :param request_timeout: when looking for group, wait for a response from leader for at most this many seconds.
     :note: request_timeout must be smaller than averaging_expiration to avoid potential deadlocks.
     :param chunk_size_bytes: tensors for AllReduce will be divided into chunks of this size (to improve gRPC throughput)
@@ -66,8 +68,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
 
     def __init__(self, averaged_tensors: Sequence[torch.Tensor], dht: hivemind.dht.DHT, *, start: bool,
                  prefix: str, target_group_size: int, min_group_size: int = 2, initial_group_bits: Optional[str] = None,
-                 averaging_expiration: float = 15, allreduce_timeout: Optional[float] = None,
-                 request_timeout: float = 3, chunk_size_bytes: int = 2 ** 16,
+                 averaging_expiration: float = 15, request_timeout: float = 3, chunk_size_bytes: int = 2 ** 16,
+                 allreduce_timeout: Optional[float] = None, averaging_alpha: float = 1.0,
                  compression_type: runtime_pb2.CompressionType = runtime_pb2.CompressionType.NONE,
                  listen_on: Endpoint = '0.0.0.0:*', receiver_threads: int = 1, daemon: bool = True,
                  channel_options: Optional[Sequence[Tuple[str, Any]]] = None, **kwargs):
@@ -95,7 +97,7 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
             prefix=prefix, initial_group_bits=initial_group_bits, target_group_size=target_group_size,
             min_group_size=min_group_size, averaging_expiration=averaging_expiration, request_timeout=request_timeout,
             chunk_size_bytes=chunk_size_bytes, compression_type=compression_type)
-        self.allreduce_timeout = allreduce_timeout
+        self.averaging_alpha, self.allreduce_timeout = averaging_alpha, allreduce_timeout
         self._running_groups: Dict[GroupID, AllReduceRunner] = {}  # one or more assembled groups that run all-reduce
 
         self._pipe, self.pipe = mp.Pipe(duplex=True)  # a control pipe used to communicate with a background process
@@ -193,8 +195,8 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                 group_id = allreduce_group.group_id
                 self._running_groups[group_id] = allreduce_group
                 self._pending_group_assembled.set()
-                averaging_deltas = await asyncio.wait_for(allreduce_group.run(), self.allreduce_timeout)
-                update_ok = await loop.run_in_executor(None, lambda: self.update_tensors(averaging_deltas, add=True))
+                await asyncio.wait_for(allreduce_group.run(), self.allreduce_timeout)
+                update_ok = await loop.run_in_executor(None, self.update_tensors, allreduce_group)
 
                 # averaging is finished, exit the loop
                 future.set_result(update_ok)
@@ -213,23 +215,20 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
                 _ = self._running_groups.pop(group_id, None)
                 self._pending_group_assembled.set()
 
-    def update_tensors(self, tensors: Sequence[torch.Tensor], *, add: bool = False) -> bool:
+    def update_tensors(self, allreduce_group: AllReduceRunner) -> bool:
         """
-        Set or change the values of self.averaged_tensors.
+        a private (extendable) method that applies changes from a finished allreduce to local tensors
 
-        :param tensors: list/tuple of tensors of same shape as self.averaged_tensors
-        :param add: if True, add tensors to self.averaged_tensors in-place
-          by default, simply write the values of :tensors: to self.averaged_tensors
-        :note: if there may be updates running in background, it is recommended to use add=True
+        :return: True on success, False on failure
         """
-        assert len(tensors) == len(self._averaged_tensors)
-        with torch.no_grad(), self.lock_averaged_tensors:
-            for tensor, update in zip(self._averaged_tensors, tensors):
-                if add:
-                    tensor += update
-                else:
-                    tensor[...] = update
-        return True
+        assert allreduce_group.return_deltas and allreduce_group.future.done()
+        averaging_deltas = allreduce_group.future.result()
+
+        with torch.no_grad(), self.get_tensors() as local_tensors:
+            assert len(local_tensors) == len(self._averaged_tensors)
+            for tensor, update in zip(local_tensors, averaging_deltas):
+                tensor.add_(update, alpha=self.averaging_alpha)
+            return True
 
     @contextlib.contextmanager
     def get_tensors(self) -> Sequence[torch.Tensor]:
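
For intuition, the new averaging_alpha coefficient makes update_tensors apply only a partial step towards the group average: with the deltas defined as (group average - local tensor), the in-place update above is local <- local + alpha * delta. A minimal standalone sketch of that arithmetic with plain torch tensors (the values and alpha are illustrative; this is not the DecentralizedAverager API):

import torch

averaging_alpha = 0.5                           # illustrative; the class default of 1.0 sets local equal to the average
local = torch.tensor([0.0, 2.0, 4.0])           # this peer's parameters
group_average = torch.tensor([1.0, 1.0, 1.0])   # (estimated) average produced by all-reduce

delta = group_average - local                   # the kind of delta returned when return_deltas is enabled
local.add_(delta, alpha=averaging_alpha)        # same in-place update as in update_tensors above
print(local)                                    # tensor([0.5000, 1.5000, 2.5000]), halfway to the average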

+ 54 - 2
hivemind/dht/__init__.py

@@ -21,13 +21,12 @@ from collections import deque
 from concurrent.futures import ThreadPoolExecutor
 from typing import List, Tuple, Optional, Sequence, Union, Dict, Deque, NamedTuple, Iterator, Set
 
-import uvloop
 from numpy import nextafter
 
 from hivemind.client import RemoteExpert
 from hivemind.dht.node import DHTNode, DHTID, DHTExpiration
 from hivemind.dht.routing import get_dht_time, DHTValue
-from hivemind.utils import MPFuture, Endpoint, get_logger, switch_to_uvloop
+from hivemind.utils import MPFuture, Endpoint, Hostname, get_logger, switch_to_uvloop, strip_port
 
 logger = get_logger(__name__)
 
@@ -181,6 +180,59 @@ class DHT(mp.Process):
     def port(self) -> Optional[int]:
         return self._port.value if self._port.value != 0 else None
 
+    def get_visible_address(self, num_peers: Optional[int] = None, peers: Sequence[Endpoint] = ()) -> Hostname:
+        """
+        Get this machine's visible address by requesting other peers or using pre-specified network addresses.
+        If no parameters are specified, this function will check for a manually specified endpoint; if that is unavailable, it will ask 1 random peer.
+
+        :param num_peers: if specified, ask multiple peers and check that they perceive the same endpoint
+        :param peers: if specified, ask these exact peers instead of choosing random known peers
+        :note: if this node has no known peers in routing table, one must specify :peers: manually
+        """
+        assert num_peers is None or peers == (), "please specify either num_peers or the list of peers, not both"
+        assert not isinstance(peers, str) and isinstance(peers, Sequence), "Please send a list / tuple of endpoints"
+        future, _future = MPFuture.make_pair()
+        self.pipe.send(('_get_visible_address', [], dict(num_peers=num_peers, peers=peers, future=_future)))
+        return future.result()
+
+    async def _get_visible_address(self, node: DHTNode, num_peers: Optional[int], peers: Sequence[Endpoint],
+                                   future: Optional[MPFuture]):
+        if not peers and (num_peers or not node.protocol.node_info.endpoint):
+            # if we can't resolve the endpoint locally, ask one random peer
+            peers_and_endpoints = node.protocol.routing_table.get_nearest_neighbors(
+                DHTID.generate(), num_peers or 1, exclude=node.node_id)
+            peers = tuple(endpoint for node_id, endpoint in peers_and_endpoints)
+
+        chosen_address = None
+        if peers:
+            possible_endpoints: Sequence[Optional[Endpoint]] = await asyncio.gather(*(
+                node.protocol.get_outgoing_request_endpoint(peer) for peer in peers))
+
+            for endpoint in possible_endpoints:
+                if endpoint is None:
+                    continue
+                address = strip_port(endpoint)
+                if chosen_address is not None and address != chosen_address:
+                    logger.warning("At least two peers returned different visible addresses for this node:"
+                                   f"{address} and {chosen_address} (keeping the former one)")
+                else:
+                    chosen_address = address
+
+            if chosen_address is None:
+                logger.warning(f"None of the selected peers responded with an address ({peers})")
+
+        if node.protocol.node_info.endpoint:
+            address = strip_port(node.protocol.node_info.endpoint)
+            if chosen_address is not None and address != chosen_address:
+                logger.warning(f"Node was manually given endpoint {address} , but other peers report {chosen_address}")
+            chosen_address = chosen_address or address
+
+        if chosen_address:
+            future.set_result(chosen_address)
+        else:
+            future.set_exception(ValueError(f"Can't get address: DHT node has no peers and no public endpoint."
+                                            f" Please ensure the node is connected or specify peers=... manually."))
+
     def declare_experts(self, uids: Sequence[ExpertUID], endpoint: Endpoint, wait: bool = True,
                         timeout: Optional[float] = None) -> Dict[ExpertUID, bool]:
         """

+ 80 - 10
hivemind/dht/node.py

@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import asyncio
 import random
-from collections import defaultdict
+from collections import defaultdict, Counter
 from dataclasses import dataclass, field
 from functools import partial
 from typing import Optional, Tuple, List, Dict, DefaultDict, Collection, Union, Set, Awaitable, Callable, Any
@@ -11,9 +11,10 @@ from sortedcontainers import SortedSet
 
 from hivemind.dht.protocol import DHTProtocol
 from hivemind.dht.routing import DHTID, DHTExpiration, DHTKey, get_dht_time, DHTValue, BinaryDHTValue, Subkey
-from hivemind.dht.storage import CacheRefreshQueue, DictionaryDHTValue
+from hivemind.dht.storage import DictionaryDHTValue
 from hivemind.dht.traverse import traverse_dht
-from hivemind.utils import Endpoint, LOCALHOST, MSGPackSerializer, get_logger, SerializerBase, ValueWithExpiration
+from hivemind.utils import Endpoint, LOCALHOST, MSGPackSerializer, get_logger, SerializerBase
+from hivemind.utils.timed_storage import TimedStorage, ValueWithExpiration
 
 logger = get_logger(__name__)
 
@@ -67,6 +68,7 @@ class DHTNode:
     chunk_size: int; refresh_timeout: float; cache_locally: bool; cache_nearest: int; cache_refresh_before_expiry: float
     cache_on_store: bool; reuse_get_requests: bool; pending_get_requests: DefaultDict[DHTID, SortedSet[_SearchState]]
     cache_refresh_task: Optional[asyncio.Task]; cache_refresh_evt: asyncio.Event; cache_refresh_queue: CacheRefreshQueue
+    blacklist: Blacklist
     # fmt:on
 
     @classmethod
@@ -76,7 +78,9 @@ class DHTNode:
             wait_timeout: float = 5, refresh_timeout: Optional[float] = None, bootstrap_timeout: Optional[float] = None,
             cache_locally: bool = True, cache_nearest: int = 1, cache_size=None, cache_refresh_before_expiry: float = 5,
             cache_on_store: bool = True, reuse_get_requests: bool = True, num_workers: int = 1, chunk_size: int = 16,
-            listen: bool = True, listen_on: Endpoint = "0.0.0.0:*", **kwargs) -> DHTNode:
+            blacklist_time: float = 5.0, backoff_rate: float = 2.0,
+            listen: bool = True, listen_on: Endpoint = "0.0.0.0:*", endpoint: Optional[Endpoint] = None,
+            validate: bool = True, strict: bool = True, **kwargs) -> DHTNode:
         """
         :param node_id: current node's identifier, determines which keys it will store locally, defaults to random id
         :param initial_peers: connects to these peers to populate routing table, defaults to no peers
@@ -102,9 +106,14 @@ class DHTNode:
           all concurrent get requests for the same key will reuse the procedure that is currently in progress
         :param num_workers: concurrent workers in traverse_dht (see traverse_dht num_workers param)
         :param chunk_size: maximum number of concurrent calls in get_many and cache refresh queue
+        :param blacklist_time: excludes non-responsive peers from search for this many seconds (set 0 to disable)
+        :param backoff_rate: blacklist time will be multiplied by :backoff_rate: for each successive non-response
+        :param validate: if True, use initial peers to validate that this node is accessible and synchronized
+        :param strict: if True, any error encountered in validation will interrupt the creation of DHTNode
         :param listen: if True (default), this node will accept incoming requests and otherwise be a DHT "citizen"
           if False, this node will refuse any incoming request, effectively being only a "client"
         :param listen_on: network interface, e.g. "0.0.0.0:1337" or "localhost:*" (* means pick any port) or "[::]:7654"
+        :param endpoint: if specified, this is the node's preferred public endpoint. Otherwise, let peers infer the endpoint
         :param channel_options: options for grpc.aio.insecure_channel, e.g. [('grpc.enable_retries', 0)]
           see https://grpc.github.io/grpc/core/group__grpc__arg__keys.html for a list of all options
         :param kwargs: extra parameters used in grpc.aio.server
@@ -121,19 +130,21 @@ class DHTNode:
         self.refresh_timeout = refresh_timeout
         self.cache_locally, self.cache_nearest, self.cache_on_store = cache_locally, cache_nearest, cache_on_store
         self.cache_refresh_before_expiry = cache_refresh_before_expiry
+        self.blacklist = Blacklist(blacklist_time, backoff_rate)
         self.cache_refresh_queue = CacheRefreshQueue()
         self.cache_refresh_evt = asyncio.Event()
         self.cache_refresh_task = None
 
         self.protocol = await DHTProtocol.create(self.node_id, bucket_size, depth_modulo, num_replicas, wait_timeout,
-                                                 parallel_rpc, cache_size, listen, listen_on, **kwargs)
+                                                 parallel_rpc, cache_size, listen, listen_on, endpoint, **kwargs)
         self.port = self.protocol.port
 
         if initial_peers:
             # stage 1: ping initial_peers, add each other to the routing table
             bootstrap_timeout = bootstrap_timeout if bootstrap_timeout is not None else wait_timeout
             start_time = get_dht_time()
-            ping_tasks = map(self.protocol.call_ping, initial_peers)
+            ping_tasks = set(asyncio.create_task(self.protocol.call_ping(peer, validate=validate, strict=strict))
+                             for peer in initial_peers)
             finished_pings, unfinished_pings = await asyncio.wait(ping_tasks, return_when=asyncio.FIRST_COMPLETED)
 
             # stage 2: gather remaining peers (those who respond within bootstrap_timeout)
@@ -147,6 +158,10 @@ class DHTNode:
             if not finished_pings:
                 logger.warning("DHTNode bootstrap failed: none of the initial_peers responded to a ping.")
 
+            if strict:
+                for task in asyncio.as_completed(finished_pings):
+                    await task  # propagate exceptions
+
             # stage 3: traverse dht to find my own nearest neighbors and populate the routing table
             # ... maybe receive some values that we are meant to store (see protocol.update_routing_table)
             # note: using asyncio.wait instead of wait_for because wait_for cancels task on timeout
@@ -192,11 +207,11 @@ class DHTNode:
         if node_to_endpoint is None:
             node_to_endpoint: Dict[DHTID, Endpoint] = dict()
             for query in queries:
-                node_to_endpoint.update(
-                    self.protocol.routing_table.get_nearest_neighbors(query, beam_size, exclude=self.node_id))
+                neighbors = self.protocol.routing_table.get_nearest_neighbors(query, beam_size, exclude=self.node_id)
+                node_to_endpoint.update(self._filter_blacklisted(dict(neighbors)))
 
         async def get_neighbors(peer: DHTID, queries: Collection[DHTID]) -> Dict[DHTID, Tuple[Tuple[DHTID], bool]]:
-            response = await self.protocol.call_find(node_to_endpoint[peer], queries)
+            response = await self._call_find_with_blacklist(node_to_endpoint[peer], queries)
             if not response:
                 return {query: ([], False) for query in queries}
 
@@ -433,7 +448,7 @@ class DHTNode:
         # V-- this function will be called every time traverse_dht decides to request neighbors from a remote peer
         async def get_neighbors(peer: DHTID, queries: Collection[DHTID]) -> Dict[DHTID, Tuple[Tuple[DHTID], bool]]:
             queries = list(queries)
-            response = await self.protocol.call_find(node_to_endpoint[peer], queries)
+            response = await self._call_find_with_blacklist(node_to_endpoint[peer], queries)
             if not response:
                 return {query: ([], False) for query in queries}
 
@@ -480,6 +495,22 @@ class DHTNode:
         else:
             pending_requests.discard(finished)
 
+    async def _call_find_with_blacklist(self, endpoint: Endpoint, keys: Collection[DHTID]):
+        """ same as call_find, but skip if :endpoint: is blacklisted; also exclude blacklisted neighbors from result """
+        if endpoint in self.blacklist:
+            return None
+        response = await self.protocol.call_find(endpoint, keys)
+        if response:
+            self.blacklist.register_success(endpoint)
+            return {key: (maybe_value, self._filter_blacklisted(nearest_peers))
+                    for key, (maybe_value, nearest_peers) in response.items()}
+        else:
+            self.blacklist.register_failure(endpoint)
+            return None
+
+    def _filter_blacklisted(self, peer_endpoints: Dict[DHTID, Endpoint]):
+        return {peer: endpoint for peer, endpoint in peer_endpoints.items() if endpoint not in self.blacklist}
+
     def _trigger_cache_refresh(self, search: _SearchState):
         """ Called after get request is finished (whether it was found, not found, hit cache, cancelled, or reused) """
         if search.found_something and search.source_node_id == self.node_id:
@@ -629,3 +660,42 @@ class _SearchState:
 
     def __hash__(self):
         return hash(self.key_id)
+
+
+class Blacklist:
+    """
+    A temporary blacklist of non-responding peers with exponential backoff policy
+    :param base_time: peers are suspended for this many seconds by default
+    :param backoff_rate: suspension time increases by this factor after each successive failure
+    """
+    def __init__(self, base_time: float, backoff_rate: float, **kwargs):
+        self.base_time, self.backoff = base_time, backoff_rate
+        self.banned_peers = TimedStorage[Endpoint, int](**kwargs)
+        self.ban_counter = Counter()
+
+    def register_failure(self, peer: Endpoint):
+        """ peer failed to respond, add him to blacklist or increase his downtime """
+        if peer not in self.banned_peers and self.base_time > 0:
+            ban_duration = self.base_time * self.backoff ** self.ban_counter[peer]
+            self.banned_peers.store(peer, self.ban_counter[peer], expiration_time=get_dht_time() + ban_duration)
+            self.ban_counter[peer] += 1
+
+    def register_success(self, peer):
+        """ peer responded successfully, remove him from blacklist and reset his ban time """
+        del self.banned_peers[peer], self.ban_counter[peer]
+
+    def __contains__(self, peer: Endpoint) -> bool:
+        return peer in self.banned_peers
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(base_time={self.base_time}, backoff={self.backoff}, " \
+               f"banned_peers={len(self.banned_peers)})"
+
+    def clear(self):
+        self.banned_peers.clear()
+        self.ban_counter.clear()
+
+
+class CacheRefreshQueue(TimedStorage[DHTID, DHTExpiration]):
+    """ a queue of keys scheduled for refresh in future, used in DHTNode """
+    frozen = True
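
The Blacklist above wraps a TimedStorage of banned endpoints: each registered failure suspends the peer for base_time * backoff_rate ** n seconds, where n counts its previous failures, and a single success clears both the ban and the counter. A hedged usage sketch (the endpoint string is made up):

from hivemind.dht.node import Blacklist

blacklist = Blacklist(base_time=5.0, backoff_rate=2.0)
peer = "192.0.2.7:1337"                  # illustrative endpoint

blacklist.register_failure(peer)         # first failure: suspended for 5 seconds
assert peer in blacklist
# once that ban expires, another failure would suspend for 5 * 2 ** 1 = 10 seconds, then 20, ...

blacklist.register_success(peer)         # peer responded again: unban and reset its counter
assert peer not in blacklist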

+ 77 - 23
hivemind/dht/protocol.py

@@ -10,7 +10,7 @@ from hivemind.dht.routing import RoutingTable, DHTID, BinaryDHTValue, DHTExpirat
 from hivemind.dht.storage import DHTLocalStorage, DictionaryDHTValue
 from hivemind.proto import dht_pb2, dht_pb2_grpc as dht_grpc
 from hivemind.utils import Endpoint, get_logger, replace_port, MSGPackSerializer, ChannelCache, ValueWithExpiration
-from hivemind.utils.grpc import GRPC_KEEPALIVE_OPTIONS
+from hivemind.utils import get_dht_time, GRPC_KEEPALIVE_OPTIONS, MAX_DHT_TIME_DISCREPANCY_SECONDS
 
 logger = get_logger(__name__)
 
@@ -18,7 +18,7 @@ logger = get_logger(__name__)
 class DHTProtocol(dht_grpc.DHTServicer):
     # fmt:off
     node_id: DHTID; port: int; bucket_size: int; num_replicas: int; wait_timeout: float; node_info: dht_pb2.NodeInfo
-    channel_options: Sequence[Tuple[str, Any]]; server: grpc.aio.Server
+    channel_options: Tuple[Tuple[str, Any]]; server: grpc.aio.Server
     storage: DHTLocalStorage; cache: DHTLocalStorage; routing_table: RoutingTable; rpc_semaphore: asyncio.Semaphore
     # fmt:on
 
@@ -28,7 +28,8 @@ class DHTProtocol(dht_grpc.DHTServicer):
     @classmethod
     async def create(
             cls, node_id: DHTID, bucket_size: int, depth_modulo: int, num_replicas: int, wait_timeout: float,
-            parallel_rpc: Optional[int] = None, cache_size: Optional[int] = None, listen=True, listen_on='0.0.0.0:*',
+            parallel_rpc: Optional[int] = None, cache_size: Optional[int] = None,
+            listen=True, listen_on='0.0.0.0:*', endpoint: Optional[Endpoint] = None,
             channel_options: Sequence[Tuple[str, Any]] = (), **kwargs) -> DHTProtocol:
         """
         A protocol that allows DHT nodes to request keys/neighbors from other DHT nodes.
@@ -54,10 +55,12 @@ class DHTProtocol(dht_grpc.DHTServicer):
             self.server = grpc.aio.server(**kwargs, options=GRPC_KEEPALIVE_OPTIONS)
             dht_grpc.add_DHTServicer_to_server(self, self.server)
 
-            found_port = self.server.add_insecure_port(listen_on)
-            assert found_port != 0, f"Failed to listen to {listen_on}"
-            self.node_info = dht_pb2.NodeInfo(node_id=node_id.to_bytes(), rpc_port=found_port)
-            self.port = found_port
+            self.port = self.server.add_insecure_port(listen_on)
+            assert self.port != 0, f"Failed to listen to {listen_on}"
+            if endpoint is not None and endpoint.endswith('*'):
+                endpoint = replace_port(endpoint, self.port)
+            self.node_info = dht_pb2.NodeInfo(node_id=node_id.to_bytes(), rpc_port=self.port,
+                                              endpoint=endpoint or dht_pb2.NodeInfo.endpoint.DESCRIPTOR.default_value)
             await self.server.start()
         else:  # not listening to incoming requests, client-only mode
             # note: use empty node_info so peers won't add you to their routing tables
@@ -83,32 +86,78 @@ class DHTProtocol(dht_grpc.DHTServicer):
         """ get a DHTStub that sends requests to a given peer """
         return ChannelCache.get_stub(peer, dht_grpc.DHTStub, aio=True, options=self.channel_options)
 
-    async def call_ping(self, peer: Endpoint) -> Optional[DHTID]:
+    async def call_ping(self, peer: Endpoint, validate: bool = False, strict: bool = True) -> Optional[DHTID]:
         """
         Get peer's node id and add him to the routing table. If peer doesn't respond, return None
         :param peer: string network address, e.g. 123.123.123.123:1337 or [2a21:6c8:b192:2105]:8888
+        :param validate: if True, validates that node's endpoint is available
+        :param strict: if strict=True, validation will raise exception on fail, otherwise it will only warn
         :note: if DHTProtocol was created with listen=True, also request peer to add you to his routing table
 
         :return: node's DHTID, if peer responded and decided to send his node_id
         """
         try:
             async with self.rpc_semaphore:
-                peer_info = await self._get_dht_stub(peer).rpc_ping(self.node_info, timeout=self.wait_timeout)
+                ping_request = dht_pb2.PingRequest(peer=self.node_info, validate=validate)
+                time_requested = get_dht_time()
+                response = await self._get_dht_stub(peer).rpc_ping(ping_request, timeout=self.wait_timeout)
+                time_responded = get_dht_time()
         except grpc.aio.AioRpcError as error:
-            logger.warning(f"DHTProtocol failed to ping {peer}: {error.code()}")
-            peer_info = None
-        responded = bool(peer_info and peer_info.node_id)
-        peer_id = DHTID.from_bytes(peer_info.node_id) if responded else None
+            logger.debug(f"DHTProtocol failed to ping {peer}: {error.code()}")
+            response = None
+        responded = bool(response and response.peer and response.peer.node_id)
+
+        if responded and validate:
+            try:
+                if self.server is not None and not response.available:
+                    raise ValidationError(f"peer {peer} couldn't access this node at {response.sender_endpoint} .")
+
+                if response.dht_time != dht_pb2.PingResponse.dht_time.DESCRIPTOR.default_value:
+                    if response.dht_time < time_requested - MAX_DHT_TIME_DISCREPANCY_SECONDS or \
+                            response.dht_time > time_responded + MAX_DHT_TIME_DISCREPANCY_SECONDS:
+                        raise ValidationError(f"local time must be within {MAX_DHT_TIME_DISCREPANCY_SECONDS} seconds "
+                                              f" of others(local: {time_requested:.5f}, peer: {response.dht_time:.5f})")
+            except ValidationError as e:
+                if strict:
+                    raise
+                else:
+                    logger.warning(repr(e))
+
+        peer_id = DHTID.from_bytes(response.peer.node_id) if responded else None
         asyncio.create_task(self.update_routing_table(peer_id, peer, responded=responded))
         return peer_id
 
-    async def rpc_ping(self, peer_info: dht_pb2.NodeInfo, context: grpc.ServicerContext):
+    async def get_outgoing_request_endpoint(self, peer: Endpoint) -> Optional[Endpoint]:
+        """ ask this peer how it perceives this node's outgoing request address """
+        try:
+            async with self.rpc_semaphore:
+                ping_request = dht_pb2.PingRequest(peer=None, validate=False)
+                response = await self._get_dht_stub(peer).rpc_ping(ping_request, timeout=self.wait_timeout)
+                if response.sender_endpoint != dht_pb2.PingResponse.sender_endpoint.DESCRIPTOR.default_value:
+                    return response.sender_endpoint
+        except grpc.aio.AioRpcError as error:
+            logger.debug(f"DHTProtocol failed to ping {peer}: {error.code()}")
+
+    async def rpc_ping(self, request: dht_pb2.PingRequest, context: grpc.ServicerContext):
         """ Some node wants us to add it to our routing table. """
-        if peer_info.node_id and peer_info.rpc_port:
-            sender_id = DHTID.from_bytes(peer_info.node_id)
-            rpc_endpoint = replace_port(context.peer(), new_port=peer_info.rpc_port)
-            asyncio.create_task(self.update_routing_table(sender_id, rpc_endpoint))
-        return self.node_info
+        response = dht_pb2.PingResponse(peer=self.node_info, sender_endpoint=context.peer(),
+                                        dht_time=get_dht_time(), available=False)
+
+        if request.peer and request.peer.node_id and request.peer.rpc_port:
+            sender_id = DHTID.from_bytes(request.peer.node_id)
+            if request.peer.endpoint != dht_pb2.NodeInfo.endpoint.DESCRIPTOR.default_value:
+                sender_endpoint = request.peer.endpoint  # if peer has preferred endpoint, use it
+            else:
+                sender_endpoint = replace_port(context.peer(), new_port=request.peer.rpc_port)
+
+            response.sender_endpoint = sender_endpoint
+            if request.validate:
+                response.available = await self.call_ping(response.sender_endpoint, validate=False) == sender_id
+
+            asyncio.create_task(self.update_routing_table(sender_id, sender_endpoint,
+                                                          responded=response.available or not request.validate))
+
+        return response
 
     async def call_store(self, peer: Endpoint, keys: Sequence[DHTID],
                          values: Sequence[Union[BinaryDHTValue, DictionaryDHTValue]],
@@ -161,14 +210,14 @@ class DHTProtocol(dht_grpc.DHTServicer):
                 asyncio.create_task(self.update_routing_table(peer_id, peer, responded=True))
             return response.store_ok
         except grpc.aio.AioRpcError as error:
-            logger.warning(f"DHTProtocol failed to store at {peer}: {error.code()}")
+            logger.debug(f"DHTProtocol failed to store at {peer}: {error.code()}")
             asyncio.create_task(self.update_routing_table(self.routing_table.get(endpoint=peer), peer, responded=False))
             return None
 
     async def rpc_store(self, request: dht_pb2.StoreRequest, context: grpc.ServicerContext) -> dht_pb2.StoreResponse:
         """ Some node wants us to store this (key, value) pair """
         if request.peer:  # if requested, add peer to the routing table
-            asyncio.create_task(self.rpc_ping(request.peer, context))
+            asyncio.create_task(self.rpc_ping(dht_pb2.PingRequest(peer=request.peer), context))
         assert len(request.keys) == len(request.values) == len(request.expiration_time) == len(request.in_cache)
         response = dht_pb2.StoreResponse(store_ok=[], peer=self.node_info)
         keys = map(DHTID.from_bytes, request.keys)
@@ -225,7 +274,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
 
             return output
         except grpc.aio.AioRpcError as error:
-            logger.warning(f"DHTProtocol failed to find at {peer}: {error.code()}")
+            logger.debug(f"DHTProtocol failed to find at {peer}: {error.code()}")
             asyncio.create_task(self.update_routing_table(self.routing_table.get(endpoint=peer), peer, responded=False))
 
     async def rpc_find(self, request: dht_pb2.FindRequest, context: grpc.ServicerContext) -> dht_pb2.FindResponse:
@@ -234,7 +283,8 @@ class DHTProtocol(dht_grpc.DHTServicer):
         Also return :bucket_size: nearest neighbors from our routing table for each key (whether or not we found value)
         """
         if request.peer:  # if requested, add peer to the routing table
-            asyncio.create_task(self.rpc_ping(request.peer, context))
+            asyncio.create_task(self.rpc_ping(dht_pb2.PingRequest(peer=request.peer), context))
+
         response = dht_pb2.FindResponse(results=[], peer=self.node_info)
         for i, key_id in enumerate(map(DHTID.from_bytes, request.keys)):
             maybe_item = self.storage.get(key_id)
@@ -294,3 +344,7 @@ class DHTProtocol(dht_grpc.DHTServicer):
         else:  # we sent outgoing request and peer did not respond
             if node_id is not None and node_id in self.routing_table:
                 del self.routing_table[node_id]
+
+
+class ValidationError(Exception):
+    """ This exception is thrown if DHT node didn't pass validation by other nodes. """

+ 0 - 3
hivemind/dht/storage.py

@@ -65,6 +65,3 @@ class DHTLocalStorage(TimedStorage[DHTID, Union[BinaryDHTValue, DictionaryDHTVal
             return False
 
 
-class CacheRefreshQueue(TimedStorage[DHTID, DHTExpiration]):
-    """ a queue of keys scheduled for refresh in future, used in DHTNode """
-    frozen = True

+ 14 - 1
hivemind/proto/dht.proto

@@ -5,7 +5,7 @@ syntax = "proto3";
 
 service DHT {
   // find out recipient's DHTID and possibly update its routing table
-  rpc rpc_ping(NodeInfo) returns (NodeInfo);
+  rpc rpc_ping(PingRequest) returns (PingResponse);
 
   // request a node to store one or multiple data items (key - value - expiration)
   rpc rpc_store(StoreRequest) returns (StoreResponse);
@@ -19,6 +19,19 @@ message NodeInfo {
   // if either node_id or port is absent, simply request recipient info (for client-only mode)
   bytes node_id = 1;                   // sender's own node id serialized with DHTID.to_bytes()
   int32 rpc_port = 2;                  // port to which sender listens for DHT RPCs
+  string endpoint = 3;                 // (optional) node's preferred return address
+}
+
+message PingRequest {
+  NodeInfo peer = 1;                   // (optional) sender's own node info, same behavior as in DHT.rpc_ping
+  bool validate = 2;                   // set to True if sender wants to validate that he is accessible and synchronized
+}
+
+message PingResponse {
+  NodeInfo peer = 1;                   // respondent's node id, for you to update routing table
+  string sender_endpoint = 2;          // echo sender's visible endpoint - used to infer his ip address
+  double dht_time = 3;                 // recipient's local DHT time - used to soft-synchronize peers
+  bool available = 4;                  // if validate = True, this flag asserts that the sender is available for ping
 }
 
 message StoreRequest {
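
In terms of the new messages, a validated ping exchange carries roughly the following fields (the values are illustrative; the generated dht_pb2 bindings are used the same way in protocol.py above):

from hivemind.proto import dht_pb2

# sender: advertise own NodeInfo and request reachability validation
request = dht_pb2.PingRequest(
    peer=dht_pb2.NodeInfo(node_id=b"\x12\x34", rpc_port=31337, endpoint="203.0.113.9:31337"),
    validate=True)

# recipient: echo the caller's visible endpoint, report local DHT time and whether the ping-back succeeded
response = dht_pb2.PingResponse(
    peer=dht_pb2.NodeInfo(node_id=b"\x56\x78", rpc_port=8888),
    sender_endpoint="203.0.113.9:31337", dht_time=1600000000.0, available=True)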

+ 6 - 0
hivemind/utils/networking.py

@@ -21,6 +21,12 @@ def replace_port(endpoint: Endpoint, new_port: Port) -> Endpoint:
     return f"{endpoint[:endpoint.rindex(':')]}:{new_port}"
 
 
+def strip_port(endpoint: Endpoint) -> Hostname:
+    """ Removes port from the end of endpoint. If port is not specified, does nothing """
+    maybe_port = endpoint[endpoint.rindex(':') + 1:]
+    return endpoint[:endpoint.rindex(':')] if maybe_port.isdigit() or maybe_port == '*' else endpoint
+
+
 def find_open_port(params=(socket.AF_INET, socket.SOCK_STREAM), opt=(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)):
     """ Finds a tcp port that can be occupied with a socket with *params and use *opt options """
     try:
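
For reference, a few illustrative strip_port calls (the addresses are examples):

from hivemind.utils import strip_port

assert strip_port("192.0.2.4:1337") == "192.0.2.4"               # numeric port is stripped
assert strip_port("node.example.org:*") == "node.example.org"    # wildcard port is stripped too
assert strip_port("[2001:db8::1]:8080") == "[2001:db8::1]"       # only the trailing port is removed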

+ 1 - 1
tests/benchmark_dht.py

@@ -34,7 +34,7 @@ def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_b
     print(f"Sampled {len(expert_uids)} unique ids (after deduplication)")
     random.shuffle(expert_uids)
 
-    print(f"Storing peers to dht in batches of {expert_batch_size}...")
+    print(f"Storing experts to dht in batches of {expert_batch_size}...")
     successful_stores = total_stores = total_store_time = 0
     benchmark_started = time.perf_counter()
     endpoints = []

+ 18 - 3
tests/test_dht_experts.py → tests/test_dht.py

@@ -2,10 +2,9 @@ import random
 import numpy as np
 import pytest
 import asyncio
-import multiprocessing as mp
 
 import hivemind
-from hivemind import LOCALHOST, UidEndpoint
+from hivemind import LOCALHOST, UidEndpoint, strip_port
 
 
 @pytest.mark.forked
@@ -21,7 +20,7 @@ def test_store_get_experts():
     expert_uids = [f"my_expert.{i}" for i in range(110)]
     batch_size = 10
     for batch_start in range(0, len(expert_uids), batch_size):
-        you.declare_experts(expert_uids[batch_start: batch_start + batch_size], 'localhost', 1234)
+        you.declare_experts(expert_uids[batch_start: batch_start + batch_size], 'localhost:1234')
 
     found = theguyshetoldyounottoworryabout.get_experts(random.sample(expert_uids, 5) + ['foo', 'bar'])
     assert all(res is not None for res in found[:-2]), "Could not find some existing experts"
@@ -37,6 +36,22 @@ def test_store_get_experts():
         peer.shutdown()
 
 
+@pytest.mark.forked
+def test_dht_get_address(addr=LOCALHOST, dummy_endpoint='123.45.67.89:*'):
+    node1 = hivemind.DHT(start=True, listen_on=f"0.0.0.0:*")
+    node2 = hivemind.DHT(start=True, listen_on=f"0.0.0.0:*", initial_peers=[f"{addr}:{node1.port}"])
+    node3 = hivemind.DHT(start=True, listen_on=f"0.0.0.0:*", initial_peers=[f"{addr}:{node2.port}"])
+    assert addr in node3.get_visible_address(num_peers=2)
+
+    node4 = hivemind.DHT(start=True, listen_on=f"0.0.0.0:*")
+    with pytest.raises(ValueError):
+        node4.get_visible_address()
+    assert node4.get_visible_address(peers=[f'{addr}:{node1.port}']).endswith(addr)
+
+    node5 = hivemind.DHT(start=True, listen_on=f"0.0.0.0:*", endpoint=f"{dummy_endpoint}")
+    assert node5.get_visible_address() == strip_port(dummy_endpoint)
+
+
 @pytest.mark.forked
 def test_beam_search(dht_size=20, total_experts=128, batch_size=32, initial_peers=3, beam_size=4, parallel_rpc=256,
                      grid_dims=(32, 32, 32)):

+ 46 - 2
tests/test_dht_node.py

@@ -9,9 +9,9 @@ import pytest
 import hivemind
 from typing import List, Dict
 
-from hivemind import get_dht_time
+from hivemind import get_dht_time, replace_port
 from hivemind.dht.node import DHTID, Endpoint, DHTNode, LOCALHOST
-from hivemind.dht.protocol import DHTProtocol
+from hivemind.dht.protocol import DHTProtocol, ValidationError
 from hivemind.dht.storage import DictionaryDHTValue
 
 
@@ -104,6 +104,8 @@ def test_dht_protocol():
         assert recv_dict.data[subkey1] == (protocol.serializer.dumps(value1), expiration)
         assert recv_dict.data[subkey2] == (protocol.serializer.dumps(value2), expiration + 5)
 
+        assert LOCALHOST in loop.run_until_complete(protocol.get_outgoing_request_endpoint(f'{LOCALHOST}:{peer1_port}'))
+
         if listen:
             loop.run_until_complete(protocol.shutdown())
         print("DHTProtocol test finished successfully!")
@@ -390,3 +392,45 @@ async def test_dhtnode_reuse_get():
     assert (await futures1['k1'])[0] == 123
     assert await futures1['k2'] == await futures2['k2'] and (await futures1['k2'])[0] == 567
     assert await futures2['k3'] == await futures3['k3'] and (await futures3['k3']) is None
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_dhtnode_blacklist():
+    node1 = await hivemind.DHTNode.create(blacklist_time=999)
+    node2 = await hivemind.DHTNode.create(blacklist_time=999, initial_peers=[f"{LOCALHOST}:{node1.port}"])
+    node3 = await hivemind.DHTNode.create(blacklist_time=999, initial_peers=[f"{LOCALHOST}:{node1.port}"])
+    node4 = await hivemind.DHTNode.create(blacklist_time=999, initial_peers=[f"{LOCALHOST}:{node1.port}"])
+
+    assert await node2.store('abc', 123, expiration_time=hivemind.get_dht_time() + 99)
+    assert len(node2.blacklist.ban_counter) == 0
+
+    await node3.shutdown()
+    await node4.shutdown()
+
+    assert await node2.store('def', 456, expiration_time=hivemind.get_dht_time() + 99)
+
+    assert len(node2.blacklist.ban_counter) == 2
+
+    for banned_peer in node2.blacklist.ban_counter:
+        assert any(banned_peer.endswith(str(port)) for port in [node3.port, node4.port])
+
+    node3_endpoint = await node3.protocol.get_outgoing_request_endpoint(f"{hivemind.LOCALHOST}:{node1.port}")
+    node3_endpoint = replace_port(node3_endpoint, node3.port)
+    assert await node1.get('abc', latest=True)  # force node1 to crawl dht and discover unresponsive peers
+    assert node3_endpoint in node1.blacklist
+
+    node2_endpoint = await node2.protocol.get_outgoing_request_endpoint(f"{hivemind.LOCALHOST}:{node1.port}")
+    node2_endpoint = replace_port(node2_endpoint, node2.port)
+    assert await node1.get('abc', latest=True)  # force node1 to crawl dht and discover unresponsive peers
+    assert node2_endpoint not in node1.blacklist
+
+
+@pytest.mark.forked
+@pytest.mark.asyncio
+async def test_dhtnode_validate(fake_endpoint='127.0.0.721:*'):
+
+    node1 = await hivemind.DHTNode.create(blacklist_time=999)
+    with pytest.raises(ValidationError):
+        node2 = await hivemind.DHTNode.create(blacklist_time=999, initial_peers=[f"{LOCALHOST}:{node1.port}"],
+                                              endpoint=fake_endpoint)