import asyncio
import heapq
import multiprocessing as mp
import random
from itertools import product
from typing import Optional, List, Dict

import numpy as np
import pytest

import hivemind
from hivemind import get_dht_time, replace_port
from hivemind.dht.crypto import RSASignatureValidator
from hivemind.dht.node import DHTID, Endpoint, DHTNode, LOCALHOST
from hivemind.dht.protocol import DHTProtocol, ValidationError
from hivemind.dht.storage import DictionaryDHTValue


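# helper: run a standalone DHTProtocol listener in a child process, optionally ping a known peer,
# report (port, endpoint) back through the pipe, then serve until the parent terminates the process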
def run_protocol_listener(port: int, dhtid: DHTID, pipe_side: mp.connection.Connection,
                          ping: Optional[Endpoint] = None):
    loop = asyncio.get_event_loop()
    protocol = loop.run_until_complete(DHTProtocol.create(
        dhtid, bucket_size=20, depth_modulo=5, num_replicas=3, wait_timeout=5, listen_on=f"{LOCALHOST}:{port}"))

    port = protocol.port
    print(f"Started peer id={protocol.node_id} port={port}", flush=True)

    if ping is not None:
        loop.run_until_complete(protocol.call_ping(ping))

    pipe_side.send((protocol.port, protocol.server.endpoint))
    loop.run_until_complete(protocol.server.wait_for_termination())
    print(f"Finished peer id={protocol.node_id} port={port}", flush=True)


# note: we run network-related tests in a separate process to re-initialize all global states from scratch
# this helps us avoid undesirable side-effects (e.g. segfaults) when running multiple tests in a sequence
@pytest.mark.forked
def test_dht_protocol():
    # create the first peer
    first_side, ours_side = mp.Pipe()
    peer1_port, peer1_id = hivemind.find_open_port(), DHTID.generate()
    peer1_proc = mp.Process(target=run_protocol_listener, args=(peer1_port, peer1_id, first_side), daemon=True)
    peer1_proc.start()
    peer1_port, peer1_endpoint = ours_side.recv()

    # create another peer that connects to the first peer
    second_side, ours_side = mp.Pipe()
    peer2_port, peer2_id = hivemind.find_open_port(), DHTID.generate()
    peer2_proc = mp.Process(target=run_protocol_listener, args=(peer2_port, peer2_id, second_side),
                            kwargs={'ping': peer1_endpoint}, daemon=True)
    peer2_proc.start()
    peer2_port, peer2_endpoint = ours_side.recv()

    loop = asyncio.get_event_loop()
    for listen in [False, True]:  # note: order matters, this test assumes that first run uses listen=False
        protocol = loop.run_until_complete(DHTProtocol.create(
            DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, listen=listen))
        print(f"Self id={protocol.node_id}", flush=True)

        assert loop.run_until_complete(protocol.call_ping(peer1_endpoint)) == peer1_id

        key, value, expiration = DHTID.generate(), [random.random(), {'ololo': 'pyshpysh'}], get_dht_time() + 1e3
        store_ok = loop.run_until_complete(protocol.call_store(
            peer1_endpoint, [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
        )
        assert all(store_ok), "DHT rejected a trivial store"

        # peer 1 must know about peer 2
        (recv_value_bytes, recv_expiration), nodes_found = loop.run_until_complete(
            protocol.call_find(peer1_endpoint, [key]))[key]
        recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
        (recv_id, recv_endpoint) = next(iter(nodes_found.items()))
        assert recv_id == peer2_id and recv_endpoint == peer2_endpoint, \
            f"expected id={peer2_id}, peer={peer2_endpoint} but got {recv_id}, {recv_endpoint}"
        assert recv_value == value and recv_expiration == expiration, \
            f"call_find_value expected {value} (expires by {expiration}) " \
            f"but got {recv_value} (expires by {recv_expiration})"

        # peer 2 must know about peer 1, but not have a *random* nonexistent value
        dummy_key = DHTID.generate()
        empty_item, nodes_found_2 = loop.run_until_complete(
            protocol.call_find(peer2_endpoint, [dummy_key]))[dummy_key]
        assert empty_item is None, "Non-existent keys shouldn't have values"
        (recv_id, recv_endpoint) = next(iter(nodes_found_2.items()))
        assert recv_id == peer1_id and recv_endpoint == peer1_endpoint, \
            f"expected id={peer1_id}, peer={peer1_endpoint} but got {recv_id}, {recv_endpoint}"

        # cause a non-response by querying a nonexistent peer
        dummy_port = hivemind.find_open_port()
        assert loop.run_until_complete(protocol.call_find(f"{LOCALHOST}:{dummy_port}", [key])) is None

        # store/get a dictionary with sub-keys
        nested_key, subkey1, subkey2 = DHTID.generate(), 'foo', 'bar'
        value1, value2 = [random.random(), {'ololo': 'pyshpysh'}], 'abacaba'
        assert loop.run_until_complete(protocol.call_store(
            peer1_endpoint, keys=[nested_key], values=[hivemind.MSGPackSerializer.dumps(value1)],
            expiration_time=[expiration], subkeys=[subkey1])
        )
        assert loop.run_until_complete(protocol.call_store(
            peer1_endpoint, keys=[nested_key], values=[hivemind.MSGPackSerializer.dumps(value2)],
            expiration_time=[expiration + 5], subkeys=[subkey2])
        )
        (recv_dict, recv_expiration), nodes_found = loop.run_until_complete(
            protocol.call_find(peer1_endpoint, [nested_key]))[nested_key]
        assert isinstance(recv_dict, DictionaryDHTValue)
        assert len(recv_dict.data) == 2 and recv_expiration == expiration + 5
        assert recv_dict.data[subkey1] == (protocol.serializer.dumps(value1), expiration)
        assert recv_dict.data[subkey2] == (protocol.serializer.dumps(value2), expiration + 5)

        assert protocol.client.endpoint == loop.run_until_complete(protocol.get_outgoing_request_endpoint(peer1_endpoint))

        if listen:
            loop.run_until_complete(protocol.shutdown())

    peer1_proc.terminate()
    peer2_proc.terminate()
    loop.run_until_complete(protocol.shutdown())


@pytest.mark.forked
def test_empty_table():
    """ Test RPC methods with empty routing table """
    theirs_side, ours_side = mp.Pipe()
    peer_port, peer_id = hivemind.find_open_port(), DHTID.generate()
    peer_proc = mp.Process(target=run_protocol_listener, args=(peer_port, peer_id, theirs_side), daemon=True)
    peer_proc.start()
    peer_port, peer_endpoint = ours_side.recv()

    loop = asyncio.get_event_loop()
    protocol = loop.run_until_complete(DHTProtocol.create(
        DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, listen=False))

    key, value, expiration = DHTID.generate(), [random.random(), {'ololo': 'pyshpysh'}], get_dht_time() + 1e3

    empty_item, nodes_found = loop.run_until_complete(
        protocol.call_find(peer_endpoint, [key]))[key]
    assert empty_item is None and len(nodes_found) == 0
    assert all(loop.run_until_complete(protocol.call_store(
        peer_endpoint, [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
    )), "peer rejected store"

    (recv_value_bytes, recv_expiration), nodes_found = loop.run_until_complete(
        protocol.call_find(peer_endpoint, [key]))[key]
    recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
    assert len(nodes_found) == 0
    assert recv_value == value and recv_expiration == expiration

    assert loop.run_until_complete(protocol.call_ping(peer_endpoint)) == peer_id
    assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{hivemind.find_open_port()}')) is None

    peer_proc.terminate()
    loop.run_until_complete(protocol.shutdown())


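# helper for test_dht_node: run a full DHTNode in a child process and report its (port, endpoint)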
def run_node(node_id, peers, status_pipe: mp.connection.Connection):
    if asyncio.get_event_loop().is_running():
        asyncio.get_event_loop().stop()  # if we're in jupyter, get rid of its built-in event loop
        asyncio.set_event_loop(asyncio.new_event_loop())
    loop = asyncio.get_event_loop()
    node = loop.run_until_complete(DHTNode.create(node_id, initial_peers=peers))
    status_pipe.send((node.port, node.endpoint))
    loop.run_until_complete(node.protocol.server.wait_for_termination())


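# end-to-end DHTNode test: spin up a 50-node swarm in child processes, then check beam search
# (find self / find others / nearest-neighbor accuracy), store/get, bulk ops and dictionary sub-keys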
@pytest.mark.skip
@pytest.mark.forked
def test_dht_node():
    # create dht with 50 nodes + your 51-st node
    dht: Dict[Endpoint, DHTID] = {}
    processes: List[mp.Process] = []

    for i in range(50):
        node_id = DHTID.generate()
        peers = random.sample(list(dht.keys()), min(len(dht), 5))
        pipe_recv, pipe_send = mp.Pipe(duplex=False)
        proc = mp.Process(target=run_node, args=(node_id, peers, pipe_send), daemon=True)
        proc.start()
        port, endpoint = pipe_recv.recv()
        processes.append(proc)
        dht[endpoint] = node_id

    loop = asyncio.get_event_loop()
    me = loop.run_until_complete(DHTNode.create(initial_peers=random.sample(list(dht.keys()), min(len(dht), 5)),
                                                parallel_rpc=2, cache_refresh_before_expiry=False))

    # test 1: find self
    nearest = loop.run_until_complete(me.find_nearest_nodes([me.node_id], k_nearest=1))[me.node_id]
    assert len(nearest) == 1 and nearest[me.node_id] == me.endpoint

    # test 2: find others
    for i in range(10):
        ref_endpoint, query_id = random.choice(list(dht.items()))
        nearest = loop.run_until_complete(me.find_nearest_nodes([query_id], k_nearest=1))[query_id]
        assert len(nearest) == 1
        found_node_id, found_endpoint = next(iter(nearest.items()))
        assert found_node_id == query_id and found_endpoint == ref_endpoint

    # test 3: find neighbors to random nodes
    accuracy_numerator = accuracy_denominator = 0  # top-1 nearest neighbor accuracy
    jaccard_numerator = jaccard_denominator = 0  # jaccard similarity aka intersection over union
    all_node_ids = list(dht.values())

    for i in range(50):
        query_id = DHTID.generate()
        k_nearest = random.randint(1, len(dht))
        exclude_self = random.random() > 0.5
        nearest = loop.run_until_complete(
            me.find_nearest_nodes([query_id], k_nearest=k_nearest, exclude_self=exclude_self))[query_id]
        nearest_nodes = list(nearest)  # keys from ordered dict

        assert len(nearest_nodes) == k_nearest, "beam search must return exactly k_nearest results"
        assert me.node_id not in nearest_nodes or not exclude_self, "if exclude, results shouldn't contain self"
        assert np.all(np.diff(query_id.xor_distance(nearest_nodes)) >= 0), "results must be sorted by distance"

        ref_nearest = heapq.nsmallest(k_nearest + 1, all_node_ids, key=query_id.xor_distance)
        if exclude_self and me.node_id in ref_nearest:
            ref_nearest.remove(me.node_id)
        if len(ref_nearest) > k_nearest:
            ref_nearest.pop()

        accuracy_numerator += nearest_nodes[0] == ref_nearest[0]
        accuracy_denominator += 1
        jaccard_numerator += len(set.intersection(set(nearest_nodes), set(ref_nearest)))
        jaccard_denominator += k_nearest

    accuracy = accuracy_numerator / accuracy_denominator
    print("Top-1 accuracy:", accuracy)  # should be 98-100%
    jaccard_index = jaccard_numerator / jaccard_denominator
    print("Jaccard index (intersection over union):", jaccard_index)  # should be 95-100%
    assert accuracy >= 0.9, f"Top-1 accuracy only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
    assert jaccard_index >= 0.9, f"Jaccard index only {jaccard_index} ({jaccard_numerator} / {jaccard_denominator})"

    # test 4: find all nodes
    dummy = DHTID.generate()
    nearest = loop.run_until_complete(me.find_nearest_nodes([dummy], k_nearest=len(dht) + 100))[dummy]
    assert len(nearest) == len(dht) + 1
    assert len(set.difference(set(nearest.keys()), set(all_node_ids) | {me.node_id})) == 0

    # test 5: node without peers
    detached_node = loop.run_until_complete(DHTNode.create())
    nearest = loop.run_until_complete(detached_node.find_nearest_nodes([dummy]))[dummy]
    assert len(nearest) == 1 and nearest[detached_node.node_id] == detached_node.endpoint
    nearest = loop.run_until_complete(detached_node.find_nearest_nodes([dummy], exclude_self=True))[dummy]
    assert len(nearest) == 0

    # test 6: store and get value
    true_time = get_dht_time() + 1200
    assert loop.run_until_complete(me.store("mykey", ["Value", 10], true_time))
    that_guy = loop.run_until_complete(DHTNode.create(initial_peers=random.sample(list(dht.keys()), 3), parallel_rpc=10,
                                                      cache_refresh_before_expiry=False, cache_locally=False))

    for node in [me, that_guy]:
        val, expiration_time = loop.run_until_complete(node.get("mykey"))
        assert val == ["Value", 10], "Wrong value"
        assert expiration_time == true_time, "Wrong time"

    assert loop.run_until_complete(detached_node.get("mykey")) is None

    # test 7: bulk store and bulk get
    keys = 'foo', 'bar', 'baz', 'zzz'
    values = 3, 2, 'batman', [1, 2, 3]
    store_ok = loop.run_until_complete(me.store_many(keys, values, expiration_time=get_dht_time() + 999))
    assert all(store_ok.values()), "failed to store one or more keys"
    response = loop.run_until_complete(me.get_many(keys[::-1]))
    for key, value in zip(keys, values):
        assert key in response and response[key][0] == value

    # test 8: store dictionaries as values (with sub-keys)
    upper_key, subkey1, subkey2, subkey3 = 'ololo', 'k1', 'k2', 'k3'
    now = get_dht_time()
    assert loop.run_until_complete(me.store(upper_key, subkey=subkey1, value=123, expiration_time=now + 10))
    assert loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=456, expiration_time=now + 20))
    for node in [that_guy, me]:
        value, time = loop.run_until_complete(node.get(upper_key))
        assert isinstance(value, dict) and time == now + 20
        assert value[subkey1] == (123, now + 10)
        assert value[subkey2] == (456, now + 20)
        assert len(value) == 2

    assert not loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=345, expiration_time=now + 10))
    assert loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=567, expiration_time=now + 30))
    assert loop.run_until_complete(me.store(upper_key, subkey=subkey3, value=890, expiration_time=now + 50))
    loop.run_until_complete(asyncio.sleep(0.1))  # wait for cache to refresh

    for node in [that_guy, me]:
        value, time = loop.run_until_complete(node.get(upper_key))
        assert isinstance(value, dict) and time == now + 50, (value, time)
        assert value[subkey1] == (123, now + 10)
        assert value[subkey2] == (567, now + 30)
        assert value[subkey3] == (890, now + 50)
        assert len(value) == 3

    for proc in processes:
        proc.terminate()
    loop.run_until_complete(asyncio.gather(me.shutdown(), that_guy.shutdown(), detached_node.shutdown()))


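# each stored key should end up on exactly num_replicas peers; re-storing an existing key
# should overwrite the old replicas rather than create new ones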
@pytest.mark.forked
@pytest.mark.asyncio
async def test_dhtnode_replicas():
    dht_size = 20
    initial_peers = 3
    num_replicas = random.randint(1, 20)

    peers = []
    for i in range(dht_size):
        neighbors_i = [node.endpoint for node in random.sample(peers, min(initial_peers, len(peers)))]
        peers.append(await DHTNode.create(initial_peers=neighbors_i, num_replicas=num_replicas))

    you = random.choice(peers)
    assert await you.store('key1', 'foo', get_dht_time() + 999)

    actual_key1_replicas = sum(len(peer.protocol.storage) for peer in peers)
    assert num_replicas == actual_key1_replicas

    assert await you.store('key2', 'bar', get_dht_time() + 999)
    total_size = sum(len(peer.protocol.storage) for peer in peers)
    actual_key2_replicas = total_size - actual_key1_replicas
    assert num_replicas == actual_key2_replicas

    assert await you.store('key2', 'baz', get_dht_time() + 1000)
    assert sum(len(peer.protocol.storage) for peer in peers) == total_size, "total size should not have changed"

    for p in peers:
        await p.shutdown()


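# checks the cache refresh heuristic: entries that are read again from the local cache should be
# queued for refresh cache_refresh_before_expiry seconds before they expire (T is the base time unit)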
@pytest.mark.skip  # fails stochastically
@pytest.mark.forked
@pytest.mark.asyncio
async def test_dhtnode_caching(T=0.2):
    node2 = await hivemind.DHTNode.create(cache_refresh_before_expiry=5 * T, reuse_get_requests=False)
    node1 = await hivemind.DHTNode.create(initial_peers=[node2.endpoint],
                                          cache_refresh_before_expiry=5 * T, listen=False, reuse_get_requests=False)
    await node2.store('k', [123, 'value'], expiration_time=hivemind.get_dht_time() + 7 * T)
    await node2.store('k2', [654, 'value'], expiration_time=hivemind.get_dht_time() + 7 * T)
    await node2.store('k3', [654, 'value'], expiration_time=hivemind.get_dht_time() + 15 * T)
    await node1.get_many(['k', 'k2', 'k3', 'k4'])
    assert len(node1.protocol.cache) == 3
    assert len(node1.cache_refresh_queue) == 0

    await node1.get_many(['k', 'k2', 'k3', 'k4'])
    assert len(node1.cache_refresh_queue) == 3

    await node2.store('k', [123, 'value'], expiration_time=hivemind.get_dht_time() + 12 * T)
    await asyncio.sleep(4 * T)
    await node1.get('k')
    await asyncio.sleep(1 * T)

    assert len(node1.protocol.cache) == 3
    assert len(node1.cache_refresh_queue) == 2
    await asyncio.sleep(3 * T)

    assert len(node1.cache_refresh_queue) == 1
    await asyncio.sleep(5 * T)
    assert len(node1.cache_refresh_queue) == 0
    await asyncio.sleep(5 * T)
    assert len(node1.cache_refresh_queue) == 0

    await node2.store('k', [123, 'value'], expiration_time=hivemind.get_dht_time() + 10 * T)
    await node1.get('k')
    await asyncio.sleep(1 * T)
    assert len(node1.cache_refresh_queue) == 0
    await node1.get('k')
    await asyncio.sleep(1 * T)
    assert len(node1.cache_refresh_queue) == 1

    await asyncio.sleep(5 * T)
    assert len(node1.cache_refresh_queue) == 0

    await asyncio.gather(node1.shutdown(), node2.shutdown())


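# concurrent get requests for the same key should be merged into one pending request
# (tracked in pending_get_requests); their futures resolve to the same result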
@pytest.mark.skip  # hangs stochastically
@pytest.mark.forked
@pytest.mark.asyncio
async def test_dhtnode_reuse_get():
    peers = []
    for i in range(5):
        neighbors_i = [node.endpoint for node in random.sample(peers, min(3, len(peers)))]
        peers.append(await hivemind.DHTNode.create(initial_peers=neighbors_i, parallel_rpc=32))

    await asyncio.gather(
        random.choice(peers).store('k1', 123, hivemind.get_dht_time() + 999),
        random.choice(peers).store('k2', 567, hivemind.get_dht_time() + 999)
    )

    you = random.choice(peers)

    futures1 = await you.get_many(['k1', 'k2'], return_futures=True)
    assert len(you.pending_get_requests[DHTID.generate('k1')]) == 1
    assert len(you.pending_get_requests[DHTID.generate('k2')]) == 1

    futures2 = await you.get_many(['k2', 'k3'], return_futures=True)
    assert len(you.pending_get_requests[DHTID.generate('k2')]) == 2

    await asyncio.gather(*futures1.values(), *futures2.values())
    futures3 = await you.get_many(['k3'], return_futures=True)
    assert len(you.pending_get_requests[DHTID.generate('k1')]) == 0
    assert len(you.pending_get_requests[DHTID.generate('k2')]) == 0
    assert len(you.pending_get_requests[DHTID.generate('k3')]) == 1

    assert (await futures1['k1'])[0] == 123
    assert await futures1['k2'] == await futures2['k2'] and (await futures1['k2'])[0] == 567
    assert await futures2['k3'] == await futures3['k3'] and (await futures3['k3']) is None

    for p in peers:
        await p.shutdown()


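# peers that stop responding should end up in the blacklist of nodes that tried to contact them,
# while responsive peers should stay off the blacklist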
@pytest.mark.skip  # fails stochastically
@pytest.mark.forked
@pytest.mark.asyncio
async def test_dhtnode_blacklist():
    node1 = await hivemind.DHTNode.create(blacklist_time=999)
    node2 = await hivemind.DHTNode.create(blacklist_time=999, initial_peers=[node1.endpoint])
    node3 = await hivemind.DHTNode.create(blacklist_time=999, initial_peers=[node1.endpoint])
    node4 = await hivemind.DHTNode.create(blacklist_time=999, initial_peers=[node1.endpoint])

    assert await node2.store('abc', 123, expiration_time=hivemind.get_dht_time() + 99)
    assert len(node2.blacklist.ban_counter) == 0

    await node3.shutdown()
    await node4.shutdown()

    assert await node2.store('def', 456, expiration_time=hivemind.get_dht_time() + 99)
    assert len(node2.blacklist.ban_counter) == 2

    for banned_peer in node2.blacklist.ban_counter:
        assert any(banned_peer == endpoint for endpoint in [node3.endpoint, node4.endpoint])

    node3_endpoint = await node3.protocol.get_outgoing_request_endpoint(node1.endpoint)
    assert await node1.get('abc', latest=True)  # force node1 to crawl dht and discover unresponsive peers
    assert node3_endpoint in node1.blacklist

    node2_endpoint = await node2.protocol.get_outgoing_request_endpoint(node1.endpoint)
    assert await node1.get('abc', latest=True)  # force node1 to crawl dht and discover unresponsive peers
    assert node2_endpoint not in node1.blacklist

    for node in [node1, node2, node3, node4]:
        await node.shutdown()


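# joining via an existing peer while announcing an invalid/unreachable endpoint should raise ValidationError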
@pytest.mark.forked
@pytest.mark.asyncio
async def test_dhtnode_validate(fake_endpoint='127.0.0.721:*'):
    node1 = await hivemind.DHTNode.create(blacklist_time=999)
    with pytest.raises(ValidationError):
        node2 = await hivemind.DHTNode.create(blacklist_time=999, initial_peers=[node1.endpoint],
                                              endpoint=fake_endpoint)


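# stores and retrieves every combination of unusual keys, subkeys and values (empty strings, bools,
# ints, empty containers) to make sure serialization and dictionary storage handle them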
@pytest.mark.skip  # takes too long, was never converted
@pytest.mark.forked
@pytest.mark.asyncio
async def test_dhtnode_edge_cases():
    peers = []
    for i in range(5):
        neighbors_i = [node.endpoint for node in random.sample(peers, min(3, len(peers)))]
        peers.append(await hivemind.DHTNode.create(initial_peers=neighbors_i, parallel_rpc=4))

    subkeys = [0, '', False, True, 'abyrvalg', 4555]
    keys = subkeys + [()]
    values = subkeys + [[]]

    for key, subkey, value in product(keys, subkeys, values):
        await random.choice(peers).store(key=key, subkey=subkey, value=value,
                                         expiration_time=hivemind.get_dht_time() + 999)

        stored = await random.choice(peers).get(key=key, latest=True)
        assert stored is not None
        assert subkey in stored.value
        assert stored.value[subkey].value == value


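# with RSASignatureValidator, subkeys carrying the owner's ownership_marker may only be
# (over)written by the key owner: Mallory's store attempts must be rejected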
@pytest.mark.forked
@pytest.mark.asyncio
async def test_dhtnode_signatures():
    alice = await hivemind.DHTNode.create(record_validator=RSASignatureValidator())
    bob = await hivemind.DHTNode.create(
        record_validator=RSASignatureValidator(), initial_peers=[alice.endpoint])
    mallory = await hivemind.DHTNode.create(
        record_validator=RSASignatureValidator(), initial_peers=[alice.endpoint])

    key = b'key'
    subkey = b'protected_subkey' + bob.protocol.record_validator.ownership_marker

    assert await bob.store(key, b'true_value', hivemind.get_dht_time() + 10, subkey=subkey)
    assert (await alice.get(key, latest=True)).value[subkey].value == b'true_value'

    store_ok = await mallory.store(key, b'fake_value', hivemind.get_dht_time() + 10, subkey=subkey)
    assert not store_ok
    assert (await alice.get(key, latest=True)).value[subkey].value == b'true_value'

    assert await bob.store(key, b'updated_true_value', hivemind.get_dht_time() + 10, subkey=subkey)
    assert (await alice.get(key, latest=True)).value[subkey].value == b'updated_true_value'

    await bob.shutdown()  # Bob has shut down, now Mallory is the single peer of Alice

    store_ok = await mallory.store(key, b'updated_fake_value',
                                   hivemind.get_dht_time() + 10, subkey=subkey)
    assert not store_ok
    assert (await alice.get(key, latest=True)).value[subkey].value == b'updated_true_value'