benchmark_dht.py
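
"""Benchmark throughput of the hivemind DHT: create a swarm of DHT peers, declare expert
records from one peer, read them back from another, and report success rates and timings."""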

import argparse
import random
import time

from tqdm import trange

import hivemind
from hivemind.moe.server import declare_experts, get_experts
from hivemind.utils.limits import increase_file_limit
from hivemind.utils.logging import get_logger, use_hivemind_log_handler

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__name__)


def random_endpoint() -> hivemind.Endpoint:
    # Generate a fake "ip:port" endpoint string; IP octets are 0-255, ports 0-65535.
    return (
        f"{random.randint(0, 255)}.{random.randint(0, 255)}.{random.randint(0, 255)}."
        f"{random.randint(0, 255)}:{random.randint(0, 65535)}"
    )


def benchmark_dht(
    num_peers: int,
    initial_peers: int,
    num_experts: int,
    expert_batch_size: int,
    random_seed: int,
    wait_after_request: float,
    wait_before_read: float,
    wait_timeout: float,
    expiration: float,
):
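    """
    Measure DHT store/get throughput: spin up `num_peers` DHT nodes, declare `num_experts`
    expert records from one peer, read them back from another, then report success rates,
    timings, and node survival.
    """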
    random.seed(random_seed)

    logger.info("Creating peers...")
    peers = []
    for _ in trange(num_peers):
        # bootstrap each new node from the visible multiaddrs of up to `initial_peers` existing peers
        neighbors = sum(
            [peer.get_visible_maddrs() for peer in random.sample(peers, min(initial_peers, len(peers)))], []
        )
        peer = hivemind.DHT(initial_peers=neighbors, start=True, wait_timeout=wait_timeout)
        peers.append(peer)
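
    # The two most recently created peers act as the writer (store) and the reader (get).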
    store_peer, get_peer = peers[-2:]
    expert_uids = list(
        set(
            f"expert.{random.randint(0, 999)}.{random.randint(0, 999)}.{random.randint(0, 999)}"
            for _ in range(num_experts)
        )
    )
    logger.info(f"Sampled {len(expert_uids)} unique ids (after deduplication)")
    random.shuffle(expert_uids)

    logger.info(f"Storing experts to dht in batches of {expert_batch_size}...")
    successful_stores = total_stores = total_store_time = 0
    benchmark_started = time.perf_counter()
    endpoints = []

    for start in trange(0, num_experts, expert_batch_size):
        store_start = time.perf_counter()
        endpoints.append(random_endpoint())
        store_ok = declare_experts(
            store_peer, expert_uids[start : start + expert_batch_size], endpoints[-1], expiration=expiration
        )
        successes = store_ok.values()
        total_store_time += time.perf_counter() - store_start

        total_stores += len(successes)
        successful_stores += sum(successes)
        time.sleep(wait_after_request)

    logger.info(
        f"Store success rate: {successful_stores / total_stores * 100:.1f}% ({successful_stores} / {total_stores})"
    )
    logger.info(f"Mean store time: {total_store_time / total_stores:.5f}, Total: {total_store_time:.5f}")
    time.sleep(wait_before_read)

    if time.perf_counter() - benchmark_started > expiration:
        logger.warning("All keys expired before the benchmark started reading them. Consider increasing --expiration")
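
    # Read each uid back from a different peer and count results that match the endpoint it was stored under.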
    successful_gets = total_get_time = 0
    for start in trange(0, len(expert_uids), expert_batch_size):
        get_start = time.perf_counter()
        get_result = get_experts(get_peer, expert_uids[start : start + expert_batch_size])
        total_get_time += time.perf_counter() - get_start

        for i, expert in enumerate(get_result):
            if (
                expert is not None
                and expert.uid == expert_uids[start + i]
                and expert.endpoint == endpoints[start // expert_batch_size]
            ):
                successful_gets += 1

    if time.perf_counter() - benchmark_started > expiration:
        logger.warning(
            "Keys expired midway through the get requests. If that isn't desired, increase the --expiration param"
        )

    logger.info(
        f"Get success rate: {successful_gets / len(expert_uids) * 100:.1f}% ({successful_gets} / {len(expert_uids)})"
    )
    logger.info(f"Mean get time: {total_get_time / len(expert_uids):.5f}, Total: {total_get_time:.5f}")

    # is_alive() returns a bool per peer, so sum() counts the peers that survived the benchmark
    alive_peers = [peer.is_alive() for peer in peers]
    logger.info(f"Node survival rate: {sum(alive_peers) / len(peers) * 100:.3f}%")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_peers", type=int, default=32, required=False)
    parser.add_argument("--initial_peers", type=int, default=1, required=False)
    parser.add_argument("--num_experts", type=int, default=256, required=False)
    parser.add_argument("--expert_batch_size", type=int, default=32, required=False)
    parser.add_argument("--expiration", type=float, default=300, required=False)
    parser.add_argument("--wait_after_request", type=float, default=0, required=False)
    parser.add_argument("--wait_before_read", type=float, default=0, required=False)
    parser.add_argument("--wait_timeout", type=float, default=5, required=False)
    parser.add_argument("--random_seed", type=int, default=random.randint(1, 1000))
    parser.add_argument("--increase_file_limit", action="store_true")
    args = vars(parser.parse_args())

    if args.pop("increase_file_limit", False):
        increase_file_limit()

    benchmark_dht(**args)
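
# Example invocation (values are illustrative; the flags and their defaults are defined above):
#   python benchmark_dht.py --num_peers 32 --initial_peers 1 --num_experts 256 --expert_batch_size 32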