import asyncio
import functools
import io
import os
import subprocess
import time
import uuid
from contextlib import asynccontextmanager, AsyncExitStack
from typing import NamedTuple

from google.protobuf.message import EncodeError
from multiaddr import Multiaddr, protocols

import pytest

from hivemind import find_open_port
from hivemind.p2p.p2p_daemon_bindings.control import parse_conn_protocol, DaemonConnector, ControlClient
from hivemind.p2p.p2p_daemon_bindings.p2pclient import Client
from hivemind.p2p.p2p_daemon_bindings.utils import ControlFailure, raise_if_failed, write_unsigned_varint, \
    read_unsigned_varint, read_pbmsg_safe, write_pbmsg
from hivemind.proto import p2pd_pb2 as p2pd_pb
from hivemind.p2p.p2p_daemon_bindings.datastructures import ID, StreamInfo, PeerInfo


def test_raise_if_failed_raises():
    """A daemon response of type ERROR must make `raise_if_failed` raise `ControlFailure`."""
    resp = p2pd_pb.Response()
    resp.type = p2pd_pb.Response.ERROR
    with pytest.raises(ControlFailure):
        raise_if_failed(resp)


def test_raise_if_failed_not_raises():
    """A daemon response of type OK must pass through `raise_if_failed` silently."""
    resp = p2pd_pb.Response()
    resp.type = p2pd_pb.Response.OK
    raise_if_failed(resp)


# (integer, encoded bytes) pairs that must round-trip with the default `max_bits`.
# The encoding is base-128, low-order 7-bit group first, high bit set on every byte but the last
# (e.g. 128 -> b"\x80\x01").
pairs_int_varint_valid = (
    (0, b"\x00"),
    (1, b"\x01"),
    (128, b"\x80\x01"),
    (2 ** 32, b"\x80\x80\x80\x80\x10"),
    (2 ** 64 - 1, b"\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01"),
)

# Values >= 2 ** 64 (and their encodings) that the default `max_bits` must reject with ValueError,
# both when writing and when reading.
pairs_int_varint_overflow = (
    (2 ** 64, b"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x02"),
    (2 ** 64 + 1, b"\x81\x80\x80\x80\x80\x80\x80\x80\x80\x02"),
    (
        2 ** 128,
        b"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x04",
    ),
)


class MockReader(io.BytesIO):
    """In-memory reader exposing an awaitable `readexactly`, for feeding `read_unsigned_varint`."""

    async def readexactly(self, n):
        # yield to the event loop once so the coroutine really suspends, like a real stream read
        await asyncio.sleep(0)
        # NOTE(review): unlike a real stream's readexactly, BytesIO.read(n) returns fewer bytes at EOF
        # instead of raising; tests relying on a short-read error would not catch it with this mock.
        return self.read(n)


class MockWriter(io.BytesIO):
    """In-memory writer; relies on BytesIO's own `write` and `getvalue`."""
    pass


class MockReaderWriter(MockReader, MockWriter):
    """Single buffer usable as both writer and reader (seek back to 0 between the two)."""
    pass


@pytest.mark.parametrize("integer, var_integer", pairs_int_varint_valid)
@pytest.mark.asyncio
async def test_write_unsigned_varint(integer, var_integer):
    """Encoding each valid integer must produce exactly the expected bytes."""
    s = MockWriter()
    await write_unsigned_varint(s, integer)
    assert s.getvalue() == var_integer


@pytest.mark.parametrize("integer", tuple(i[0] for i in pairs_int_varint_overflow))
@pytest.mark.asyncio
async def test_write_unsigned_varint_overflow(integer):
    """Integers >= 2 ** 64 must be rejected by the writer with the default `max_bits`."""
    s = MockWriter()
    with pytest.raises(ValueError):
        await write_unsigned_varint(s, integer)


@pytest.mark.parametrize("integer", (-1, -(2 ** 32), -(2 ** 64), -(2 ** 128)))
@pytest.mark.asyncio
async def test_write_unsigned_varint_negative(integer):
    """Negative integers cannot be encoded as unsigned varints and must raise ValueError."""
    s = MockWriter()
    with pytest.raises(ValueError):
        await write_unsigned_varint(s, integer)


@pytest.mark.parametrize("integer, var_integer", pairs_int_varint_valid)
@pytest.mark.asyncio
async def test_read_unsigned_varint(integer, var_integer):
    """Decoding each valid encoding must yield the original integer."""
    s = MockReader(var_integer)
    result = await read_unsigned_varint(s)
    assert result == integer


@pytest.mark.parametrize("var_integer", tuple(i[1] for i in pairs_int_varint_overflow))
@pytest.mark.asyncio
async def test_read_unsigned_varint_overflow(var_integer):
    """Encodings of values >= 2 ** 64 must be rejected by the reader with the default `max_bits`."""
    s = MockReader(var_integer)
    with pytest.raises(ValueError):
        await read_unsigned_varint(s)


@pytest.mark.parametrize("max_bits", (2, 31, 32, 63, 64, 127, 128))
@pytest.mark.asyncio
async def test_read_write_unsigned_varint_max_bits_edge(max_bits):
    """
    Test the edge with different `max_bits`
    """
    # round-trip the three largest values that still fit: 2**max_bits - 3 .. 2**max_bits - 1
    for i in range(-3, 0):
        integer = i + (2 ** max_bits)
        s = MockReaderWriter()
        await write_unsigned_varint(s, integer, max_bits=max_bits)
        s.seek(0, 0)
        result = await read_unsigned_varint(s, max_bits=max_bits)
        assert integer == result


@pytest.fixture(scope="module")
def peer_id_string():
    """Base58 text form of a fixed peer ID."""
    return "QmS5QmciTXXnCUCyxud5eWFenUMAmvAWSDa1c7dvdXRMZ7"


@pytest.fixture(scope="module")
def peer_id_bytes():
    # NOTE(review): SOURCE is corrupted here and cannot be repaired from this view. The bytes literal
    # below is truncated, and the text that originally followed it is missing. That text presumably
    # held the rest of this fixture, further fixtures used later in the file (e.g. `peer_id_random`,
    # `num_p2pds`, `func_make_p2pd_pair`) and the head of `try_until_success`, which is called below but
    # not defined anywhere visible. The fragment `= timeout: ... await asyncio.sleep(0.01)` looks like
    # that function's tail. The surviving text is kept verbatim on the next line; restore the lost span
    # from version control.
    return b'\x12 7\x87F.[\xb5\xb1o\xe5*\xc7\xb9\xbb\x11:"Z|j2\x8ad\x1b\xa6\xe5= timeout: # timeout assert False, f"{coro_func} still failed after `{timeout}` seconds" await asyncio.sleep(0.01)


class Daemon:
    """Launches a `p2pd` subprocess for a test and captures its stdout/stderr in a /tmp log file."""

    control_maddr = None
    proc_daemon = None
    log_filename = ""
    f_log = None
    # NOTE(review): never read or written below; instances track shutdown via `self.is_closed`.
    closed = None

    def __init__(
            self, control_maddr, enable_control, enable_connmgr, enable_dht, enable_pubsub
    ):
        """Store the feature flags, open the log file and start the daemon process immediately."""
        self.control_maddr = control_maddr
        self.enable_control = enable_control
        self.enable_connmgr = enable_connmgr
        self.enable_dht = enable_dht
        self.enable_pubsub = enable_pubsub
        self.is_closed = False
        self._start_logging()
        self._run()

    def _start_logging(self):
        """Open a per-daemon log file whose name is derived from the control multiaddr."""
        # make the multiaddr filename-safe by replacing path separators and dots
        name_control_maddr = str(self.control_maddr).replace("/", "_").replace(".", "_")
        self.log_filename = f"/tmp/log_p2pd{name_control_maddr}.txt"
        self.f_log = open(self.log_filename, "wb")

    def _run(self):
        """Spawn the p2pd binary with flags matching the enabled features."""
        # relative binary path: presumably tests are run from the repository root — TODO confirm
        cmd_list = ["hivemind/hivemind_cli/p2pd", f"-listen={str(self.control_maddr)}"]
        cmd_list += [f"-hostAddrs=/ip4/127.0.0.1/tcp/{find_open_port()}"]
        if self.enable_connmgr:
            cmd_list += ["-connManager=true", "-connLo=1", "-connHi=2", "-connGrace=0"]
        if self.enable_dht:
            cmd_list += ["-dht=true"]
        if self.enable_pubsub:
            cmd_list += ["-pubsub=true", "-pubsubRouter=gossipsub"]
        # unbuffered so wait_until_ready() can see the startup lines as soon as they are written
        self.proc_daemon = subprocess.Popen(
            cmd_list, stdout=self.f_log, stderr=self.f_log, bufsize=0
        )

    async def wait_until_ready(self):
        """Poll the daemon log until all three startup header lines have appeared."""
        lines_head_pattern = (b"Control socket:", b"Peer ID:", b"Peer Addrs:")
        lines_head_occurred = {line: False for line in lines_head_pattern}

        with open(self.log_filename, "rb") as f_log_read:

            async def read_from_daemon_and_check():
                # consume one log line per attempt and mark any header prefix it starts with
                line = f_log_read.readline()
                for head_pattern in lines_head_occurred:
                    if line.startswith(head_pattern):
                        lines_head_occurred[head_pattern] = True
                return all([value for _, value in lines_head_occurred.items()])

            await try_until_success(read_from_daemon_and_check)

        # sleep a little longer in case the daemon is not fully ready even after printing these lines
        await asyncio.sleep(0.1)

    def close(self):
        """Terminate the daemon, wait for it to exit and close the log; repeated calls are no-ops."""
        if self.is_closed:
            return
        self.proc_daemon.terminate()
        self.proc_daemon.wait()
        self.f_log.close()
        self.is_closed = True


class DaemonTuple(NamedTuple):
    """A running daemon together with the client connected to its control endpoint."""
    daemon: Daemon
    client: Client


class ConnectionFailure(Exception):
    pass


@asynccontextmanager
async def make_p2pd_pair_unix(
        enable_control, enable_connmgr, enable_dht, enable_pubsub
):
    """Yield a DaemonTuple whose control and listen endpoints are unix sockets under /tmp."""
    # random 8-char suffix keeps socket paths distinct between daemons
    name = str(uuid.uuid4())[:8]
    control_maddr = Multiaddr(f"/unix/tmp/test_p2pd_control_{name}.sock")
    listen_maddr = Multiaddr(f"/unix/tmp/test_p2pd_listen_{name}.sock")
    # Remove the unix socket files if they already exist
    try:
        os.unlink(control_maddr.value_for_protocol(protocols.P_UNIX))
    except FileNotFoundError:
        pass
    try:
        os.unlink(listen_maddr.value_for_protocol(protocols.P_UNIX))
    except FileNotFoundError:
        pass
    async with _make_p2pd_pair(
            control_maddr=control_maddr,
            listen_maddr=listen_maddr,
            enable_control=enable_control,
            enable_connmgr=enable_connmgr,
            enable_dht=enable_dht,
            enable_pubsub=enable_pubsub,
    ) as pair:
        yield pair


@asynccontextmanager
async def make_p2pd_pair_ip4(enable_control, enable_connmgr, enable_dht, enable_pubsub):
    """Yield a DaemonTuple whose control and listen endpoints are TCP ports on 127.0.0.1."""
    control_maddr = Multiaddr(f"/ip4/127.0.0.1/tcp/{find_open_port()}")
    listen_maddr = Multiaddr(f"/ip4/127.0.0.1/tcp/{find_open_port()}")
    async with _make_p2pd_pair(
            control_maddr=control_maddr,
            listen_maddr=listen_maddr,
            enable_control=enable_control,
            enable_connmgr=enable_connmgr,
            enable_dht=enable_dht,
            enable_pubsub=enable_pubsub,
    ) as pair:
        yield pair


@asynccontextmanager
async def _make_p2pd_pair(
        control_maddr,
        listen_maddr,
        enable_control,
        enable_connmgr,
        enable_dht,
        enable_pubsub,
):
    """Start a daemon, wait until it is ready, then yield it with a listening client; stop it on exit."""
    p2pd = Daemon(
        control_maddr=control_maddr,
        enable_control=enable_control,
        enable_connmgr=enable_connmgr,
        enable_dht=enable_dht,
        enable_pubsub=enable_pubsub,
    )
    # wait for daemon ready
    # NOTE(review): this await and the Client(...) call run before the `try` below, so if either raises,
    # the finally clause never runs and the daemon subprocess (and its log file) leak. Consider moving
    # them inside the try.
    await p2pd.wait_until_ready()
    client = Client(control_maddr=control_maddr, listen_maddr=listen_maddr)
    try:
        async with client.listen():
            yield DaemonTuple(daemon=p2pd, client=client)
    finally:
        if not p2pd.is_closed:
            p2pd.close()


@pytest.fixture
async def p2pcs(
        num_p2pds,
        enable_control,
        enable_connmgr,
        enable_dht,
        enable_pubsub,
        func_make_p2pd_pair,
):
    """Start `num_p2pds` daemons one after another and yield a tuple of their clients.

    The AsyncExitStack shuts every daemon down when the fixture is torn down. The fixture parameters
    presumably come from fixtures/parametrization defined outside this view — TODO confirm.
    """
    # TODO: Change back to gather style
    async with AsyncExitStack() as stack:
        p2pd_tuples = [
            await stack.enter_async_context(
                func_make_p2pd_pair(
                    enable_control=enable_control,
                    enable_connmgr=enable_connmgr,
                    enable_dht=enable_dht,
                    enable_pubsub=enable_pubsub,
                )
            )
            for _ in range(num_p2pds)
        ]
        yield tuple(p2pd_tuple.client for p2pd_tuple in p2pd_tuples)


@pytest.mark.parametrize(
    "enable_control, func_make_p2pd_pair", ((True, make_p2pd_pair_unix),)
)
@pytest.mark.asyncio
async def test_client_identify_unix_socket(p2pcs):
    """`identify` must work when the daemon is reached over unix sockets."""
    await p2pcs[0].identify()


@pytest.mark.parametrize("enable_control", (True,))
@pytest.mark.asyncio
async def test_client_identify(p2pcs):
    """`identify` must succeed against a freshly started daemon."""
    await p2pcs[0].identify()


@pytest.mark.parametrize("enable_control", (True,))
@pytest.mark.asyncio
async def test_client_connect_success(p2pcs):
    """Connecting two daemons works, and connecting again in the opposite direction does not fail."""
    peer_id_0, maddrs_0 = await p2pcs[0].identify()
    peer_id_1, maddrs_1 = await p2pcs[1].identify()
    await p2pcs[0].connect(peer_id_1, maddrs_1)
    # test case: repeated connections
    await p2pcs[1].connect(peer_id_0, maddrs_0)


@pytest.mark.parametrize("enable_control", (True,))
@pytest.mark.asyncio
async def test_client_connect_failure(peer_id_random, p2pcs):
    """`connect` must raise ControlFailure for a wrong peer ID, no addresses or unreachable addresses."""
    # NOTE(review): the `peer_id_random` fixture is not defined anywhere visible — presumably lost in the
    # corrupted span above or provided by a conftest.
    peer_id_1, maddrs_1 = await p2pcs[1].identify()
    await p2pcs[0].identify()
    # test case: `peer_id` mismatches
    with pytest.raises(ControlFailure):
        await p2pcs[0].connect(peer_id_random, maddrs_1)
    # test case: empty maddrs
    with pytest.raises(ControlFailure):
        await p2pcs[0].connect(peer_id_1, [])
    # test case: wrong maddrs
    with pytest.raises(ControlFailure):
        await p2pcs[0].connect(peer_id_1, [Multiaddr("/ip4/127.0.0.1/udp/0")])


async def _check_connection(p2pd_tuple_0, p2pd_tuple_1):
    """Return True when each side lists the other among its connected peers."""
    # NOTE(review): despite the parameter names, callers pass the Client objects yielded by `p2pcs`.
    peer_id_0, _ = await p2pd_tuple_0.identify()
    peer_id_1, _ = await p2pd_tuple_1.identify()
    peers_0 = [pinfo.peer_id for pinfo in await p2pd_tuple_0.list_peers()]
    peers_1 = [pinfo.peer_id for pinfo in await p2pd_tuple_1.list_peers()]
    return (peer_id_0 in peers_1) and (peer_id_1 in peers_0)


async def connect_safe(p2pd_tuple_0, p2pd_tuple_1):
    """Connect the first client to the second and poll until both sides see the connection."""
    peer_id_1, maddrs_1 = await p2pd_tuple_1.identify()
    await p2pd_tuple_0.connect(peer_id_1, maddrs_1)
    await try_until_success(
        functools.partial(
            _check_connection, p2pd_tuple_0=p2pd_tuple_0, p2pd_tuple_1=p2pd_tuple_1
        )
    )


@pytest.mark.parametrize("enable_control", (True,))
@pytest.mark.asyncio
async def test_connect_safe(p2pcs):
    """`connect_safe` must return once the two daemons see each other."""
    await connect_safe(p2pcs[0], p2pcs[1])
@pytest.mark.parametrize("enable_control", (True,))
@pytest.mark.asyncio
async def test_client_list_peers(p2pcs):
    """Peer lists grow on both sides as `connect_safe` links daemons together."""
    # test case: no peers
    assert len(await p2pcs[0].list_peers()) == 0
    # test case: 1 peer
    await connect_safe(p2pcs[0], p2pcs[1])
    assert len(await p2pcs[0].list_peers()) == 1
    assert len(await p2pcs[1].list_peers()) == 1
    # test case: one more peer
    await connect_safe(p2pcs[0], p2pcs[2])
    assert len(await p2pcs[0].list_peers()) == 2
    assert len(await p2pcs[1].list_peers()) == 1
    assert len(await p2pcs[2].list_peers()) == 1


async def _has_no_peers(client):
    """Return True when ``client`` currently reports an empty peer list."""
    return len(await client.list_peers()) == 0


@pytest.mark.parametrize("enable_control", (True,))
@pytest.mark.asyncio
async def test_client_disconnect(peer_id_random, p2pcs):
    """`disconnect` drops an existing connection and is harmless for unknown or already-gone peers."""
    # test case: disconnect a peer without connections
    await p2pcs[1].disconnect(peer_id_random)
    # test case: disconnect
    peer_id_0, _ = await p2pcs[0].identify()
    await connect_safe(p2pcs[0], p2pcs[1])
    assert len(await p2pcs[0].list_peers()) == 1
    assert len(await p2pcs[1].list_peers()) == 1
    await p2pcs[1].disconnect(peer_id_0)
    # The remote daemon may learn about the closed connection only after a short delay, so poll
    # (with the same helper `connect_safe` uses) instead of asserting right away.
    await try_until_success(functools.partial(_has_no_peers, p2pcs[0]))
    await try_until_success(functools.partial(_has_no_peers, p2pcs[1]))
    # test case: disconnect twice
    await p2pcs[1].disconnect(peer_id_0)
    assert len(await p2pcs[0].list_peers()) == 0
    assert len(await p2pcs[1].list_peers()) == 0


@pytest.mark.parametrize("enable_control", (True,))
@pytest.mark.asyncio
async def test_client_stream_open_success(p2pcs):
    """`stream_open` reports the remote peer, one of its addresses and the negotiated protocol."""
    peer_id_1, maddrs_1 = await p2pcs[1].identify()
    await connect_safe(p2pcs[0], p2pcs[1])

    proto = "123"

    async def handle_proto(stream_info, reader, writer):
        await reader.readexactly(1)

    await p2pcs[1].stream_handler(proto, handle_proto)

    # test case: normal
    stream_info, _, writer = await p2pcs[0].stream_open(peer_id_1, (proto,))
    assert stream_info.peer_id == peer_id_1
    assert stream_info.addr in maddrs_1
    assert stream_info.proto == "123"
    writer.close()

    # test case: open with multiple protocols
    stream_info, _, writer = await p2pcs[0].stream_open(
        peer_id_1, (proto, "another_protocol")
    )
    assert stream_info.peer_id == peer_id_1
    assert stream_info.addr in maddrs_1
    assert stream_info.proto == "123"
    writer.close()


@pytest.mark.parametrize("enable_control", (True,))
@pytest.mark.asyncio
async def test_client_stream_open_failure(p2pcs):
    """`stream_open` raises ControlFailure when the remote has no handler for the requested protocol."""
    peer_id_1, _ = await p2pcs[1].identify()
    await connect_safe(p2pcs[0], p2pcs[1])

    proto = "123"

    # test case: `stream_open` to a peer who didn't register the protocol
    with pytest.raises(ControlFailure):
        await p2pcs[0].stream_open(peer_id_1, (proto,))

    # test case: `stream_open` to a peer for a non-registered protocol
    async def handle_proto(stream_info, reader, writer):
        pass

    await p2pcs[1].stream_handler(proto, handle_proto)
    with pytest.raises(ControlFailure):
        await p2pcs[0].stream_open(peer_id_1, ("another_protocol",))


@pytest.mark.parametrize("enable_control", (True,))
@pytest.mark.asyncio
async def test_client_stream_handler_success(p2pcs):
    """Registered handlers receive the bytes sent on their protocol, and re-registering overrides."""
    peer_id_1, _ = await p2pcs[1].identify()
    await connect_safe(p2pcs[0], p2pcs[1])

    proto = "protocol123"
    bytes_to_send = b"yoyoyoyoyog"

    # Handlers are invoked by the control client, not awaited by this test. An `assert` inside a handler
    # therefore cannot fail the test. It would also skip the `event.set()` after it, making the
    # `wait()` below hang forever. So handlers only record what they read (and always signal completion
    # via `finally`); the checks happen here in the test body.
    received = {}

    # event for this test function to wait until the handler function has consumed the incoming data
    event_handler_finished = asyncio.Event()

    async def handle_proto(stream_info, reader, writer):
        try:
            received[proto] = await reader.readexactly(len(bytes_to_send))
        finally:
            event_handler_finished.set()

    await p2pcs[1].stream_handler(proto, handle_proto)
    assert proto in p2pcs[1].control.handlers
    assert handle_proto == p2pcs[1].control.handlers[proto]

    # test case: test the stream handler `handle_proto`
    _, _, writer = await p2pcs[0].stream_open(peer_id_1, (proto,))
    writer.write(bytes_to_send)
    writer.close()
    # wait for the handler to finish, then check what it actually read
    await event_handler_finished.wait()
    assert received.get(proto) == bytes_to_send

    # test case: two streams to different handlers respectively
    another_proto = "another_protocol123"
    another_bytes_to_send = b"456"
    event_another_proto = asyncio.Event()
    event_another_proto_finished = asyncio.Event()

    async def handle_another_proto(stream_info, reader, writer):
        event_another_proto.set()
        try:
            received[another_proto] = await reader.readexactly(len(another_bytes_to_send))
        finally:
            event_another_proto_finished.set()

    await p2pcs[1].stream_handler(another_proto, handle_another_proto)
    assert another_proto in p2pcs[1].control.handlers
    assert handle_another_proto == p2pcs[1].control.handlers[another_proto]

    _, _, writer = await p2pcs[0].stream_open(peer_id_1, (another_proto,))
    # the handler has started and is about to block reading, since nothing has been sent yet
    await event_another_proto.wait()
    writer.write(another_bytes_to_send)
    writer.close()
    # wait for this handler too, otherwise its read result would never be checked
    await event_another_proto_finished.wait()
    assert received.get(another_proto) == another_bytes_to_send

    # test case: registering twice can override the previous registration
    event_third = asyncio.Event()

    async def handler_third(stream_info, reader, writer):
        event_third.set()

    await p2pcs[1].stream_handler(another_proto, handler_third)
    assert another_proto in p2pcs[1].control.handlers
    # ensure the handler is overridden
    assert handler_third == p2pcs[1].control.handlers[another_proto]

    _, _, writer = await p2pcs[0].stream_open(peer_id_1, (another_proto,))
    # ensure the overriding handler is called when a stream is opened for the protocol
    await event_third.wait()
    # close our end so the stream does not leak past the test
    writer.close()


@pytest.mark.parametrize("enable_control", (True,))
@pytest.mark.asyncio
async def test_client_stream_handler_failure(p2pcs):
    """Opening a stream for a protocol the remote never registered raises ControlFailure."""
    peer_id_1, _ = await p2pcs[1].identify()
    await connect_safe(p2pcs[0], p2pcs[1])

    proto = "123"

    # test case: registered a wrong protocol name
    # (same (stream_info, reader, writer) signature as every other handler in this file)
    async def handle_proto_correct_params(stream_info, reader, writer):
        pass

    await p2pcs[1].stream_handler("another_protocol", handle_proto_correct_params)
    with pytest.raises(ControlFailure):
        await p2pcs[0].stream_open(peer_id_1, (proto,))