import asyncio
import heapq
import random
from itertools import product

import numpy as np
import pytest

import hivemind
from hivemind import get_dht_time
from hivemind.dht.node import DHTID, DHTNode
from hivemind.utils.logging import get_logger
from test_utils.dht_swarms import launch_star_shaped_swarm, launch_swarm_in_separate_processes

logger = get_logger(__name__)

# note: we run network-related tests in a separate process to re-initialize all global states from scratch
# this helps us avoid undesirable gRPC side-effects (e.g. segfaults) when running multiple tests in sequence


@pytest.mark.forked
@pytest.mark.asyncio
async def test_dht_node(
    n_peers: int = 20, n_sequential_peers: int = 5, parallel_rpc: int = 10, bucket_size: int = 5, num_replicas: int = 3
):
    # step A: create a swarm of n_peers (20) DHT nodes in separate processes
    #         (the first n_sequential_peers are created sequentially, the rest in parallel)
    processes, dht, swarm_maddrs = launch_swarm_in_separate_processes(
        n_peers=n_peers, n_sequential_peers=n_sequential_peers, bucket_size=bucket_size, num_replicas=num_replicas
    )

    # step B: run one more node in this process
    initial_peers = random.choice(swarm_maddrs)
    me = await DHTNode.create(
        initial_peers=initial_peers,
        parallel_rpc=parallel_rpc,
        bucket_size=bucket_size,
        num_replicas=num_replicas,
        cache_refresh_before_expiry=False,
    )

    # test 1: find self
    nearest = (await me.find_nearest_nodes([me.node_id], k_nearest=1))[me.node_id]
    assert len(nearest) == 1 and nearest[me.node_id] == me.peer_id

    # test 2: find others
    for _ in range(10):
        ref_peer_id, query_id = random.choice(list(dht.items()))
        nearest = (await me.find_nearest_nodes([query_id], k_nearest=1))[query_id]
        assert len(nearest) == 1
        found_node_id, found_peer_id = next(iter(nearest.items()))
        assert found_node_id == query_id and found_peer_id == ref_peer_id

    # test 3: find neighbors to random nodes
    accuracy_numerator = accuracy_denominator = 0  # top-1 nearest neighbor accuracy
    jaccard_numerator = jaccard_denominator = 0  # jaccard similarity aka intersection over union
    all_node_ids = list(dht.values())

    for _ in range(20):
        query_id = DHTID.generate()
        k_nearest = random.randint(1, 10)
        exclude_self = random.random() > 0.5
        find_result = await me.find_nearest_nodes([query_id], k_nearest=k_nearest, exclude_self=exclude_self)
        nearest_nodes = list(find_result[query_id])  # keys from ordered dict

        assert len(nearest_nodes) == k_nearest, "beam search must return exactly k_nearest results"
        assert me.node_id not in nearest_nodes or not exclude_self, "if exclude_self, results shouldn't contain self"
        assert np.all(np.diff(query_id.xor_distance(nearest_nodes)) >= 0), "results must be sorted by distance"

        ref_nearest = heapq.nsmallest(k_nearest + 1, all_node_ids, key=query_id.xor_distance)
        if exclude_self and me.node_id in ref_nearest:
            ref_nearest.remove(me.node_id)
        if len(ref_nearest) > k_nearest:
            ref_nearest.pop()

        accuracy_numerator += nearest_nodes[0] == ref_nearest[0]
        accuracy_denominator += 1
        jaccard_numerator += len(set.intersection(set(nearest_nodes), set(ref_nearest)))
        jaccard_denominator += k_nearest

    accuracy = accuracy_numerator / accuracy_denominator
    logger.debug(f"Top-1 accuracy: {accuracy}")  # should be 90-100%
    jaccard_index = jaccard_numerator / jaccard_denominator
    logger.debug(f"Jaccard index (intersection over union): {jaccard_index}")  # should be 95-100%

    assert accuracy >= 0.8, f"Top-1 accuracy only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
    assert jaccard_index >= 0.9, f"Jaccard index only {jaccard_index} ({jaccard_numerator} / {jaccard_denominator})"

    # test 4: find all nodes
    dummy = DHTID.generate()
    nearest = (await me.find_nearest_nodes([dummy], k_nearest=len(dht) + 100))[dummy]
    assert len(nearest) == len(dht) + 1
    assert len(set.difference(set(nearest.keys()), set(all_node_ids) | {me.node_id})) == 0

    # test 5: node without peers
    detached_node = await DHTNode.create()
    nearest = (await detached_node.find_nearest_nodes([dummy]))[dummy]
    assert len(nearest) == 1 and nearest[detached_node.node_id] == detached_node.peer_id
    nearest = (await detached_node.find_nearest_nodes([dummy], exclude_self=True))[dummy]
    assert len(nearest) == 0

    # test 6: store and get value
    true_time = get_dht_time() + 1200
    assert await me.store("mykey", ["Value", 10], true_time)

    initial_peers = random.choice(swarm_maddrs)
    that_guy = await DHTNode.create(
        initial_peers=initial_peers,
        parallel_rpc=parallel_rpc,
        cache_refresh_before_expiry=False,
        cache_locally=False,
    )

    for node in [me, that_guy]:
        val, expiration_time = await node.get("mykey")
        assert val == ["Value", 10], "Wrong value"
        assert expiration_time == true_time, "Wrong time"

    assert not await detached_node.get("mykey")

    # test 7: bulk store and bulk get
    keys = "foo", "bar", "baz", "zzz"
    values = 3, 2, "batman", [1, 2, 3]
    store_ok = await me.store_many(keys, values, expiration_time=get_dht_time() + 999)
    assert all(store_ok.values()), "failed to store one or more keys"
    response = await me.get_many(keys[::-1])
    for key, value in zip(keys, values):
        assert key in response and response[key][0] == value

    # test 8: store dictionaries as values (with sub-keys)
    upper_key, subkey1, subkey2, subkey3 = "ololo", "k1", "k2", "k3"
    now = get_dht_time()
    assert await me.store(upper_key, subkey=subkey1, value=123, expiration_time=now + 10)
    assert await me.store(upper_key, subkey=subkey2, value=456, expiration_time=now + 20)
    for node in [that_guy, me]:
        value, time = await node.get(upper_key)
        assert isinstance(value, dict) and time == now + 20
        assert value[subkey1] == (123, now + 10)
        assert value[subkey2] == (456, now + 20)
        assert len(value) == 2

    # a sub-key can only be overwritten by a store with a later expiration time
    assert not await me.store(upper_key, subkey=subkey2, value=345, expiration_time=now + 10)
    assert await me.store(upper_key, subkey=subkey2, value=567, expiration_time=now + 30)
    assert await me.store(upper_key, subkey=subkey3, value=890, expiration_time=now + 50)
    for node in [that_guy, me]:
        value, time = await node.get(upper_key, latest=True)
        assert isinstance(value, dict) and time == now + 50, (value, time)
        assert value[subkey1] == (123, now + 10)
        assert value[subkey2] == (567, now + 30)
        assert value[subkey3] == (890, now + 50)
        assert len(value) == 3

    for proc in processes:
        proc.terminate()
    # the nodes don't own their hivemind.p2p.P2P instances, so we shut them down separately
    await asyncio.gather(me.shutdown(), that_guy.shutdown(), detached_node.shutdown())
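
# For reference, the minimal store/get round-trip exercised by test_dht_node looks roughly
# like this in application code (a sketch, assuming two mutually reachable nodes; the names
# "alice"/"bob" and the 60-second expiration are illustrative, not part of this test):
#
#     alice = await DHTNode.create()
#     bob = await DHTNode.create(initial_peers=await alice.protocol.p2p.get_visible_maddrs())
#     assert await bob.store("key", "value", expiration_time=get_dht_time() + 60)
#     value, expiration = await alice.get("key")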


@pytest.mark.forked
@pytest.mark.asyncio
async def test_dhtnode_replicas():
    num_replicas = random.randint(1, 20)
    peers = await launch_star_shaped_swarm(n_peers=20, num_replicas=num_replicas)

    you = random.choice(peers)
    assert await you.store("key1", "foo", get_dht_time() + 999)
    actual_key1_replicas = sum(len(peer.protocol.storage) for peer in peers)
    assert num_replicas == actual_key1_replicas

    assert await you.store("key2", "bar", get_dht_time() + 999)
    total_size = sum(len(peer.protocol.storage) for peer in peers)
    actual_key2_replicas = total_size - actual_key1_replicas
    assert num_replicas == actual_key2_replicas

    assert await you.store("key2", "baz", get_dht_time() + 1000)
    assert sum(len(peer.protocol.storage) for peer in peers) == total_size, "total size should not have changed"
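
# Why the sums above count replicas: hivemind stores each key on the num_replicas peers whose
# node IDs are nearest to the key's DHTID by XOR distance, so every successful store should add
# exactly num_replicas entries across the swarm's peer.protocol.storage instances.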


@pytest.mark.forked
@pytest.mark.asyncio
async def test_dhtnode_caching(T=0.05):
    node2 = await DHTNode.create(cache_refresh_before_expiry=5 * T, reuse_get_requests=False)
    node1 = await DHTNode.create(
        initial_peers=await node2.protocol.p2p.get_visible_maddrs(),
        cache_refresh_before_expiry=5 * T,
        client_mode=True,
        reuse_get_requests=False,
    )
    await node2.store("k", [123, "value"], expiration_time=hivemind.get_dht_time() + 7 * T)
    await node2.store("k2", [654, "value"], expiration_time=hivemind.get_dht_time() + 7 * T)
    await node2.store("k3", [654, "value"], expiration_time=hivemind.get_dht_time() + 15 * T)

    # the first get_many populates node1's cache; querying the cached keys again schedules them for refresh
    await node1.get_many(["k", "k2", "k3", "k4"])
    assert len(node1.protocol.cache) == 3
    assert len(node1.cache_refresh_queue) == 0

    await node1.get_many(["k", "k2", "k3", "k4"])
    assert len(node1.cache_refresh_queue) == 3

    # refreshing "k" (re-stored with a later expiration) removes it from the refresh queue
    await node2.store("k", [123, "value"], expiration_time=hivemind.get_dht_time() + 12 * T)
    await asyncio.sleep(4 * T)
    await node1.get("k")
    await asyncio.sleep(1 * T)
    assert len(node1.protocol.cache) == 3
    assert len(node1.cache_refresh_queue) == 2

    # the remaining entries drop out of the queue as their refresh deadlines pass
    await asyncio.sleep(3 * T)
    assert len(node1.cache_refresh_queue) == 1
    await asyncio.sleep(5 * T)
    assert len(node1.cache_refresh_queue) == 0
    await asyncio.sleep(5 * T)
    assert len(node1.cache_refresh_queue) == 0

    # a key is only scheduled for refresh on a repeated cache hit, not on the first get
    await node2.store("k", [123, "value"], expiration_time=hivemind.get_dht_time() + 10 * T)
    await node1.get("k")
    await asyncio.sleep(1 * T)
    assert len(node1.cache_refresh_queue) == 0
    await node1.get("k")
    await asyncio.sleep(1 * T)
    assert len(node1.cache_refresh_queue) == 1

    await asyncio.sleep(5 * T)
    assert len(node1.cache_refresh_queue) == 0

    await asyncio.gather(node1.shutdown(), node2.shutdown())
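
# The timeline above hinges on two DHTNode.create() options used in this test: with
# cache_refresh_before_expiry=5 * T, a cached entry that gets requested again is scheduled
# for a background re-fetch roughly 5 * T before its expiration; and client_mode=True makes
# node1 a pure client, so node2 remains the only peer that actually serves stored values.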


@pytest.mark.forked
@pytest.mark.asyncio
async def test_dhtnode_reuse_get():
    peers = await launch_star_shaped_swarm(n_peers=10, parallel_rpc=256)

    await asyncio.gather(
        random.choice(peers).store("k1", 123, hivemind.get_dht_time() + 999),
        random.choice(peers).store("k2", 567, hivemind.get_dht_time() + 999),
    )

    you = random.choice(peers)

    futures1 = await you.get_many(["k1", "k2"], return_futures=True)
    assert len(you.pending_get_requests[DHTID.generate("k1")]) == 1
    assert len(you.pending_get_requests[DHTID.generate("k2")]) == 1

    # a concurrent get for "k2" piggybacks on the request already in flight
    futures2 = await you.get_many(["k2", "k3"], return_futures=True)
    assert len(you.pending_get_requests[DHTID.generate("k2")]) == 2

    await asyncio.gather(*futures1.values(), *futures2.values())
    futures3 = await you.get_many(["k3"], return_futures=True)
    assert len(you.pending_get_requests[DHTID.generate("k1")]) == 0
    assert len(you.pending_get_requests[DHTID.generate("k2")]) == 0
    assert len(you.pending_get_requests[DHTID.generate("k3")]) == 1

    assert (await futures1["k1"])[0] == 123
    assert await futures1["k2"] == await futures2["k2"] and (await futures1["k2"])[0] == 567
    assert await futures2["k3"] == await futures3["k3"] and (await futures3["k3"]) is None
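
# The pending counts above demonstrate request reuse: with reuse_get_requests enabled (the
# default), concurrent get() calls for the same key attach to the traversal already in flight
# rather than starting their own, which is why the second get_many for "k2" raises its pending
# count to 2 instead of issuing a separate search.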


@pytest.mark.forked
@pytest.mark.asyncio
async def test_dhtnode_blacklist():
    node1, node2, node3, node4 = await launch_star_shaped_swarm(n_peers=4, blacklist_time=999)

    assert await node2.store("abc", 123, expiration_time=hivemind.get_dht_time() + 99)
    assert len(node2.blacklist.ban_counter) == 0

    await asyncio.gather(node3.shutdown(), node4.shutdown())

    assert await node2.store("def", 456, expiration_time=hivemind.get_dht_time() + 99)
    assert set(node2.blacklist.ban_counter.keys()) == {node3.peer_id, node4.peer_id}

    assert await node1.get("abc", latest=True)  # force node1 to crawl the DHT and discover the unresponsive peers
    assert node3.peer_id in node1.blacklist

    assert await node1.get("abc", latest=True)  # responsive peers (e.g. node2) should never get blacklisted
    assert node2.peer_id not in node1.blacklist

    await asyncio.gather(node1.shutdown(), node2.shutdown())
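
# Mechanism under test: a peer that fails to respond to an RPC is added to the caller's
# blacklist for blacklist_time seconds and skipped in subsequent traversals, while responsive
# peers such as node2 never accumulate ban counts.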


@pytest.mark.forked
@pytest.mark.asyncio
async def test_dhtnode_edge_cases():
    peers = await launch_star_shaped_swarm(n_peers=4, parallel_rpc=4)

    subkeys = [0, "", False, True, "abyrvalg", 4555]
    keys = subkeys + [()]
    values = subkeys + [[]]
    for key, subkey, value in product(keys, subkeys, values):
        await random.choice(peers).store(
            key=key, subkey=subkey, value=value, expiration_time=hivemind.get_dht_time() + 999
        )

        stored = await random.choice(peers).get(key=key, latest=True)
        assert stored is not None
        assert subkey in stored.value
        assert stored.value[subkey].value == value

    await asyncio.gather(*(node.shutdown() for node in peers))