"""Tests for DHTProtocol ping/store/find RPCs against listener peers running in separate processes."""
- import asyncio
- import multiprocessing as mp
- import random
- import signal
- from typing import List, Sequence, Tuple
- import pytest
- from multiaddr import Multiaddr
- import hivemind
- from hivemind import P2P, PeerID, get_dht_time, get_logger
- from hivemind.dht import DHTID
- from hivemind.dht.protocol import DHTProtocol
- from hivemind.dht.storage import DictionaryDHTValue
- logger = get_logger(__name__)
def maddrs_to_peer_ids(maddrs: List[Multiaddr]) -> List[PeerID]:
    """Return the distinct peer ids named by the "p2p" component of the given multiaddrs.

    Duplicates are collapsed through a set, so the order of the result is not tied to the input order.
    """
    unique_peer_ids = set()
    for address in maddrs:
        unique_peer_ids.add(PeerID.from_base58(address["p2p"]))
    return list(unique_peer_ids)
def run_protocol_listener(
    dhtid: DHTID, maddr_conn: mp.connection.Connection, initial_peers: Sequence[Multiaddr]
) -> None:
    """Run a DHTProtocol peer in the current process until SIGTERM arrives.

    Creates a fresh event loop, starts a P2P instance and a DHTProtocol on it, pings every peer named in
    ``initial_peers``, then sends ``(p2p.peer_id, visible_maddrs)`` through ``maddr_conn`` and keeps the loop
    running until a SIGTERM schedules the shutdown coroutine.

    :param dhtid: node id passed to DHTProtocol.create
    :param maddr_conn: connection used to report (peer id, visible multiaddrs) to whoever launched this listener
    :param initial_peers: multiaddrs passed to P2P.create; their peer ids are pinged once the protocol is up
    """
    event_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(event_loop)
    run = event_loop.run_until_complete

    p2p = run(P2P.create(initial_peers=initial_peers))
    visible_maddrs = run(p2p.get_visible_maddrs())
    protocol = run(DHTProtocol.create(p2p, dhtid, bucket_size=20, depth_modulo=5, num_replicas=3, wait_timeout=5))

    logger.info(f"Started peer id={protocol.node_id} visible_maddrs={visible_maddrs}")

    for known_peer in maddrs_to_peer_ids(initial_peers):
        run(protocol.call_ping(known_peer))

    # report readiness only after the initial peers have been pinged
    maddr_conn.send((p2p.peer_id, visible_maddrs))

    async def shutdown():
        await p2p.shutdown()
        logger.info(f"Finished peer id={protocol.node_id} maddrs={visible_maddrs}")
        event_loop.stop()

    def schedule_shutdown():
        # signal callbacks are plain functions, so the coroutine has to be scheduled as a task
        event_loop.create_task(shutdown())

    event_loop.add_signal_handler(signal.SIGTERM, schedule_shutdown)
    event_loop.run_forever()
def launch_protocol_listener(
    initial_peers: Sequence[Multiaddr] = (),
) -> Tuple[DHTID, mp.Process, PeerID, List[Multiaddr]]:
    """Start run_protocol_listener in a daemon process and wait for it to report its address.

    :param initial_peers: multiaddrs forwarded to the listener process
    :returns: (generated DHTID, the listener process, the peer id and the visible multiaddrs it reported)
    :raises EOFError: if the listener process exits before reporting its address
    """
    remote_conn, local_conn = mp.Pipe()
    dht_id = DHTID.generate()
    process = mp.Process(target=run_protocol_listener, args=(dht_id, remote_conn, initial_peers), daemon=True)
    process.start()
    # The child now owns its own handle to remote_conn. Drop the parent's copy: while it stays open,
    # recv() below would block forever if the child died during startup instead of raising EOFError.
    remote_conn.close()
    try:
        peer_id, visible_maddrs = local_conn.recv()
    finally:
        # the pipe is only used for this one-shot handshake
        local_conn.close()
    return dht_id, process, peer_id, visible_maddrs
@pytest.mark.forked
@pytest.mark.asyncio
async def test_dht_protocol():
    """Exercise ping, store and find RPCs against two listener peers, first in client mode, then in regular mode."""
    peer1_node_id, peer1_proc, peer1_id, peer1_maddrs = launch_protocol_listener()
    spawned_procs = [peer1_proc]
    try:
        peer2_node_id, peer2_proc, peer2_id, _ = launch_protocol_listener(initial_peers=peer1_maddrs)
        spawned_procs.append(peer2_proc)

        for client_mode in [True, False]:  # note: order matters, this test assumes that first run uses client mode
            peer_id = DHTID.generate()
            p2p = await P2P.create(initial_peers=peer1_maddrs)
            protocol = await DHTProtocol.create(
                p2p, peer_id, bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, client_mode=client_mode
            )
            logger.info(f"Self id={protocol.node_id}")

            assert peer1_node_id == await protocol.call_ping(peer1_id)

            key, value, expiration = DHTID.generate(), [random.random(), {"ololo": "pyshpysh"}], get_dht_time() + 1e3
            store_ok = await protocol.call_store(
                peer1_id, [key], [hivemind.MSGPackSerializer.dumps(value)], expiration
            )
            assert all(store_ok), "DHT rejected a trivial store"

            # peer 1 must know about peer 2
            (recv_value_bytes, recv_expiration), nodes_found = (await protocol.call_find(peer1_id, [key]))[key]
            recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
            (recv_id, recv_peer_id) = next(iter(nodes_found.items()))
            assert (
                recv_id == peer2_node_id and recv_peer_id == peer2_id
            ), f"expected id={peer2_node_id}, peer={peer2_id} but got {recv_id}, {recv_peer_id}"

            assert recv_value == value and recv_expiration == expiration, (
                f"call_find_value expected {value} (expires by {expiration}) "
                f"but got {recv_value} (expires by {recv_expiration})"
            )

            # peer 2 must know about peer 1, but not have a *random* nonexistent value
            dummy_key = DHTID.generate()
            empty_item, nodes_found_2 = (await protocol.call_find(peer2_id, [dummy_key]))[dummy_key]
            assert empty_item is None, "Non-existent keys shouldn't have values"
            (recv_id, recv_peer_id) = next(iter(nodes_found_2.items()))
            assert (
                recv_id == peer1_node_id and recv_peer_id == peer1_id
            ), f"expected id={peer1_node_id}, peer={peer1_id} but got {recv_id}, {recv_peer_id}"

            # cause a non-response by querying a nonexistent peer
            assert not await protocol.call_find(PeerID.from_base58("fakeid"), [key])

            # store/get a dictionary with sub-keys
            nested_key, subkey1, subkey2 = DHTID.generate(), "foo", "bar"
            value1, value2 = [random.random(), {"ololo": "pyshpysh"}], "abacaba"
            assert await protocol.call_store(
                peer1_id,
                keys=[nested_key],
                values=[hivemind.MSGPackSerializer.dumps(value1)],
                expiration_time=[expiration],
                subkeys=[subkey1],
            )
            assert await protocol.call_store(
                peer1_id,
                keys=[nested_key],
                values=[hivemind.MSGPackSerializer.dumps(value2)],
                expiration_time=[expiration + 5],
                subkeys=[subkey2],
            )
            (recv_dict, recv_expiration), nodes_found = (await protocol.call_find(peer1_id, [nested_key]))[nested_key]
            assert isinstance(recv_dict, DictionaryDHTValue)
            assert len(recv_dict.data) == 2 and recv_expiration == expiration + 5
            assert recv_dict.data[subkey1] == (protocol.serializer.dumps(value1), expiration)
            assert recv_dict.data[subkey2] == (protocol.serializer.dumps(value2), expiration + 5)

            if not client_mode:
                await p2p.shutdown()
    finally:
        # terminate the listeners even when an assertion above fails, so that no child process outlives the test
        for proc in spawned_procs:
            proc.terminate()
@pytest.mark.forked
@pytest.mark.asyncio
async def test_empty_table():
    """Test RPC methods with empty routing table"""
    peer_id, peer_proc, peer_peer_id, peer_maddrs = launch_protocol_listener()
    try:
        p2p = await P2P.create(initial_peers=peer_maddrs)
        protocol = await DHTProtocol.create(
            p2p, DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, client_mode=True
        )

        key, value, expiration = DHTID.generate(), [random.random(), {"ololo": "pyshpysh"}], get_dht_time() + 1e3

        # nothing has been stored yet, and the lone listener has no other nodes to suggest
        empty_item, nodes_found = (await protocol.call_find(peer_peer_id, [key]))[key]
        assert empty_item is None and len(nodes_found) == 0
        assert all(
            await protocol.call_store(peer_peer_id, [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
        )

        # the stored value must now be returned, still with no neighbouring nodes
        (recv_value_bytes, recv_expiration), nodes_found = (await protocol.call_find(peer_peer_id, [key]))[key]
        recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
        assert len(nodes_found) == 0
        assert recv_value == value and recv_expiration == expiration

        assert peer_id == await protocol.call_ping(peer_peer_id)
        assert not await protocol.call_ping(PeerID.from_base58("fakeid"))
    finally:
        # terminate the listener even when an assertion above fails, so that no child process outlives the test
        peer_proc.terminate()
|