# benchmark_dht.py — benchmark of hivemind DHT store/get throughput with scheduled node failures
import argparse
import asyncio
import random
import time
import uuid
from logging import shutdown  # NOTE(review): unused in this file — looks like an accidental auto-import; safe to remove
from typing import Tuple

import numpy as np
from tqdm import trange

import hivemind
from hivemind.utils.limits import increase_file_limit
from hivemind.utils.logging import get_logger, use_hivemind_log_handler

# Route log records through hivemind's root-logger handler, then create a module-level logger.
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__name__)
  15. class NodeKiller:
  16. """Auxiliary class that kills dht nodes over a pre-defined schedule"""
  17. def __init__(self, shutdown_peers: list, shutdown_timestamps: list):
  18. self.shutdown_peers = set(shutdown_peers)
  19. self.shutdown_timestamps = shutdown_timestamps
  20. self.current_iter = 0
  21. self.timestamp_iter = 0
  22. self.lock = asyncio.Lock()
  23. async def check_and_kill(self):
  24. async with self.lock:
  25. if (
  26. self.shutdown_timestamps != None
  27. and self.timestamp_iter < len(self.shutdown_timestamps)
  28. and self.current_iter == self.shutdown_timestamps[self.timestamp_iter]
  29. ):
  30. shutdown_peer = random.sample(self.shutdown_peers, 1)[0]
  31. shutdown_peer.shutdown()
  32. self.shutdown_peers.remove(shutdown_peer)
  33. self.timestamp_iter += 1
  34. self.current_iter += 1
  35. async def store_and_get_task(
  36. peers: list,
  37. total_num_rounds: int,
  38. num_store_peers: int,
  39. num_get_peers: int,
  40. wait_after_iteration: float,
  41. delay: float,
  42. expiration: float,
  43. latest: bool,
  44. node_killer: NodeKiller,
  45. ) -> Tuple[list, list, list, list, int, int]:
  46. """Iteratively choose random peers to store data onto the dht, then retreive with another random subset of peers"""
  47. total_stores = total_gets = 0
  48. successful_stores = []
  49. successful_gets = []
  50. store_times = []
  51. get_times = []
  52. for _ in range(total_num_rounds):
  53. key = uuid.uuid4().hex
  54. store_start = time.perf_counter()
  55. store_peers = random.sample(peers, min(num_store_peers, len(peers)))
  56. store_subkeys = [uuid.uuid4().hex for _ in store_peers]
  57. store_values = {subkey: uuid.uuid4().hex for subkey in store_subkeys}
  58. store_tasks = [
  59. peer.store(
  60. key,
  61. subkey=subkey,
  62. value=store_values[subkey],
  63. expiration_time=hivemind.get_dht_time() + expiration,
  64. return_future=True,
  65. )
  66. for peer, subkey in zip(store_peers, store_subkeys)
  67. ]
  68. store_result = await asyncio.gather(*store_tasks)
  69. await node_killer.check_and_kill()
  70. store_times.append(time.perf_counter() - store_start)
  71. total_stores += len(store_result)
  72. successful_stores_per_iter = sum(store_result)
  73. successful_stores.append(successful_stores_per_iter)
  74. await asyncio.sleep(delay)
  75. get_start = time.perf_counter()
  76. get_peers = random.sample(peers, min(num_get_peers, len(peers)))
  77. get_tasks = [peer.get(key, latest, return_future=True) for peer in get_peers]
  78. get_result = await asyncio.gather(*get_tasks)
  79. get_times.append(time.perf_counter() - get_start)
  80. successful_gets_per_iter = 0
  81. total_gets += len(get_result)
  82. for result in get_result:
  83. if result != None:
  84. attendees, expiration = result
  85. if len(attendees.keys()) == successful_stores_per_iter:
  86. get_ok = True
  87. for key in attendees:
  88. if attendees[key][0] != store_values[key]:
  89. get_ok = False
  90. break
  91. successful_gets_per_iter += get_ok
  92. successful_gets.append(successful_gets_per_iter)
  93. await asyncio.sleep(wait_after_iteration)
  94. return store_times, get_times, successful_stores, successful_gets, total_stores, total_gets
  95. async def benchmark_dht(
  96. num_peers: int,
  97. initial_peers: int,
  98. random_seed: int,
  99. num_threads: int,
  100. total_num_rounds: int,
  101. num_store_peers: int,
  102. num_get_peers: int,
  103. wait_after_iteration: float,
  104. delay: float,
  105. wait_timeout: float,
  106. expiration: float,
  107. latest: bool,
  108. failure_rate: float,
  109. ):
  110. random.seed(random_seed)
  111. logger.info("Creating peers...")
  112. peers = []
  113. for _ in trange(num_peers):
  114. neighbors = sum(
  115. [peer.get_visible_maddrs() for peer in random.sample(peers, min(initial_peers, len(peers)))], []
  116. )
  117. peer = hivemind.DHT(initial_peers=neighbors, start=True, wait_timeout=wait_timeout)
  118. peers.append(peer)
  119. benchmark_started = time.perf_counter()
  120. logger.info("Creating store and get tasks...")
  121. shutdown_peers = random.sample(peers, min(int(failure_rate * num_peers), num_peers))
  122. assert len(shutdown_peers) != len(peers)
  123. remaining_peers = list(set(peers) - set(shutdown_peers))
  124. shutdown_timestamps = random.sample(
  125. range(0, num_threads * total_num_rounds), min(len(shutdown_peers), num_threads * total_num_rounds)
  126. )
  127. shutdown_timestamps.sort()
  128. node_killer = NodeKiller(shutdown_peers, shutdown_timestamps)
  129. task_list = [
  130. asyncio.create_task(
  131. store_and_get_task(
  132. remaining_peers,
  133. total_num_rounds,
  134. num_store_peers,
  135. num_get_peers,
  136. wait_after_iteration,
  137. delay,
  138. expiration,
  139. latest,
  140. node_killer,
  141. )
  142. )
  143. for _ in trange(num_threads)
  144. ]
  145. store_and_get_result = await asyncio.gather(*task_list)
  146. benchmark_total_time = time.perf_counter() - benchmark_started
  147. total_store_times = []
  148. total_get_times = []
  149. total_successful_stores = []
  150. total_successful_gets = []
  151. total_stores = total_gets = 0
  152. for result in store_and_get_result:
  153. store_times, get_times, successful_stores, successful_gets, stores, gets = result
  154. total_store_times.extend(store_times)
  155. total_get_times.extend(get_times)
  156. total_successful_stores.extend(successful_stores)
  157. total_successful_gets.extend(successful_gets)
  158. total_stores += stores
  159. total_gets += gets
  160. alive_peers = [peer.is_alive() for peer in peers]
  161. logger.info(
  162. f"Store wall time (sec.): mean({np.mean(total_store_times):.3f}) "
  163. + f"std({np.std(total_store_times, ddof=1):.3f}) max({np.max(total_store_times):.3f})"
  164. )
  165. logger.info(
  166. f"Get wall time (sec.): mean({np.mean(total_get_times):.3f}) "
  167. + f"std({np.std(total_get_times, ddof=1):.3f}) max({np.max(total_get_times):.3f})"
  168. )
  169. logger.info(f"Average store time per worker: {sum(total_store_times) / num_threads:.3f} sec.")
  170. logger.info(f"Average get time per worker: {sum(total_get_times) / num_threads:.3f} sec.")
  171. logger.info(f"Total benchmark time: {benchmark_total_time:.5f} sec.")
  172. logger.info(
  173. "Store success rate: "
  174. + f"{sum(total_successful_stores) / total_stores * 100:.1f}% ({sum(total_successful_stores)}/{total_stores})"
  175. )
  176. logger.info(
  177. "Get success rate: "
  178. + f"{sum(total_successful_gets) / total_gets * 100:.1f}% ({sum(total_successful_gets)}/{total_gets})"
  179. )
  180. logger.info(f"Node survival rate: {len(alive_peers) / len(peers) * 100:.3f}%")
  181. if __name__ == "__main__":
  182. parser = argparse.ArgumentParser()
  183. parser.add_argument("--num_peers", type=int, default=16, required=False)
  184. parser.add_argument("--initial_peers", type=int, default=4, required=False)
  185. parser.add_argument("--random_seed", type=int, default=30, required=False)
  186. parser.add_argument("--num_threads", type=int, default=10, required=False)
  187. parser.add_argument("--total_num_rounds", type=int, default=16, required=False)
  188. parser.add_argument("--num_store_peers", type=int, default=8, required=False)
  189. parser.add_argument("--num_get_peers", type=int, default=8, required=False)
  190. parser.add_argument("--wait_after_iteration", type=float, default=0, required=False)
  191. parser.add_argument("--delay", type=float, default=0, required=False)
  192. parser.add_argument("--wait_timeout", type=float, default=5, required=False)
  193. parser.add_argument("--expiration", type=float, default=300, required=False)
  194. parser.add_argument("--latest", type=bool, default=True, required=False)
  195. parser.add_argument("--failure_rate", type=float, default=0.1, required=False)
  196. parser.add_argument("--increase_file_limit", action="store_true")
  197. args = vars(parser.parse_args())
  198. if args.pop("increase_file_limit", False):
  199. increase_file_limit()
  200. asyncio.run(benchmark_dht(**args))