# benchmark_dht.py (8.5 KB)
import argparse
import asyncio
import random
import time
import uuid
from typing import Tuple

import numpy as np
from tqdm import trange

import hivemind
from hivemind.utils.limits import increase_file_limit
from hivemind.utils.logging import get_logger, use_hivemind_log_handler

# Route hivemind's internal log records through the root logger so that
# benchmark output is formatted consistently with this script's own logging.
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__name__)
  14. class NodeKiller:
  15. """Auxiliary class that kills dht nodes over a pre-defined schedule"""
  16. def __init__(self, shutdown_peers: list, shutdown_timestamps: list):
  17. self.shutdown_peers = set(shutdown_peers)
  18. self.shutdown_timestamps = shutdown_timestamps
  19. self.current_iter = 0
  20. self.timestamp_iter = 0
  21. self.lock = asyncio.Lock()
  22. async def check_and_kill(self):
  23. async with self.lock:
  24. if (
  25. self.shutdown_timestamps != None
  26. and self.timestamp_iter < len(self.shutdown_timestamps)
  27. and self.current_iter == self.shutdown_timestamps[self.timestamp_iter]
  28. ):
  29. shutdown_peer = random.sample(self.shutdown_peers, 1)[0]
  30. shutdown_peer.shutdown()
  31. self.shutdown_peers.remove(shutdown_peer)
  32. self.timestamp_iter += 1
  33. self.current_iter += 1
  34. async def store_and_get_task(
  35. peers: list,
  36. total_num_rounds: int,
  37. num_store_peers: int,
  38. num_get_peers: int,
  39. wait_after_iteration: float,
  40. delay: float,
  41. expiration: float,
  42. latest: bool,
  43. node_killer: NodeKiller,
  44. ) -> Tuple[list, list, list, list, int, int]:
  45. """Iteratively choose random peers to store data onto the dht, then retrieve with another random subset of peers"""
  46. total_stores = total_gets = 0
  47. successful_stores = []
  48. successful_gets = []
  49. store_times = []
  50. get_times = []
  51. for _ in range(total_num_rounds):
  52. key = uuid.uuid4().hex
  53. store_start = time.perf_counter()
  54. store_peers = random.sample(peers, min(num_store_peers, len(peers)))
  55. store_subkeys = [uuid.uuid4().hex for _ in store_peers]
  56. store_values = {subkey: uuid.uuid4().hex for subkey in store_subkeys}
  57. store_tasks = [
  58. peer.store(
  59. key,
  60. subkey=subkey,
  61. value=store_values[subkey],
  62. expiration_time=hivemind.get_dht_time() + expiration,
  63. return_future=True,
  64. )
  65. for peer, subkey in zip(store_peers, store_subkeys)
  66. ]
  67. store_result = await asyncio.gather(*store_tasks)
  68. await node_killer.check_and_kill()
  69. store_times.append(time.perf_counter() - store_start)
  70. total_stores += len(store_result)
  71. successful_stores_per_iter = sum(store_result)
  72. successful_stores.append(successful_stores_per_iter)
  73. await asyncio.sleep(delay)
  74. get_start = time.perf_counter()
  75. get_peers = random.sample(peers, min(num_get_peers, len(peers)))
  76. get_tasks = [peer.get(key, latest, return_future=True) for peer in get_peers]
  77. get_result = await asyncio.gather(*get_tasks)
  78. get_times.append(time.perf_counter() - get_start)
  79. successful_gets_per_iter = 0
  80. total_gets += len(get_result)
  81. for result in get_result:
  82. if result != None:
  83. attendees, expiration = result
  84. if len(attendees.keys()) == successful_stores_per_iter:
  85. get_ok = True
  86. for key in attendees:
  87. if attendees[key][0] != store_values[key]:
  88. get_ok = False
  89. break
  90. successful_gets_per_iter += get_ok
  91. successful_gets.append(successful_gets_per_iter)
  92. await asyncio.sleep(wait_after_iteration)
  93. return store_times, get_times, successful_stores, successful_gets, total_stores, total_gets
  94. async def benchmark_dht(
  95. num_peers: int,
  96. initial_peers: int,
  97. random_seed: int,
  98. num_threads: int,
  99. total_num_rounds: int,
  100. num_store_peers: int,
  101. num_get_peers: int,
  102. wait_after_iteration: float,
  103. delay: float,
  104. wait_timeout: float,
  105. expiration: float,
  106. latest: bool,
  107. failure_rate: float,
  108. ):
  109. random.seed(random_seed)
  110. logger.info("Creating peers...")
  111. peers = []
  112. for _ in trange(num_peers):
  113. neighbors = sum(
  114. [peer.get_visible_maddrs() for peer in random.sample(peers, min(initial_peers, len(peers)))], []
  115. )
  116. peer = hivemind.DHT(initial_peers=neighbors, start=True, wait_timeout=wait_timeout)
  117. peers.append(peer)
  118. benchmark_started = time.perf_counter()
  119. logger.info("Creating store and get tasks...")
  120. shutdown_peers = random.sample(peers, min(int(failure_rate * num_peers), num_peers))
  121. assert len(shutdown_peers) != len(peers)
  122. remaining_peers = list(set(peers) - set(shutdown_peers))
  123. shutdown_timestamps = random.sample(
  124. range(0, num_threads * total_num_rounds), min(len(shutdown_peers), num_threads * total_num_rounds)
  125. )
  126. shutdown_timestamps.sort()
  127. node_killer = NodeKiller(shutdown_peers, shutdown_timestamps)
  128. task_list = [
  129. asyncio.create_task(
  130. store_and_get_task(
  131. remaining_peers,
  132. total_num_rounds,
  133. num_store_peers,
  134. num_get_peers,
  135. wait_after_iteration,
  136. delay,
  137. expiration,
  138. latest,
  139. node_killer,
  140. )
  141. )
  142. for _ in trange(num_threads)
  143. ]
  144. store_and_get_result = await asyncio.gather(*task_list)
  145. benchmark_total_time = time.perf_counter() - benchmark_started
  146. total_store_times = []
  147. total_get_times = []
  148. total_successful_stores = []
  149. total_successful_gets = []
  150. total_stores = total_gets = 0
  151. for result in store_and_get_result:
  152. store_times, get_times, successful_stores, successful_gets, stores, gets = result
  153. total_store_times.extend(store_times)
  154. total_get_times.extend(get_times)
  155. total_successful_stores.extend(successful_stores)
  156. total_successful_gets.extend(successful_gets)
  157. total_stores += stores
  158. total_gets += gets
  159. alive_peers = [peer.is_alive() for peer in peers]
  160. logger.info(
  161. f"Store wall time (sec.): mean({np.mean(total_store_times):.3f}) "
  162. + f"std({np.std(total_store_times, ddof=1):.3f}) max({np.max(total_store_times):.3f})"
  163. )
  164. logger.info(
  165. f"Get wall time (sec.): mean({np.mean(total_get_times):.3f}) "
  166. + f"std({np.std(total_get_times, ddof=1):.3f}) max({np.max(total_get_times):.3f})"
  167. )
  168. logger.info(f"Average store time per worker: {sum(total_store_times) / num_threads:.3f} sec.")
  169. logger.info(f"Average get time per worker: {sum(total_get_times) / num_threads:.3f} sec.")
  170. logger.info(f"Total benchmark time: {benchmark_total_time:.5f} sec.")
  171. logger.info(
  172. "Store success rate: "
  173. + f"{sum(total_successful_stores) / total_stores * 100:.1f}% ({sum(total_successful_stores)}/{total_stores})"
  174. )
  175. logger.info(
  176. "Get success rate: "
  177. + f"{sum(total_successful_gets) / total_gets * 100:.1f}% ({sum(total_successful_gets)}/{total_gets})"
  178. )
  179. logger.info(f"Node survival rate: {len(alive_peers) / len(peers) * 100:.3f}%")
  180. if __name__ == "__main__":
  181. parser = argparse.ArgumentParser()
  182. parser.add_argument("--num_peers", type=int, default=16, required=False)
  183. parser.add_argument("--initial_peers", type=int, default=4, required=False)
  184. parser.add_argument("--random_seed", type=int, default=30, required=False)
  185. parser.add_argument("--num_threads", type=int, default=10, required=False)
  186. parser.add_argument("--total_num_rounds", type=int, default=16, required=False)
  187. parser.add_argument("--num_store_peers", type=int, default=8, required=False)
  188. parser.add_argument("--num_get_peers", type=int, default=8, required=False)
  189. parser.add_argument("--wait_after_iteration", type=float, default=0, required=False)
  190. parser.add_argument("--delay", type=float, default=0, required=False)
  191. parser.add_argument("--wait_timeout", type=float, default=5, required=False)
  192. parser.add_argument("--expiration", type=float, default=300, required=False)
  193. parser.add_argument("--latest", type=bool, default=True, required=False)
  194. parser.add_argument("--failure_rate", type=float, default=0.1, required=False)
  195. parser.add_argument("--increase_file_limit", action="store_true")
  196. args = vars(parser.parse_args())
  197. if args.pop("increase_file_limit", False):
  198. increase_file_limit()
  199. asyncio.run(benchmark_dht(**args))