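""" Tests for hivemind's DHT stack: raw KademliaProtocol RPCs, DHTNode nearest-neighbor search and
key-value storage, the high-level hivemind.DHT expert declare/get interface, and LocalStorage. """
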
import time
import asyncio
import multiprocessing as mp
import random
import heapq
import uuid
from functools import partial
from itertools import chain
from typing import Optional, List, Dict

import numpy as np

import hivemind
from hivemind import get_dht_time
from hivemind.dht.node import DHTID, Endpoint, DHTNode, LOCALHOST, KademliaProtocol
from hivemind.dht.protocol import LocalStorage


def run_protocol_listener(port: int, dhtid: DHTID, started: mp.synchronize.Event,
                          ping: Optional[hivemind.Endpoint] = None):
    """ Run a KademliaProtocol peer on a dedicated event loop; optionally ping another peer once started. """
    loop = asyncio.new_event_loop()
    protocol = partial(KademliaProtocol, dhtid, bucket_size=20, depth_modulo=5, wait_timeout=5)
    listen = loop.create_datagram_endpoint(protocol, local_addr=('127.0.0.1', port))
    transport, protocol = loop.run_until_complete(listen)
    print(f"Started peer id={protocol.node_id} port={port}", flush=True)

    if ping is not None:
        loop.run_until_complete(protocol.call_ping(ping))
    started.set()
    loop.run_forever()
    print(f"Finished peer id={protocol.node_id} port={port}", flush=True)


def test_kademlia_protocol():
    try:
        # create the first peer
        peer1_port, peer1_id, peer1_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
        peer1_proc = mp.Process(target=run_protocol_listener, args=(peer1_port, peer1_id, peer1_started), daemon=True)
        peer1_proc.start(), peer1_started.wait()

        # create another peer that connects to the first peer
        peer2_port, peer2_id, peer2_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
        peer2_proc = mp.Process(target=run_protocol_listener, args=(peer2_port, peer2_id, peer2_started),
                                kwargs={'ping': ('127.0.0.1', peer1_port)}, daemon=True)
        peer2_proc.start(), peer2_started.wait()

        port = hivemind.find_open_port()
        loop = asyncio.new_event_loop()
        protocol = partial(KademliaProtocol, DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5)
        listen = loop.create_datagram_endpoint(protocol, local_addr=('127.0.0.1', port))
        transport, protocol = loop.run_until_complete(listen)
        print(f"Self id={protocol.node_id} port={port}", flush=True)

        assert loop.run_until_complete(protocol.call_ping(('127.0.0.1', peer1_port))) == peer1_id

        key, value, expiration = DHTID.generate(), [123, {'ololo': 'pyshpysh'}], get_dht_time() + 1e3
        assert loop.run_until_complete(protocol.call_store(('127.0.0.1', peer1_port), key, value, expiration))

        # peer 1 must know about peer 2
        nodes_found = loop.run_until_complete(
            protocol.call_find_node(('127.0.0.1', peer1_port), key))
        (recv_id, recv_endpoint) = next(iter(nodes_found.items()))
        assert recv_id == peer2_id and recv_endpoint == ('127.0.0.1', peer2_port), \
            f"expected id={peer2_id}, port={('127.0.0.1', peer2_port)} but got {recv_id}, {recv_endpoint}"

        # peer 2 must know about peer 1
        nodes_found_2 = loop.run_until_complete(protocol.call_find_node(('127.0.0.1', peer2_port), key))
        (recv_id, recv_endpoint) = next(iter(nodes_found_2.items()))
        assert recv_id == peer1_id and recv_endpoint == ('127.0.0.1', peer1_port), \
            f"expected id={peer1_id}, port={('127.0.0.1', peer1_port)} but got {recv_id}, {recv_endpoint}"

        recv_value, recv_expiration, recv_peers = loop.run_until_complete(
            protocol.call_find_value(('127.0.0.1', peer1_port), key))
        assert recv_value == value and recv_expiration == expiration, "call_find_value expected " \
            f"{value} (expires by {expiration}) but got {recv_value} (expires by {recv_expiration})"
        print(recv_peers, nodes_found)
        assert recv_peers == nodes_found, "call_find_value must return the same peers as call_find_node"
        print("Kademlia test finished successfully!")

    finally:
        peer1_proc.terminate()
        peer2_proc.terminate()


def run_node(node_id, port, peers, status_pipe: mp.Pipe):
    """ Run a DHTNode in this process and report startup status (or an exception) through the pipe. """
    if asyncio.get_event_loop().is_running():
        asyncio.get_event_loop().stop()  # if we're in jupyter, get rid of its built-in event loop
        asyncio.set_event_loop(asyncio.new_event_loop())
    try:
        node = DHTNode(node_id, port, initial_peers=peers)
        status_pipe.send('STARTED')
        while True:
            asyncio.get_event_loop().run_forever()
    except BaseException as e:
        status_pipe.send(e)  # report exception to master
        if not isinstance(e, OSError):
            raise e


def test_dht():
    # create dht with 50 nodes + your 51-st node
    dht: Dict[Endpoint, DHTID] = {}
    processes: List[mp.Process] = []
    port_fails, max_port_fails = 0, 10

    while len(dht) < 50:
        node_id = DHTID.generate()
        peers = random.sample(list(dht.keys()), min(len(dht), 5))
        port = hivemind.find_open_port()
        pipe_recv, pipe_send = mp.Pipe(duplex=False)
        proc = mp.Process(target=run_node, args=(node_id, port, peers, pipe_send), daemon=True)
        proc.start()

        status = pipe_recv.recv()
        if status == 'STARTED':
            processes.append(proc)
            dht[(LOCALHOST, port)] = node_id
        else:
            assert isinstance(status, BaseException)
            proc.terminate()
            if isinstance(status, OSError):  # port already in use. It just happens sometimes.
                port_fails += 1
                if port_fails > max_port_fails:
                    raise OSError("Too many 'Address already in use' errors.")
            else:
                raise ValueError(f"Failed to create node due to an error {status}, see traceback above")

    loop = asyncio.get_event_loop()
    me = hivemind.dht.node.DHTNode(initial_peers=random.sample(peers, 5), port=0)  # port=0 means os-specified port

    # test 1: find self
    nearest = loop.run_until_complete(me.find_nearest_nodes(query_id=me.node_id, k_nearest=1))
    assert len(nearest) == 1 and nearest[me.node_id] == (LOCALHOST, me.port)

    # test 2: find others
    for i in range(10):
        ref_endpoint, query_id = random.choice(list(dht.items()))
        nearest = loop.run_until_complete(me.find_nearest_nodes(query_id=query_id, k_nearest=1))
        assert len(nearest) == 1 and next(iter(nearest.items())) == (query_id, ref_endpoint)

    # test 3: find neighbors to random nodes
    accuracy_numerator = accuracy_denominator = 0  # top-1 nearest neighbor accuracy
    jaccard_numerator = jaccard_denominator = 0  # jaccard similarity aka intersection over union
    all_node_ids = list(dht.values())

    for i in range(100):
        query_id = DHTID.generate()
        k_nearest = random.randint(1, 20)
        exclude_self = random.random() > 0.5
        nearest = loop.run_until_complete(
            me.find_nearest_nodes(query_id=query_id, k_nearest=k_nearest, exclude_self=exclude_self))
        nearest_nodes = list(nearest)  # keys from ordered dict

        assert len(nearest_nodes) == k_nearest, "beam search must return exactly k_nearest results"
        assert me.node_id not in nearest_nodes or not exclude_self, "if exclude, results should not contain own node id"
        assert np.all(np.diff(query_id.xor_distance(nearest_nodes)) >= 0), "results must be sorted by distance"

        ref_nearest = heapq.nsmallest(k_nearest + 1, all_node_ids, key=query_id.xor_distance)
        if exclude_self and me.node_id in ref_nearest:
            ref_nearest.remove(me.node_id)
        if len(ref_nearest) > k_nearest:
            ref_nearest.pop()

        accuracy_numerator += nearest_nodes[0] == ref_nearest[0]
        accuracy_denominator += 1
        jaccard_numerator += len(set.intersection(set(nearest_nodes), set(ref_nearest)))
        jaccard_denominator += k_nearest

    accuracy = accuracy_numerator / accuracy_denominator
    print("Top-1 accuracy:", accuracy)  # should be 98-100%
    jaccard_index = jaccard_numerator / jaccard_denominator
    print("Jaccard index (intersection over union):", jaccard_index)  # should be 95-100%
    assert accuracy >= 0.9, f"Top-1 accuracy only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
    assert jaccard_index >= 0.9, f"Jaccard index only {jaccard_index} ({jaccard_numerator} / {jaccard_denominator})"

    # test 4: find all nodes
    nearest = loop.run_until_complete(
        me.find_nearest_nodes(query_id=DHTID.generate(), k_nearest=len(dht) + 100))
    assert len(nearest) == len(dht) + 1
    assert len(set.difference(set(nearest.keys()), set(all_node_ids) | {me.node_id})) == 0

    # test 5: node without peers
    other_node = hivemind.dht.node.DHTNode()
    nearest = loop.run_until_complete(other_node.find_nearest_nodes(DHTID.generate()))
    assert len(nearest) == 1 and nearest[other_node.node_id] == (LOCALHOST, other_node.port)
    nearest = loop.run_until_complete(other_node.find_nearest_nodes(DHTID.generate(), exclude_self=True))
    assert len(nearest) == 0

    # test 6: store and get value
    true_time = get_dht_time() + 1200
    assert loop.run_until_complete(me.store("mykey", ["Value", 10], true_time))
    val, expiration_time = loop.run_until_complete(me.get("mykey"))
    assert expiration_time == true_time, "Wrong time"
    assert val == ["Value", 10], "Wrong value"

    # terminate remaining processes
    for proc in processes:
        proc.terminate()


def test_hivemind_dht():
    peers = [hivemind.dht.DHT(start=True)]
    for i in range(10):
        neighbors_i = [('localhost', node.port) for node in random.sample(peers, min(3, len(peers)))]
        peers.append(hivemind.DHT(*neighbors_i, start=True))

    you: hivemind.dht.DHT = random.choice(peers)
    theguyshetoldyounottoworryabout: hivemind.dht.DHT = random.choice(peers)

    expert_uids = [str(uuid.uuid4()) for _ in range(110)]
    batch_size = 10
    for batch_start in range(0, len(expert_uids), batch_size):
        you.declare_experts(expert_uids[batch_start: batch_start + batch_size], 'localhost', 1234)

    found = theguyshetoldyounottoworryabout.get_experts(random.sample(expert_uids, 5) + ['foo', 'bar'])
    assert all(res is not None for res in found[:-2]), "Could not find some existing experts"
    assert all(res is None for res in found[-2:]), "Found non-existing experts"

    that_guys_expert, that_guys_port = str(uuid.uuid4()), random.randint(1000, 9999)
    theguyshetoldyounottoworryabout.declare_experts([that_guys_expert], 'that_host', that_guys_port)
    you_notfound, you_found = you.get_experts(['foobar', that_guys_expert])
    assert isinstance(you_found, hivemind.RemoteExpert)
    assert you_found.host == 'that_host' and you_found.port == that_guys_port

    # test first_k_active
    assert theguyshetoldyounottoworryabout.first_k_active(expert_uids, k=10) == expert_uids[:10]

    some_permuted_experts = random.sample(expert_uids, k=32)
    assert theguyshetoldyounottoworryabout.first_k_active(some_permuted_experts, k=32) == some_permuted_experts
    assert theguyshetoldyounottoworryabout.first_k_active(some_permuted_experts, k=1) == some_permuted_experts[:1]

    fake_and_real_experts = list(chain(*zip(
        [str(uuid.uuid4()) for _ in some_permuted_experts], some_permuted_experts)))
    assert theguyshetoldyounottoworryabout.first_k_active(fake_and_real_experts, k=9) == some_permuted_experts[:9]

    for peer in peers:
        peer.shutdown()


def test_store():
    d = LocalStorage()
    d.store("key", "val", get_dht_time() + 10)
    assert d.get("key")[0] == "val", "Wrong value"
    print("Test store passed")


def test_get_expired():
    d = LocalStorage()
    d.store("key", "val", get_dht_time() + 1)
    time.sleep(2)
    assert d.get("key") == (None, None), "Expired value must be deleted"
    print("Test get expired passed")


def test_get_empty():
    d = LocalStorage()
    assert d.get("key") == (None, None), "Getting a missing key must return (None, None)"
    print("Test get empty passed")


def test_change_expiration_time():
    d = LocalStorage()
    d.store("key", "val1", get_dht_time() + 2)
    d.store("key", "val2", get_dht_time() + 200)
    time.sleep(4)
    assert d.get("key")[0] == "val2", "Value must be changed, but still kept in table"
    print("Test change expiration time passed")


def test_maxsize_cache():
    d = LocalStorage(maxsize=1)
    d.store("key1", "val1", get_dht_time() + 1)
    d.store("key2", "val2", get_dht_time() + 200)
    assert d.get("key2")[0] == "val2", "Value with the longer expiration time must be kept"
    assert d.get("key1")[0] is None, "Value with the shorter expiration time must be evicted"
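

# Convenience entry point (an assumed addition, not part of the pytest-collected suite): lets the
# file be executed directly via `python test_dht.py`, running each test defined above in sequence.
if __name__ == '__main__':
    test_kademlia_protocol()
    test_dht()
    test_hivemind_dht()
    test_store()
    test_get_expired()
    test_get_empty()
    test_change_expiration_time()
    test_maxsize_cache()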