feat DHT: use P2P as backend (#208)

Relates: #185
Ilya, 4 years ago
commit 43ef3c6465
6 changed files with 210 additions and 122 deletions
  1. hivemind/dht/node.py (+5 -1)
  2. hivemind/dht/protocol.py (+56 -30)
  3. hivemind/p2p/__init__.py (+1 -1)
  4. hivemind/p2p/p2p_daemon.py (+46 -10)
  5. tests/test_dht_node.py (+76 -55)
  6. tests/test_p2p_daemon.py (+26 -25)

+ 5 - 1
hivemind/dht/node.py

@@ -141,6 +141,7 @@ class DHTNode:
                                                   parallel_rpc, cache_size, listen, listen_on, endpoint, record_validator,
                                                   **kwargs)
         self.port = self.protocol.port
+        self.endpoint = self.protocol.client.endpoint
 
         if initial_peers:
             # stage 1: ping initial_peers, add each other to the routing table
@@ -181,6 +182,9 @@ class DHTNode:
         assert _initialized_with_create, " Please use DHTNode.create coroutine to spawn new node instances "
         super().__init__()
 
+    def __del__(self):
+        self.protocol.__del__()
+
     async def shutdown(self, timeout=None):
         """ Process existing requests, close all connections and stop the server """
         self.is_alive = False
@@ -233,7 +237,7 @@ class DHTNode:
         for query, nearest_nodes in nearest_nodes_per_query.items():
             if not exclude_self:
                 nearest_nodes = sorted(nearest_nodes + [self.node_id], key=query.xor_distance)
-                node_to_endpoint[self.node_id] = f"{LOCALHOST}:{self.port}"
+                node_to_endpoint[self.node_id] = self.endpoint
             nearest_nodes_with_endpoints[query] = {node: node_to_endpoint[node] for node in nearest_nodes[:k_nearest]}
         return nearest_nodes_with_endpoints
 

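Editor's note (a minimal sketch, not part of the commit): with this change a DHTNode is addressed by the base58 peer ID of its P2P daemon, exposed as node.endpoint, instead of a "host:port" string. Assuming the hivemind API exactly as exercised by the updated tests below:

    import asyncio
    import hivemind

    async def main():
        alice = await hivemind.DHTNode.create()
        # alice.endpoint is protocol.client.endpoint, i.e. a libp2p peer ID, not f"{LOCALHOST}:{port}"
        bob = await hivemind.DHTNode.create(initial_peers=[alice.endpoint])
        assert await bob.protocol.call_ping(alice.endpoint) == alice.node_id
        await asyncio.gather(alice.shutdown(), bob.shutdown())

    asyncio.run(main())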
+ 56 - 30
hivemind/dht/protocol.py

@@ -2,6 +2,7 @@
 from __future__ import annotations
 
 import asyncio
+import functools
 from itertools import zip_longest
 from typing import Optional, List, Tuple, Dict, Any, Sequence, Union, Collection
 
@@ -11,7 +12,8 @@ from hivemind.dht.crypto import DHTRecord, RecordValidatorBase
 from hivemind.dht.routing import RoutingTable, DHTID, BinaryDHTValue, DHTExpiration, Subkey
 from hivemind.dht.storage import DHTLocalStorage, DictionaryDHTValue
 from hivemind.proto import dht_pb2, dht_pb2_grpc as dht_grpc
-from hivemind.utils import Endpoint, get_logger, replace_port, MSGPackSerializer, ChannelCache, ValueWithExpiration
+from hivemind.p2p import P2P, P2PContext
+from hivemind.utils import Endpoint, get_logger, replace_port, get_port, MSGPackSerializer, ChannelCache, ValueWithExpiration
 from hivemind.utils import get_dht_time, GRPC_KEEPALIVE_OPTIONS, MAX_DHT_TIME_DISCREPANCY_SECONDS
 
 logger = get_logger(__name__)
@@ -20,7 +22,7 @@ logger = get_logger(__name__)
 class DHTProtocol(dht_grpc.DHTServicer):
     # fmt:off
     node_id: DHTID; port: int; bucket_size: int; num_replicas: int; wait_timeout: float; node_info: dht_pb2.NodeInfo
-    channel_options: Tuple[Tuple[str, Any]]; server: grpc.aio.Server
+    channel_options: Tuple[Tuple[str, Any]]; server: P2P
     storage: DHTLocalStorage; cache: DHTLocalStorage; routing_table: RoutingTable; rpc_semaphore: asyncio.Semaphore
     record_validator: Optional[RecordValidatorBase]
     # fmt:on
@@ -28,6 +30,8 @@ class DHTProtocol(dht_grpc.DHTServicer):
     serializer = MSGPackSerializer  # used to pack/unpack DHT Values for transfer over network
     RESERVED_SUBKEYS = IS_REGULAR_VALUE, IS_DICTIONARY = serializer.dumps(None), b''
 
+    PING_NAME, STORE_NAME, FIND_NAME = '__ping__', '__store__', '__find__'
+
     @classmethod
     async def create(
             cls, node_id: DHTID, bucket_size: int, depth_modulo: int, num_replicas: int, wait_timeout: float,
@@ -55,18 +59,23 @@ class DHTProtocol(dht_grpc.DHTServicer):
         self.rpc_semaphore = asyncio.Semaphore(parallel_rpc if parallel_rpc is not None else float('inf'))
         self.record_validator = record_validator
 
+        self.client = await P2P.create(host_port=get_port(listen_on))
         if listen:  # set up server to process incoming rpc requests
-            grpc.aio.init_grpc_aio()
-            self.server = grpc.aio.server(**kwargs, options=GRPC_KEEPALIVE_OPTIONS)
-            dht_grpc.add_DHTServicer_to_server(self, self.server)
-
-            self.port = self.server.add_insecure_port(listen_on)
+            self.server = self.client  #TODO deduplicate with client
+            await self.server.add_unary_handler(
+                DHTProtocol.PING_NAME, functools.partial(DHTProtocol.rpc_ping, self),
+                dht_pb2.PingRequest, dht_pb2.PingResponse)
+            await self.server.add_unary_handler(
+                DHTProtocol.STORE_NAME, functools.partial(DHTProtocol.rpc_store, self),
+                dht_pb2.StoreRequest, dht_pb2.StoreResponse)
+            await self.server.add_unary_handler(
+                DHTProtocol.FIND_NAME, functools.partial(DHTProtocol.rpc_find, self),
+                dht_pb2.FindRequest, dht_pb2.FindResponse)
+
+            self.port = self.server._host_port
             assert self.port != 0, f"Failed to listen to {listen_on}"
-            if endpoint is not None and endpoint.endswith('*'):
-                endpoint = replace_port(endpoint, self.port)
             self.node_info = dht_pb2.NodeInfo(node_id=node_id.to_bytes(), rpc_port=self.port,
-                                              endpoint=endpoint or dht_pb2.NodeInfo.endpoint.DESCRIPTOR.default_value)
-            await self.server.start()
+                                              endpoint=endpoint or self.server.endpoint)
         else:  # not listening to incoming requests, client-only mode
             # note: use empty node_info so peers won't add you to their routing tables
             self.node_info, self.server, self.port = dht_pb2.NodeInfo(), None, None
@@ -80,16 +89,36 @@ class DHTProtocol(dht_grpc.DHTServicer):
         assert _initialized_with_create, " Please use DHTProtocol.create coroutine to spawn new protocol instances "
         super().__init__()
 
+    def __del__(self):
+        self.client.__del__()
+
     async def shutdown(self, timeout=None):
         """ Process existing requests, close all connections and stop the server """
         if self.server:
-            await self.server.stop(timeout)
+            await self.server.stop_listening()
         else:
             logger.warning("DHTProtocol has no server (due to listen=False), it doesn't need to be shut down")
 
-    def _get_dht_stub(self, peer: Endpoint) -> dht_grpc.DHTStub:
+    class DHTStub:
+        def __init__(self, protocol: DHTProtocol, peer: Endpoint):
+            self.protocol = protocol
+            self.peer = peer
+
+        async def rpc_ping(self, request: dht_pb2.PingRequest, timeout=None) -> dht_pb2.PingResponse:
+            return await self.protocol.client.call_unary_handler(
+                self.peer, DHTProtocol.PING_NAME, request, dht_pb2.PingResponse)
+
+        async def rpc_store(self, request: dht_pb2.StoreRequest, timeout=None) -> dht_pb2.StoreResponse:
+            return await self.protocol.client.call_unary_handler(
+                self.peer, DHTProtocol.STORE_NAME, request, dht_pb2.StoreResponse)
+
+        async def rpc_find(self, request: dht_pb2.FindRequest, timeout=None) -> dht_pb2.FindResponse:
+            return await self.protocol.client.call_unary_handler(
+                self.peer, DHTProtocol.FIND_NAME, request, dht_pb2.FindResponse)
+
+    def _get_dht_stub(self, peer: Endpoint) -> DHTProtocol.DHTStub:
         """ get a DHTStub that sends requests to a given peer """
-        return ChannelCache.get_stub(peer, dht_grpc.DHTStub, aio=True, options=self.channel_options)
+        return DHTProtocol.DHTStub(self, peer)
 
     async def call_ping(self, peer: Endpoint, validate: bool = False, strict: bool = True) -> Optional[DHTID]:
         """
@@ -107,8 +136,8 @@ class DHTProtocol(dht_grpc.DHTServicer):
                 time_requested = get_dht_time()
                 response = await self._get_dht_stub(peer).rpc_ping(ping_request, timeout=self.wait_timeout)
                 time_responded = get_dht_time()
-        except grpc.aio.AioRpcError as error:
-            logger.debug(f"DHTProtocol failed to ping {peer}: {error.code()}")
+        except Exception as e:
+            logger.debug(f"DHTProtocol failed to ping {peer}: {e}")
             response = None
         responded = bool(response and response.peer and response.peer.node_id)
 
@@ -141,10 +170,10 @@ class DHTProtocol(dht_grpc.DHTServicer):
                 response = await self._get_dht_stub(peer).rpc_ping(ping_request, timeout=self.wait_timeout)
                 if response.sender_endpoint != dht_pb2.PingResponse.sender_endpoint.DESCRIPTOR.default_value:
                     return response.sender_endpoint
-        except grpc.aio.AioRpcError as error:
-            logger.debug(f"DHTProtocol failed to ping {peer}: {error.code()}")
+        except Exception as e:
+            logger.debug(f"DHTProtocol failed to ping {peer}: {e}")
 
-    async def rpc_ping(self, request: dht_pb2.PingRequest, context: grpc.ServicerContext):
+    async def rpc_ping(self, request: dht_pb2.PingRequest, context: P2PContext):
         """ Some node wants us to add it to our routing table. """
         response = dht_pb2.PingResponse(peer=self.node_info, sender_endpoint=context.peer(),
                                         dht_time=get_dht_time(), available=False)
@@ -152,15 +181,12 @@ class DHTProtocol(dht_grpc.DHTServicer):
         if request.peer and request.peer.node_id and request.peer.rpc_port:
             sender_id = DHTID.from_bytes(request.peer.node_id)
             if request.peer.endpoint != dht_pb2.NodeInfo.endpoint.DESCRIPTOR.default_value:
-                sender_endpoint = request.peer.endpoint  # if peer has preferred endpoint, use it
-            else:
-                sender_endpoint = replace_port(context.peer(), new_port=request.peer.rpc_port)
+                response.sender_endpoint = request.peer.endpoint  # if peer has preferred endpoint, use it
 
-            response.sender_endpoint = sender_endpoint
             if request.validate:
                 response.available = await self.call_ping(response.sender_endpoint, validate=False) == sender_id
 
-            asyncio.create_task(self.update_routing_table(sender_id, sender_endpoint,
+            asyncio.create_task(self.update_routing_table(sender_id, response.sender_endpoint,
                                                           responded=response.available or not request.validate))
 
         return response
@@ -215,12 +241,12 @@ class DHTProtocol(dht_grpc.DHTServicer):
                 peer_id = DHTID.from_bytes(response.peer.node_id)
                 asyncio.create_task(self.update_routing_table(peer_id, peer, responded=True))
             return response.store_ok
-        except grpc.aio.AioRpcError as error:
-            logger.debug(f"DHTProtocol failed to store at {peer}: {error.code()}")
+        except Exception as e:
+            logger.debug(f"DHTProtocol failed to store at {peer}: {e}")
             asyncio.create_task(self.update_routing_table(self.routing_table.get(endpoint=peer), peer, responded=False))
             return None
 
-    async def rpc_store(self, request: dht_pb2.StoreRequest, context: grpc.ServicerContext) -> dht_pb2.StoreResponse:
+    async def rpc_store(self, request: dht_pb2.StoreRequest, context: P2PContext) -> dht_pb2.StoreResponse:
         """ Some node wants us to store this (key, value) pair """
         if request.peer:  # if requested, add peer to the routing table
             asyncio.create_task(self.rpc_ping(dht_pb2.PingRequest(peer=request.peer), context))
@@ -300,11 +326,11 @@ class DHTProtocol(dht_grpc.DHTServicer):
                     logger.error(f"Unknown result type: {result.type}")
 
             return output
-        except grpc.aio.AioRpcError as error:
-            logger.debug(f"DHTProtocol failed to find at {peer}: {error.code()}")
+        except Exception as e:
+            logger.debug(f"DHTProtocol failed to find at {peer}: {e}")
            asyncio.create_task(self.update_routing_table(self.routing_table.get(endpoint=peer), peer, responded=False))
 
-    async def rpc_find(self, request: dht_pb2.FindRequest, context: grpc.ServicerContext) -> dht_pb2.FindResponse:
+    async def rpc_find(self, request: dht_pb2.FindRequest, context: P2PContext) -> dht_pb2.FindResponse:
         """
         Someone wants to find keys in the DHT. For all keys that we have locally, return value and expiration
         Also return :bucket_size: nearest neighbors from our routing table for each key (whether or not we found value)

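For orientation (an editor's sketch under the assumptions of the tests below, not part of the commit): DHTProtocol.create() now registers rpc_ping/rpc_store/rpc_find as libp2p unary handlers under '__ping__'/'__store__'/'__find__', and the nested DHTStub resolves the same names on the calling side through P2P.call_unary_handler. A minimal ping round trip, mirroring test_dht_protocol / test_empty_table:

    import hivemind
    from hivemind.dht.protocol import DHTProtocol
    from hivemind.dht.routing import DHTID

    async def ping_round_trip():
        server = await DHTProtocol.create(DHTID.generate(), bucket_size=20, depth_modulo=5,
                                          num_replicas=3, wait_timeout=5,
                                          listen_on=f"{hivemind.LOCALHOST}:{hivemind.find_open_port()}")
        client = await DHTProtocol.create(DHTID.generate(), bucket_size=20, depth_modulo=5,
                                          num_replicas=3, wait_timeout=5, listen=False)
        # peers are addressed by the daemon's base58 peer ID, exposed as server.server.endpoint
        assert await client.call_ping(server.server.endpoint) == server.node_id
        await server.shutdown()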
+ 1 - 1
hivemind/p2p/__init__.py

@@ -1 +1 @@
-from hivemind.p2p.p2p_daemon import P2P
+from hivemind.p2p.p2p_daemon import P2P, P2PContext

+ 46 - 10
hivemind/p2p/p2p_daemon.py

@@ -2,6 +2,7 @@ import asyncio
 import copy
 from pathlib import Path
 import pickle
+import signal
 import subprocess
 import typing as tp
 import warnings
@@ -10,6 +11,7 @@ import google.protobuf
 from multiaddr import Multiaddr
 import hivemind.p2p.p2p_daemon_bindings.p2pclient as p2pclient
 from hivemind.p2p.p2p_daemon_bindings.datastructures import ID, StreamInfo
+from hivemind.p2p.p2p_daemon_bindings.utils import ControlFailure
 
 from hivemind.utils.networking import find_open_port
 
@@ -22,6 +24,9 @@ class P2PContext(object):
         self.ours_port = ours_port
         self.handle_name = handle_name
 
+    def peer(self) -> str:
+        return self.peer_id.to_base58()
+
 
 class P2P(object):
     """
@@ -31,7 +36,7 @@ class P2P(object):
     """
 
     P2PD_RELATIVE_PATH = 'hivemind_cli/p2pd'
-    NUM_RETRIES = 3
+    NUM_RETRIES = 4
     RETRY_DELAY = 0.4
     HEADER_LEN = 8
     BYTEORDER = 'big'
@@ -106,7 +111,8 @@ class P2P(object):
     async def _identify_client(self, delay):
         await asyncio.sleep(delay)
         encoded = await self._client.identify()
-        self.id = encoded[0].to_base58()
+        self.id = encoded[0]
+        self.endpoint = self.id.to_base58()
 
     def _assign_daemon_ports(self, host_port=None, daemon_listen_port=None):
         self._host_port, self._daemon_listen_port = host_port, daemon_listen_port
@@ -223,12 +229,16 @@ class P2P(object):
 
         return do_handle_unary_stream
 
-    def start_listening(self):
+    async def start_listening(self):
+        started = asyncio.Event()
+
         async def listen():
             async with self._client.listen():
+                started.set()
                 await self._server_stopped.wait()
 
         self._listen_task = asyncio.create_task(listen())
+        await started.wait()
 
     async def stop_listening(self):
         if self._listen_task is not None:
@@ -242,25 +252,44 @@ class P2P(object):
 
     async def add_stream_handler(self, name, handle):
         if self._listen_task is None:
-            self.start_listening()
+            await self.start_listening()
         await self._client.stream_handler(name, P2P._handle_stream(handle))
 
     async def add_unary_handler(self, name, handle, in_proto_type, out_proto_type):
         if self._listen_task is None:
-            self.start_listening()
+            await self.start_listening()
         context = P2PContext(ours_id=self.id, ours_port=self._host_port, handle_name=name)
         await self._client.stream_handler(
             name, P2P._handle_unary_stream(handle, context, in_proto_type, out_proto_type))
 
-    async def call_peer_handler(self, peer_id, handler_name, input_data):
-        libp2p_peer_id = ID.from_base58(peer_id)
-        stream_info, reader, writer = await self._client.stream_open(libp2p_peer_id, (handler_name,))
+    async def _call_handler(self, peer_endpoint, handle_name, input_data: bytes) -> bytes:
+        peer_id = ID.from_base58(peer_endpoint)
+        for try_count in range(P2P.NUM_RETRIES):
+            try:
+                stream_info, reader, writer = await self._client.stream_open(
+                    peer_id, (handle_name,))
+            except ControlFailure:
+                if try_count == P2P.NUM_RETRIES - 1:
+                    raise
+                await asyncio.sleep(P2P.RETRY_DELAY * (2 ** try_count))
         try:
-            await P2P.send_data(input_data, writer)
-            return await P2P.receive_data(reader)
+            await P2P.send_raw_data(input_data, writer)
+            return await P2P.receive_raw_data(reader)
         finally:
             writer.close()
 
+    async def call_peer_handler(self, peer_endpoint, handle_name, request):
+        response_data = await self._call_handler(peer_endpoint, handle_name, pickle.dumps(request))
+        return pickle.loads(response_data)
+
+    async def call_unary_handler(self, peer_endpoint, handle_name, request_protobuf,
+                                 response_proto_type):
+        response_data = await self._call_handler(peer_endpoint, handle_name,
+                                                 request_protobuf.SerializeToString())
+        response = response_proto_type()
+        response.ParseFromString(response_data)
+        return response
+
     def __del__(self):
         self._kill_child()
 
@@ -269,6 +298,13 @@ class P2P(object):
             self._child.kill()
             self._child.wait()
 
+    async def wait_for_termination(self):
+        def _handler(signum, frame):
+            self._kill_child()
+
+        signal.signal(signal.SIGTERM, _handler)
+        await asyncio.Event().wait()
+
     def _make_process_args(self, *args, **kwargs) -> tp.List[str]:
         proc_args = []
         proc_args.extend(

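Below, an illustrative round trip through the new unary-handler API (the handler and its name are invented for this sketch and are not part of the commit; it mirrors test_call_unary_handler further down):

    from hivemind.p2p import P2P, P2PContext
    from hivemind.proto import dht_pb2

    async def echo_ping(request: dht_pb2.PingRequest, context: P2PContext) -> dht_pb2.PingResponse:
        # context.peer() returns the caller's base58 peer ID
        return dht_pb2.PingResponse(sender_endpoint=context.peer(), available=True)

    async def unary_round_trip():
        server = await P2P.create()
        await server.add_unary_handler('echo_ping', echo_ping, dht_pb2.PingRequest, dht_pb2.PingResponse)

        client = await P2P.create()
        response = await client.call_unary_handler(server.endpoint, 'echo_ping',
                                                   dht_pb2.PingRequest(validate=True),
                                                   dht_pb2.PingResponse)
        assert response.available and response.sender_endpoint == client.endpoint
        await server.stop_listening()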
+ 76 - 55
tests/test_dht_node.py

@@ -16,37 +16,43 @@ from hivemind.dht.protocol import DHTProtocol, ValidationError
 from hivemind.dht.storage import DictionaryDHTValue
 
 
-def run_protocol_listener(port: int, dhtid: DHTID, started: mp.synchronize.Event, ping: Optional[Endpoint] = None):
+def run_protocol_listener(port: int, dhtid: DHTID, pipe_side: mp.connection.Connection, ping: Optional[Endpoint] = None):
     loop = asyncio.get_event_loop()
     protocol = loop.run_until_complete(DHTProtocol.create(
         dhtid, bucket_size=20, depth_modulo=5, num_replicas=3, wait_timeout=5, listen_on=f"{LOCALHOST}:{port}"))
 
-    assert protocol.port == port
+    port = protocol.port
     print(f"Started peer id={protocol.node_id} port={port}", flush=True)
 
     if ping is not None:
         loop.run_until_complete(protocol.call_ping(ping))
-    started.set()
+
+    pipe_side.send((protocol.port, protocol.server.endpoint))
+
     loop.run_until_complete(protocol.server.wait_for_termination())
     print(f"Finished peer id={protocol.node_id} port={port}", flush=True)
 
 
-# note: we run grpc-related tests in a separate process to re-initialize all global states from scratch
-# this helps us avoid undesirable side-effects (e.g. segfaults) when running multiple tests in sequence
+# note: we run network-related tests in a separate process to re-initialize all global states from scratch
+# this helps us avoid undesirable side-effects (e.g. segfaults) when running multiple tests in a sequence
 
 
 @pytest.mark.forked
 def test_dht_protocol():
     # create the first peer
-    peer1_port, peer1_id, peer1_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
-    peer1_proc = mp.Process(target=run_protocol_listener, args=(peer1_port, peer1_id, peer1_started), daemon=True)
-    peer1_proc.start(), peer1_started.wait()
+    first_side, ours_side = mp.Pipe()
+    peer1_port, peer1_id = hivemind.find_open_port(), DHTID.generate()
+    peer1_proc = mp.Process(target=run_protocol_listener, args=(peer1_port, peer1_id, first_side), daemon=True)
+    peer1_proc.start()
+    peer1_port, peer1_endpoint = ours_side.recv()
 
     # create another peer that connects to the first peer
-    peer2_port, peer2_id, peer2_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
-    peer2_proc = mp.Process(target=run_protocol_listener, args=(peer2_port, peer2_id, peer2_started),
-                            kwargs={'ping': f'{LOCALHOST}:{peer1_port}'}, daemon=True)
-    peer2_proc.start(), peer2_started.wait()
+    second_side, ours_side = mp.Pipe()
+    peer2_port, peer2_id = hivemind.find_open_port(), DHTID.generate()
+    peer2_proc = mp.Process(target=run_protocol_listener, args=(peer2_port, peer2_id, second_side),
+                            kwargs={'ping': peer1_endpoint}, daemon=True)
+    peer2_proc.start()
+    peer2_port, peer2_endpoint = ours_side.recv()
 
     loop = asyncio.get_event_loop()
     for listen in [False, True]:  # note: order matters, this test assumes that first run uses listen=False
@@ -54,21 +60,21 @@ def test_dht_protocol():
             DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, listen=listen))
         print(f"Self id={protocol.node_id}", flush=True)
 
-        assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{peer1_port}')) == peer1_id
+        assert loop.run_until_complete(protocol.call_ping(peer1_endpoint)) == peer1_id
 
         key, value, expiration = DHTID.generate(), [random.random(), {'ololo': 'pyshpysh'}], get_dht_time() + 1e3
         store_ok = loop.run_until_complete(protocol.call_store(
-            f'{LOCALHOST}:{peer1_port}', [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
+            peer1_endpoint, [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
         )
         assert all(store_ok), "DHT rejected a trivial store"
 
         # peer 1 must know about peer 2
         (recv_value_bytes, recv_expiration), nodes_found = loop.run_until_complete(
-            protocol.call_find(f'{LOCALHOST}:{peer1_port}', [key]))[key]
+            protocol.call_find(peer1_endpoint, [key]))[key]
         recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
         (recv_id, recv_endpoint) = next(iter(nodes_found.items()))
-        assert recv_id == peer2_id and ':'.join(recv_endpoint.split(':')[-2:]) == f"{LOCALHOST}:{peer2_port}", \
-            f"expected id={peer2_id}, peer={LOCALHOST}:{peer2_port} but got {recv_id}, {recv_endpoint}"
+        assert recv_id == peer2_id and recv_endpoint == peer2_endpoint, \
+            f"expected id={peer2_id}, peer={peer2_endpoint} but got {recv_id}, {recv_endpoint}"
 
         assert recv_value == value and recv_expiration == expiration, \
             f"call_find_value expected {value} (expires by {expiration}) " \
@@ -77,11 +83,11 @@ def test_dht_protocol():
         # peer 2 must know about peer 1, but not have a *random* nonexistent value
         dummy_key = DHTID.generate()
         empty_item, nodes_found_2 = loop.run_until_complete(
-            protocol.call_find(f'{LOCALHOST}:{peer2_port}', [dummy_key]))[dummy_key]
+            protocol.call_find(peer2_endpoint, [dummy_key]))[dummy_key]
         assert empty_item is None, "Non-existent keys shouldn't have values"
         (recv_id, recv_endpoint) = next(iter(nodes_found_2.items()))
-        assert recv_id == peer1_id and recv_endpoint == f"{LOCALHOST}:{peer1_port}", \
-            f"expected id={peer1_id}, peer={LOCALHOST}:{peer1_port} but got {recv_id}, {recv_endpoint}"
+        assert recv_id == peer1_id and recv_endpoint == peer1_endpoint, \
+            f"expected id={peer1_id}, peer={peer1_endpoint} but got {recv_id}, {recv_endpoint}"
 
         # cause a non-response by querying a nonexistent peer
         dummy_port = hivemind.find_open_port()
@@ -91,35 +97,38 @@ def test_dht_protocol():
         nested_key, subkey1, subkey2 = DHTID.generate(), 'foo', 'bar'
         value1, value2 = [random.random(), {'ololo': 'pyshpysh'}], 'abacaba'
         assert loop.run_until_complete(protocol.call_store(
-            f'{LOCALHOST}:{peer1_port}', keys=[nested_key], values=[hivemind.MSGPackSerializer.dumps(value1)],
+            peer1_endpoint, keys=[nested_key], values=[hivemind.MSGPackSerializer.dumps(value1)],
             expiration_time=[expiration], subkeys=[subkey1])
         )
         assert loop.run_until_complete(protocol.call_store(
-            f'{LOCALHOST}:{peer1_port}', keys=[nested_key], values=[hivemind.MSGPackSerializer.dumps(value2)],
+            peer1_endpoint, keys=[nested_key], values=[hivemind.MSGPackSerializer.dumps(value2)],
             expiration_time=[expiration + 5], subkeys=[subkey2])
         )
         (recv_dict, recv_expiration), nodes_found = loop.run_until_complete(
-            protocol.call_find(f'{LOCALHOST}:{peer1_port}', [nested_key]))[nested_key]
+            protocol.call_find(peer1_endpoint, [nested_key]))[nested_key]
         assert isinstance(recv_dict, DictionaryDHTValue)
         assert len(recv_dict.data) == 2 and recv_expiration == expiration + 5
         assert recv_dict.data[subkey1] == (protocol.serializer.dumps(value1), expiration)
         assert recv_dict.data[subkey2] == (protocol.serializer.dumps(value2), expiration + 5)
 
-        assert LOCALHOST in loop.run_until_complete(protocol.get_outgoing_request_endpoint(f'{LOCALHOST}:{peer1_port}'))
+        assert protocol.client.endpoint == loop.run_until_complete(protocol.get_outgoing_request_endpoint(peer1_endpoint))
 
         if listen:
             loop.run_until_complete(protocol.shutdown())
 
     peer1_proc.terminate()
     peer2_proc.terminate()
+    protocol.__del__() #TODO
 
 
 @pytest.mark.forked
 def test_empty_table():
     """ Test RPC methods with empty routing table """
-    peer_port, peer_id, peer_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
-    peer_proc = mp.Process(target=run_protocol_listener, args=(peer_port, peer_id, peer_started), daemon=True)
-    peer_proc.start(), peer_started.wait()
+    theirs_side, ours_side = mp.Pipe()
+    peer_port, peer_id = hivemind.find_open_port(), DHTID.generate()
+    peer_proc = mp.Process(target=run_protocol_listener, args=(peer_port, peer_id, theirs_side), daemon=True)
+    peer_proc.start()
+    peer_port, peer_endpoint = ours_side.recv()
 
     loop = asyncio.get_event_loop()
     protocol = loop.run_until_complete(DHTProtocol.create(
@@ -128,21 +137,22 @@ def test_empty_table():
     key, value, expiration = DHTID.generate(), [random.random(), {'ololo': 'pyshpysh'}], get_dht_time() + 1e3
 
     empty_item, nodes_found = loop.run_until_complete(
-        protocol.call_find(f'{LOCALHOST}:{peer_port}', [key]))[key]
+        protocol.call_find(peer_endpoint, [key]))[key]
     assert empty_item is None and len(nodes_found) == 0
     assert all(loop.run_until_complete(protocol.call_store(
-        f'{LOCALHOST}:{peer_port}', [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
+        peer_endpoint, [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
     )), "peer rejected store"
 
     (recv_value_bytes, recv_expiration), nodes_found = loop.run_until_complete(
-        protocol.call_find(f'{LOCALHOST}:{peer_port}', [key]))[key]
+        protocol.call_find(peer_endpoint, [key]))[key]
     recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
     assert len(nodes_found) == 0
     assert recv_value == value and recv_expiration == expiration
 
-    assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{peer_port}')) == peer_id
+    assert loop.run_until_complete(protocol.call_ping(peer_endpoint)) == peer_id
     assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{hivemind.find_open_port()}')) is None
     peer_proc.terminate()
+    protocol.__del__() #TODO
 
 
 def run_node(node_id, peers, status_pipe: mp.Pipe):
@@ -151,9 +161,8 @@ def run_node(node_id, peers, status_pipe: mp.Pipe):
         asyncio.set_event_loop(asyncio.new_event_loop())
     loop = asyncio.get_event_loop()
     node = loop.run_until_complete(DHTNode.create(node_id, initial_peers=peers))
-    status_pipe.send(node.port)
-    while True:
-        loop.run_forever()
+    status_pipe.send((node.port, node.endpoint))
+    loop.run_until_complete(node.protocol.server.wait_for_termination())
 
 
 @pytest.mark.forked
@@ -168,17 +177,17 @@ def test_dht_node():
         pipe_recv, pipe_send = mp.Pipe(duplex=False)
         proc = mp.Process(target=run_node, args=(node_id, peers, pipe_send), daemon=True)
         proc.start()
-        port = pipe_recv.recv()
+        port, endpoint = pipe_recv.recv()
         processes.append(proc)
-        dht[f"{LOCALHOST}:{port}"] = node_id
+        dht[endpoint] = node_id
 
     loop = asyncio.get_event_loop()
-    me = loop.run_until_complete(DHTNode.create(initial_peers=random.sample(dht.keys(), 5), parallel_rpc=10,
+    me = loop.run_until_complete(DHTNode.create(initial_peers=random.sample(dht.keys(), min(len(dht), 5)), parallel_rpc=2,
                                                 cache_refresh_before_expiry=False))
 
     # test 1: find self
     nearest = loop.run_until_complete(me.find_nearest_nodes([me.node_id], k_nearest=1))[me.node_id]
-    assert len(nearest) == 1 and ':'.join(nearest[me.node_id].split(':')[-2:]) == f"{LOCALHOST}:{me.port}"
+    assert len(nearest) == 1 and nearest[me.node_id] == me.endpoint
 
     # test 2: find others
     for i in range(10):
@@ -186,7 +195,7 @@ def test_dht_node():
         nearest = loop.run_until_complete(me.find_nearest_nodes([query_id], k_nearest=1))[query_id]
         assert len(nearest) == 1
         found_node_id, found_endpoint = next(iter(nearest.items()))
-        assert found_node_id == query_id and ':'.join(found_endpoint.split(':')[-2:]) == ref_endpoint
+        assert found_node_id == query_id and found_endpoint == ref_endpoint
 
     # test 3: find neighbors to random nodes
     accuracy_numerator = accuracy_denominator = 0  # top-1 nearest neighbor accuracy
@@ -195,7 +204,7 @@ def test_dht_node():
 
     for i in range(10):
         query_id = DHTID.generate()
-        k_nearest = random.randint(1, 10)
+        k_nearest = random.randint(1, len(dht))
         exclude_self = random.random() > 0.5
         nearest = loop.run_until_complete(
             me.find_nearest_nodes([query_id], k_nearest=k_nearest, exclude_self=exclude_self))[query_id]
@@ -233,7 +242,7 @@ def test_dht_node():
     # test 5: node without peers
     detached_node = loop.run_until_complete(DHTNode.create())
     nearest = loop.run_until_complete(detached_node.find_nearest_nodes([dummy]))[dummy]
-    assert len(nearest) == 1 and nearest[detached_node.node_id] == f"{LOCALHOST}:{detached_node.port}"
+    assert len(nearest) == 1 and nearest[detached_node.node_id] == detached_node.endpoint
     nearest = loop.run_until_complete(detached_node.find_nearest_nodes([dummy], exclude_self=True))[dummy]
     assert len(nearest) == 0
 
@@ -286,6 +295,9 @@ def test_dht_node():
 
     for proc in processes:
         proc.terminate()
+    me.__del__()#TODO
+    detached_node.__del__()
+    that_guy.__del__()
 
 
 @pytest.mark.forked
@@ -314,12 +326,15 @@ async def test_dhtnode_replicas():
     assert await you.store('key2', 'baz', get_dht_time() + 1000)
     assert sum(len(peer.protocol.storage) for peer in peers) == total_size, "total size should not have changed"
 
+    for p in peers:
+        p.__del__()#TODO
+
 
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_dhtnode_caching(T=0.05):
     node2 = await hivemind.DHTNode.create(cache_refresh_before_expiry=5 * T, reuse_get_requests=False)
-    node1 = await hivemind.DHTNode.create(initial_peers=[f'localhost:{node2.port}'],
+    node1 = await hivemind.DHTNode.create(initial_peers=[node2.endpoint],
                                           cache_refresh_before_expiry=5 * T, listen=False, reuse_get_requests=False)
     await node2.store('k', [123, 'value'], expiration_time=hivemind.get_dht_time() + 7 * T)
     await node2.store('k2', [654, 'value'], expiration_time=hivemind.get_dht_time() + 7 * T)
@@ -359,15 +374,17 @@ async def test_dhtnode_caching(T=0.05):
     assert len(node1.cache_refresh_queue) == 0
 
     await asyncio.gather(node1.shutdown(), node2.shutdown())
+    node1.__del__()#TODO
+    node2.__del__()
 
 
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_dhtnode_reuse_get():
     peers = []
-    for i in range(10):
-        neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
-        peers.append(await hivemind.DHTNode.create(initial_peers=neighbors_i, parallel_rpc=256))
+    for i in range(5):
+        neighbors_i = [node.endpoint for node in random.sample(peers, min(3, len(peers)))]
+        peers.append(await hivemind.DHTNode.create(initial_peers=neighbors_i, parallel_rpc=32))
 
     await asyncio.gather(
         random.choice(peers).store('k1', 123, hivemind.get_dht_time() + 999),
@@ -393,14 +410,17 @@ async def test_dhtnode_reuse_get():
     assert await futures1['k2'] == await futures2['k2'] and (await futures1['k2'])[0] == 567
     assert await futures2['k3'] == await futures3['k3'] and (await futures3['k3']) is None
 
+    for p in peers:
+        p.__del__()#TODO
+
 
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_dhtnode_blacklist():
     node1 = await hivemind.DHTNode.create(blacklist_time=999)
-    node2 = await hivemind.DHTNode.create(blacklist_time=999, initial_peers=[f"{LOCALHOST}:{node1.port}"])
-    node3 = await hivemind.DHTNode.create(blacklist_time=999, initial_peers=[f"{LOCALHOST}:{node1.port}"])
-    node4 = await hivemind.DHTNode.create(blacklist_time=999, initial_peers=[f"{LOCALHOST}:{node1.port}"])
+    node2 = await hivemind.DHTNode.create(blacklist_time=999, initial_peers=[node1.endpoint])
+    node3 = await hivemind.DHTNode.create(blacklist_time=999, initial_peers=[node1.endpoint])
+    node4 = await hivemind.DHTNode.create(blacklist_time=999, initial_peers=[node1.endpoint])
 
     assert await node2.store('abc', 123, expiration_time=hivemind.get_dht_time() + 99)
     assert len(node2.blacklist.ban_counter) == 0
@@ -413,25 +433,26 @@ async def test_dhtnode_blacklist():
     assert len(node2.blacklist.ban_counter) == 2
 
     for banned_peer in node2.blacklist.ban_counter:
-        assert any(banned_peer.endswith(str(port)) for port in [node3.port, node4.port])
+        assert any(banned_peer == endpoint for endpoint in [node3.endpoint, node4.endpoint])
 
-    node3_endpoint = await node3.protocol.get_outgoing_request_endpoint(f"{hivemind.LOCALHOST}:{node1.port}")
-    node3_endpoint = replace_port(node3_endpoint, node3.port)
+    node3_endpoint = await node3.protocol.get_outgoing_request_endpoint(node1.endpoint)
     assert await node1.get('abc', latest=True)  # force node1 to crawl dht and discover unresponsive peers
    assert node3_endpoint in node1.blacklist
 
-    node2_endpoint = await node2.protocol.get_outgoing_request_endpoint(f"{hivemind.LOCALHOST}:{node1.port}")
-    node2_endpoint = replace_port(node2_endpoint, node2.port)
+    node2_endpoint = await node2.protocol.get_outgoing_request_endpoint(node1.endpoint)
     assert await node1.get('abc', latest=True)  # force node1 to crawl dht and discover unresponsive peers
     assert node2_endpoint not in node1.blacklist
 
+    for node in [node1, node2, node3, node4]:
+        node.__del__()#TODO
+
 
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_dhtnode_validate(fake_endpoint='127.0.0.721:*'):
     node1 = await hivemind.DHTNode.create(blacklist_time=999)
     with pytest.raises(ValidationError):
-        node2 = await hivemind.DHTNode.create(blacklist_time=999, initial_peers=[f"{LOCALHOST}:{node1.port}"],
+        node2 = await hivemind.DHTNode.create(blacklist_time=999, initial_peers=[node1.endpoint],
                                               endpoint=fake_endpoint)
 
 
@@ -440,7 +461,7 @@ async def test_dhtnode_validate(fake_endpoint='127.0.0.721:*'):
 async def test_dhtnode_edge_cases():
     peers = []
     for i in range(5):
-        neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
+        neighbors_i = [node.endpoint for node in random.sample(peers, min(3, len(peers)))]
         peers.append(await hivemind.DHTNode.create(initial_peers=neighbors_i, parallel_rpc=4))
 
     subkeys = [0, '', False, True, 'abyrvalg', 4555]

+ 26 - 25
tests/test_p2p_daemon.py

@@ -89,8 +89,8 @@ async def test_call_unary_handler(should_cancel, replicate, handle_name="handle"
             handler_cancelled = True
         return dht_pb2.PingResponse(
             peer=dht_pb2.NodeInfo(
-                node_id=context.ours_id.encode(), rpc_port=context.ours_port),
-            sender_endpoint=context.handle_name, available=True)
+                node_id=context.ours_id.to_bytes(), rpc_port=context.ours_port),
+            sender_endpoint=context.peer(), available=True)
 
     server_primary = await P2P.create()
     server = await replicate_if_needed(server_primary, replicate)
@@ -105,24 +105,24 @@ async def test_call_unary_handler(should_cancel, replicate, handle_name="handle"
     assert is_process_running(client_pid)
 
     ping_request = dht_pb2.PingRequest(
-        peer=dht_pb2.NodeInfo(node_id=client.id.encode(), rpc_port=client._host_port),
+        peer=dht_pb2.NodeInfo(node_id=client.id.to_bytes(), rpc_port=client._host_port),
         validate=True)
     expected_response = dht_pb2.PingResponse(
-        peer=dht_pb2.NodeInfo(node_id=server.id.encode(), rpc_port=server._host_port),
-        sender_endpoint=handle_name, available=True)
+        peer=dht_pb2.NodeInfo(node_id=server.id.to_bytes(), rpc_port=server._host_port),
+        sender_endpoint=client.endpoint, available=True)
 
     await asyncio.sleep(1)
-    libp2p_server_id = ID.from_base58(server.id)
-    stream_info, reader, writer = await client._client.stream_open(libp2p_server_id, (handle_name,))
-
-    await P2P.send_raw_data(ping_request.SerializeToString(), writer)
 
     if should_cancel:
+        stream_info, reader, writer = await client._client.stream_open(
+            server.id, (handle_name,))
+        await P2P.send_raw_data(ping_request.SerializeToString(), writer)
         writer.close()
         await asyncio.sleep(1)
         assert handler_cancelled
     else:
-        result = await P2P.receive_protobuf(dht_pb2.PingResponse, reader)
+        result = await client.call_unary_handler(server.endpoint, handle_name, ping_request,
+                                                 dht_pb2.PingResponse)
         assert result == expected_response
         assert not handler_cancelled
 
@@ -154,8 +154,8 @@ async def test_call_peer_single_process(test_input, handle, handler_name="handle"
     client_pid = client._child.pid
     assert is_process_running(client_pid)
 
-    await asyncio.sleep(1)
-    result = await client.call_peer_handler(server.id, handler_name, test_input)
+    # await asyncio.sleep(1)
+    result = await client.call_peer_handler(server.endpoint, handler_name, test_input)
     assert result == handle(test_input)
 
     server.__del__()
@@ -193,17 +193,18 @@ async def test_call_peer_different_processes():
     response_received = mp.Value(np.ctypeslib.as_ctypes_type(np.int32))
     response_received.value = 0
 
-    proc = mp.Process(target=server_target, args=(handler_name, server_side, client_side, response_received))
+    proc = mp.Process(target=server_target,
+                      args=(handler_name, server_side, client_side, response_received))
     proc.start()
 
     client = await P2P.create()
     client_pid = client._child.pid
     assert is_process_running(client_pid)
 
-    await asyncio.sleep(1)
+    # await asyncio.sleep(1)
     peer_id = client_side.recv()
 
-    result = await client.call_peer_handler(peer_id, handler_name, test_input)
+    result = await client.call_peer_handler(peer_id.to_base58(), handler_name, test_input)
     assert np.allclose(result, handle_square(test_input))
     response_received.value = 1
 
@@ -230,8 +231,8 @@ async def test_call_peer_numpy(test_input, handle, replicate, handler_name="hand"
     client_primary = await P2P.create()
     client = await replicate_if_needed(client_primary, replicate)
 
-    await asyncio.sleep(1)
-    result = await client.call_peer_handler(server.id, handler_name, test_input)
+    # await asyncio.sleep(1)
+    result = await client.call_peer_handler(server.endpoint, handler_name, test_input)
     assert np.allclose(result, handle(test_input))
 
 
@@ -250,8 +251,8 @@ async def test_call_peer_error(replicate, handler_name="handle"):
     client_primary = await P2P.create()
     client = await replicate_if_needed(client_primary, replicate)
 
-    await asyncio.sleep(1)
-    result = await client.call_peer_handler(server.id, handler_name,
+    # await asyncio.sleep(1)
+    result = await client.call_peer_handler(server.endpoint, handler_name,
                                             [np.zeros((2, 3)), np.zeros((3, 2))])
     assert type(result) == ValueError
 
@@ -262,7 +263,7 @@ async def test_handlers_on_different_replicas(handler_name="handle"):
         return key
 
     server_primary = await P2P.create()
-    server_id = server_primary.id
+    server_endpoint = server_primary.endpoint
     await server_primary.add_stream_handler(handler_name, partial(handler, key="primary"))
 
     server_replica1 = await replicate_if_needed(server_primary, True)
@@ -273,13 +274,13 @@ async def test_handlers_on_different_replicas(handler_name="handle"):
 
     client = await P2P.create()
     await asyncio.sleep(1)
-    result = await client.call_peer_handler(server_id, handler_name, "")
+    result = await client.call_peer_handler(server_endpoint, handler_name, "")
     assert result == "primary"
 
-    result = await client.call_peer_handler(server_id, handler_name + "1", "")
+    result = await client.call_peer_handler(server_endpoint, handler_name + "1", "")
     assert result == "replica1"
 
-    result = await client.call_peer_handler(server_id, handler_name + "2", "")
+    result = await client.call_peer_handler(server_endpoint, handler_name + "2", "")
     assert result == "replica2"
 
     await server_replica1.stop_listening()
@@ -287,9 +288,9 @@ async def test_handlers_on_different_replicas(handler_name="handle"):
 
     # Primary does not handle replicas protocols
     with pytest.raises(P2P.IncompleteRead):
-        await client.call_peer_handler(server_id, handler_name + "1", "")
+        await client.call_peer_handler(server_endpoint, handler_name + "1", "")
     with pytest.raises(P2P.IncompleteRead):
-        await client.call_peer_handler(server_id, handler_name + "2", "")
+        await client.call_peer_handler(server_endpoint, handler_name + "2", "")
 
     await server_primary.stop_listening()
     server_primary.__del__()