
feat DHT: use P2P as backend (#208)

Relates: #185
Ilya 4 years ago
parent
commit
43ef3c6465

+ 5 - 1
hivemind/dht/node.py

@@ -141,6 +141,7 @@ class DHTNode:
                                                  parallel_rpc, cache_size, listen, listen_on, endpoint, record_validator,
                                                  **kwargs)
         self.port = self.protocol.port
+        self.endpoint = self.protocol.client.endpoint
 
         if initial_peers:
             # stage 1: ping initial_peers, add each other to the routing table
@@ -181,6 +182,9 @@ class DHTNode:
         assert _initialized_with_create, " Please use DHTNode.create coroutine to spawn new node instances "
         super().__init__()
 
+    def __del__(self):
+        self.protocol.__del__()
+
     async def shutdown(self, timeout=None):
         """ Process existing requests, close all connections and stop the server """
         self.is_alive = False
@@ -233,7 +237,7 @@ class DHTNode:
         for query, nearest_nodes in nearest_nodes_per_query.items():
             if not exclude_self:
                 nearest_nodes = sorted(nearest_nodes + [self.node_id], key=query.xor_distance)
-                node_to_endpoint[self.node_id] = f"{LOCALHOST}:{self.port}"
+                node_to_endpoint[self.node_id] = self.endpoint
             nearest_nodes_with_endpoints[query] = {node: node_to_endpoint[node] for node in nearest_nodes[:k_nearest]}
         return nearest_nodes_with_endpoints
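
Note: a minimal sketch of what the new attribute means for callers, assuming only the DHTNode API visible in this diff. After this change, node.endpoint is the base58 libp2p peer ID reported by the underlying P2P client, not a "{LOCALHOST}:{port}" string:

import asyncio
import hivemind

async def main():
    # node.endpoint now comes from protocol.client.endpoint (a base58 peer ID),
    # and is used directly wherever an Endpoint string was expected before.
    node = await hivemind.DHTNode.create()
    print(node.endpoint)
    await node.shutdown()

asyncio.run(main())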
 

+ 56 - 30
hivemind/dht/protocol.py

@@ -2,6 +2,7 @@
 from __future__ import annotations
 
 import asyncio
+import functools
 from itertools import zip_longest
 from typing import Optional, List, Tuple, Dict, Any, Sequence, Union, Collection
 
@@ -11,7 +12,8 @@ from hivemind.dht.crypto import DHTRecord, RecordValidatorBase
 from hivemind.dht.routing import RoutingTable, DHTID, BinaryDHTValue, DHTExpiration, Subkey
 from hivemind.dht.storage import DHTLocalStorage, DictionaryDHTValue
 from hivemind.proto import dht_pb2, dht_pb2_grpc as dht_grpc
-from hivemind.utils import Endpoint, get_logger, replace_port, MSGPackSerializer, ChannelCache, ValueWithExpiration
+from hivemind.p2p import P2P, P2PContext
+from hivemind.utils import Endpoint, get_logger, replace_port, get_port, MSGPackSerializer, ChannelCache, ValueWithExpiration
 from hivemind.utils import get_dht_time, GRPC_KEEPALIVE_OPTIONS, MAX_DHT_TIME_DISCREPANCY_SECONDS
 
 logger = get_logger(__name__)
@@ -20,7 +22,7 @@ logger = get_logger(__name__)
 class DHTProtocol(dht_grpc.DHTServicer):
     # fmt:off
     node_id: DHTID; port: int; bucket_size: int; num_replicas: int; wait_timeout: float; node_info: dht_pb2.NodeInfo
-    channel_options: Tuple[Tuple[str, Any]]; server: grpc.aio.Server
+    channel_options: Tuple[Tuple[str, Any]]; server: P2P
     storage: DHTLocalStorage; cache: DHTLocalStorage; routing_table: RoutingTable; rpc_semaphore: asyncio.Semaphore
     record_validator: Optional[RecordValidatorBase]
     # fmt:on
@@ -28,6 +30,8 @@ class DHTProtocol(dht_grpc.DHTServicer):
     serializer = MSGPackSerializer  # used to pack/unpack DHT Values for transfer over network
     RESERVED_SUBKEYS = IS_REGULAR_VALUE, IS_DICTIONARY = serializer.dumps(None), b''
 
+    PING_NAME, STORE_NAME, FIND_NAME = '__ping__', '__store__', '__find__'
+
     @classmethod
     async def create(
             cls, node_id: DHTID, bucket_size: int, depth_modulo: int, num_replicas: int, wait_timeout: float,
@@ -55,18 +59,23 @@ class DHTProtocol(dht_grpc.DHTServicer):
         self.rpc_semaphore = asyncio.Semaphore(parallel_rpc if parallel_rpc is not None else float('inf'))
         self.record_validator = record_validator
 
+        self.client = await P2P.create(host_port=get_port(listen_on))
         if listen:  # set up server to process incoming rpc requests
-            grpc.aio.init_grpc_aio()
-            self.server = grpc.aio.server(**kwargs, options=GRPC_KEEPALIVE_OPTIONS)
-            dht_grpc.add_DHTServicer_to_server(self, self.server)
-
-            self.port = self.server.add_insecure_port(listen_on)
+            self.server = self.client  #TODO deduplicate with client
+            await self.server.add_unary_handler(
+                DHTProtocol.PING_NAME, functools.partial(DHTProtocol.rpc_ping, self),
+                dht_pb2.PingRequest, dht_pb2.PingResponse)
+            await self.server.add_unary_handler(
+                DHTProtocol.STORE_NAME, functools.partial(DHTProtocol.rpc_store, self),
+                dht_pb2.StoreRequest, dht_pb2.StoreResponse)
+            await self.server.add_unary_handler(
+                DHTProtocol.FIND_NAME, functools.partial(DHTProtocol.rpc_find, self),
+                dht_pb2.FindRequest, dht_pb2.FindResponse)
+
+            self.port = self.server._host_port
             assert self.port != 0, f"Failed to listen to {listen_on}"
-            if endpoint is not None and endpoint.endswith('*'):
-                endpoint = replace_port(endpoint, self.port)
             self.node_info = dht_pb2.NodeInfo(node_id=node_id.to_bytes(), rpc_port=self.port,
-                                              endpoint=endpoint or dht_pb2.NodeInfo.endpoint.DESCRIPTOR.default_value)
-            await self.server.start()
+                                              endpoint=endpoint or self.server.endpoint)
         else:  # not listening to incoming requests, client-only mode
             # note: use empty node_info so peers won't add you to their routing tables
             self.node_info, self.server, self.port = dht_pb2.NodeInfo(), None, None
@@ -80,16 +89,36 @@ class DHTProtocol(dht_grpc.DHTServicer):
         assert _initialized_with_create, " Please use DHTProtocol.create coroutine to spawn new protocol instances "
         super().__init__()
 
+    def __del__(self):
+        self.client.__del__()
+
     async def shutdown(self, timeout=None):
         """ Process existing requests, close all connections and stop the server """
         if self.server:
-            await self.server.stop(timeout)
+            await self.server.stop_listening()
         else:
             logger.warning("DHTProtocol has no server (due to listen=False), it doesn't need to be shut down")
 
-    def _get_dht_stub(self, peer: Endpoint) -> dht_grpc.DHTStub:
+    class DHTStub:
+        def __init__(self, protocol: DHTProtocol, peer: Endpoint):
+            self.protocol = protocol
+            self.peer = peer
+
+        async def rpc_ping(self, request: dht_pb2.PingRequest, timeout=None) -> dht_pb2.PingResponse:
+            return await self.protocol.client.call_unary_handler(
+                self.peer, DHTProtocol.PING_NAME, request, dht_pb2.PingResponse)
+
+        async def rpc_store(self, request: dht_pb2.StoreRequest, timeout=None) -> dht_pb2.StoreResponse:
+            return await self.protocol.client.call_unary_handler(
+                self.peer, DHTProtocol.STORE_NAME, request, dht_pb2.StoreResponse)
+
+        async def rpc_find(self, request: dht_pb2.FindRequest, timeout=None) -> dht_pb2.FindResponse:
+            return await self.protocol.client.call_unary_handler(
+                self.peer, DHTProtocol.FIND_NAME, request, dht_pb2.FindResponse)
+
+    def _get_dht_stub(self, peer: Endpoint) -> DHTProtocol.DHTStub:
         """ get a DHTStub that sends requests to a given peer """
-        return ChannelCache.get_stub(peer, dht_grpc.DHTStub, aio=True, options=self.channel_options)
+        return DHTProtocol.DHTStub(self, peer)
 
     async def call_ping(self, peer: Endpoint, validate: bool = False, strict: bool = True) -> Optional[DHTID]:
         """
@@ -107,8 +136,8 @@ class DHTProtocol(dht_grpc.DHTServicer):
                 time_requested = get_dht_time()
                 response = await self._get_dht_stub(peer).rpc_ping(ping_request, timeout=self.wait_timeout)
                 time_responded = get_dht_time()
-        except grpc.aio.AioRpcError as error:
-            logger.debug(f"DHTProtocol failed to ping {peer}: {error.code()}")
+        except Exception as e:
+            logger.debug(f"DHTProtocol failed to ping {peer}: {e}")
             response = None
         responded = bool(response and response.peer and response.peer.node_id)
 
@@ -141,10 +170,10 @@ class DHTProtocol(dht_grpc.DHTServicer):
                 response = await self._get_dht_stub(peer).rpc_ping(ping_request, timeout=self.wait_timeout)
                 if response.sender_endpoint != dht_pb2.PingResponse.sender_endpoint.DESCRIPTOR.default_value:
                     return response.sender_endpoint
-        except grpc.aio.AioRpcError as error:
-            logger.debug(f"DHTProtocol failed to ping {peer}: {error.code()}")
+        except Exception as e:
+            logger.debug(f"DHTProtocol failed to ping {peer}: {e}")
 
-    async def rpc_ping(self, request: dht_pb2.PingRequest, context: grpc.ServicerContext):
+    async def rpc_ping(self, request: dht_pb2.PingRequest, context: P2PContext):
         """ Some node wants us to add it to our routing table. """
         response = dht_pb2.PingResponse(peer=self.node_info, sender_endpoint=context.peer(),
                                         dht_time=get_dht_time(), available=False)
@@ -152,15 +181,12 @@ class DHTProtocol(dht_grpc.DHTServicer):
         if request.peer and request.peer.node_id and request.peer.rpc_port:
             sender_id = DHTID.from_bytes(request.peer.node_id)
             if request.peer.endpoint != dht_pb2.NodeInfo.endpoint.DESCRIPTOR.default_value:
-                sender_endpoint = request.peer.endpoint  # if peer has preferred endpoint, use it
-            else:
-                sender_endpoint = replace_port(context.peer(), new_port=request.peer.rpc_port)
+                response.sender_endpoint = request.peer.endpoint  # if peer has preferred endpoint, use it
 
-            response.sender_endpoint = sender_endpoint
             if request.validate:
                 response.available = await self.call_ping(response.sender_endpoint, validate=False) == sender_id
 
-            asyncio.create_task(self.update_routing_table(sender_id, sender_endpoint,
+            asyncio.create_task(self.update_routing_table(sender_id, response.sender_endpoint,
                                                           responded=response.available or not request.validate))
 
         return response
@@ -215,12 +241,12 @@ class DHTProtocol(dht_grpc.DHTServicer):
                 peer_id = DHTID.from_bytes(response.peer.node_id)
                 asyncio.create_task(self.update_routing_table(peer_id, peer, responded=True))
             return response.store_ok
-        except grpc.aio.AioRpcError as error:
-            logger.debug(f"DHTProtocol failed to store at {peer}: {error.code()}")
+        except Exception as e:
+            logger.debug(f"DHTProtocol failed to store at {peer}: {e}")
             asyncio.create_task(self.update_routing_table(self.routing_table.get(endpoint=peer), peer, responded=False))
             return None
 
-    async def rpc_store(self, request: dht_pb2.StoreRequest, context: grpc.ServicerContext) -> dht_pb2.StoreResponse:
+    async def rpc_store(self, request: dht_pb2.StoreRequest, context: P2PContext) -> dht_pb2.StoreResponse:
         """ Some node wants us to store this (key, value) pair """
         if request.peer:  # if requested, add peer to the routing table
             asyncio.create_task(self.rpc_ping(dht_pb2.PingRequest(peer=request.peer), context))
@@ -300,11 +326,11 @@ class DHTProtocol(dht_grpc.DHTServicer):
                     logger.error(f"Unknown result type: {result.type}")
 
             return output
-        except grpc.aio.AioRpcError as error:
-            logger.debug(f"DHTProtocol failed to find at {peer}: {error.code()}")
+        except Exception as e:
+            logger.debug(f"DHTProtocol failed to find at {peer}: {e}")
             asyncio.create_task(self.update_routing_table(self.routing_table.get(endpoint=peer), peer, responded=False))
 
-    async def rpc_find(self, request: dht_pb2.FindRequest, context: grpc.ServicerContext) -> dht_pb2.FindResponse:
+    async def rpc_find(self, request: dht_pb2.FindRequest, context: P2PContext) -> dht_pb2.FindResponse:
         """
         Someone wants to find keys in the DHT. For all keys that we have locally, return value and expiration
         Also return :bucket_size: nearest neighbors from our routing table for each key (whether or not we found value)
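
Note: the DHTStub above is a thin wrapper that routes each RPC name to P2P.call_unary_handler. A hedged sketch of a call site, mirroring call_ping from this diff (the request construction is illustrative); `peer` is now a base58 libp2p peer ID rather than a host:port string, and failures surface as plain exceptions instead of grpc.aio.AioRpcError:

async def ping_once(protocol: DHTProtocol, peer: Endpoint) -> dht_pb2.PingResponse:
    # Illustrative only: build a ping request and send it through the P2P-backed stub.
    stub = protocol._get_dht_stub(peer)
    request = dht_pb2.PingRequest(peer=protocol.node_info, validate=True)
    return await stub.rpc_ping(request, timeout=protocol.wait_timeout)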

+ 1 - 1
hivemind/p2p/__init__.py

@@ -1 +1 @@
-from hivemind.p2p.p2p_daemon import P2P
+from hivemind.p2p.p2p_daemon import P2P, P2PContext

+ 46 - 10
hivemind/p2p/p2p_daemon.py

@@ -2,6 +2,7 @@ import asyncio
 import copy
 from pathlib import Path
 import pickle
+import signal
 import subprocess
 import typing as tp
 import warnings
@@ -10,6 +11,7 @@ import google.protobuf
 from multiaddr import Multiaddr
 import hivemind.p2p.p2p_daemon_bindings.p2pclient as p2pclient
 from hivemind.p2p.p2p_daemon_bindings.datastructures import ID, StreamInfo
+from hivemind.p2p.p2p_daemon_bindings.utils import ControlFailure
 
 from hivemind.utils.networking import find_open_port
 
@@ -22,6 +24,9 @@ class P2PContext(object):
         self.ours_port = ours_port
         self.handle_name = handle_name
 
+    def peer(self) -> str:
+        return self.peer_id.to_base58()
+
 
 class P2P(object):
     """
@@ -31,7 +36,7 @@ class P2P(object):
     """
 
     P2PD_RELATIVE_PATH = 'hivemind_cli/p2pd'
-    NUM_RETRIES = 3
+    NUM_RETRIES = 4
     RETRY_DELAY = 0.4
     HEADER_LEN = 8
     BYTEORDER = 'big'
@@ -106,7 +111,8 @@ class P2P(object):
     async def _identify_client(self, delay):
         await asyncio.sleep(delay)
         encoded = await self._client.identify()
-        self.id = encoded[0].to_base58()
+        self.id = encoded[0]
+        self.endpoint = self.id.to_base58()
 
     def _assign_daemon_ports(self, host_port=None, daemon_listen_port=None):
         self._host_port, self._daemon_listen_port = host_port, daemon_listen_port
@@ -223,12 +229,16 @@ class P2P(object):
 
         return do_handle_unary_stream
 
-    def start_listening(self):
+    async def start_listening(self):
+        started = asyncio.Event()
+
         async def listen():
             async with self._client.listen():
+                started.set()
                 await self._server_stopped.wait()
 
         self._listen_task = asyncio.create_task(listen())
+        await started.wait()
 
     async def stop_listening(self):
         if self._listen_task is not None:
@@ -242,25 +252,44 @@ class P2P(object):
 
     async def add_stream_handler(self, name, handle):
         if self._listen_task is None:
-            self.start_listening()
+            await self.start_listening()
         await self._client.stream_handler(name, P2P._handle_stream(handle))
 
     async def add_unary_handler(self, name, handle, in_proto_type, out_proto_type):
         if self._listen_task is None:
-            self.start_listening()
+            await self.start_listening()
         context = P2PContext(ours_id=self.id, ours_port=self._host_port, handle_name=name)
         await self._client.stream_handler(
             name, P2P._handle_unary_stream(handle, context, in_proto_type, out_proto_type))
 
-    async def call_peer_handler(self, peer_id, handler_name, input_data):
-        libp2p_peer_id = ID.from_base58(peer_id)
-        stream_info, reader, writer = await self._client.stream_open(libp2p_peer_id, (handler_name,))
+    async def _call_handler(self, peer_endpoint, handle_name, input_data: bytes) -> bytes:
+        peer_id = ID.from_base58(peer_endpoint)
+        for try_count in range(P2P.NUM_RETRIES):
+            try:
+                stream_info, reader, writer = await self._client.stream_open(
+                    peer_id, (handle_name,))
+            except ControlFailure:
+                if try_count == P2P.NUM_RETRIES - 1:
+                    raise
+                await asyncio.sleep(P2P.RETRY_DELAY * (2 ** try_count))
         try:
-            await P2P.send_data(input_data, writer)
-            return await P2P.receive_data(reader)
+            await P2P.send_raw_data(input_data, writer)
+            return await P2P.receive_raw_data(reader)
         finally:
             writer.close()
 
+    async def call_peer_handler(self, peer_endpoint, handle_name, request):
+        response_data = await self._call_handler(peer_endpoint, handle_name, pickle.dumps(request))
+        return pickle.loads(response_data)
+
+    async def call_unary_handler(self, peer_endpoint, handle_name, request_protobuf,
+                                 response_proto_type):
+        response_data = await self._call_handler(peer_endpoint, handle_name,
+                                                 request_protobuf.SerializeToString())
+        response = response_proto_type()
+        response.ParseFromString(response_data)
+        return response
+
     def __del__(self):
         self._kill_child()
 
@@ -269,6 +298,13 @@ class P2P(object):
             self._child.kill()
             self._child.wait()
 
+    async def wait_for_termination(self):
+        def _handler(signum, frame):
+            self._kill_child()
+
+        signal.signal(signal.SIGTERM, _handler)
+        await asyncio.Event().wait()
+
     def _make_process_args(self, *args, **kwargs) -> tp.List[str]:
         proc_args = []
         proc_args.extend(
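
Note: taken together, add_unary_handler and call_unary_handler form the unary-RPC surface that DHTProtocol now builds on. A minimal end-to-end sketch, assuming a local p2pd daemon binary is available and reusing the protobuf messages from this repository (the handler name 'ping' is illustrative, not part of this commit):

import asyncio

from hivemind.p2p import P2P, P2PContext
from hivemind.proto import dht_pb2


async def ping_handler(request: dht_pb2.PingRequest, context: P2PContext) -> dht_pb2.PingResponse:
    # Echo the caller's base58 peer ID, similar to DHTProtocol.rpc_ping.
    return dht_pb2.PingResponse(sender_endpoint=context.peer(), available=True)


async def main():
    server = await P2P.create()
    await server.add_unary_handler('ping', ping_handler, dht_pb2.PingRequest, dht_pb2.PingResponse)

    client = await P2P.create()
    response = await client.call_unary_handler(server.endpoint, 'ping',
                                               dht_pb2.PingRequest(validate=True),
                                               dht_pb2.PingResponse)
    print(response.sender_endpoint)  # the client's own peer ID, echoed back by the server

    await server.stop_listening()

asyncio.run(main())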

+ 76 - 55
tests/test_dht_node.py

@@ -16,37 +16,43 @@ from hivemind.dht.protocol import DHTProtocol, ValidationError
 from hivemind.dht.storage import DictionaryDHTValue
 
 
-def run_protocol_listener(port: int, dhtid: DHTID, started: mp.synchronize.Event, ping: Optional[Endpoint] = None):
+def run_protocol_listener(port: int, dhtid: DHTID, pipe_side: mp.connection.Connection, ping: Optional[Endpoint] = None):
     loop = asyncio.get_event_loop()
     protocol = loop.run_until_complete(DHTProtocol.create(
         dhtid, bucket_size=20, depth_modulo=5, num_replicas=3, wait_timeout=5, listen_on=f"{LOCALHOST}:{port}"))
 
-    assert protocol.port == port
+    port = protocol.port
     print(f"Started peer id={protocol.node_id} port={port}", flush=True)
 
     if ping is not None:
         loop.run_until_complete(protocol.call_ping(ping))
-    started.set()
+
+    pipe_side.send((protocol.port, protocol.server.endpoint))
+
     loop.run_until_complete(protocol.server.wait_for_termination())
     print(f"Finished peer id={protocol.node_id} port={port}", flush=True)
 
 
-# note: we run grpc-related tests in a separate process to re-initialize all global states from scratch
-# this helps us avoid undesirable side-effects (e.g. segfaults) when running multiple tests in sequence
+# note: we run network-related tests in a separate process to re-initialize all global states from scratch
+# this helps us avoid undesirable side-effects (e.g. segfaults) when running multiple tests in a sequence
 
 
 @pytest.mark.forked
 def test_dht_protocol():
     # create the first peer
-    peer1_port, peer1_id, peer1_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
-    peer1_proc = mp.Process(target=run_protocol_listener, args=(peer1_port, peer1_id, peer1_started), daemon=True)
-    peer1_proc.start(), peer1_started.wait()
+    first_side, ours_side = mp.Pipe()
+    peer1_port, peer1_id = hivemind.find_open_port(), DHTID.generate()
+    peer1_proc = mp.Process(target=run_protocol_listener, args=(peer1_port, peer1_id, first_side), daemon=True)
+    peer1_proc.start()
+    peer1_port, peer1_endpoint = ours_side.recv()
 
     # create another peer that connects to the first peer
-    peer2_port, peer2_id, peer2_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
-    peer2_proc = mp.Process(target=run_protocol_listener, args=(peer2_port, peer2_id, peer2_started),
-                            kwargs={'ping': f'{LOCALHOST}:{peer1_port}'}, daemon=True)
-    peer2_proc.start(), peer2_started.wait()
+    second_side, ours_side = mp.Pipe()
+    peer2_port, peer2_id = hivemind.find_open_port(), DHTID.generate()
+    peer2_proc = mp.Process(target=run_protocol_listener, args=(peer2_port, peer2_id, second_side),
+                            kwargs={'ping': peer1_endpoint}, daemon=True)
+    peer2_proc.start()
+    peer2_port, peer2_endpoint = ours_side.recv()
 
     loop = asyncio.get_event_loop()
     for listen in [False, True]:  # note: order matters, this test assumes that first run uses listen=False
@@ -54,21 +60,21 @@ def test_dht_protocol():
             DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, listen=listen))
         print(f"Self id={protocol.node_id}", flush=True)
 
-        assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{peer1_port}')) == peer1_id
+        assert loop.run_until_complete(protocol.call_ping(peer1_endpoint)) == peer1_id
 
         key, value, expiration = DHTID.generate(), [random.random(), {'ololo': 'pyshpysh'}], get_dht_time() + 1e3
         store_ok = loop.run_until_complete(protocol.call_store(
-            f'{LOCALHOST}:{peer1_port}', [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
+            peer1_endpoint, [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
         )
         assert all(store_ok), "DHT rejected a trivial store"
 
         # peer 1 must know about peer 2
         (recv_value_bytes, recv_expiration), nodes_found = loop.run_until_complete(
-            protocol.call_find(f'{LOCALHOST}:{peer1_port}', [key]))[key]
+            protocol.call_find(peer1_endpoint, [key]))[key]
         recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
         (recv_id, recv_endpoint) = next(iter(nodes_found.items()))
-        assert recv_id == peer2_id and ':'.join(recv_endpoint.split(':')[-2:]) == f"{LOCALHOST}:{peer2_port}", \
-            f"expected id={peer2_id}, peer={LOCALHOST}:{peer2_port} but got {recv_id}, {recv_endpoint}"
+        assert recv_id == peer2_id and recv_endpoint == peer2_endpoint, \
+            f"expected id={peer2_id}, peer={peer2_endpoint} but got {recv_id}, {recv_endpoint}"
 
         assert recv_value == value and recv_expiration == expiration, \
             f"call_find_value expected {value} (expires by {expiration}) " \
@@ -77,11 +83,11 @@ def test_dht_protocol():
         # peer 2 must know about peer 1, but not have a *random* nonexistent value
         dummy_key = DHTID.generate()
         empty_item, nodes_found_2 = loop.run_until_complete(
-            protocol.call_find(f'{LOCALHOST}:{peer2_port}', [dummy_key]))[dummy_key]
+            protocol.call_find(peer2_endpoint, [dummy_key]))[dummy_key]
         assert empty_item is None, "Non-existent keys shouldn't have values"
         (recv_id, recv_endpoint) = next(iter(nodes_found_2.items()))
-        assert recv_id == peer1_id and recv_endpoint == f"{LOCALHOST}:{peer1_port}", \
-            f"expected id={peer1_id}, peer={LOCALHOST}:{peer1_port} but got {recv_id}, {recv_endpoint}"
+        assert recv_id == peer1_id and recv_endpoint == peer1_endpoint, \
+            f"expected id={peer1_id}, peer={peer1_endpoint} but got {recv_id}, {recv_endpoint}"
 
         # cause a non-response by querying a nonexistent peer
         dummy_port = hivemind.find_open_port()
@@ -91,35 +97,38 @@ def test_dht_protocol():
         nested_key, subkey1, subkey2 = DHTID.generate(), 'foo', 'bar'
         value1, value2 = [random.random(), {'ololo': 'pyshpysh'}], 'abacaba'
         assert loop.run_until_complete(protocol.call_store(
-            f'{LOCALHOST}:{peer1_port}', keys=[nested_key], values=[hivemind.MSGPackSerializer.dumps(value1)],
+            peer1_endpoint, keys=[nested_key], values=[hivemind.MSGPackSerializer.dumps(value1)],
             expiration_time=[expiration], subkeys=[subkey1])
         )
         assert loop.run_until_complete(protocol.call_store(
-            f'{LOCALHOST}:{peer1_port}', keys=[nested_key], values=[hivemind.MSGPackSerializer.dumps(value2)],
+            peer1_endpoint, keys=[nested_key], values=[hivemind.MSGPackSerializer.dumps(value2)],
             expiration_time=[expiration + 5], subkeys=[subkey2])
         )
         (recv_dict, recv_expiration), nodes_found = loop.run_until_complete(
-            protocol.call_find(f'{LOCALHOST}:{peer1_port}', [nested_key]))[nested_key]
+            protocol.call_find(peer1_endpoint, [nested_key]))[nested_key]
         assert isinstance(recv_dict, DictionaryDHTValue)
         assert len(recv_dict.data) == 2 and recv_expiration == expiration + 5
         assert recv_dict.data[subkey1] == (protocol.serializer.dumps(value1), expiration)
         assert recv_dict.data[subkey2] == (protocol.serializer.dumps(value2), expiration + 5)
 
-        assert LOCALHOST in loop.run_until_complete(protocol.get_outgoing_request_endpoint(f'{LOCALHOST}:{peer1_port}'))
+        assert protocol.client.endpoint == loop.run_until_complete(protocol.get_outgoing_request_endpoint(peer1_endpoint))
 
         if listen:
             loop.run_until_complete(protocol.shutdown())
 
     peer1_proc.terminate()
     peer2_proc.terminate()
+    protocol.__del__() #TODO
 
 
 @pytest.mark.forked
 def test_empty_table():
     """ Test RPC methods with empty routing table """
-    peer_port, peer_id, peer_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
-    peer_proc = mp.Process(target=run_protocol_listener, args=(peer_port, peer_id, peer_started), daemon=True)
-    peer_proc.start(), peer_started.wait()
+    theirs_side, ours_side = mp.Pipe()
+    peer_port, peer_id = hivemind.find_open_port(), DHTID.generate()
+    peer_proc = mp.Process(target=run_protocol_listener, args=(peer_port, peer_id, theirs_side), daemon=True)
+    peer_proc.start()
+    peer_port, peer_endpoint = ours_side.recv()
 
     loop = asyncio.get_event_loop()
     protocol = loop.run_until_complete(DHTProtocol.create(
@@ -128,21 +137,22 @@ def test_empty_table():
     key, value, expiration = DHTID.generate(), [random.random(), {'ololo': 'pyshpysh'}], get_dht_time() + 1e3
 
     empty_item, nodes_found = loop.run_until_complete(
-        protocol.call_find(f'{LOCALHOST}:{peer_port}', [key]))[key]
+        protocol.call_find(peer_endpoint, [key]))[key]
     assert empty_item is None and len(nodes_found) == 0
     assert all(loop.run_until_complete(protocol.call_store(
-        f'{LOCALHOST}:{peer_port}', [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
+        peer_endpoint, [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
     )), "peer rejected store"
 
     (recv_value_bytes, recv_expiration), nodes_found = loop.run_until_complete(
-        protocol.call_find(f'{LOCALHOST}:{peer_port}', [key]))[key]
+        protocol.call_find(peer_endpoint, [key]))[key]
     recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
     assert len(nodes_found) == 0
     assert recv_value == value and recv_expiration == expiration
 
-    assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{peer_port}')) == peer_id
+    assert loop.run_until_complete(protocol.call_ping(peer_endpoint)) == peer_id
     assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{hivemind.find_open_port()}')) is None
     peer_proc.terminate()
+    protocol.__del__() #TODO
 
 
 def run_node(node_id, peers, status_pipe: mp.Pipe):
@@ -151,9 +161,8 @@ def run_node(node_id, peers, status_pipe: mp.Pipe):
         asyncio.set_event_loop(asyncio.new_event_loop())
     loop = asyncio.get_event_loop()
     node = loop.run_until_complete(DHTNode.create(node_id, initial_peers=peers))
-    status_pipe.send(node.port)
-    while True:
-        loop.run_forever()
+    status_pipe.send((node.port, node.endpoint))
+    loop.run_until_complete(node.protocol.server.wait_for_termination())
 
 
 @pytest.mark.forked
@@ -168,17 +177,17 @@ def test_dht_node():
         pipe_recv, pipe_send = mp.Pipe(duplex=False)
         proc = mp.Process(target=run_node, args=(node_id, peers, pipe_send), daemon=True)
         proc.start()
-        port = pipe_recv.recv()
+        port, endpoint = pipe_recv.recv()
         processes.append(proc)
-        dht[f"{LOCALHOST}:{port}"] = node_id
+        dht[endpoint] = node_id
 
     loop = asyncio.get_event_loop()
-    me = loop.run_until_complete(DHTNode.create(initial_peers=random.sample(dht.keys(), 5), parallel_rpc=10,
+    me = loop.run_until_complete(DHTNode.create(initial_peers=random.sample(dht.keys(), min(len(dht), 5)), parallel_rpc=2,
                                                 cache_refresh_before_expiry=False))
 
     # test 1: find self
     nearest = loop.run_until_complete(me.find_nearest_nodes([me.node_id], k_nearest=1))[me.node_id]
-    assert len(nearest) == 1 and ':'.join(nearest[me.node_id].split(':')[-2:]) == f"{LOCALHOST}:{me.port}"
+    assert len(nearest) == 1 and nearest[me.node_id] == me.endpoint
 
     # test 2: find others
     for i in range(10):
@@ -186,7 +195,7 @@ def test_dht_node():
         nearest = loop.run_until_complete(me.find_nearest_nodes([query_id], k_nearest=1))[query_id]
         assert len(nearest) == 1
         found_node_id, found_endpoint = next(iter(nearest.items()))
-        assert found_node_id == query_id and ':'.join(found_endpoint.split(':')[-2:]) == ref_endpoint
+        assert found_node_id == query_id and found_endpoint == ref_endpoint
 
     # test 3: find neighbors to random nodes
     accuracy_numerator = accuracy_denominator = 0  # top-1 nearest neighbor accuracy
@@ -195,7 +204,7 @@ def test_dht_node():
 
     for i in range(10):
         query_id = DHTID.generate()
-        k_nearest = random.randint(1, 10)
+        k_nearest = random.randint(1, len(dht))
         exclude_self = random.random() > 0.5
         nearest = loop.run_until_complete(
             me.find_nearest_nodes([query_id], k_nearest=k_nearest, exclude_self=exclude_self))[query_id]
@@ -233,7 +242,7 @@ def test_dht_node():
     # test 5: node without peers
     detached_node = loop.run_until_complete(DHTNode.create())
     nearest = loop.run_until_complete(detached_node.find_nearest_nodes([dummy]))[dummy]
-    assert len(nearest) == 1 and nearest[detached_node.node_id] == f"{LOCALHOST}:{detached_node.port}"
+    assert len(nearest) == 1 and nearest[detached_node.node_id] == detached_node.endpoint
     nearest = loop.run_until_complete(detached_node.find_nearest_nodes([dummy], exclude_self=True))[dummy]
     assert len(nearest) == 0
 
@@ -286,6 +295,9 @@ def test_dht_node():
 
     for proc in processes:
         proc.terminate()
+    me.__del__()#TODO
+    detached_node.__del__()
+    that_guy.__del__()
 
 
 @pytest.mark.forked
@@ -314,12 +326,15 @@ async def test_dhtnode_replicas():
     assert await you.store('key2', 'baz', get_dht_time() + 1000)
     assert sum(len(peer.protocol.storage) for peer in peers) == total_size, "total size should not have changed"
 
+    for p in peers:
+        p.__del__()#TODO
+
 
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_dhtnode_caching(T=0.05):
     node2 = await hivemind.DHTNode.create(cache_refresh_before_expiry=5 * T, reuse_get_requests=False)
-    node1 = await hivemind.DHTNode.create(initial_peers=[f'localhost:{node2.port}'],
+    node1 = await hivemind.DHTNode.create(initial_peers=[node2.endpoint],
                                           cache_refresh_before_expiry=5 * T, listen=False, reuse_get_requests=False)
     await node2.store('k', [123, 'value'], expiration_time=hivemind.get_dht_time() + 7 * T)
     await node2.store('k2', [654, 'value'], expiration_time=hivemind.get_dht_time() + 7 * T)
@@ -359,15 +374,17 @@ async def test_dhtnode_caching(T=0.05):
     assert len(node1.cache_refresh_queue) == 0
 
     await asyncio.gather(node1.shutdown(), node2.shutdown())
+    node1.__del__()#TODO
+    node2.__del__()
 
 
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_dhtnode_reuse_get():
     peers = []
-    for i in range(10):
-        neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
-        peers.append(await hivemind.DHTNode.create(initial_peers=neighbors_i, parallel_rpc=256))
+    for i in range(5):
+        neighbors_i = [node.endpoint for node in random.sample(peers, min(3, len(peers)))]
+        peers.append(await hivemind.DHTNode.create(initial_peers=neighbors_i, parallel_rpc=32))
 
     await asyncio.gather(
         random.choice(peers).store('k1', 123, hivemind.get_dht_time() + 999),
@@ -393,14 +410,17 @@ async def test_dhtnode_reuse_get():
     assert await futures1['k2'] == await futures2['k2'] and (await futures1['k2'])[0] == 567
     assert await futures2['k3'] == await futures3['k3'] and (await futures3['k3']) is None
 
+    for p in peers:
+        p.__del__()#TODO
+
 
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_dhtnode_blacklist():
     node1 = await hivemind.DHTNode.create(blacklist_time=999)
-    node2 = await hivemind.DHTNode.create(blacklist_time=999, initial_peers=[f"{LOCALHOST}:{node1.port}"])
-    node3 = await hivemind.DHTNode.create(blacklist_time=999, initial_peers=[f"{LOCALHOST}:{node1.port}"])
-    node4 = await hivemind.DHTNode.create(blacklist_time=999, initial_peers=[f"{LOCALHOST}:{node1.port}"])
+    node2 = await hivemind.DHTNode.create(blacklist_time=999, initial_peers=[node1.endpoint])
+    node3 = await hivemind.DHTNode.create(blacklist_time=999, initial_peers=[node1.endpoint])
+    node4 = await hivemind.DHTNode.create(blacklist_time=999, initial_peers=[node1.endpoint])
 
     assert await node2.store('abc', 123, expiration_time=hivemind.get_dht_time() + 99)
     assert len(node2.blacklist.ban_counter) == 0
@@ -413,25 +433,26 @@ async def test_dhtnode_blacklist():
     assert len(node2.blacklist.ban_counter) == 2
 
     for banned_peer in node2.blacklist.ban_counter:
-        assert any(banned_peer.endswith(str(port)) for port in [node3.port, node4.port])
+        assert any(banned_peer == endpoint for endpoint in [node3.endpoint, node4.endpoint])
 
-    node3_endpoint = await node3.protocol.get_outgoing_request_endpoint(f"{hivemind.LOCALHOST}:{node1.port}")
-    node3_endpoint = replace_port(node3_endpoint, node3.port)
+    node3_endpoint = await node3.protocol.get_outgoing_request_endpoint(node1.endpoint)
     assert await node1.get('abc', latest=True)  # force node1 to crawl dht and discover unresponsive peers
     assert node3_endpoint in node1.blacklist
 
-    node2_endpoint = await node2.protocol.get_outgoing_request_endpoint(f"{hivemind.LOCALHOST}:{node1.port}")
-    node2_endpoint = replace_port(node2_endpoint, node2.port)
+    node2_endpoint = await node2.protocol.get_outgoing_request_endpoint(node1.endpoint)
     assert await node1.get('abc', latest=True)  # force node1 to crawl dht and discover unresponsive peers
     assert node2_endpoint not in node1.blacklist
 
+    for node in [node1, node2, node3, node4]:
+        node.__del__()#TODO
+
 
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_dhtnode_validate(fake_endpoint='127.0.0.721:*'):
     node1 = await hivemind.DHTNode.create(blacklist_time=999)
     with pytest.raises(ValidationError):
-        node2 = await hivemind.DHTNode.create(blacklist_time=999, initial_peers=[f"{LOCALHOST}:{node1.port}"],
+        node2 = await hivemind.DHTNode.create(blacklist_time=999, initial_peers=[node1.endpoint],
                                               endpoint=fake_endpoint)
 
 
@@ -440,7 +461,7 @@ async def test_dhtnode_validate(fake_endpoint='127.0.0.721:*'):
 async def test_dhtnode_edge_cases():
     peers = []
     for i in range(5):
-        neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
+        neighbors_i = [node.endpoint for node in random.sample(peers, min(3, len(peers)))]
         peers.append(await hivemind.DHTNode.create(initial_peers=neighbors_i, parallel_rpc=4))
 
     subkeys = [0, '', False, True, 'abyrvalg', 4555]
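
Note: the mp.Event handshake is replaced by an mp.Pipe throughout these tests because the parent can no longer derive a peer's address from a pre-picked port; the child only knows its endpoint (libp2p peer ID) once the protocol is up, so it sends (port, endpoint) back. A hedged sketch of that handshake, reusing run_protocol_listener defined above (variable names are illustrative):

import multiprocessing as mp

import hivemind
from hivemind.dht.routing import DHTID

theirs_side, ours_side = mp.Pipe()
proc = mp.Process(target=run_protocol_listener,
                  args=(hivemind.find_open_port(), DHTID.generate(), theirs_side), daemon=True)
proc.start()
peer_port, peer_endpoint = ours_side.recv()  # blocks until the child protocol is listening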

+ 26 - 25
tests/test_p2p_daemon.py

@@ -89,8 +89,8 @@ async def test_call_unary_handler(should_cancel, replicate, handle_name="handle"
             handler_cancelled = True
         return dht_pb2.PingResponse(
             peer=dht_pb2.NodeInfo(
-                node_id=context.ours_id.encode(), rpc_port=context.ours_port),
-            sender_endpoint=context.handle_name, available=True)
+                node_id=context.ours_id.to_bytes(), rpc_port=context.ours_port),
+            sender_endpoint=context.peer(), available=True)
 
     server_primary = await P2P.create()
     server = await replicate_if_needed(server_primary, replicate)
@@ -105,24 +105,24 @@ async def test_call_unary_handler(should_cancel, replicate, handle_name="handle"
     assert is_process_running(client_pid)
 
     ping_request = dht_pb2.PingRequest(
-        peer=dht_pb2.NodeInfo(node_id=client.id.encode(), rpc_port=client._host_port),
+        peer=dht_pb2.NodeInfo(node_id=client.id.to_bytes(), rpc_port=client._host_port),
         validate=True)
     expected_response = dht_pb2.PingResponse(
-        peer=dht_pb2.NodeInfo(node_id=server.id.encode(), rpc_port=server._host_port),
-        sender_endpoint=handle_name, available=True)
+        peer=dht_pb2.NodeInfo(node_id=server.id.to_bytes(), rpc_port=server._host_port),
+        sender_endpoint=client.endpoint, available=True)
 
     await asyncio.sleep(1)
-    libp2p_server_id = ID.from_base58(server.id)
-    stream_info, reader, writer = await client._client.stream_open(libp2p_server_id, (handle_name,))
-
-    await P2P.send_raw_data(ping_request.SerializeToString(), writer)
 
     if should_cancel:
+        stream_info, reader, writer = await client._client.stream_open(
+            server.id, (handle_name,))
+        await P2P.send_raw_data(ping_request.SerializeToString(), writer)
         writer.close()
         await asyncio.sleep(1)
         assert handler_cancelled
     else:
-        result = await P2P.receive_protobuf(dht_pb2.PingResponse, reader)
+        result = await client.call_unary_handler(server.endpoint, handle_name, ping_request,
+                                                 dht_pb2.PingResponse)
         assert result == expected_response
         assert not handler_cancelled
 
@@ -154,8 +154,8 @@ async def test_call_peer_single_process(test_input, handle, handler_name="handle
     client_pid = client._child.pid
     assert is_process_running(client_pid)
 
-    await asyncio.sleep(1)
-    result = await client.call_peer_handler(server.id, handler_name, test_input)
+    # await asyncio.sleep(1)
+    result = await client.call_peer_handler(server.endpoint, handler_name, test_input)
     assert result == handle(test_input)
 
     server.__del__()
@@ -193,17 +193,18 @@ async def test_call_peer_different_processes():
     response_received = mp.Value(np.ctypeslib.as_ctypes_type(np.int32))
     response_received.value = 0
 
-    proc = mp.Process(target=server_target, args=(handler_name, server_side, client_side, response_received))
+    proc = mp.Process(target=server_target,
+                      args=(handler_name, server_side, client_side, response_received))
     proc.start()
 
     client = await P2P.create()
     client_pid = client._child.pid
     assert is_process_running(client_pid)
 
-    await asyncio.sleep(1)
+    # await asyncio.sleep(1)
     peer_id = client_side.recv()
 
-    result = await client.call_peer_handler(peer_id, handler_name, test_input)
+    result = await client.call_peer_handler(peer_id.to_base58(), handler_name, test_input)
     assert np.allclose(result, handle_square(test_input))
     response_received.value = 1
 
@@ -230,8 +231,8 @@ async def test_call_peer_numpy(test_input, handle, replicate, handler_name="hand
     client_primary = await P2P.create()
     client = await replicate_if_needed(client_primary, replicate)
 
-    await asyncio.sleep(1)
-    result = await client.call_peer_handler(server.id, handler_name, test_input)
+    # await asyncio.sleep(1)
+    result = await client.call_peer_handler(server.endpoint, handler_name, test_input)
     assert np.allclose(result, handle(test_input))
 
 
@@ -250,8 +251,8 @@ async def test_call_peer_error(replicate, handler_name="handle"):
     client_primary = await P2P.create()
     client = await replicate_if_needed(client_primary, replicate)
 
-    await asyncio.sleep(1)
-    result = await client.call_peer_handler(server.id, handler_name,
+    # await asyncio.sleep(1)
+    result = await client.call_peer_handler(server.endpoint, handler_name,
                                             [np.zeros((2, 3)), np.zeros((3, 2))])
     assert type(result) == ValueError
 
@@ -262,7 +263,7 @@ async def test_handlers_on_different_replicas(handler_name="handle"):
         return key
 
     server_primary = await P2P.create()
-    server_id = server_primary.id
+    server_endpoint = server_primary.endpoint
     await server_primary.add_stream_handler(handler_name, partial(handler, key="primary"))
 
     server_replica1 = await replicate_if_needed(server_primary, True)
@@ -273,13 +274,13 @@ async def test_handlers_on_different_replicas(handler_name="handle"):
 
     client = await P2P.create()
     await asyncio.sleep(1)
-    result = await client.call_peer_handler(server_id, handler_name, "")
+    result = await client.call_peer_handler(server_endpoint, handler_name, "")
     assert result == "primary"
 
-    result = await client.call_peer_handler(server_id, handler_name + "1", "")
+    result = await client.call_peer_handler(server_endpoint, handler_name + "1", "")
     assert result == "replica1"
 
-    result = await client.call_peer_handler(server_id, handler_name + "2", "")
+    result = await client.call_peer_handler(server_endpoint, handler_name + "2", "")
     assert result == "replica2"
 
     await server_replica1.stop_listening()
@@ -287,9 +288,9 @@ async def test_handlers_on_different_replicas(handler_name="handle"):
 
     # Primary does not handle replicas protocols
     with pytest.raises(P2P.IncompleteRead):
-        await client.call_peer_handler(server_id, handler_name + "1", "")
+        await client.call_peer_handler(server_endpoint, handler_name + "1", "")
     with pytest.raises(P2P.IncompleteRead):
-        await client.call_peer_handler(server_id, handler_name + "2", "")
+        await client.call_peer_handler(server_endpoint, handler_name + "2", "")
 
     await server_primary.stop_listening()
     server_primary.__del__()