# test_dht_protocol.py

import asyncio
import multiprocessing as mp
import random
import signal
from typing import List, Sequence, Tuple

import pytest
from multiaddr import Multiaddr

import hivemind
from hivemind import P2P, PeerID, get_dht_time, get_logger
from hivemind.dht import DHTID
from hivemind.dht.protocol import DHTProtocol
from hivemind.dht.storage import DictionaryDHTValue

logger = get_logger(__name__)


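# Convert a list of multiaddrs to the unique peer IDs encoded in their "p2p" components.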
def maddrs_to_peer_ids(maddrs: List[Multiaddr]) -> List[PeerID]:
    return list({PeerID.from_base58(maddr["p2p"]) for maddr in maddrs})


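# Entry point for a background process: runs a single DHT peer on its own event loop,
# reports its peer id and visible multiaddrs through the pipe, then serves until SIGTERM.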
def run_protocol_listener(
    dhtid: DHTID, maddr_conn: mp.connection.Connection, initial_peers: Sequence[Multiaddr]
) -> None:
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    p2p = loop.run_until_complete(P2P.create(initial_peers=initial_peers))
    visible_maddrs = loop.run_until_complete(p2p.get_visible_maddrs())

    protocol = loop.run_until_complete(
        DHTProtocol.create(p2p, dhtid, bucket_size=20, depth_modulo=5, num_replicas=3, wait_timeout=5)
    )

    logger.info(f"Started peer id={protocol.node_id} visible_maddrs={visible_maddrs}")

    for peer_id in maddrs_to_peer_ids(initial_peers):
        loop.run_until_complete(protocol.call_ping(peer_id))

    maddr_conn.send((p2p.peer_id, visible_maddrs))

    async def shutdown():
        await p2p.shutdown()
        logger.info(f"Finished peer id={protocol.node_id} maddrs={visible_maddrs}")
        loop.stop()

    loop.add_signal_handler(signal.SIGTERM, lambda: loop.create_task(shutdown()))
    loop.run_forever()


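# Spawn a listener peer in a daemon process and block until it reports its identity.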
def launch_protocol_listener(
    initial_peers: Sequence[Multiaddr] = (),
) -> Tuple[DHTID, mp.Process, PeerID, List[Multiaddr]]:
    remote_conn, local_conn = mp.Pipe()
    dht_id = DHTID.generate()
    process = mp.Process(target=run_protocol_listener, args=(dht_id, remote_conn, initial_peers), daemon=True)
    process.start()
    peer_id, visible_maddrs = local_conn.recv()

    return dht_id, process, peer_id, visible_maddrs


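# End-to-end check of the DHT RPCs (call_ping, call_store, call_find) against two live
# listener peers, first from a client-mode node and then from a full peer.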
@pytest.mark.forked
@pytest.mark.asyncio
async def test_dht_protocol():
    peer1_node_id, peer1_proc, peer1_id, peer1_maddrs = launch_protocol_listener()
    peer2_node_id, peer2_proc, peer2_id, _ = launch_protocol_listener(initial_peers=peer1_maddrs)

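    # peer2 bootstraps from peer1 and pings it on startup, so each listener learns about the other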
    for client_mode in [True, False]:  # note: order matters, this test assumes that first run uses client mode
        peer_id = DHTID.generate()
        p2p = await P2P.create(initial_peers=peer1_maddrs)
        protocol = await DHTProtocol.create(
            p2p, peer_id, bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, client_mode=client_mode
        )
        logger.info(f"Self id={protocol.node_id}")

        assert peer1_node_id == await protocol.call_ping(peer1_id)

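        # store a trivial value under a fresh key with a 1000-second expiration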
        key, value, expiration = DHTID.generate(), [random.random(), {"ololo": "pyshpysh"}], get_dht_time() + 1e3
        store_ok = await protocol.call_store(peer1_id, [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
        assert all(store_ok), "DHT rejected a trivial store"

        # peer 1 must know about peer 2
        (recv_value_bytes, recv_expiration), nodes_found = (await protocol.call_find(peer1_id, [key]))[key]
        recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
        (recv_id, recv_peer_id) = next(iter(nodes_found.items()))
        assert (
            recv_id == peer2_node_id and recv_peer_id == peer2_id
        ), f"expected id={peer2_node_id}, peer={peer2_id} but got {recv_id}, {recv_peer_id}"
        assert recv_value == value and recv_expiration == expiration, (
            f"call_find_value expected {value} (expires by {expiration}) "
            f"but got {recv_value} (expires by {recv_expiration})"
        )

        # peer 2 must know about peer 1, but not have a *random* nonexistent value
        dummy_key = DHTID.generate()
        empty_item, nodes_found_2 = (await protocol.call_find(peer2_id, [dummy_key]))[dummy_key]
        assert empty_item is None, "Non-existent keys shouldn't have values"
        (recv_id, recv_peer_id) = next(iter(nodes_found_2.items()))
        assert (
            recv_id == peer1_node_id and recv_peer_id == peer1_id
        ), f"expected id={peer1_node_id}, peer={peer1_id} but got {recv_id}, {recv_peer_id}"

        # cause a non-response by querying a nonexistent peer
        assert not await protocol.call_find(PeerID.from_base58("fakeid"), [key])

        # store/get a dictionary with sub-keys
        nested_key, subkey1, subkey2 = DHTID.generate(), "foo", "bar"
        value1, value2 = [random.random(), {"ololo": "pyshpysh"}], "abacaba"
        assert await protocol.call_store(
            peer1_id,
            keys=[nested_key],
            values=[hivemind.MSGPackSerializer.dumps(value1)],
            expiration_time=[expiration],
            subkeys=[subkey1],
        )
        assert await protocol.call_store(
            peer1_id,
            keys=[nested_key],
            values=[hivemind.MSGPackSerializer.dumps(value2)],
            expiration_time=[expiration + 5],
            subkeys=[subkey2],
        )
        (recv_dict, recv_expiration), nodes_found = (await protocol.call_find(peer1_id, [nested_key]))[nested_key]
        assert isinstance(recv_dict, DictionaryDHTValue)
        assert len(recv_dict.data) == 2 and recv_expiration == expiration + 5
        assert recv_dict.data[subkey1] == (protocol.serializer.dumps(value1), expiration)
        assert recv_dict.data[subkey2] == (protocol.serializer.dumps(value2), expiration + 5)

        if not client_mode:
            await p2p.shutdown()

    peer1_proc.terminate()
    peer2_proc.terminate()


@pytest.mark.forked
@pytest.mark.asyncio
async def test_empty_table():
    """Test RPC methods with empty routing table"""
    peer_id, peer_proc, peer_peer_id, peer_maddrs = launch_protocol_listener()

    p2p = await P2P.create(initial_peers=peer_maddrs)
    protocol = await DHTProtocol.create(
        p2p, DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, client_mode=True
    )

    key, value, expiration = DHTID.generate(), [random.random(), {"ololo": "pyshpysh"}], get_dht_time() + 1e3

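    # the listener was started without initial peers, so its routing table is empty and find returns no neighbors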
    empty_item, nodes_found = (await protocol.call_find(peer_peer_id, [key]))[key]
    assert empty_item is None and len(nodes_found) == 0
    assert all(await protocol.call_store(peer_peer_id, [key], [hivemind.MSGPackSerializer.dumps(value)], expiration))

    (recv_value_bytes, recv_expiration), nodes_found = (await protocol.call_find(peer_peer_id, [key]))[key]
    recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
    assert len(nodes_found) == 0
    assert recv_value == value and recv_expiration == expiration

    assert peer_id == await protocol.call_ping(peer_peer_id)
    assert not await protocol.call_ping(PeerID.from_base58("fakeid"))

    peer_proc.terminate()