import asyncio
import heapq
import multiprocessing as mp
import random
import signal
from itertools import product
from typing import List, Sequence, Tuple

import numpy as np
import pytest
from multiaddr import Multiaddr

import hivemind
from hivemind import get_dht_time
from hivemind.dht.node import DHTID, DHTNode
from hivemind.dht.protocol import DHTProtocol
from hivemind.dht.storage import DictionaryDHTValue
from hivemind.p2p import P2P, PeerID
from hivemind.utils.logging import get_logger

from test_utils.dht_swarms import launch_swarm_in_separate_processes, launch_star_shaped_swarm

logger = get_logger(__name__)


def maddrs_to_peer_ids(maddrs: List[Multiaddr]) -> List[PeerID]:
    return list({PeerID.from_base58(maddr['p2p']) for maddr in maddrs})


def run_protocol_listener(dhtid: DHTID, maddr_conn: mp.connection.Connection,
                          initial_peers: Sequence[Multiaddr]) -> None:
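    """Run a P2P instance with a DHTProtocol listener in this (child) process, ping the initial peers,
    report the visible maddrs back through maddr_conn, and serve requests until SIGTERM arrives."""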
    loop = asyncio.get_event_loop()

    p2p = loop.run_until_complete(P2P.create(initial_peers=initial_peers))
    visible_maddrs = loop.run_until_complete(p2p.get_visible_maddrs())

    protocol = loop.run_until_complete(DHTProtocol.create(
        p2p, dhtid, bucket_size=20, depth_modulo=5, num_replicas=3, wait_timeout=5))
    logger.info(f"Started peer id={protocol.node_id} visible_maddrs={visible_maddrs}")

    for endpoint in maddrs_to_peer_ids(initial_peers):
        loop.run_until_complete(protocol.call_ping(endpoint))

    maddr_conn.send((p2p.id, visible_maddrs))

    async def shutdown():
        await p2p.shutdown()
        logger.info(f"Finished peer id={protocol.node_id} maddrs={visible_maddrs}")
        loop.stop()

    loop.add_signal_handler(signal.SIGTERM, lambda: loop.create_task(shutdown()))
    loop.run_forever()


def launch_protocol_listener(initial_peers: Sequence[Multiaddr] = ()) -> \
        Tuple[DHTID, mp.Process, PeerID, List[Multiaddr]]:
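    """Start run_protocol_listener in a daemon process and return its DHT id, process handle, peer id,
    and visible maddrs."""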
    remote_conn, local_conn = mp.Pipe()
    dht_id = DHTID.generate()

    process = mp.Process(target=run_protocol_listener, args=(dht_id, remote_conn, initial_peers), daemon=True)
    process.start()

    peer_id, visible_maddrs = local_conn.recv()
    return dht_id, process, peer_id, visible_maddrs


# note: we run network-related tests in a separate process to re-initialize all global states from scratch
# this helps us avoid undesirable side-effects (e.g. segfaults) when running multiple tests in sequence
@pytest.mark.forked
def test_dht_protocol():
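    """Exercise DHTProtocol RPCs (ping, store, find, and dictionary sub-keys) against two listener processes,
    once with listen=False and once with listen=True."""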
    peer1_id, peer1_proc, peer1_endpoint, peer1_maddrs = launch_protocol_listener()
    peer2_id, peer2_proc, peer2_endpoint, _ = launch_protocol_listener(initial_peers=peer1_maddrs)

    loop = asyncio.get_event_loop()
    for listen in [False, True]:  # note: order matters, this test assumes that first run uses listen=False
        p2p = loop.run_until_complete(P2P.create(initial_peers=peer1_maddrs))
        protocol = loop.run_until_complete(DHTProtocol.create(
            p2p, DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, listen=listen))
        logger.info(f"Self id={protocol.node_id}")

        assert loop.run_until_complete(protocol.call_ping(peer1_endpoint)) == peer1_id

        key, value, expiration = DHTID.generate(), [random.random(), {'ololo': 'pyshpysh'}], get_dht_time() + 1e3
        store_ok = loop.run_until_complete(protocol.call_store(
            peer1_endpoint, [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
        )
        assert all(store_ok), "DHT rejected a trivial store"

        # peer 1 must know about peer 2
        (recv_value_bytes, recv_expiration), nodes_found = loop.run_until_complete(
            protocol.call_find(peer1_endpoint, [key]))[key]
        recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
        (recv_id, recv_endpoint) = next(iter(nodes_found.items()))
        assert recv_id == peer2_id and recv_endpoint == peer2_endpoint, \
            f"expected id={peer2_id}, peer={peer2_endpoint} but got {recv_id}, {recv_endpoint}"
        assert recv_value == value and recv_expiration == expiration, \
            f"call_find_value expected {value} (expires by {expiration}) " \
            f"but got {recv_value} (expires by {recv_expiration})"

        # peer 2 must know about peer 1, but not have a *random* nonexistent value
        dummy_key = DHTID.generate()
        empty_item, nodes_found_2 = loop.run_until_complete(
            protocol.call_find(peer2_endpoint, [dummy_key]))[dummy_key]
        assert empty_item is None, "Non-existent keys shouldn't have values"
        (recv_id, recv_endpoint) = next(iter(nodes_found_2.items()))
        assert recv_id == peer1_id and recv_endpoint == peer1_endpoint, \
            f"expected id={peer1_id}, peer={peer1_endpoint} but got {recv_id}, {recv_endpoint}"

        # cause a non-response by querying a nonexistent peer
        assert loop.run_until_complete(protocol.call_find(PeerID.from_base58('fakeid'), [key])) is None

        # store/get a dictionary with sub-keys
        nested_key, subkey1, subkey2 = DHTID.generate(), 'foo', 'bar'
        value1, value2 = [random.random(), {'ololo': 'pyshpysh'}], 'abacaba'
        assert loop.run_until_complete(protocol.call_store(
            peer1_endpoint, keys=[nested_key], values=[hivemind.MSGPackSerializer.dumps(value1)],
            expiration_time=[expiration], subkeys=[subkey1])
        )
        assert loop.run_until_complete(protocol.call_store(
            peer1_endpoint, keys=[nested_key], values=[hivemind.MSGPackSerializer.dumps(value2)],
            expiration_time=[expiration + 5], subkeys=[subkey2])
        )
        (recv_dict, recv_expiration), nodes_found = loop.run_until_complete(
            protocol.call_find(peer1_endpoint, [nested_key]))[nested_key]
        assert isinstance(recv_dict, DictionaryDHTValue)
        assert len(recv_dict.data) == 2 and recv_expiration == expiration + 5
        assert recv_dict.data[subkey1] == (protocol.serializer.dumps(value1), expiration)
        assert recv_dict.data[subkey2] == (protocol.serializer.dumps(value2), expiration + 5)

        if listen:
            loop.run_until_complete(p2p.shutdown())

    peer1_proc.terminate()
    peer2_proc.terminate()


@pytest.mark.forked
def test_empty_table():
    """ Test RPC methods with empty routing table """
    peer_id, peer_proc, peer_endpoint, peer_maddrs = launch_protocol_listener()

    loop = asyncio.get_event_loop()
    p2p = loop.run_until_complete(P2P.create(initial_peers=peer_maddrs))
    protocol = loop.run_until_complete(DHTProtocol.create(
        p2p, DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, listen=False))

    key, value, expiration = DHTID.generate(), [random.random(), {'ololo': 'pyshpysh'}], get_dht_time() + 1e3

    empty_item, nodes_found = loop.run_until_complete(
        protocol.call_find(peer_endpoint, [key]))[key]
    assert empty_item is None and len(nodes_found) == 0
    assert all(loop.run_until_complete(protocol.call_store(
        peer_endpoint, [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
    )), "peer rejected store"

    (recv_value_bytes, recv_expiration), nodes_found = loop.run_until_complete(
        protocol.call_find(peer_endpoint, [key]))[key]
    recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
    assert len(nodes_found) == 0
    assert recv_value == value and recv_expiration == expiration

    assert loop.run_until_complete(protocol.call_ping(peer_endpoint)) == peer_id
    assert loop.run_until_complete(protocol.call_ping(PeerID.from_base58('fakeid'))) is None
    peer_proc.terminate()


@pytest.mark.forked
def test_dht_node():
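    """End-to-end DHTNode test on a 50-peer swarm: nearest-node search, plain and bulk store/get,
    and dictionary values with sub-keys."""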
    # step A: create a swarm of 50 dht nodes in separate processes
    # (first 5 created sequentially, others created in parallel)
    processes, dht, swarm_maddrs = launch_swarm_in_separate_processes(n_peers=50, n_sequential_peers=5)

    # step B: run the 51st node in this process
    loop = asyncio.get_event_loop()
    initial_peers = random.choice(swarm_maddrs)
    me = loop.run_until_complete(DHTNode.create(initial_peers=initial_peers, parallel_rpc=10,
                                                cache_refresh_before_expiry=False))

    # test 1: find self
    nearest = loop.run_until_complete(me.find_nearest_nodes([me.node_id], k_nearest=1))[me.node_id]
    assert len(nearest) == 1 and nearest[me.node_id] == me.endpoint

    # test 2: find others
    for _ in range(10):
        ref_endpoint, query_id = random.choice(list(dht.items()))
        nearest = loop.run_until_complete(me.find_nearest_nodes([query_id], k_nearest=1))[query_id]
        assert len(nearest) == 1
        found_node_id, found_endpoint = next(iter(nearest.items()))
        assert found_node_id == query_id and found_endpoint == ref_endpoint
    # test 3: find neighbors to random nodes
    accuracy_numerator = accuracy_denominator = 0  # top-1 nearest neighbor accuracy
    jaccard_numerator = jaccard_denominator = 0  # jaccard similarity aka intersection over union
    all_node_ids = list(dht.values())

    for _ in range(10):
        query_id = DHTID.generate()
        k_nearest = random.randint(1, 10)
        exclude_self = random.random() > 0.5
        nearest = loop.run_until_complete(
            me.find_nearest_nodes([query_id], k_nearest=k_nearest, exclude_self=exclude_self))[query_id]
        nearest_nodes = list(nearest)  # keys from ordered dict

        assert len(nearest_nodes) == k_nearest, "beam search must return exactly k_nearest results"
        assert me.node_id not in nearest_nodes or not exclude_self, "if exclude, results shouldn't contain self"
        assert np.all(np.diff(query_id.xor_distance(nearest_nodes)) >= 0), "results must be sorted by distance"

        ref_nearest = heapq.nsmallest(k_nearest + 1, all_node_ids, key=query_id.xor_distance)
        if exclude_self and me.node_id in ref_nearest:
            ref_nearest.remove(me.node_id)
        if len(ref_nearest) > k_nearest:
            ref_nearest.pop()

        accuracy_numerator += nearest_nodes[0] == ref_nearest[0]
        accuracy_denominator += 1
        jaccard_numerator += len(set.intersection(set(nearest_nodes), set(ref_nearest)))
        jaccard_denominator += k_nearest

    accuracy = accuracy_numerator / accuracy_denominator
    logger.debug(f"Top-1 accuracy: {accuracy}")  # should be 98-100%
    jaccard_index = jaccard_numerator / jaccard_denominator
    logger.debug(f"Jaccard index (intersection over union): {jaccard_index}")  # should be 95-100%
    assert accuracy >= 0.9, f"Top-1 accuracy only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
    assert jaccard_index >= 0.9, f"Jaccard index only {jaccard_index} ({jaccard_numerator} / {jaccard_denominator})"
    # test 4: find all nodes
    dummy = DHTID.generate()
    nearest = loop.run_until_complete(me.find_nearest_nodes([dummy], k_nearest=len(dht) + 100))[dummy]
    assert len(nearest) == len(dht) + 1
    assert len(set.difference(set(nearest.keys()), set(all_node_ids) | {me.node_id})) == 0

    # test 5: node without peers
    detached_node = loop.run_until_complete(DHTNode.create())
    nearest = loop.run_until_complete(detached_node.find_nearest_nodes([dummy]))[dummy]
    assert len(nearest) == 1 and nearest[detached_node.node_id] == detached_node.endpoint
    nearest = loop.run_until_complete(detached_node.find_nearest_nodes([dummy], exclude_self=True))[dummy]
    assert len(nearest) == 0

    # test 6: store and get value
    true_time = get_dht_time() + 1200
    assert loop.run_until_complete(me.store("mykey", ["Value", 10], true_time))

    initial_peers = random.choice(swarm_maddrs)
    that_guy = loop.run_until_complete(DHTNode.create(initial_peers=initial_peers, parallel_rpc=10,
                                                      cache_refresh_before_expiry=False, cache_locally=False))

    for node in [me, that_guy]:
        val, expiration_time = loop.run_until_complete(node.get("mykey"))
        assert val == ["Value", 10], "Wrong value"
        assert expiration_time == true_time, "Wrong time"

    assert loop.run_until_complete(detached_node.get("mykey")) is None
    # test 7: bulk store and bulk get
    keys = 'foo', 'bar', 'baz', 'zzz'
    values = 3, 2, 'batman', [1, 2, 3]
    store_ok = loop.run_until_complete(me.store_many(keys, values, expiration_time=get_dht_time() + 999))
    assert all(store_ok.values()), "failed to store one or more keys"
    response = loop.run_until_complete(me.get_many(keys[::-1]))
    for key, value in zip(keys, values):
        assert key in response and response[key][0] == value

    # test 8: store dictionaries as values (with sub-keys)
    upper_key, subkey1, subkey2, subkey3 = 'ololo', 'k1', 'k2', 'k3'
    now = get_dht_time()
    assert loop.run_until_complete(me.store(upper_key, subkey=subkey1, value=123, expiration_time=now + 10))
    assert loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=456, expiration_time=now + 20))
    for node in [that_guy, me]:
        value, time = loop.run_until_complete(node.get(upper_key))
        assert isinstance(value, dict) and time == now + 20
        assert value[subkey1] == (123, now + 10)
        assert value[subkey2] == (456, now + 20)
        assert len(value) == 2

    assert not loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=345, expiration_time=now + 10))
    assert loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=567, expiration_time=now + 30))
    assert loop.run_until_complete(me.store(upper_key, subkey=subkey3, value=890, expiration_time=now + 50))
    loop.run_until_complete(asyncio.sleep(0.1))  # wait for cache to refresh

    for node in [that_guy, me]:
        value, time = loop.run_until_complete(node.get(upper_key))
        assert isinstance(value, dict) and time == now + 50, (value, time)
        assert value[subkey1] == (123, now + 10)
        assert value[subkey2] == (567, now + 30)
        assert value[subkey3] == (890, now + 50)
        assert len(value) == 3

    for proc in processes:
        proc.terminate()

    # the nodes don't own their hivemind.p2p.P2P instances, so we shut them down separately
    loop.run_until_complete(asyncio.wait([node.shutdown() for node in [me, detached_node, that_guy]]))


@pytest.mark.forked
@pytest.mark.asyncio
async def test_dhtnode_replicas():
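    """Check that each stored key is replicated to exactly num_replicas peers and that overwriting
    an existing key does not create additional replicas."""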
    num_replicas = random.randint(1, 20)
    peers = await launch_star_shaped_swarm(n_peers=20, num_replicas=num_replicas)

    you = random.choice(peers)
    assert await you.store('key1', 'foo', get_dht_time() + 999)
    actual_key1_replicas = sum(len(peer.protocol.storage) for peer in peers)
    assert num_replicas == actual_key1_replicas

    assert await you.store('key2', 'bar', get_dht_time() + 999)
    total_size = sum(len(peer.protocol.storage) for peer in peers)
    actual_key2_replicas = total_size - actual_key1_replicas
    assert num_replicas == actual_key2_replicas

    assert await you.store('key2', 'baz', get_dht_time() + 1000)
    assert sum(len(peer.protocol.storage) for peer in peers) == total_size, "total size should not have changed"


@pytest.mark.forked
@pytest.mark.asyncio
async def test_dhtnode_caching(T=0.05):
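    """Check cache refresh behavior: repeatedly requested keys are queued for refresh before they expire,
    and the refresh queue drains over time (all timings are multiples of the step T)."""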
    node2 = await DHTNode.create(cache_refresh_before_expiry=5 * T, reuse_get_requests=False)
    node1 = await DHTNode.create(initial_peers=await node2.protocol.p2p.get_visible_maddrs(),
                                 cache_refresh_before_expiry=5 * T, listen=False, reuse_get_requests=False)
    await node2.store('k', [123, 'value'], expiration_time=hivemind.get_dht_time() + 7 * T)
    await node2.store('k2', [654, 'value'], expiration_time=hivemind.get_dht_time() + 7 * T)
    await node2.store('k3', [654, 'value'], expiration_time=hivemind.get_dht_time() + 15 * T)
    await node1.get_many(['k', 'k2', 'k3', 'k4'])
    assert len(node1.protocol.cache) == 3
    assert len(node1.cache_refresh_queue) == 0

    await node1.get_many(['k', 'k2', 'k3', 'k4'])
    assert len(node1.cache_refresh_queue) == 3

    await node2.store('k', [123, 'value'], expiration_time=hivemind.get_dht_time() + 12 * T)
    await asyncio.sleep(4 * T)
    await node1.get('k')
    await asyncio.sleep(1 * T)

    assert len(node1.protocol.cache) == 3
    assert len(node1.cache_refresh_queue) == 2

    await asyncio.sleep(3 * T)
    assert len(node1.cache_refresh_queue) == 1

    await asyncio.sleep(5 * T)
    assert len(node1.cache_refresh_queue) == 0

    await asyncio.sleep(5 * T)
    assert len(node1.cache_refresh_queue) == 0

    await node2.store('k', [123, 'value'], expiration_time=hivemind.get_dht_time() + 10 * T)
    await node1.get('k')
    await asyncio.sleep(1 * T)
    assert len(node1.cache_refresh_queue) == 0

    await node1.get('k')
    await asyncio.sleep(1 * T)
    assert len(node1.cache_refresh_queue) == 1

    await asyncio.sleep(5 * T)
    assert len(node1.cache_refresh_queue) == 0

    await asyncio.gather(node1.shutdown(), node2.shutdown())


@pytest.mark.forked
@pytest.mark.asyncio
async def test_dhtnode_reuse_get():
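    """Check that concurrent get requests for the same key are merged into a single pending request
    and that all returned futures resolve to the same result."""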
    peers = await launch_star_shaped_swarm(n_peers=10, parallel_rpc=256)

    await asyncio.gather(
        random.choice(peers).store('k1', 123, hivemind.get_dht_time() + 999),
        random.choice(peers).store('k2', 567, hivemind.get_dht_time() + 999)
    )

    you = random.choice(peers)

    futures1 = await you.get_many(['k1', 'k2'], return_futures=True)
    assert len(you.pending_get_requests[DHTID.generate('k1')]) == 1
    assert len(you.pending_get_requests[DHTID.generate('k2')]) == 1

    futures2 = await you.get_many(['k2', 'k3'], return_futures=True)
    assert len(you.pending_get_requests[DHTID.generate('k2')]) == 2

    await asyncio.gather(*futures1.values(), *futures2.values())
    futures3 = await you.get_many(['k3'], return_futures=True)
    assert len(you.pending_get_requests[DHTID.generate('k1')]) == 0
    assert len(you.pending_get_requests[DHTID.generate('k2')]) == 0
    assert len(you.pending_get_requests[DHTID.generate('k3')]) == 1

    assert (await futures1['k1'])[0] == 123
    assert await futures1['k2'] == await futures2['k2'] and (await futures1['k2'])[0] == 567
    assert await futures2['k3'] == await futures3['k3'] and (await futures3['k3']) is None


@pytest.mark.forked
@pytest.mark.asyncio
async def test_dhtnode_blacklist():
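    """Check that peers that stop responding (here: nodes that were shut down) end up in the blacklist,
    while responsive peers do not."""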
    node1, node2, node3, node4 = await launch_star_shaped_swarm(n_peers=4, blacklist_time=999)

    assert await node2.store('abc', 123, expiration_time=hivemind.get_dht_time() + 99)
    assert len(node2.blacklist.ban_counter) == 0

    await asyncio.gather(node3.shutdown(), node4.shutdown())

    assert await node2.store('def', 456, expiration_time=hivemind.get_dht_time() + 99)
    assert set(node2.blacklist.ban_counter.keys()) == {node3.endpoint, node4.endpoint}

    assert await node1.get('abc', latest=True)  # force node1 to crawl dht and discover unresponsive peers
    assert node3.endpoint in node1.blacklist

    assert await node1.get('abc', latest=True)  # force node1 to crawl dht and discover unresponsive peers
    assert node2.endpoint not in node1.blacklist

    await asyncio.gather(node1.shutdown(), node2.shutdown())


@pytest.mark.forked
@pytest.mark.asyncio
async def test_dhtnode_edge_cases():
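    """Store and retrieve values under edge-case keys and subkeys: zero, empty strings, booleans,
    and empty containers."""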
    peers = await launch_star_shaped_swarm(n_peers=4, parallel_rpc=4)

    subkeys = [0, '', False, True, 'abyrvalg', 4555]
    keys = subkeys + [()]
    values = subkeys + [[]]
    for key, subkey, value in product(keys, subkeys, values):
        await random.choice(peers).store(key=key, subkey=subkey, value=value,
                                         expiration_time=hivemind.get_dht_time() + 999)

        stored = await random.choice(peers).get(key=key, latest=True)
        assert stored is not None
        assert subkey in stored.value
        assert stored.value[subkey].value == value

    await asyncio.wait([node.shutdown() for node in peers])