# test_dht_node.py
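# This suite exercises the low-level DHTProtocol RPCs (ping / store / find), behaviour against an
# empty routing table, multi-node DHTNode search and store/get (including sub-keys and bulk ops),
# the replication factor, cache refresh, and reuse of concurrent get requests. Each test runs its
# client logic in a separate process so that global event-loop state does not leak between tests.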
import asyncio
import multiprocessing as mp
import random
import heapq
from typing import Optional, List, Dict

import numpy as np

import hivemind
from hivemind import get_dht_time
from hivemind.dht.node import DHTID, Endpoint, DHTNode, LOCALHOST
from hivemind.dht.protocol import DHTProtocol
from hivemind.dht.storage import DictionaryDHTValue


def run_protocol_listener(port: int, dhtid: DHTID, started: mp.synchronize.Event, ping: Optional[Endpoint] = None):
    loop = asyncio.get_event_loop()
    protocol = loop.run_until_complete(DHTProtocol.create(
        dhtid, bucket_size=20, depth_modulo=5, num_replicas=3, wait_timeout=5, listen_on=f"{LOCALHOST}:{port}"))

    assert protocol.port == port
    print(f"Started peer id={protocol.node_id} port={port}", flush=True)

    if ping is not None:
        loop.run_until_complete(protocol.call_ping(ping))
    started.set()
    loop.run_until_complete(protocol.server.wait_for_termination())
    print(f"Finished peer id={protocol.node_id} port={port}", flush=True)


def test_dht_protocol():
    # create the first peer
    peer1_port, peer1_id, peer1_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
    peer1_proc = mp.Process(target=run_protocol_listener, args=(peer1_port, peer1_id, peer1_started), daemon=True)
    peer1_proc.start(), peer1_started.wait()

    # create another peer that connects to the first peer
    peer2_port, peer2_id, peer2_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
    peer2_proc = mp.Process(target=run_protocol_listener, args=(peer2_port, peer2_id, peer2_started),
                            kwargs={'ping': f'{LOCALHOST}:{peer1_port}'}, daemon=True)
    peer2_proc.start(), peer2_started.wait()

    test_success = mp.Event()

    def _tester():
        # note: we run everything in a separate process to re-initialize all global states from scratch
        # this helps us avoid undesirable side-effects when running multiple tests in sequence
        loop = asyncio.get_event_loop()
        for listen in [False, True]:  # note: order matters, this test assumes that first run uses listen=False
            protocol = loop.run_until_complete(DHTProtocol.create(
                DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, listen=listen))
            print(f"Self id={protocol.node_id}", flush=True)

            assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{peer1_port}')) == peer1_id

            key, value, expiration = DHTID.generate(), [random.random(), {'ololo': 'pyshpysh'}], get_dht_time() + 1e3
            store_ok = loop.run_until_complete(protocol.call_store(
                f'{LOCALHOST}:{peer1_port}', [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
            )
            assert all(store_ok), "DHT rejected a trivial store"

            # peer 1 must know about peer 2
            recv_value_bytes, recv_expiration, nodes_found = loop.run_until_complete(
                protocol.call_find(f'{LOCALHOST}:{peer1_port}', [key]))[key]
            recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
            (recv_id, recv_endpoint) = next(iter(nodes_found.items()))
            assert recv_id == peer2_id and ':'.join(recv_endpoint.split(':')[-2:]) == f"{LOCALHOST}:{peer2_port}", \
                f"expected id={peer2_id}, peer={LOCALHOST}:{peer2_port} but got {recv_id}, {recv_endpoint}"
            assert recv_value == value and recv_expiration == expiration, \
                f"call_find_value expected {value} (expires by {expiration}) " \
                f"but got {recv_value} (expires by {recv_expiration})"

            # peer 2 must know about peer 1, but not have a *random* nonexistent value
            dummy_key = DHTID.generate()
            recv_dummy_value, recv_dummy_expiration, nodes_found_2 = loop.run_until_complete(
                protocol.call_find(f'{LOCALHOST}:{peer2_port}', [dummy_key]))[dummy_key]
            assert recv_dummy_value is None and recv_dummy_expiration is None, "Non-existent keys shouldn't have values"
            (recv_id, recv_endpoint) = next(iter(nodes_found_2.items()))
            assert recv_id == peer1_id and recv_endpoint == f"{LOCALHOST}:{peer1_port}", \
                f"expected id={peer1_id}, peer={LOCALHOST}:{peer1_port} but got {recv_id}, {recv_endpoint}"

            # cause a non-response by querying a nonexistent peer
            dummy_port = hivemind.find_open_port()
            assert loop.run_until_complete(protocol.call_find(f"{LOCALHOST}:{dummy_port}", [key])) is None

            # store/get a dictionary with sub-keys
            nested_key, subkey1, subkey2 = DHTID.generate(), 'foo', 'bar'
            value1, value2 = [random.random(), {'ololo': 'pyshpysh'}], 'abacaba'
            assert loop.run_until_complete(protocol.call_store(
                f'{LOCALHOST}:{peer1_port}', keys=[nested_key], values=[hivemind.MSGPackSerializer.dumps(value1)],
                expiration_time=[expiration], subkeys=[subkey1])
            )
            assert loop.run_until_complete(protocol.call_store(
                f'{LOCALHOST}:{peer1_port}', keys=[nested_key], values=[hivemind.MSGPackSerializer.dumps(value2)],
                expiration_time=[expiration + 5], subkeys=[subkey2])
            )
            recv_dict, recv_expiration, nodes_found = loop.run_until_complete(
                protocol.call_find(f'{LOCALHOST}:{peer1_port}', [nested_key]))[nested_key]
            assert isinstance(recv_dict, DictionaryDHTValue)
            assert len(recv_dict.data) == 2 and recv_expiration == expiration + 5
            assert recv_dict.data[subkey1] == (protocol.serializer.dumps(value1), expiration)
            assert recv_dict.data[subkey2] == (protocol.serializer.dumps(value2), expiration + 5)

            if listen:
                loop.run_until_complete(protocol.shutdown())
        print("DHTProtocol test finished successfully!")
        test_success.set()

    tester = mp.Process(target=_tester, daemon=True)
    tester.start()
    tester.join()
    assert test_success.is_set()

    peer1_proc.terminate()
    peer2_proc.terminate()


def test_empty_table():
    """ Test RPC methods with empty routing table """
    peer_port, peer_id, peer_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
    peer_proc = mp.Process(target=run_protocol_listener, args=(peer_port, peer_id, peer_started), daemon=True)
    peer_proc.start(), peer_started.wait()

    test_success = mp.Event()

    def _tester():
        # note: we run everything in a separate process to re-initialize all global states from scratch
        # this helps us avoid undesirable side-effects when running multiple tests in sequence
        loop = asyncio.get_event_loop()

        protocol = loop.run_until_complete(DHTProtocol.create(
            DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, listen=False))

        key, value, expiration = DHTID.generate(), [random.random(), {'ololo': 'pyshpysh'}], get_dht_time() + 1e3

        recv_value_bytes, recv_expiration, nodes_found = loop.run_until_complete(
            protocol.call_find(f'{LOCALHOST}:{peer_port}', [key]))[key]
        assert recv_value_bytes is None and recv_expiration is None and len(nodes_found) == 0
        assert all(loop.run_until_complete(protocol.call_store(
            f'{LOCALHOST}:{peer_port}', [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
        )), "peer rejected store"

        recv_value_bytes, recv_expiration, nodes_found = loop.run_until_complete(
            protocol.call_find(f'{LOCALHOST}:{peer_port}', [key]))[key]
        recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
        assert len(nodes_found) == 0
        assert recv_value == value and recv_expiration == expiration, "call_find_value expected " \
            f"{value} (expires by {expiration}) but got {recv_value} (expires by {recv_expiration})"

        assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{peer_port}')) == peer_id
        assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{hivemind.find_open_port()}')) is None
        test_success.set()

    tester = mp.Process(target=_tester, daemon=True)
    tester.start()
    tester.join()
    assert test_success.is_set()

    peer_proc.terminate()


def run_node(node_id, peers, status_pipe: mp.Pipe):
    if asyncio.get_event_loop().is_running():
        asyncio.get_event_loop().stop()  # if we're in jupyter, get rid of its built-in event loop
        asyncio.set_event_loop(asyncio.new_event_loop())
    loop = asyncio.get_event_loop()
    node = loop.run_until_complete(DHTNode.create(node_id, initial_peers=peers))
    status_pipe.send(node.port)
    while True:
        loop.run_forever()


def test_dht_node():
    # create dht with 50 nodes + your 51-st node
    dht: Dict[Endpoint, DHTID] = {}
    processes: List[mp.Process] = []

    for i in range(50):
        node_id = DHTID.generate()
        peers = random.sample(list(dht.keys()), min(len(dht), 5))
        pipe_recv, pipe_send = mp.Pipe(duplex=False)
        proc = mp.Process(target=run_node, args=(node_id, peers, pipe_send), daemon=True)
        proc.start()
        port = pipe_recv.recv()
        processes.append(proc)
        dht[f"{LOCALHOST}:{port}"] = node_id

    test_success = mp.Event()

    def _tester():
        # note: we run everything in a separate process to re-initialize all global states from scratch
        # this helps us avoid undesirable side-effects when running multiple tests in sequence
        loop = asyncio.get_event_loop()
        me = loop.run_until_complete(DHTNode.create(initial_peers=random.sample(list(dht.keys()), 5),
                                                    parallel_rpc=10, cache_refresh_before_expiry=False))

        # test 1: find self
        nearest = loop.run_until_complete(me.find_nearest_nodes([me.node_id], k_nearest=1))[me.node_id]
        assert len(nearest) == 1 and ':'.join(nearest[me.node_id].split(':')[-2:]) == f"{LOCALHOST}:{me.port}"

        # test 2: find others
        for i in range(10):
            ref_endpoint, query_id = random.choice(list(dht.items()))
            nearest = loop.run_until_complete(me.find_nearest_nodes([query_id], k_nearest=1))[query_id]
            assert len(nearest) == 1
            found_node_id, found_endpoint = next(iter(nearest.items()))
            assert found_node_id == query_id and ':'.join(found_endpoint.split(':')[-2:]) == ref_endpoint

        # test 3: find neighbors to random nodes
        accuracy_numerator = accuracy_denominator = 0  # top-1 nearest neighbor accuracy
        jaccard_numerator = jaccard_denominator = 0  # jaccard similarity aka intersection over union
        all_node_ids = list(dht.values())

        for i in range(100):
            query_id = DHTID.generate()
            k_nearest = random.randint(1, 20)
            exclude_self = random.random() > 0.5
            nearest = loop.run_until_complete(
                me.find_nearest_nodes([query_id], k_nearest=k_nearest, exclude_self=exclude_self))[query_id]
            nearest_nodes = list(nearest)  # keys from ordered dict

            assert len(nearest_nodes) == k_nearest, "beam search must return exactly k_nearest results"
            assert me.node_id not in nearest_nodes or not exclude_self, "if exclude, results shouldn't contain self"
            assert np.all(np.diff(query_id.xor_distance(nearest_nodes)) >= 0), "results must be sorted by distance"

            ref_nearest = heapq.nsmallest(k_nearest + 1, all_node_ids, key=query_id.xor_distance)
            if exclude_self and me.node_id in ref_nearest:
                ref_nearest.remove(me.node_id)
            if len(ref_nearest) > k_nearest:
                ref_nearest.pop()

            accuracy_numerator += nearest_nodes[0] == ref_nearest[0]
            accuracy_denominator += 1

            jaccard_numerator += len(set.intersection(set(nearest_nodes), set(ref_nearest)))
            jaccard_denominator += k_nearest

        accuracy = accuracy_numerator / accuracy_denominator
        print("Top-1 accuracy:", accuracy)  # should be 98-100%
        jaccard_index = jaccard_numerator / jaccard_denominator
        print("Jaccard index (intersection over union):", jaccard_index)  # should be 95-100%
        assert accuracy >= 0.9, f"Top-1 accuracy only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
        assert jaccard_index >= 0.9, f"Jaccard index only {jaccard_index} ({jaccard_numerator} / {jaccard_denominator})"

        # test 4: find all nodes
        dummy = DHTID.generate()
        nearest = loop.run_until_complete(me.find_nearest_nodes([dummy], k_nearest=len(dht) + 100))[dummy]
        assert len(nearest) == len(dht) + 1
        assert len(set.difference(set(nearest.keys()), set(all_node_ids) | {me.node_id})) == 0

        # test 5: node without peers
        detached_node = loop.run_until_complete(DHTNode.create())
        nearest = loop.run_until_complete(detached_node.find_nearest_nodes([dummy]))[dummy]
        assert len(nearest) == 1 and nearest[detached_node.node_id] == f"{LOCALHOST}:{detached_node.port}"
        nearest = loop.run_until_complete(detached_node.find_nearest_nodes([dummy], exclude_self=True))[dummy]
        assert len(nearest) == 0

        # test 6: store and get value
        true_time = get_dht_time() + 1200
        assert loop.run_until_complete(me.store("mykey", ["Value", 10], true_time))
        that_guy = loop.run_until_complete(DHTNode.create(initial_peers=random.sample(list(dht.keys()), 3),
                                                          parallel_rpc=10, cache_refresh_before_expiry=False,
                                                          cache_locally=False))

        for node in [me, that_guy]:
            val, expiration_time = loop.run_until_complete(node.get("mykey"))
            assert val == ["Value", 10], "Wrong value"
            assert expiration_time == true_time, "Wrong time"

        assert loop.run_until_complete(detached_node.get("mykey")) == (None, None)

        # test 7: bulk store and bulk get
        keys = 'foo', 'bar', 'baz', 'zzz'
        values = 3, 2, 'batman', [1, 2, 3]
        store_ok = loop.run_until_complete(me.store_many(keys, values, expiration_time=get_dht_time() + 999))
        assert all(store_ok.values()), "failed to store one or more keys"
        response = loop.run_until_complete(me.get_many(keys[::-1]))
        for key, value in zip(keys, values):
            assert key in response and response[key][0] == value

        # test 8: store dictionaries as values (with sub-keys)
        upper_key, subkey1, subkey2, subkey3 = 'ololo', 'k1', 'k2', 'k3'
        now = get_dht_time()
        assert loop.run_until_complete(me.store(upper_key, subkey=subkey1, value=123, expiration_time=now + 10))
        assert loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=456, expiration_time=now + 20))
        for node in [that_guy, me]:
            value, time = loop.run_until_complete(node.get(upper_key))
            assert isinstance(value, dict) and time == now + 20
            assert value[subkey1] == (123, now + 10)
            assert value[subkey2] == (456, now + 20)
            assert len(value) == 2

        assert not loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=345, expiration_time=now + 10))
        assert loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=567, expiration_time=now + 30))
        assert loop.run_until_complete(me.store(upper_key, subkey=subkey3, value=890, expiration_time=now + 50))
        loop.run_until_complete(asyncio.sleep(0.1))  # wait for cache to refresh

        for node in [that_guy, me]:
            value, time = loop.run_until_complete(node.get(upper_key))
            assert isinstance(value, dict) and time == now + 50, (value, time)
            assert value[subkey1] == (123, now + 10)
            assert value[subkey2] == (567, now + 30)
            assert value[subkey3] == (890, now + 50)
            assert len(value) == 3

        test_success.set()

    tester = mp.Process(target=_tester, daemon=True)
    tester.start()
    tester.join()
    assert test_success.is_set()

    for proc in processes:
        proc.terminate()


def test_dhtnode_replicas():
    dht_size = 20
    initial_peers = 3
    num_replicas = random.randint(1, 20)
    test_success = mp.Event()

    async def _tester():
        peers = []
        for i in range(dht_size):
            neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(initial_peers, len(peers)))]
            peers.append(await DHTNode.create(initial_peers=neighbors_i, num_replicas=num_replicas))

        you = random.choice(peers)
        assert await you.store('key1', 'foo', get_dht_time() + 999)

        actual_key1_replicas = sum(len(peer.protocol.storage) for peer in peers)
        assert num_replicas == actual_key1_replicas

        assert await you.store('key2', 'bar', get_dht_time() + 999)
        total_size = sum(len(peer.protocol.storage) for peer in peers)
        actual_key2_replicas = total_size - actual_key1_replicas
        assert num_replicas == actual_key2_replicas

        assert await you.store('key2', 'baz', get_dht_time() + 1000)
        assert sum(len(peer.protocol.storage) for peer in peers) == total_size, "total size should not have changed"
        test_success.set()

    proc = mp.Process(target=lambda: asyncio.run(_tester()))
    proc.start()
    proc.join()
    assert test_success.is_set()


def test_dhtnode_caching(T=0.05):
    test_success = mp.Event()

    async def _tester():
        node2 = await hivemind.DHTNode.create(cache_refresh_before_expiry=5 * T, reuse_get_requests=False)
        node1 = await hivemind.DHTNode.create(initial_peers=[f'localhost:{node2.port}'],
                                              cache_refresh_before_expiry=5 * T, listen=False,
                                              reuse_get_requests=False)
        await node2.store('k', [123, 'value'], expiration_time=hivemind.get_dht_time() + 7 * T)
        await node2.store('k2', [654, 'value'], expiration_time=hivemind.get_dht_time() + 7 * T)
        await node2.store('k3', [654, 'value'], expiration_time=hivemind.get_dht_time() + 15 * T)
        await node1.get_many(['k', 'k2', 'k3', 'k4'])
        assert len(node1.protocol.cache) == 3
        assert len(node1.cache_refresh_queue) == 0

        await node1.get_many(['k', 'k2', 'k3', 'k4'])
        assert len(node1.cache_refresh_queue) == 3

        await node2.store('k', [123, 'value'], expiration_time=hivemind.get_dht_time() + 12 * T)
        await asyncio.sleep(4 * T)
        await node1.get('k')
        await asyncio.sleep(1 * T)

        assert len(node1.protocol.cache) == 3
        assert len(node1.cache_refresh_queue) == 2

        await asyncio.sleep(3 * T)
        assert len(node1.cache_refresh_queue) == 1

        await asyncio.sleep(5 * T)
        assert len(node1.cache_refresh_queue) == 0

        await asyncio.sleep(5 * T)
        assert len(node1.cache_refresh_queue) == 0

        await node2.store('k', [123, 'value'], expiration_time=hivemind.get_dht_time() + 10 * T)
        await node1.get('k')
        await asyncio.sleep(1 * T)
        assert len(node1.cache_refresh_queue) == 0
        await node1.get('k')
        await asyncio.sleep(1 * T)
        assert len(node1.cache_refresh_queue) == 1

        await asyncio.sleep(5 * T)
        assert len(node1.cache_refresh_queue) == 0

        await asyncio.gather(node1.shutdown(), node2.shutdown())
        test_success.set()

    proc = mp.Process(target=lambda: asyncio.run(_tester()))
    proc.start()
    proc.join()
    assert test_success.is_set()


def test_dhtnode_reuse_get():
    test_success = mp.Event()

    async def _tester():
        peers = []
        for i in range(10):
            neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
            peers.append(await hivemind.DHTNode.create(initial_peers=neighbors_i, parallel_rpc=256))

        await asyncio.gather(
            random.choice(peers).store('k1', 123, hivemind.get_dht_time() + 999),
            random.choice(peers).store('k2', 567, hivemind.get_dht_time() + 999)
        )

        you = random.choice(peers)

        futures1 = await you.get_many(['k1', 'k2'], return_futures=True)
        assert len(you.pending_get_requests[DHTID.generate('k1')]) == 1
        assert len(you.pending_get_requests[DHTID.generate('k2')]) == 1

        futures2 = await you.get_many(['k2', 'k3'], return_futures=True)
        assert len(you.pending_get_requests[DHTID.generate('k2')]) == 2

        await asyncio.gather(*futures1.values(), *futures2.values())
        futures3 = await you.get_many(['k3'], return_futures=True)
        assert len(you.pending_get_requests[DHTID.generate('k1')]) == 0
        assert len(you.pending_get_requests[DHTID.generate('k2')]) == 0
        assert len(you.pending_get_requests[DHTID.generate('k3')]) == 1

        assert (await futures1['k1'])[0] == 123
        assert await futures1['k2'] == await futures2['k2'] and (await futures1['k2'])[0] == 567
        assert await futures2['k3'] == await futures3['k3'] and (await futures3['k3']) == (None, None)
        test_success.set()

    proc = mp.Process(target=lambda: asyncio.run(_tester()))
    proc.start()
    proc.join()
    assert test_success.is_set()
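

# A minimal manual runner (a sketch): the test_* functions follow the pytest naming convention
# and use no fixtures, so they can also be invoked directly. This assumes a fork-capable
# multiprocessing start method (e.g. Linux), since the tests spawn worker processes from
# locally defined functions.
if __name__ == '__main__':
    for test_fn in (test_dht_protocol, test_empty_table, test_dht_node,
                    test_dhtnode_replicas, test_dhtnode_caching, test_dhtnode_reuse_get):
        print(f"Running {test_fn.__name__} ...", flush=True)
        test_fn()
        print(f"{test_fn.__name__} passed", flush=True)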