benchmark_dht.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. import argparse
  2. import time
  3. import asyncio
  4. import multiprocessing as mp
  5. import random
  6. import hivemind
  7. from typing import List, Dict
  8. from hivemind import get_dht_time
  9. from hivemind.dht.node import DHTID, Endpoint, DHTNode, LOCALHOST
  10. def run_benchmark_node(node_id, port, peers, ready: mp.Event, request_perod,
  11. expiration_time, wait_before_read, time_to_test, statistics: mp.Queue, dht_loaded: mp.Event):
  12. if asyncio.get_event_loop().is_running():
  13. asyncio.get_event_loop().stop() # if we're in jupyter, get rid of its built-in event loop
  14. loop = asyncio.new_event_loop()
  15. asyncio.set_event_loop(loop)
  16. node = DHTNode(node_id, port, initial_peers=peers)
  17. await_forever = hivemind.run_forever(asyncio.get_event_loop().run_forever)
  18. ready.set()
  19. dht_loaded.wait()
  20. start = time.perf_counter()
  21. while time.perf_counter() < start + time_to_test:
  22. query_id = DHTID.generate()
  23. store_value = random.randint(0, 256)
  24. store_time = time.perf_counter()
  25. success_store = asyncio.run_coroutine_threadsafe(
  26. node.store(query_id, store_value, get_dht_time() + expiration_time), loop).result()
  27. store_time = time.perf_counter() - store_time
  28. if success_store:
  29. time.sleep(wait_before_read)
  30. get_time = time.perf_counter()
  31. get_value, get_time_expiration = asyncio.run_coroutine_threadsafe(node.get(query_id), loop).result()
  32. get_time = time.perf_counter() - get_time
  33. success_get = (get_value == store_value)
  34. statistics.put((success_store, store_time, success_get, get_time))
  35. else:
  36. statistics.put((success_store, store_time, None, None))
  37. await_forever.result() # process will exit only if event loop broke down
  38. if __name__ == "__main__":
  39. parser = argparse.ArgumentParser()
  40. parser.add_argument('--num_nodes', type=int, default=20, required=False)
  41. parser.add_argument('--request_perod', type=float, default=2, required=False)
  42. parser.add_argument('--expiration_time', type=float, default=10, required=False)
  43. parser.add_argument('--wait_before_read', type=float, default=1, required=False)
  44. parser.add_argument('--time_to_test', type=float, default=10, required=False)
  45. args = parser.parse_args()
  46. statistics = mp.Queue()
  47. dht: Dict[Endpoint, DHTID] = {}
  48. processes: List[mp.Process] = []
  49. num_nodes = args.num_nodes
  50. request_perod = args.request_perod
  51. expiration_time = args.expiration_time
  52. wait_before_read = args.wait_before_read
  53. time_to_test = args.time_to_test
  54. dht_loaded = mp.Event()
  55. for i in range(num_nodes):
  56. node_id = DHTID.generate()
  57. port = hivemind.find_open_port()
  58. peers = random.sample(dht.keys(), min(len(dht), 5))
  59. ready = mp.Event()
  60. proc = mp.Process(target=run_benchmark_node, args=(node_id, port, peers, ready, request_perod,
  61. expiration_time, wait_before_read, time_to_test, statistics,
  62. dht_loaded), daemon=True)
  63. proc.start()
  64. ready.wait()
  65. processes.append(proc)
  66. dht[(LOCALHOST, port)] = node_id
  67. dht_loaded.set()
  68. time.sleep(time_to_test)
  69. success_store = 0
  70. all_store = 0
  71. time_store = 0
  72. success_get = 0
  73. all_get = 0
  74. time_get = 0
  75. while not statistics.empty():
  76. success_store_i, store_time_i, success_get_i, get_time_i = statistics.get()
  77. all_store += 1
  78. time_store += store_time_i
  79. if success_store_i:
  80. success_store += 1
  81. all_get += 1
  82. success_get += 1 if success_get_i else 0
  83. time_get += get_time_i
  84. alive_nodes_count = 0
  85. loop = asyncio.new_event_loop()
  86. node = DHTNode(loop=loop)
  87. for addr, port in dht:
  88. if loop.run_until_complete(node.protocol.call_ping((addr, port))) is not None:
  89. alive_nodes_count += 1
  90. print("store success rate: ", success_store / all_store)
  91. print("mean store time: ", time_store / all_store)
  92. print("get success rate: ", success_get / all_get)
  93. print("mean get time: ", time_get / all_get)
  94. print("death rate: ", (num_nodes - alive_nodes_count) / num_nodes)