# test_dht.py
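"""
Tests for hivemind's DHT stack: low-level DHTProtocol RPCs (ping / store / find),
DHTNode routing and nearest-node search, the high-level hivemind.DHT expert
declaration / lookup API, and the LocalStorage expiration cache.
"""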
import time
import asyncio
import multiprocessing as mp
import random
import heapq
import uuid
from itertools import chain
from typing import Optional, List, Dict

import numpy as np

import hivemind
from hivemind import get_dht_time
from hivemind.dht.node import DHTID, Endpoint, DHTNode, LOCALHOST, DHTProtocol
from hivemind.dht.protocol import LocalStorage

def run_protocol_listener(port: int, dhtid: DHTID, started: mp.synchronize.Event, ping: Optional[Endpoint] = None):
    loop = asyncio.get_event_loop()
    protocol = loop.run_until_complete(DHTProtocol.create(
        dhtid, bucket_size=20, depth_modulo=5, num_replicas=3, wait_timeout=5, listen_on=f"{LOCALHOST}:{port}"))
    assert protocol.port == port
    print(f"Started peer id={protocol.node_id} port={port}", flush=True)

    if ping is not None:
        loop.run_until_complete(protocol.call_ping(ping))
    started.set()

    loop.run_until_complete(protocol.server.wait_for_termination())
    print(f"Finished peer id={protocol.node_id} port={port}", flush=True)

def test_dht_protocol():
    # create the first peer
    peer1_port, peer1_id, peer1_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
    peer1_proc = mp.Process(target=run_protocol_listener, args=(peer1_port, peer1_id, peer1_started), daemon=True)
    peer1_proc.start(), peer1_started.wait()

    # create another peer that connects to the first peer
    peer2_port, peer2_id, peer2_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
    peer2_proc = mp.Process(target=run_protocol_listener, args=(peer2_port, peer2_id, peer2_started),
                            kwargs={'ping': f'{LOCALHOST}:{peer1_port}'}, daemon=True)
    peer2_proc.start(), peer2_started.wait()

    test_success = mp.Event()

    def _tester():
        # note: we run everything in a separate process to re-initialize all global states from scratch
        # this helps us avoid undesirable side-effects when running multiple tests in sequence
        loop = asyncio.get_event_loop()
        for listen in [False, True]:  # note: order matters, this test assumes that first run uses listen=False
            protocol = loop.run_until_complete(DHTProtocol.create(
                DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, listen=listen))
            print(f"Self id={protocol.node_id}", flush=True)

            assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{peer1_port}')) == peer1_id

            key, value, expiration = DHTID.generate(), [random.random(), {'ololo': 'pyshpysh'}], get_dht_time() + 1e3
            store_ok = loop.run_until_complete(protocol.call_store(
                f'{LOCALHOST}:{peer1_port}', [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
            )
            assert all(store_ok), "DHT rejected a trivial store"

            # peer 1 must know about peer 2
            recv_value_bytes, recv_expiration, nodes_found = loop.run_until_complete(
                protocol.call_find(f'{LOCALHOST}:{peer1_port}', [key]))[key]
            recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
            (recv_id, recv_endpoint) = next(iter(nodes_found.items()))
            assert recv_id == peer2_id and ':'.join(recv_endpoint.split(':')[-2:]) == f"{LOCALHOST}:{peer2_port}", \
                f"expected id={peer2_id}, peer={LOCALHOST}:{peer2_port} but got {recv_id}, {recv_endpoint}"

            assert recv_value == value and recv_expiration == expiration, \
                f"call_find_value expected {value} (expires by {expiration}) " \
                f"but got {recv_value} (expires by {recv_expiration})"

            # peer 2 must know about peer 1, but not have a *random* nonexistent value
            dummy_key = DHTID.generate()
            recv_dummy_value, recv_dummy_expiration, nodes_found_2 = loop.run_until_complete(
                protocol.call_find(f'{LOCALHOST}:{peer2_port}', [dummy_key]))[dummy_key]
            assert recv_dummy_value is None and recv_dummy_expiration is None, "Non-existent keys shouldn't have values"
            (recv_id, recv_endpoint) = next(iter(nodes_found_2.items()))
            assert recv_id == peer1_id and recv_endpoint == f"{LOCALHOST}:{peer1_port}", \
                f"expected id={peer1_id}, peer={LOCALHOST}:{peer1_port} but got {recv_id}, {recv_endpoint}"

            # cause a non-response by querying a nonexistent peer
            dummy_port = hivemind.find_open_port()
            assert loop.run_until_complete(protocol.call_find(f"{LOCALHOST}:{dummy_port}", [key])) is None

            if listen:
                loop.run_until_complete(protocol.shutdown())

        print("DHTProtocol test finished successfully!")
        test_success.set()

    tester = mp.Process(target=_tester, daemon=True)
    tester.start()
    tester.join()
    assert test_success.is_set()

    peer1_proc.terminate()
    peer2_proc.terminate()

def test_empty_table():
    """ Test RPC methods with empty routing table """
    peer_port, peer_id, peer_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
    peer_proc = mp.Process(target=run_protocol_listener, args=(peer_port, peer_id, peer_started), daemon=True)
    peer_proc.start(), peer_started.wait()

    test_success = mp.Event()

    def _tester():
        # note: we run everything in a separate process to re-initialize all global states from scratch
        # this helps us avoid undesirable side-effects when running multiple tests in sequence
        loop = asyncio.get_event_loop()
        protocol = loop.run_until_complete(DHTProtocol.create(
            DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, listen=False))

        key, value, expiration = DHTID.generate(), [random.random(), {'ololo': 'pyshpysh'}], get_dht_time() + 1e3

        recv_value_bytes, recv_expiration, nodes_found = loop.run_until_complete(
            protocol.call_find(f'{LOCALHOST}:{peer_port}', [key]))[key]
        assert recv_value_bytes is None and recv_expiration is None and len(nodes_found) == 0
        assert all(loop.run_until_complete(protocol.call_store(
            f'{LOCALHOST}:{peer_port}', [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
        )), "peer rejected store"

        recv_value_bytes, recv_expiration, nodes_found = loop.run_until_complete(
            protocol.call_find(f'{LOCALHOST}:{peer_port}', [key]))[key]
        recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
        assert len(nodes_found) == 0
        assert recv_value == value and recv_expiration == expiration, "call_find_value expected " \
            f"{value} (expires by {expiration}) but got {recv_value} (expires by {recv_expiration})"

        assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{peer_port}')) == peer_id
        assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{hivemind.find_open_port()}')) is None
        test_success.set()

    tester = mp.Process(target=_tester, daemon=True)
    tester.start()
    tester.join()
    assert test_success.is_set()
    peer_proc.terminate()

def run_node(node_id, peers, status_pipe: mp.Pipe):
    """ Start a DHTNode with the given id and initial peers, report its port via status_pipe, then serve forever """
    if asyncio.get_event_loop().is_running():
        asyncio.get_event_loop().stop()  # if we're in jupyter, get rid of its built-in event loop
        asyncio.set_event_loop(asyncio.new_event_loop())
    loop = asyncio.get_event_loop()
    node = loop.run_until_complete(DHTNode.create(node_id, initial_peers=peers))
    status_pipe.send(node.port)
    while True:
        loop.run_forever()

def test_dht_node():
    # create dht with 50 nodes + your 51-st node
    dht: Dict[Endpoint, DHTID] = {}
    processes: List[mp.Process] = []
    for i in range(50):
        node_id = DHTID.generate()
        peers = random.sample(list(dht.keys()), min(len(dht), 5))
        pipe_recv, pipe_send = mp.Pipe(duplex=False)
        proc = mp.Process(target=run_node, args=(node_id, peers, pipe_send), daemon=True)
        proc.start()
        port = pipe_recv.recv()
        processes.append(proc)
        dht[f"{LOCALHOST}:{port}"] = node_id

    test_success = mp.Event()

    def _tester():
        # note: we run everything in a separate process to re-initialize all global states from scratch
        # this helps us avoid undesirable side-effects when running multiple tests in sequence
        loop = asyncio.get_event_loop()
        me = loop.run_until_complete(DHTNode.create(initial_peers=random.sample(list(dht.keys()), 5), parallel_rpc=10))

        # test 1: find self
        nearest = loop.run_until_complete(me.find_nearest_nodes([me.node_id], k_nearest=1))[me.node_id]
        assert len(nearest) == 1 and ':'.join(nearest[me.node_id].split(':')[-2:]) == f"{LOCALHOST}:{me.port}"

        # test 2: find others
        for i in range(10):
            ref_endpoint, query_id = random.choice(list(dht.items()))
            nearest = loop.run_until_complete(me.find_nearest_nodes([query_id], k_nearest=1))[query_id]
            assert len(nearest) == 1
            found_node_id, found_endpoint = next(iter(nearest.items()))
            assert found_node_id == query_id and ':'.join(found_endpoint.split(':')[-2:]) == ref_endpoint

        # test 3: find neighbors to random nodes
        accuracy_numerator = accuracy_denominator = 0  # top-1 nearest neighbor accuracy
        jaccard_numerator = jaccard_denominator = 0  # jaccard similarity aka intersection over union
        all_node_ids = list(dht.values())

        for i in range(100):
            query_id = DHTID.generate()
            k_nearest = random.randint(1, 20)
            exclude_self = random.random() > 0.5
            nearest = loop.run_until_complete(
                me.find_nearest_nodes([query_id], k_nearest=k_nearest, exclude_self=exclude_self))[query_id]
            nearest_nodes = list(nearest)  # keys from ordered dict

            assert len(nearest_nodes) == k_nearest, "beam search must return exactly k_nearest results"
            assert me.node_id not in nearest_nodes or not exclude_self, "if exclude, results shouldn't contain self"
            assert np.all(np.diff(query_id.xor_distance(nearest_nodes)) >= 0), "results must be sorted by distance"

            ref_nearest = heapq.nsmallest(k_nearest + 1, all_node_ids, key=query_id.xor_distance)
            if exclude_self and me.node_id in ref_nearest:
                ref_nearest.remove(me.node_id)
            if len(ref_nearest) > k_nearest:
                ref_nearest.pop()

            accuracy_numerator += nearest_nodes[0] == ref_nearest[0]
            accuracy_denominator += 1
            jaccard_numerator += len(set.intersection(set(nearest_nodes), set(ref_nearest)))
            jaccard_denominator += k_nearest

        accuracy = accuracy_numerator / accuracy_denominator
        print("Top-1 accuracy:", accuracy)  # should be 98-100%
        jaccard_index = jaccard_numerator / jaccard_denominator
        print("Jaccard index (intersection over union):", jaccard_index)  # should be 95-100%
        assert accuracy >= 0.9, f"Top-1 accuracy only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
        assert jaccard_index >= 0.9, f"Jaccard index only {jaccard_index} ({jaccard_numerator} / {jaccard_denominator})"

        # test 4: find all nodes
        dummy = DHTID.generate()
        nearest = loop.run_until_complete(me.find_nearest_nodes([dummy], k_nearest=len(dht) + 100))[dummy]
        assert len(nearest) == len(dht) + 1
        assert len(set.difference(set(nearest.keys()), set(all_node_ids) | {me.node_id})) == 0

        # test 5: node without peers
        other_node = loop.run_until_complete(DHTNode.create())
        nearest = loop.run_until_complete(other_node.find_nearest_nodes([dummy]))[dummy]
        assert len(nearest) == 1 and nearest[other_node.node_id] == f"{LOCALHOST}:{other_node.port}"
        nearest = loop.run_until_complete(other_node.find_nearest_nodes([dummy], exclude_self=True))[dummy]
        assert len(nearest) == 0

        # test 6: store and get value
        true_time = get_dht_time() + 1200
        assert loop.run_until_complete(me.store("mykey", ["Value", 10], true_time))
        for node in [me, other_node]:
            val, expiration_time = loop.run_until_complete(me.get("mykey"))
            assert expiration_time == true_time, "Wrong time"
            assert val == ["Value", 10], "Wrong value"

        # test 7: bulk store and bulk get
        keys = 'foo', 'bar', 'baz', 'zzz'
        values = 3, 2, 'batman', [1, 2, 3]
        store_ok = loop.run_until_complete(me.store_many(keys, values, expiration_time=get_dht_time() + 999))
        assert all(store_ok.values()), "failed to store one or more keys"
        response = loop.run_until_complete(me.get_many(keys[::-1]))
        for key, value in zip(keys, values):
            assert key in response and response[key][0] == value

        test_success.set()

    tester = mp.Process(target=_tester, daemon=True)
    tester.start()
    tester.join()
    assert test_success.is_set()

    for proc in processes:
        proc.terminate()

def test_hivemind_dht():
    # build a small swarm: each new peer bootstraps from up to 3 already-running peers
    peers = [hivemind.DHT(start=True)]
    for i in range(10):
        neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
        peers.append(hivemind.DHT(initial_peers=neighbors_i, start=True))

    you: hivemind.dht.DHT = random.choice(peers)
    theguyshetoldyounottoworryabout: hivemind.dht.DHT = random.choice(peers)

    # declare experts from one peer and look them up from another
    expert_uids = [str(uuid.uuid4()) for _ in range(110)]
    batch_size = 10
    for batch_start in range(0, len(expert_uids), batch_size):
        you.declare_experts(expert_uids[batch_start: batch_start + batch_size], 'localhost', 1234)

    found = theguyshetoldyounottoworryabout.get_experts(random.sample(expert_uids, 5) + ['foo', 'bar'])
    assert all(res is not None for res in found[:-2]), "Could not find some existing experts"
    assert all(res is None for res in found[-2:]), "Found non-existing experts"

    that_guys_expert, that_guys_port = str(uuid.uuid4()), random.randint(1000, 9999)
    theguyshetoldyounottoworryabout.declare_experts([that_guys_expert], f'that_host:{that_guys_port}')
    you_notfound, you_found = you.get_experts(['foobar', that_guys_expert])
    assert isinstance(you_found, hivemind.RemoteExpert)
    assert you_found.endpoint == f'that_host:{that_guys_port}'

    # test first_k_active
    assert list(theguyshetoldyounottoworryabout.first_k_active(expert_uids, k=10)) == expert_uids[:10]

    some_permuted_experts = random.sample(expert_uids, k=32)
    assert list(theguyshetoldyounottoworryabout.first_k_active(some_permuted_experts, k=32)) == some_permuted_experts
    assert list(theguyshetoldyounottoworryabout.first_k_active(some_permuted_experts, k=1)) == some_permuted_experts[:1]
    fake_and_real_experts = list(chain(*zip(
        [str(uuid.uuid4()) for _ in some_permuted_experts], some_permuted_experts)))
    assert list(theguyshetoldyounottoworryabout.first_k_active(fake_and_real_experts, k=9)) == some_permuted_experts[:9]

    for peer in peers:
        peer.shutdown()

def test_dht_single_node():
    node = hivemind.DHT(start=True)
    assert node.first_k_active(['e3', 'e2'], k=3) == {}
    assert node.get_experts(['e3', 'e2']) == [None, None]

    assert all(node.declare_experts(['e1', 'e2', 'e3'], f"{hivemind.LOCALHOST}:1337"))
    for expert in node.get_experts(['e3', 'e2']):
        assert expert.endpoint == f"{hivemind.LOCALHOST}:1337"

    active_found = node.first_k_active(['e0', 'e1', 'e3', 'e5', 'e2'], k=2)
    assert list(active_found.keys()) == ['e1', 'e3']
    assert all(expert.uid.startswith(prefix) for prefix, expert in active_found.items())

    assert all(node.declare_experts(['e1', 'e2', 'e3'], f"{hivemind.LOCALHOST}:1337"))

def test_first_k_active():
    node = hivemind.DHT(start=True)
    assert all(node.declare_experts(['e.1.2.3', 'e.1.2.4', 'e.3.4.5'], endpoint=f"{hivemind.LOCALHOST}:1337"))
    assert all(node.declare_experts(['e.2.1.1'], endpoint=f"{hivemind.LOCALHOST}:1338"))

    results = node.first_k_active(['e.0', 'e.1', 'e.2', 'e.3'], k=2)
    assert len(results) == 2 and next(iter(results.keys())) == 'e.1'
    assert results['e.1'].uid in ('e.1.2.3', 'e.1.2.4') and results['e.1'].endpoint == f"{hivemind.LOCALHOST}:1337"
    assert results['e.2'].uid == 'e.2.1.1' and results['e.2'].endpoint == f"{hivemind.LOCALHOST}:1338"

    results = node.first_k_active(['e', 'e.1', 'e.1.2', 'e.1.2.3'], k=10)
    assert len(results) == 4
    assert 'e' in results
    for k in ('e.1', 'e.1.2', 'e.1.2.3'):
        assert results[k].uid in ('e.1.2.3', 'e.1.2.4') and results[k].endpoint == f"{hivemind.LOCALHOST}:1337"

def test_store():
    d = LocalStorage()
    d.store(DHTID.generate("key"), b"val", get_dht_time() + 0.5)
    assert d.get(DHTID.generate("key"))[0] == b"val", "Wrong value"
    print("Test store passed")


def test_get_expired():
    d = LocalStorage()
    d.store(DHTID.generate("key"), b"val", get_dht_time() + 0.1)
    time.sleep(0.5)
    assert d.get(DHTID.generate("key")) == (None, None), "Expired value must be deleted"
    print("Test get expired passed")

def test_get_empty():
    d = LocalStorage()
    assert d.get(DHTID.generate(source="key")) == (None, None), "LocalStorage returned a non-existent value"
    print("Test get empty passed")

def test_change_expiration_time():
    d = LocalStorage()
    d.store(DHTID.generate("key"), b"val1", get_dht_time() + 1)
    assert d.get(DHTID.generate("key"))[0] == b"val1", "Wrong value"
    d.store(DHTID.generate("key"), b"val2", get_dht_time() + 200)
    time.sleep(1)
    assert d.get(DHTID.generate("key"))[0] == b"val2", "Value must be changed, but still kept in table"
    print("Test change expiration time passed")

def test_maxsize_cache():
    d = LocalStorage(maxsize=1)
    d.store(DHTID.generate("key1"), b"val1", get_dht_time() + 1)
    d.store(DHTID.generate("key2"), b"val2", get_dht_time() + 200)
    assert d.get(DHTID.generate("key2"))[0] == b"val2", "Value with the larger expiration time must be kept"
    assert d.get(DHTID.generate("key1"))[0] is None, "Value with the smaller expiration time must be evicted"
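

# A minimal convenience entry point, added as a sketch: pytest discovers the test_* functions above
# on its own, but running this file directly executes the lightweight LocalStorage checks.
# The DHTProtocol / DHTNode / hivemind.DHT tests spawn many processes and are better left to pytest.
if __name__ == '__main__':
    test_store()
    test_get_expired()
    test_get_empty()
    test_change_expiration_time()
    test_maxsize_cache()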