Pārlūkot izejas kodu

Rename Endpoint to PeerID in DHT (#313)

This PR follows #296, removes importing PeerID as Endpoint in dht/{node,protocol,routing}.py and related tests, and performs a number of replacements like Endpoint -> PeerID and endpoint -> peer_id.
Alexander Borzunov 4 gadi atpakaļ
vecāks
revīzija
4a33d1b711

+ 42 - 42
hivemind/dht/node.py

@@ -17,7 +17,7 @@ from hivemind.dht.protocol import DHTProtocol
 from hivemind.dht.routing import DHTID, DHTKey, get_dht_time, DHTValue, BinaryDHTValue, Subkey
 from hivemind.dht.storage import DictionaryDHTValue
 from hivemind.dht.traverse import traverse_dht
-from hivemind.p2p import P2P, PeerID as Endpoint
+from hivemind.p2p import P2P, PeerID
 from hivemind.utils import MSGPackSerializer, get_logger, SerializerBase
 from hivemind.utils.auth import AuthorizerBase
 from hivemind.utils.timed_storage import DHTExpiration, TimedStorage, ValueWithExpiration
@@ -70,7 +70,7 @@ class DHTNode:
 
     """
     # fmt:off
-    node_id: DHTID; is_alive: bool; endpoint: Endpoint; num_replicas: int; num_workers: int; protocol: DHTProtocol
+    node_id: DHTID; is_alive: bool; peer_id: PeerID; num_replicas: int; num_workers: int; protocol: DHTProtocol
     chunk_size: int; refresh_timeout: float; cache_locally: bool; cache_nearest: int; cache_refresh_before_expiry: float
     cache_on_store: bool; reuse_get_requests: bool; pending_get_requests: DefaultDict[DHTID, SortedSet[_SearchState]]
     cache_refresh_task: Optional[asyncio.Task]; cache_refresh_evt: asyncio.Event; cache_refresh_queue: CacheRefreshQueue
@@ -167,10 +167,10 @@ class DHTNode:
         self.protocol = await DHTProtocol.create(
             p2p, self.node_id, bucket_size, depth_modulo, num_replicas, wait_timeout,
             parallel_rpc, cache_size, listen, record_validator, authorizer)
-        self.endpoint = p2p.id
+        self.peer_id = p2p.id
 
         if initial_peers:
-            initial_peers = {Endpoint.from_base58(Multiaddr(item)['p2p']) for item in initial_peers}
+            initial_peers = {PeerID.from_base58(Multiaddr(item)['p2p']) for item in initial_peers}
 
             # stage 1: ping initial_peers, add each other to the routing table
             bootstrap_timeout = bootstrap_timeout if bootstrap_timeout is not None else wait_timeout
@@ -218,17 +218,17 @@ class DHTNode:
 
     async def find_nearest_nodes(
             self, queries: Collection[DHTID], k_nearest: Optional[int] = None, beam_size: Optional[int] = None,
-            num_workers: Optional[int] = None, node_to_endpoint: Optional[Dict[DHTID, Endpoint]] = None,
-            exclude_self: bool = False, **kwargs) -> Dict[DHTID, Dict[DHTID, Endpoint]]:
+            num_workers: Optional[int] = None, node_to_peer_id: Optional[Dict[DHTID, PeerID]] = None,
+            exclude_self: bool = False, **kwargs) -> Dict[DHTID, Dict[DHTID, PeerID]]:
         """
         :param queries: find k nearest nodes for each of these DHTIDs
         :param k_nearest: return this many nearest nodes for every query (if there are enough nodes)
         :param beam_size: replacement for self.beam_size, see traverse_dht beam_size param
         :param num_workers: replacement for self.num_workers, see traverse_dht num_workers param
-        :param node_to_endpoint: if specified, uses this dict[node_id => endpoint] as initial peers
+        :param node_to_peer_id: if specified, uses this dict[node_id => peer_id] as initial peers
         :param exclude_self: if True, nearest nodes will not contain self.node_id (default = use local peers)
         :param kwargs: additional params passed to traverse_dht
-        :returns: for every query, return nearest peers ordered dict[peer DHTID -> network Endpoint], nearest-first
+        :returns: for every query, return nearest peers ordered dict[peer DHTID -> network PeerID], nearest-first
         """
         queries = tuple(queries)
         k_nearest = k_nearest if k_nearest is not None else self.protocol.bucket_size
@@ -236,35 +236,35 @@ class DHTNode:
         beam_size = beam_size if beam_size is not None else max(self.protocol.bucket_size, k_nearest)
         if k_nearest > beam_size:
             logger.warning("Warning: beam_size is too small, beam search is not guaranteed to find enough nodes")
-        if node_to_endpoint is None:
-            node_to_endpoint: Dict[DHTID, Endpoint] = dict()
+        if node_to_peer_id is None:
+            node_to_peer_id: Dict[DHTID, PeerID] = dict()
             for query in queries:
                 neighbors = self.protocol.routing_table.get_nearest_neighbors(query, beam_size, exclude=self.node_id)
-                node_to_endpoint.update(self._filter_blacklisted(dict(neighbors)))
+                node_to_peer_id.update(self._filter_blacklisted(dict(neighbors)))
 
         async def get_neighbors(peer: DHTID, queries: Collection[DHTID]) -> Dict[DHTID, Tuple[Tuple[DHTID], bool]]:
-            response = await self._call_find_with_blacklist(node_to_endpoint[peer], queries)
+            response = await self._call_find_with_blacklist(node_to_peer_id[peer], queries)
             if not response:
                 return {query: ([], False) for query in queries}
 
             output: Dict[DHTID, Tuple[Tuple[DHTID], bool]] = {}
             for query, (_, peers) in response.items():
-                node_to_endpoint.update(peers)
+                node_to_peer_id.update(peers)
                 output[query] = tuple(peers.keys()), False  # False means "do not interrupt search"
             return output
 
         nearest_nodes_per_query, visited_nodes = await traverse_dht(
-            queries, initial_nodes=list(node_to_endpoint), beam_size=beam_size, num_workers=num_workers,
+            queries, initial_nodes=list(node_to_peer_id), beam_size=beam_size, num_workers=num_workers,
             queries_per_call=int(len(queries) ** 0.5), get_neighbors=get_neighbors,
             visited_nodes={query: {self.node_id} for query in queries}, **kwargs)
 
-        nearest_nodes_with_endpoints = {}
+        nearest_nodes_with_peer_ids = {}
         for query, nearest_nodes in nearest_nodes_per_query.items():
             if not exclude_self:
                 nearest_nodes = sorted(nearest_nodes + [self.node_id], key=query.xor_distance)
-                node_to_endpoint[self.node_id] = self.endpoint
-            nearest_nodes_with_endpoints[query] = {node: node_to_endpoint[node] for node in nearest_nodes[:k_nearest]}
-        return nearest_nodes_with_endpoints
+                node_to_peer_id[self.node_id] = self.peer_id
+            nearest_nodes_with_peer_ids[query] = {node: node_to_peer_id[node] for node in nearest_nodes[:k_nearest]}
+        return nearest_nodes_with_peer_ids
 
     async def store(self, key: DHTKey, value: DHTValue, expiration_time: DHTExpiration,
                     subkey: Optional[Subkey] = None, **kwargs) -> bool:
@@ -310,10 +310,10 @@ class DHTNode:
         store_ok = {(key, subkey): None for key, subkey in zip(keys, subkeys)}  # outputs, updated during search
         store_finished_events = {(key, subkey): asyncio.Event() for key, subkey in zip(keys, subkeys)}
 
-        # pre-populate node_to_endpoint
-        node_to_endpoint: Dict[DHTID, Endpoint] = dict()
+        # pre-populate node_to_peer_id
+        node_to_peer_id: Dict[DHTID, PeerID] = dict()
         for key_id in unfinished_key_ids:
-            node_to_endpoint.update(self.protocol.routing_table.get_nearest_neighbors(
+            node_to_peer_id.update(self.protocol.routing_table.get_nearest_neighbors(
                 key_id, self.protocol.bucket_size, exclude=self.node_id))
 
         async def on_found(key_id: DHTID, nearest_nodes: List[DHTID], visited_nodes: Set[DHTID]) -> None:
@@ -361,7 +361,7 @@ class DHTNode:
                                 store_finished_events[original_key, subkey].set()
                     else:
                         pending_store_tasks.add(asyncio.create_task(self.protocol.call_store(
-                            node_to_endpoint[node_id], keys=[key_id] * len(current_values), values=binary_values,
+                            node_to_peer_id[node_id], keys=[key_id] * len(current_values), values=binary_values,
                             expiration_time=current_expirations, subkeys=current_subkeys)))
 
                 # await nearest task. If it fails, dispatch more on the next iteration
@@ -384,7 +384,7 @@ class DHTNode:
                 store_finished_events[original_key, subkey].set()
 
         store_task = asyncio.create_task(self.find_nearest_nodes(
-            queries=set(unfinished_key_ids), k_nearest=self.num_replicas, node_to_endpoint=node_to_endpoint,
+            queries=set(unfinished_key_ids), k_nearest=self.num_replicas, node_to_peer_id=node_to_peer_id,
             found_callback=on_found, exclude_self=exclude_self, **kwargs))
         try:
             await asyncio.gather(store_task, *(evt.wait() for evt in store_finished_events.values()))
@@ -496,21 +496,21 @@ class DHTNode:
 
         # stage 2: traverse the DHT to get the remaining keys from remote peers
         unfinished_key_ids = [key_id for key_id in key_ids if not search_results[key_id].finished]
-        node_to_endpoint: Dict[DHTID, Endpoint] = dict()  # global routing table for all keys
+        node_to_peer_id: Dict[DHTID, PeerID] = dict()  # global routing table for all keys
         for key_id in unfinished_key_ids:
-            node_to_endpoint.update(self.protocol.routing_table.get_nearest_neighbors(
+            node_to_peer_id.update(self.protocol.routing_table.get_nearest_neighbors(
                 key_id, self.protocol.bucket_size, exclude=self.node_id))
 
         # V-- this function will be called every time traverse_dht decides to request neighbors from a remote peer
         async def get_neighbors(peer: DHTID, queries: Collection[DHTID]) -> Dict[DHTID, Tuple[Tuple[DHTID], bool]]:
             queries = list(queries)
-            response = await self._call_find_with_blacklist(node_to_endpoint[peer], queries)
+            response = await self._call_find_with_blacklist(node_to_peer_id[peer], queries)
             if not response:
                 return {query: ([], False) for query in queries}
 
             output: Dict[DHTID, Tuple[Tuple[DHTID], bool]] = {}
             for key_id, (maybe_value_with_expiration, peers) in response.items():
-                node_to_endpoint.update(peers)
+                node_to_peer_id.update(peers)
                 search_results[key_id].add_candidate(maybe_value_with_expiration, source_node_id=peer)
                 output[key_id] = tuple(peers.keys()), search_results[key_id].finished
                 # note: we interrupt search either if key is either found or finished otherwise (e.g. cancelled by user)
@@ -519,10 +519,10 @@ class DHTNode:
         # V-- this function will be called exactly once when traverse_dht finishes search for a given key
         async def found_callback(key_id: DHTID, nearest_nodes: List[DHTID], _visited: Set[DHTID]):
             search_results[key_id].finish_search()  # finish search whether or we found something
-            self._cache_new_result(search_results[key_id], nearest_nodes, node_to_endpoint, _is_refresh=_is_refresh)
+            self._cache_new_result(search_results[key_id], nearest_nodes, node_to_peer_id, _is_refresh=_is_refresh)
 
         asyncio.create_task(traverse_dht(
-            queries=list(unfinished_key_ids), initial_nodes=list(node_to_endpoint), beam_size=beam_size,
+            queries=list(unfinished_key_ids), initial_nodes=list(node_to_peer_id), beam_size=beam_size,
             num_workers=num_workers, queries_per_call=min(int(len(unfinished_key_ids) ** 0.5), self.chunk_size),
             get_neighbors=get_neighbors, visited_nodes={key_id: {self.node_id} for key_id in unfinished_key_ids},
             found_callback=found_callback, await_all_tasks=False))
@@ -551,21 +551,21 @@ class DHTNode:
         else:
             pending_requests.discard(finished)
 
-    async def _call_find_with_blacklist(self, endpoint: Endpoint, keys: Collection[DHTID]):
-        """ same as call_find, but skip if :endpoint: is blacklisted; also exclude blacklisted neighbors from result """
-        if endpoint in self.blacklist:
+    async def _call_find_with_blacklist(self, peer_id: PeerID, keys: Collection[DHTID]):
+        """ same as call_find, but skip if :peer_id: is blacklisted; also exclude blacklisted neighbors from result """
+        if peer_id in self.blacklist:
             return None
-        response = await self.protocol.call_find(endpoint, keys)
+        response = await self.protocol.call_find(peer_id, keys)
         if response:
-            self.blacklist.register_success(endpoint)
+            self.blacklist.register_success(peer_id)
             return {key: (maybe_value, self._filter_blacklisted(nearest_peers))
                     for key, (maybe_value, nearest_peers) in response.items()}
         else:
-            self.blacklist.register_failure(endpoint)
+            self.blacklist.register_failure(peer_id)
             return None
 
-    def _filter_blacklisted(self, peer_endpoints: Dict[DHTID, Endpoint]):
-        return {peer: endpoint for peer, endpoint in peer_endpoints.items() if endpoint not in self.blacklist}
+    def _filter_blacklisted(self, peer_ids: Dict[DHTID, PeerID]):
+        return {peer: peer_id for peer, peer_id in peer_ids.items() if peer_id not in self.blacklist}
 
     def _trigger_cache_refresh(self, search: _SearchState):
         """ Called after get request is finished (whether it was found, not found, hit cache, cancelled, or reused) """
@@ -620,7 +620,7 @@ class DHTNode:
                 await self.get_many_by_id(keys_to_refresh, sufficient_expiration_time, _is_refresh=True)
 
     def _cache_new_result(self, search: _SearchState, nearest_nodes: List[DHTID],
-                          node_to_endpoint: Dict[DHTID, Endpoint], _is_refresh: bool = False):
+                          node_to_peer_id: Dict[DHTID, PeerID], _is_refresh: bool = False):
         """ after key_id is found, update cache according to caching policy. used internally in get and get_many """
         if search.found_something:
             _, storage_expiration_time = self.protocol.storage.get(search.key_id) or (None, -float('inf'))
@@ -635,7 +635,7 @@ class DHTNode:
                         if node_id == search.source_node_id:
                             continue
                         asyncio.create_task(self.protocol.call_store(
-                            node_to_endpoint[node_id], [search.key_id], [search.binary_value], [search.expiration_time],
+                            node_to_peer_id[node_id], [search.key_id], [search.binary_value], [search.expiration_time],
                             in_cache=True))
                         num_cached_nodes += 1
                         if num_cached_nodes >= self.cache_nearest:
@@ -746,10 +746,10 @@ class Blacklist:
 
     def __init__(self, base_time: float, backoff_rate: float, **kwargs):
         self.base_time, self.backoff = base_time, backoff_rate
-        self.banned_peers = TimedStorage[Endpoint, int](**kwargs)
+        self.banned_peers = TimedStorage[PeerID, int](**kwargs)
         self.ban_counter = Counter()
 
-    def register_failure(self, peer: Endpoint):
+    def register_failure(self, peer: PeerID):
         """ peer failed to respond, add him to blacklist or increase his downtime """
         if peer not in self.banned_peers and self.base_time > 0:
             ban_duration = self.base_time * self.backoff ** self.ban_counter[peer]
@@ -760,7 +760,7 @@ class Blacklist:
         """ peer responded successfully, remove him from blacklist and reset his ban time """
         del self.banned_peers[peer], self.ban_counter[peer]
 
-    def __contains__(self, peer: Endpoint) -> bool:
+    def __contains__(self, peer: PeerID) -> bool:
         return peer in self.banned_peers
 
     def __repr__(self):

+ 24 - 26
hivemind/dht/protocol.py

@@ -2,14 +2,12 @@
 from __future__ import annotations
 
 import asyncio
-from typing import Optional, List, Tuple, Dict, Any, Sequence, Union, Collection
-
-import grpc
+from typing import Optional, List, Tuple, Dict, Sequence, Union, Collection
 
 from hivemind.dht.crypto import DHTRecord, RecordValidatorBase
 from hivemind.dht.routing import RoutingTable, DHTID, BinaryDHTValue, Subkey
 from hivemind.dht.storage import DHTLocalStorage, DictionaryDHTValue
-from hivemind.p2p import P2P, P2PContext, PeerID as Endpoint, Servicer
+from hivemind.p2p import P2P, P2PContext, PeerID, Servicer
 from hivemind.proto import dht_pb2
 from hivemind.utils import get_logger, MSGPackSerializer
 from hivemind.utils.auth import AuthRole, AuthRPCWrapper, AuthorizerBase
@@ -44,7 +42,7 @@ class DHTProtocol(Servicer):
         See DHTNode (node.py) for a more detailed description.
 
         :note: the rpc_* methods defined in this class will be automatically exposed to other DHT nodes,
-         for instance, def rpc_ping can be called as protocol.call_ping(endpoint, dht_id) from a remote machine
+         for instance, def rpc_ping can be called as protocol.call_ping(peer_id, dht_id) from a remote machine
          Only the call_* methods are meant to be called publicly, e.g. from DHTNode
          Read more: https://github.com/bmuller/rpcudp/tree/master/rpcudp
         """
@@ -73,16 +71,16 @@ class DHTProtocol(Servicer):
         assert _initialized_with_create, " Please use DHTProtocol.create coroutine to spawn new protocol instances "
         super().__init__()
 
-    def get_stub(self, peer: Endpoint) -> AuthRPCWrapper:
+    def get_stub(self, peer: PeerID) -> AuthRPCWrapper:
         """ get a stub that sends requests to a given peer """
         stub = super().get_stub(self.p2p, peer)
         return AuthRPCWrapper(stub, AuthRole.CLIENT, self.authorizer, service_public_key=None)
 
-    async def call_ping(self, peer: Endpoint, validate: bool = False, strict: bool = True) -> Optional[DHTID]:
+    async def call_ping(self, peer: PeerID, validate: bool = False, strict: bool = True) -> Optional[DHTID]:
         """
         Get peer's node id and add him to the routing table. If peer doesn't respond, return None
         :param peer: peer ID to ping
-        :param validate: if True, validates that node's endpoint is available
+        :param validate: if True, validates that node's peer_id is available
         :param strict: if strict=True, validation will raise exception on fail, otherwise it will only warn
         :note: if DHTProtocol was created with listen=True, also request peer to add you to his routing table
 
@@ -128,17 +126,17 @@ class DHTProtocol(Servicer):
 
         if request.peer and request.peer.node_id:
             sender_id = DHTID.from_bytes(request.peer.node_id)
-            sender_endpoint = context.remote_id
+            sender_peer_id = context.remote_id
 
             if request.validate:
-                response.available = await self.call_ping(sender_endpoint, validate=False) == sender_id
+                response.available = await self.call_ping(sender_peer_id, validate=False) == sender_id
 
-            asyncio.create_task(self.update_routing_table(sender_id, sender_endpoint,
+            asyncio.create_task(self.update_routing_table(sender_id, sender_peer_id,
                                                           responded=response.available or not request.validate))
 
         return response
 
-    async def call_store(self, peer: Endpoint, keys: Sequence[DHTID],
+    async def call_store(self, peer: PeerID, keys: Sequence[DHTID],
                          values: Sequence[Union[BinaryDHTValue, DictionaryDHTValue]],
                          expiration_time: Union[DHTExpiration, Sequence[DHTExpiration]],
                          subkeys: Optional[Union[Subkey, Sequence[Optional[Subkey]]]] = None,
@@ -190,7 +188,7 @@ class DHTProtocol(Servicer):
             return response.store_ok
         except Exception as e:
             logger.debug(f"DHTProtocol failed to store at {peer}", exc_info=True)
-            asyncio.create_task(self.update_routing_table(self.routing_table.get(endpoint=peer), peer, responded=False))
+            asyncio.create_task(self.update_routing_table(self.routing_table.get(peer_id=peer), peer, responded=False))
             return None
 
     async def rpc_store(self, request: dht_pb2.StoreRequest, context: P2PContext) -> dht_pb2.StoreResponse:
@@ -226,8 +224,8 @@ class DHTProtocol(Servicer):
                 response.store_ok.append(storage.store_subkey(key_id, subkey, value_bytes, expiration_time))
         return response
 
-    async def call_find(self, peer: Endpoint, keys: Collection[DHTID]) -> Optional[Dict[
-        DHTID, Tuple[Optional[ValueWithExpiration[Union[BinaryDHTValue, DictionaryDHTValue]]], Dict[DHTID, Endpoint]]]]:
+    async def call_find(self, peer: PeerID, keys: Collection[DHTID]) -> Optional[Dict[
+        DHTID, Tuple[Optional[ValueWithExpiration[Union[BinaryDHTValue, DictionaryDHTValue]]], Dict[DHTID, PeerID]]]]:
         """
         Request keys from a peer. For each key, look for its (value, expiration time) locally and
          k additional peers that are most likely to have this key (ranked by XOR distance)
@@ -235,7 +233,7 @@ class DHTProtocol(Servicer):
         :returns: A dict key => Tuple[optional value, optional expiration time, nearest neighbors]
          value: value stored by the recipient with that key, or None if peer doesn't have this value
          expiration time: expiration time of the returned value, None if no value was found
-         neighbors: a dictionary[node_id : endpoint] containing nearest neighbors from peer's routing table
+         neighbors: a dictionary[node_id : peer_id] containing nearest neighbors from peer's routing table
          If peer didn't respond, returns None
         """
         keys = list(keys)
@@ -252,7 +250,7 @@ class DHTProtocol(Servicer):
             for key, result in zip(keys, response.results):
                 key_bytes = DHTID.to_bytes(key)
                 nearest = dict(zip(map(DHTID.from_bytes, result.nearest_node_ids),
-                                   map(Endpoint.from_base58, result.nearest_endpoints)))
+                                   map(PeerID.from_base58, result.nearest_peer_ids)))
 
                 if result.type == dht_pb2.NOT_FOUND:
                     output[key] = None, nearest
@@ -276,7 +274,7 @@ class DHTProtocol(Servicer):
             return output
         except Exception as e:
             logger.debug(f"DHTProtocol failed to find at {peer}", exc_info=True)
-            asyncio.create_task(self.update_routing_table(self.routing_table.get(endpoint=peer), peer, responded=False))
+            asyncio.create_task(self.update_routing_table(self.routing_table.get(peer_id=peer), peer, responded=False))
 
     async def rpc_find(self, request: dht_pb2.FindRequest, context: P2PContext) -> dht_pb2.FindResponse:
         """
@@ -303,23 +301,23 @@ class DHTProtocol(Servicer):
                 item = dht_pb2.FindResult(type=dht_pb2.FOUND_REGULAR, value=maybe_item.value,
                                           expiration_time=maybe_item.expiration_time)
 
-            for node_id, endpoint in self.routing_table.get_nearest_neighbors(
+            for node_id, peer_id in self.routing_table.get_nearest_neighbors(
                     key_id, k=self.bucket_size, exclude=DHTID.from_bytes(request.peer.node_id)):
                 item.nearest_node_ids.append(node_id.to_bytes())
-                item.nearest_endpoints.append(endpoint.to_base58())
+                item.nearest_peer_ids.append(peer_id.to_base58())
             response.results.append(item)
         return response
 
-    async def update_routing_table(self, node_id: Optional[DHTID], peer_endpoint: Endpoint, responded=True):
+    async def update_routing_table(self, node_id: Optional[DHTID], peer_id: PeerID, responded=True):
         """
         This method is called on every incoming AND outgoing request to update the routing table
 
-        :param peer_endpoint: sender endpoint for incoming requests, recipient endpoint for outgoing requests
+        :param peer_id: sender peer_id for incoming requests, recipient peer_id for outgoing requests
         :param node_id: sender node id for incoming requests, recipient node id for outgoing requests
         :param responded: for outgoing requests, this indicated whether recipient responded or not.
           For incoming requests, this should always be True
         """
-        node_id = node_id if node_id is not None else self.routing_table.get(endpoint=peer_endpoint)
+        node_id = node_id if node_id is not None else self.routing_table.get(peer_id=peer_id)
         if responded:  # incoming request or outgoing request with response
             if node_id not in self.routing_table:
                 # we just met a new node, maybe we know some values that it *should* store
@@ -334,13 +332,13 @@ class DHTProtocol(Servicer):
                     if not neighbors or (new_node_should_store and this_node_is_responsible):
                         data_to_send.append((key, item.value, item.expiration_time))
                 if data_to_send:
-                    asyncio.create_task(self.call_store(peer_endpoint, *zip(*data_to_send), in_cache=False))
+                    asyncio.create_task(self.call_store(peer_id, *zip(*data_to_send), in_cache=False))
 
-            maybe_node_to_ping = self.routing_table.add_or_update_node(node_id, peer_endpoint)
+            maybe_node_to_ping = self.routing_table.add_or_update_node(node_id, peer_id)
             if maybe_node_to_ping is not None:
                 # we couldn't add new node because the table was full. Check if existing peers are alive (Section 2.2)
                 # ping one least-recently updated peer: if it won't respond, remove it from the table, else update it
-                asyncio.create_task(self.call_ping(maybe_node_to_ping[1]))  # [1]-th element is that node's endpoint
+                asyncio.create_task(self.call_ping(maybe_node_to_ping[1]))  # [1]-th element is that node's peer_id
 
         else:  # we sent outgoing request and peer did not respond
             if node_id is not None and node_id in self.routing_table:

+ 58 - 58
hivemind/dht/routing.py

@@ -8,7 +8,7 @@ import random
 from collections.abc import Iterable
 from itertools import chain
 from typing import Tuple, Optional, List, Dict, Set, Union, Any, Sequence
-from hivemind.p2p import PeerID as Endpoint
+from hivemind.p2p import PeerID
 from hivemind.utils import MSGPackSerializer, get_dht_time
 
 DHTKey, Subkey, DHTValue, BinaryDHTID, BinaryDHTValue, = Any, Any, Any, bytes, bytes
@@ -28,8 +28,8 @@ class RoutingTable:
     def __init__(self, node_id: DHTID, bucket_size: int, depth_modulo: int):
         self.node_id, self.bucket_size, self.depth_modulo = node_id, bucket_size, depth_modulo
         self.buckets = [KBucket(node_id.MIN, node_id.MAX, bucket_size)]
-        self.endpoint_to_uid: Dict[Endpoint, DHTID] = dict()  # all nodes currently in buckets, including replacements
-        self.uid_to_endpoint: Dict[DHTID, Endpoint] = dict()  # all nodes currently in buckets, including replacements
+        self.peer_id_to_uid: Dict[PeerID, DHTID] = dict()  # all nodes currently in buckets, including replacements
+        self.uid_to_peer_id: Dict[DHTID, PeerID] = dict()  # all nodes currently in buckets, including replacements
 
     def get_bucket_index(self, node_id: DHTID) -> int:
         """ Get the index of the bucket that the given node would fall into. """
@@ -43,9 +43,9 @@ class RoutingTable:
         assert upper_index - lower_index == 1
         return lower_index
 
-    def add_or_update_node(self, node_id: DHTID, endpoint: Endpoint) -> Optional[Tuple[DHTID, Endpoint]]:
+    def add_or_update_node(self, node_id: DHTID, peer_id: PeerID) -> Optional[Tuple[DHTID, PeerID]]:
         """
-        Update routing table after an incoming request from :endpoint: or outgoing request to :endpoint:
+        Update routing table after an incoming request from :peer_id: or outgoing request to :peer_id:
 
         :returns: If we cannot add node_id to the routing table, return the least-recently-updated node (Section 2.2)
         :note: DHTProtocol calls this method for every incoming and outgoing request if there was a response.
@@ -54,19 +54,19 @@ class RoutingTable:
         """
         bucket_index = self.get_bucket_index(node_id)
         bucket = self.buckets[bucket_index]
-        store_success = bucket.add_or_update_node(node_id, endpoint)
+        store_success = bucket.add_or_update_node(node_id, peer_id)
 
-        if node_id in bucket.nodes_to_endpoint or node_id in bucket.replacement_nodes:
+        if node_id in bucket.nodes_to_peer_id or node_id in bucket.replacement_nodes:
             # if we added node to bucket or as a replacement, throw it into lookup dicts as well
-            self.uid_to_endpoint[node_id] = endpoint
-            self.endpoint_to_uid[endpoint] = node_id
+            self.uid_to_peer_id[node_id] = peer_id
+            self.peer_id_to_uid[peer_id] = node_id
 
         if not store_success:
             # Per section 4.2 of paper, split if the bucket has node's own id in its range
             # or if bucket depth is not congruent to 0 mod $b$
             if bucket.has_in_range(self.node_id) or bucket.depth % self.depth_modulo != 0:
                 self.split_bucket(bucket_index)
-                return self.add_or_update_node(node_id, endpoint)
+                return self.add_or_update_node(node_id, peer_id)
 
             # The bucket is full and won't split further. Return a node to ping (see this method's docstring)
             return bucket.request_ping_node()
@@ -77,48 +77,48 @@ class RoutingTable:
         self.buckets[index] = first
         self.buckets.insert(index + 1, second)
 
-    def get(self, *, node_id: Optional[DHTID] = None, endpoint: Optional[Endpoint] = None, default=None):
-        """ Find endpoint for a given DHTID or vice versa """
-        assert (node_id is None) != (endpoint is None), "Please specify either node_id or endpoint, but not both"
+    def get(self, *, node_id: Optional[DHTID] = None, peer_id: Optional[PeerID] = None, default=None):
+        """ Find peer_id for a given DHTID or vice versa """
+        assert (node_id is None) != (peer_id is None), "Please specify either node_id or peer_id, but not both"
         if node_id is not None:
-            return self.uid_to_endpoint.get(node_id, default)
+            return self.uid_to_peer_id.get(node_id, default)
         else:
-            return self.endpoint_to_uid.get(endpoint, default)
+            return self.peer_id_to_uid.get(peer_id, default)
 
-    def __getitem__(self, item: Union[DHTID, Endpoint]) -> Union[Endpoint, DHTID]:
-        """ Find endpoint for a given DHTID or vice versa """
-        return self.uid_to_endpoint[item] if isinstance(item, DHTID) else self.endpoint_to_uid[item]
+    def __getitem__(self, item: Union[DHTID, PeerID]) -> Union[PeerID, DHTID]:
+        """ Find peer_id for a given DHTID or vice versa """
+        return self.uid_to_peer_id[item] if isinstance(item, DHTID) else self.peer_id_to_uid[item]
 
-    def __setitem__(self, node_id: DHTID, endpoint: Endpoint) -> NotImplementedError:
+    def __setitem__(self, node_id: DHTID, peer_id: PeerID) -> NotImplementedError:
         raise NotImplementedError("RoutingTable doesn't support direct item assignment. Use table.try_add_node instead")
 
-    def __contains__(self, item: Union[DHTID, Endpoint]) -> bool:
-        return (item in self.uid_to_endpoint) if isinstance(item, DHTID) else (item in self.endpoint_to_uid)
+    def __contains__(self, item: Union[DHTID, PeerID]) -> bool:
+        return (item in self.uid_to_peer_id) if isinstance(item, DHTID) else (item in self.peer_id_to_uid)
 
     def __delitem__(self, node_id: DHTID):
         del self.buckets[self.get_bucket_index(node_id)][node_id]
-        node_endpoint = self.uid_to_endpoint.pop(node_id)
-        if self.endpoint_to_uid.get(node_endpoint) == node_id:
-            del self.endpoint_to_uid[node_endpoint]
+        node_peer_id = self.uid_to_peer_id.pop(node_id)
+        if self.peer_id_to_uid.get(node_peer_id) == node_id:
+            del self.peer_id_to_uid[node_peer_id]
 
     def get_nearest_neighbors(
-            self, query_id: DHTID, k: int, exclude: Optional[DHTID] = None) -> List[Tuple[DHTID, Endpoint]]:
+            self, query_id: DHTID, k: int, exclude: Optional[DHTID] = None) -> List[Tuple[DHTID, PeerID]]:
         """
         Find k nearest neighbors from routing table according to XOR distance, does NOT include self.node_id
 
         :param query_id: find neighbors of this node
         :param k: find this many neighbors. If there aren't enough nodes in the table, returns all nodes
         :param exclude: if True, results will not contain query_node_id even if it is in table
-        :return: a list of tuples (node_id, endpoint) for up to k neighbors sorted from nearest to farthest
+        :return: a list of tuples (node_id, peer_id) for up to k neighbors sorted from nearest to farthest
         """
         # algorithm: first add up all buckets that can contain one of k nearest nodes, then heap-sort to find best
-        candidates: List[Tuple[int, DHTID, Endpoint]] = []  # min-heap based on xor distance to query_id
+        candidates: List[Tuple[int, DHTID, PeerID]] = []  # min-heap based on xor distance to query_id
 
         # step 1: add current bucket to the candidates heap
         nearest_index = self.get_bucket_index(query_id)
         nearest_bucket = self.buckets[nearest_index]
-        for node_id, endpoint in nearest_bucket.nodes_to_endpoint.items():
-            heapq.heappush(candidates, (query_id.xor_distance(node_id), node_id, endpoint))
+        for node_id, peer_id in nearest_bucket.nodes_to_peer_id.items():
+            heapq.heappush(candidates, (query_id.xor_distance(node_id), node_id, peer_id))
 
         # step 2: add adjacent buckets by ascending code tree one level at a time until you have enough nodes
         left_index, right_index = nearest_index, nearest_index + 1  # bucket indices considered, [left, right)
@@ -132,8 +132,8 @@ class RoutingTable:
             if split_direction == 0:  # leaf was split on the left, merge its right peer(s)
                 current_upper += current_upper - current_lower
                 while right_index < len(self.buckets) and self.buckets[right_index].upper <= current_upper:
-                    for node_id, endpoint in self.buckets[right_index].nodes_to_endpoint.items():
-                        heapq.heappush(candidates, (query_id.xor_distance(node_id), node_id, endpoint))
+                    for node_id, peer_id in self.buckets[right_index].nodes_to_peer_id.items():
+                        heapq.heappush(candidates, (query_id.xor_distance(node_id), node_id, peer_id))
                     right_index += 1  # note: we may need to add more than one bucket if they are on a lower depth level
                 assert self.buckets[right_index - 1].upper == current_upper
 
@@ -141,13 +141,13 @@ class RoutingTable:
                 current_lower -= current_upper - current_lower
                 while left_index > 0 and self.buckets[left_index - 1].lower >= current_lower:
                     left_index -= 1  # note: we may need to add more than one bucket if they are on a lower depth level
-                    for node_id, endpoint in self.buckets[left_index].nodes_to_endpoint.items():
-                        heapq.heappush(candidates, (query_id.xor_distance(node_id), node_id, endpoint))
+                    for node_id, peer_id in self.buckets[left_index].nodes_to_peer_id.items():
+                        heapq.heappush(candidates, (query_id.xor_distance(node_id), node_id, peer_id))
                 assert self.buckets[left_index].lower == current_lower
 
         # step 3: select k nearest vertices from candidates heap
-        heap_top: List[Tuple[int, DHTID, Endpoint]] = heapq.nsmallest(k + int(exclude is not None), candidates)
-        return [(node, endpoint) for _, node, endpoint in heap_top if node != exclude][:k]
+        heap_top: List[Tuple[int, DHTID, PeerID]] = heapq.nsmallest(k + int(exclude is not None), candidates)
+        return [(node, peer_id) for _, node, peer_id in heap_top if node != exclude][:k]
 
     def __repr__(self):
         bucket_info = "\n".join(repr(bucket) for bucket in self.buckets)
@@ -158,14 +158,14 @@ class RoutingTable:
 class KBucket:
     """
     A bucket containing up to :size: of DHTIDs in [lower, upper) semi-interval.
-    Maps DHT node ids to their endpoints
+    Maps DHT node ids to their peer_ids
     """
 
     def __init__(self, lower: int, upper: int, size: int, depth: int = 0):
         assert upper - lower == 2 ** (DHTID.HASH_NBYTES * 8 - depth)
         self.lower, self.upper, self.size, self.depth = lower, upper, size, depth
-        self.nodes_to_endpoint: Dict[DHTID, Endpoint] = {}
-        self.replacement_nodes: Dict[DHTID, Endpoint] = {}
+        self.nodes_to_peer_id: Dict[DHTID, PeerID] = {}
+        self.replacement_nodes: Dict[DHTID, PeerID] = {}
         self.nodes_requested_for_ping: Set[DHTID] = set()
         self.last_updated = get_dht_time()
 
@@ -173,53 +173,53 @@ class KBucket:
         """ Check if node_id is between this bucket's lower and upper bounds """
         return self.lower <= node_id < self.upper
 
-    def add_or_update_node(self, node_id: DHTID, endpoint: Endpoint) -> bool:
+    def add_or_update_node(self, node_id: DHTID, peer_id: PeerID) -> bool:
         """
         Add node to KBucket or update existing node, return True if successful, False if the bucket is full.
         If the bucket is full, keep track of node in a replacement list, per section 4.1 of the paper.
 
         :param node_id: dht node identifier that should be added or moved to the front of bucket
-        :param endpoint: network address associated with that node id
+        :param peer_id: network address associated with that node id
         :note: this function has a side-effect of resetting KBucket.last_updated time
         """
         if node_id in self.nodes_requested_for_ping:
             self.nodes_requested_for_ping.remove(node_id)
         self.last_updated = get_dht_time()
-        if node_id in self.nodes_to_endpoint:
-            del self.nodes_to_endpoint[node_id]
-            self.nodes_to_endpoint[node_id] = endpoint
-        elif len(self.nodes_to_endpoint) < self.size:
-            self.nodes_to_endpoint[node_id] = endpoint
+        if node_id in self.nodes_to_peer_id:
+            del self.nodes_to_peer_id[node_id]
+            self.nodes_to_peer_id[node_id] = peer_id
+        elif len(self.nodes_to_peer_id) < self.size:
+            self.nodes_to_peer_id[node_id] = peer_id
         else:
             if node_id in self.replacement_nodes:
                 del self.replacement_nodes[node_id]
-            self.replacement_nodes[node_id] = endpoint
+            self.replacement_nodes[node_id] = peer_id
             return False
         return True
 
-    def request_ping_node(self) -> Optional[Tuple[DHTID, Endpoint]]:
+    def request_ping_node(self) -> Optional[Tuple[DHTID, PeerID]]:
         """ :returns: least-recently updated node that isn't already being pinged right now -- if such node exists """
-        for uid, endpoint in self.nodes_to_endpoint.items():
+        for uid, peer_id in self.nodes_to_peer_id.items():
             if uid not in self.nodes_requested_for_ping:
                 self.nodes_requested_for_ping.add(uid)
-                return uid, endpoint
+                return uid, peer_id
 
-    def __getitem__(self, node_id: DHTID) -> Endpoint:
-        return self.nodes_to_endpoint[node_id] if node_id in self.nodes_to_endpoint else self.replacement_nodes[node_id]
+    def __getitem__(self, node_id: DHTID) -> PeerID:
+        return self.nodes_to_peer_id[node_id] if node_id in self.nodes_to_peer_id else self.replacement_nodes[node_id]
 
     def __delitem__(self, node_id: DHTID):
-        if not (node_id in self.nodes_to_endpoint or node_id in self.replacement_nodes):
+        if not (node_id in self.nodes_to_peer_id or node_id in self.replacement_nodes):
             raise KeyError(f"KBucket does not contain node id={node_id}.")
 
         if node_id in self.replacement_nodes:
             del self.replacement_nodes[node_id]
 
-        if node_id in self.nodes_to_endpoint:
-            del self.nodes_to_endpoint[node_id]
+        if node_id in self.nodes_to_peer_id:
+            del self.nodes_to_peer_id[node_id]
 
             if self.replacement_nodes:
                 newnode_id, newnode = self.replacement_nodes.popitem()
-                self.nodes_to_endpoint[newnode_id] = newnode
+                self.nodes_to_peer_id[newnode_id] = newnode
 
     def split(self) -> Tuple[KBucket, KBucket]:
         """ Split bucket over midpoint, rounded down, assign nodes to according to their id """
@@ -227,13 +227,13 @@ class KBucket:
         assert self.lower < midpoint < self.upper, f"Bucket to small to be split: [{self.lower}: {self.upper})"
         left = KBucket(self.lower, midpoint, self.size, depth=self.depth + 1)
         right = KBucket(midpoint, self.upper, self.size, depth=self.depth + 1)
-        for node_id, endpoint in chain(self.nodes_to_endpoint.items(), self.replacement_nodes.items()):
+        for node_id, peer_id in chain(self.nodes_to_peer_id.items(), self.replacement_nodes.items()):
             bucket = left if int(node_id) <= midpoint else right
-            bucket.add_or_update_node(node_id, endpoint)
+            bucket.add_or_update_node(node_id, peer_id)
         return left, right
 
     def __repr__(self):
-        return f"{self.__class__.__name__}({len(self.nodes_to_endpoint)} nodes" \
+        return f"{self.__class__.__name__}({len(self.nodes_to_peer_id)} nodes" \
                f" with {len(self.replacement_nodes)} replacements, depth={self.depth}, max size={self.size}" \
                f" lower={hex(self.lower)}, upper={hex(self.upper)})"
 

+ 2 - 2
hivemind/proto/dht.proto

@@ -64,9 +64,9 @@ message FindResult {
   bytes value = 2;                     // n/a  | serialized value | serialized DictionaryDHTValue with serialized fields
   double expiration_time = 3;          // n/a  | expiration time  | DictionaryDHTValue.latest_expiration_time
 
-  // two aligned arrays: DHTIDs and Endpoints for nearest peers (sorted by XOR distance)
+  // two aligned arrays: DHTIDs and PeerIDs for nearest peers (sorted by XOR distance)
   repeated bytes nearest_node_ids = 4;      // DHTIDs of the nearest peers serialized with node_id.to_bytes()
-  repeated string nearest_endpoints = 5;    // Base58-serialized libp2p PeerIDs of the nearest peers
+  repeated string nearest_peer_ids = 5;     // Base58-serialized libp2p PeerIDs of the nearest peers
 }
 
 message FindResponse {

+ 30 - 30
tests/test_dht_node.py

@@ -39,8 +39,8 @@ def run_protocol_listener(dhtid: DHTID, maddr_conn: mp.connection.Connection,
 
     logger.info(f"Started peer id={protocol.node_id} visible_maddrs={visible_maddrs}")
 
-    for endpoint in maddrs_to_peer_ids(initial_peers):
-        loop.run_until_complete(protocol.call_ping(endpoint))
+    for peer_id in maddrs_to_peer_ids(initial_peers):
+        loop.run_until_complete(protocol.call_ping(peer_id))
 
     maddr_conn.send((p2p.id, visible_maddrs))
 
@@ -70,8 +70,8 @@ def launch_protocol_listener(initial_peers: Sequence[Multiaddr] = ()) -> \
 
 @pytest.mark.forked
 def test_dht_protocol():
-    peer1_id, peer1_proc, peer1_endpoint, peer1_maddrs = launch_protocol_listener()
-    peer2_id, peer2_proc, peer2_endpoint, _ = launch_protocol_listener(initial_peers=peer1_maddrs)
+    peer1_node_id, peer1_proc, peer1_id, peer1_maddrs = launch_protocol_listener()
+    peer2_node_id, peer2_proc, peer2_id, _ = launch_protocol_listener(initial_peers=peer1_maddrs)
 
     loop = asyncio.get_event_loop()
     for listen in [False, True]:  # note: order matters, this test assumes that first run uses listen=False
@@ -80,21 +80,21 @@ def test_dht_protocol():
             p2p, DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, listen=listen))
         logger.info(f"Self id={protocol.node_id}")
 
-        assert loop.run_until_complete(protocol.call_ping(peer1_endpoint)) == peer1_id
+        assert loop.run_until_complete(protocol.call_ping(peer1_id)) == peer1_node_id
 
         key, value, expiration = DHTID.generate(), [random.random(), {'ololo': 'pyshpysh'}], get_dht_time() + 1e3
         store_ok = loop.run_until_complete(protocol.call_store(
-            peer1_endpoint, [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
+            peer1_id, [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
         )
         assert all(store_ok), "DHT rejected a trivial store"
 
         # peer 1 must know about peer 2
         (recv_value_bytes, recv_expiration), nodes_found = loop.run_until_complete(
-            protocol.call_find(peer1_endpoint, [key]))[key]
+            protocol.call_find(peer1_id, [key]))[key]
         recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
-        (recv_id, recv_endpoint) = next(iter(nodes_found.items()))
-        assert recv_id == peer2_id and recv_endpoint == peer2_endpoint, \
-            f"expected id={peer2_id}, peer={peer2_endpoint} but got {recv_id}, {recv_endpoint}"
+        (recv_id, recv_peer_id) = next(iter(nodes_found.items()))
+        assert recv_id == peer2_node_id and recv_peer_id == peer2_id, \
+            f"expected id={peer2_node_id}, peer={peer2_id} but got {recv_id}, {recv_peer_id}"
 
         assert recv_value == value and recv_expiration == expiration, \
             f"call_find_value expected {value} (expires by {expiration}) " \
@@ -103,11 +103,11 @@ def test_dht_protocol():
         # peer 2 must know about peer 1, but not have a *random* nonexistent value
         dummy_key = DHTID.generate()
         empty_item, nodes_found_2 = loop.run_until_complete(
-            protocol.call_find(peer2_endpoint, [dummy_key]))[dummy_key]
+            protocol.call_find(peer2_id, [dummy_key]))[dummy_key]
         assert empty_item is None, "Non-existent keys shouldn't have values"
-        (recv_id, recv_endpoint) = next(iter(nodes_found_2.items()))
-        assert recv_id == peer1_id and recv_endpoint == peer1_endpoint, \
-            f"expected id={peer1_id}, peer={peer1_endpoint} but got {recv_id}, {recv_endpoint}"
+        (recv_id, recv_peer_id) = next(iter(nodes_found_2.items()))
+        assert recv_id == peer1_node_id and recv_peer_id == peer1_id, \
+            f"expected id={peer1_node_id}, peer={peer1_id} but got {recv_id}, {recv_peer_id}"
 
         # cause a non-response by querying a nonexistent peer
         assert loop.run_until_complete(protocol.call_find(PeerID.from_base58('fakeid'), [key])) is None
@@ -116,15 +116,15 @@ def test_dht_protocol():
         nested_key, subkey1, subkey2 = DHTID.generate(), 'foo', 'bar'
         value1, value2 = [random.random(), {'ololo': 'pyshpysh'}], 'abacaba'
         assert loop.run_until_complete(protocol.call_store(
-            peer1_endpoint, keys=[nested_key], values=[hivemind.MSGPackSerializer.dumps(value1)],
+            peer1_id, keys=[nested_key], values=[hivemind.MSGPackSerializer.dumps(value1)],
             expiration_time=[expiration], subkeys=[subkey1])
         )
         assert loop.run_until_complete(protocol.call_store(
-            peer1_endpoint, keys=[nested_key], values=[hivemind.MSGPackSerializer.dumps(value2)],
+            peer1_id, keys=[nested_key], values=[hivemind.MSGPackSerializer.dumps(value2)],
             expiration_time=[expiration + 5], subkeys=[subkey2])
         )
         (recv_dict, recv_expiration), nodes_found = loop.run_until_complete(
-            protocol.call_find(peer1_endpoint, [nested_key]))[nested_key]
+            protocol.call_find(peer1_id, [nested_key]))[nested_key]
         assert isinstance(recv_dict, DictionaryDHTValue)
         assert len(recv_dict.data) == 2 and recv_expiration == expiration + 5
         assert recv_dict.data[subkey1] == (protocol.serializer.dumps(value1), expiration)
@@ -140,7 +140,7 @@ def test_dht_protocol():
 @pytest.mark.forked
 def test_empty_table():
     """ Test RPC methods with empty routing table """
-    peer_id, peer_proc, peer_endpoint, peer_maddrs = launch_protocol_listener()
+    peer_id, peer_proc, peer_peer_id, peer_maddrs = launch_protocol_listener()
 
     loop = asyncio.get_event_loop()
     p2p = loop.run_until_complete(P2P.create(initial_peers=peer_maddrs))
@@ -150,19 +150,19 @@ def test_empty_table():
     key, value, expiration = DHTID.generate(), [random.random(), {'ololo': 'pyshpysh'}], get_dht_time() + 1e3
 
     empty_item, nodes_found = loop.run_until_complete(
-        protocol.call_find(peer_endpoint, [key]))[key]
+        protocol.call_find(peer_peer_id, [key]))[key]
     assert empty_item is None and len(nodes_found) == 0
     assert all(loop.run_until_complete(protocol.call_store(
-        peer_endpoint, [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
+        peer_peer_id, [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
     )), "peer rejected store"
 
     (recv_value_bytes, recv_expiration), nodes_found = loop.run_until_complete(
-        protocol.call_find(peer_endpoint, [key]))[key]
+        protocol.call_find(peer_peer_id, [key]))[key]
     recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
     assert len(nodes_found) == 0
     assert recv_value == value and recv_expiration == expiration
 
-    assert loop.run_until_complete(protocol.call_ping(peer_endpoint)) == peer_id
+    assert loop.run_until_complete(protocol.call_ping(peer_peer_id)) == peer_id
     assert loop.run_until_complete(protocol.call_ping(PeerID.from_base58('fakeid'))) is None
     peer_proc.terminate()
 
@@ -181,15 +181,15 @@ def test_dht_node():
 
     # test 1: find self
     nearest = loop.run_until_complete(me.find_nearest_nodes([me.node_id], k_nearest=1))[me.node_id]
-    assert len(nearest) == 1 and nearest[me.node_id] == me.endpoint
+    assert len(nearest) == 1 and nearest[me.node_id] == me.peer_id
 
     # test 2: find others
     for _ in range(10):
-        ref_endpoint, query_id = random.choice(list(dht.items()))
+        ref_peer_id, query_id = random.choice(list(dht.items()))
         nearest = loop.run_until_complete(me.find_nearest_nodes([query_id], k_nearest=1))[query_id]
         assert len(nearest) == 1
-        found_node_id, found_endpoint = next(iter(nearest.items()))
-        assert found_node_id == query_id and found_endpoint == ref_endpoint
+        found_node_id, found_peer_id = next(iter(nearest.items()))
+        assert found_node_id == query_id and found_peer_id == ref_peer_id
 
     # test 3: find neighbors to random nodes
     accuracy_numerator = accuracy_denominator = 0  # top-1 nearest neighbor accuracy
@@ -236,7 +236,7 @@ def test_dht_node():
     # test 5: node without peers
     detached_node = loop.run_until_complete(DHTNode.create())
     nearest = loop.run_until_complete(detached_node.find_nearest_nodes([dummy]))[dummy]
-    assert len(nearest) == 1 and nearest[detached_node.node_id] == detached_node.endpoint
+    assert len(nearest) == 1 and nearest[detached_node.node_id] == detached_node.peer_id
     nearest = loop.run_until_complete(detached_node.find_nearest_nodes([dummy], exclude_self=True))[dummy]
     assert len(nearest) == 0
 
@@ -404,13 +404,13 @@ async def test_dhtnode_blacklist():
 
     assert await node2.store('def', 456, expiration_time=hivemind.get_dht_time() + 99)
 
-    assert set(node2.blacklist.ban_counter.keys()) == {node3.endpoint, node4.endpoint}
+    assert set(node2.blacklist.ban_counter.keys()) == {node3.peer_id, node4.peer_id}
 
     assert await node1.get('abc', latest=True)  # force node1 to crawl dht and discover unresponsive peers
-    assert node3.endpoint in node1.blacklist
+    assert node3.peer_id in node1.blacklist
 
     assert await node1.get('abc', latest=True)  # force node1 to crawl dht and discover unresponsive peers
-    assert node2.endpoint not in node1.blacklist
+    assert node2.peer_id not in node1.blacklist
 
     await asyncio.gather(node1.shutdown(), node2.shutdown())
 

+ 9 - 9
tests/test_routing.py

@@ -76,7 +76,7 @@ def test_routing_table_parameters():
         for phony_neighbor_port in random.sample(range(1_000_000), 10_000):
             routing_table.add_or_update_node(DHTID.generate(), f'{LOCALHOST}:{phony_neighbor_port}')
         for bucket in routing_table.buckets:
-            assert len(bucket.replacement_nodes) == 0 or len(bucket.nodes_to_endpoint) <= bucket.size
+            assert len(bucket.replacement_nodes) == 0 or len(bucket.nodes_to_peer_id) <= bucket.size
         assert min_nbuckets <= len(routing_table.buckets) <= max_nbuckets, (
             f"Unexpected number of buckets: {min_nbuckets} <= {len(routing_table.buckets)} <= {max_nbuckets}")
 
@@ -92,34 +92,34 @@ def test_routing_table_search():
 
         for phony_neighbor_port in random.sample(range(1_000_000), table_size):
             routing_table.add_or_update_node(DHTID.generate(), f'{LOCALHOST}:{phony_neighbor_port}')
-            new_total = sum(len(bucket.nodes_to_endpoint) for bucket in routing_table.buckets)
+            new_total = sum(len(bucket.nodes_to_peer_id) for bucket in routing_table.buckets)
             num_added += new_total > total_nodes
             total_nodes = new_total
         num_replacements = sum(len(bucket.replacement_nodes) for bucket in routing_table.buckets)
-    
+
         all_active_neighbors = list(chain(
-            *(bucket.nodes_to_endpoint.keys() for bucket in routing_table.buckets)
+            *(bucket.nodes_to_peer_id.keys() for bucket in routing_table.buckets)
         ))
         assert lower_active <= len(all_active_neighbors) <= upper_active
         assert len(all_active_neighbors) == num_added
         assert num_added + num_replacements == table_size
-    
+
         # random queries
         for i in range(1000):
             k = random.randint(1, 100)
             query_id = DHTID.generate()
             exclude = query_id if random.random() < 0.5 else None
-            our_knn, our_endpoints = zip(*routing_table.get_nearest_neighbors(query_id, k=k, exclude=exclude))
+            our_knn, our_peer_ids = zip(*routing_table.get_nearest_neighbors(query_id, k=k, exclude=exclude))
             reference_knn = heapq.nsmallest(k, all_active_neighbors, key=query_id.xor_distance)
             assert all(our == ref for our, ref in zip_longest(our_knn, reference_knn))
-            assert all(our_endpoint == routing_table[our_node]
-                       for our_node, our_endpoint in zip(our_knn, our_endpoints))
+            assert all(our_peer_id == routing_table[our_node]
+                       for our_node, our_peer_id in zip(our_knn, our_peer_ids))
 
         # queries from table
         for i in range(1000):
             k = random.randint(1, 100)
             query_id = random.choice(all_active_neighbors)
-            our_knn, our_endpoints = zip(*routing_table.get_nearest_neighbors(query_id, k=k, exclude=query_id))
+            our_knn, our_peer_ids = zip(*routing_table.get_nearest_neighbors(query_id, k=k, exclude=query_id))
 
             reference_knn = heapq.nsmallest(k + 1, all_active_neighbors, key=query_id.xor_distance)
             if query_id in reference_knn:

+ 4 - 3
tests/test_utils/dht_swarms.py

@@ -7,7 +7,8 @@ from typing import Dict, List, Tuple
 
 from multiaddr import Multiaddr
 
-from hivemind.dht.node import DHTID, Endpoint, DHTNode
+from hivemind.dht.node import DHTID, DHTNode
+from hivemind.p2p import PeerID
 
 
 def run_node(initial_peers: List[Multiaddr], info_queue: mp.Queue):
@@ -19,7 +20,7 @@ def run_node(initial_peers: List[Multiaddr], info_queue: mp.Queue):
     node = loop.run_until_complete(DHTNode.create(initial_peers=initial_peers, ping_n_attempts=10))
     maddrs = loop.run_until_complete(node.get_visible_maddrs())
 
-    info_queue.put((node.node_id, node.endpoint, maddrs))
+    info_queue.put((node.node_id, node.peer_id, maddrs))
 
     async def shutdown():
         await node.shutdown()
@@ -30,7 +31,7 @@ def run_node(initial_peers: List[Multiaddr], info_queue: mp.Queue):
 
 
 def launch_swarm_in_separate_processes(n_peers: int, n_sequential_peers: int) -> \
-        Tuple[List[mp.Process], Dict[Endpoint, DHTID], List[List[Multiaddr]]]:
+        Tuple[List[mp.Process], Dict[PeerID, DHTID], List[List[Multiaddr]]]:
     assert n_sequential_peers < n_peers, \
         'Parameters imply that first n_sequential_peers of n_peers will be run sequentially'