benchmark_dht.py

import argparse
import random
import time

from tqdm import trange

import hivemind
import hivemind.server.expert_uid
from hivemind.utils.threading import increase_file_limit

logger = hivemind.get_logger(__name__)
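

# random_endpoint produces a fake ip:port pair; the benchmark only checks that stored
# values round-trip through the DHT, so these endpoints are never actually dialed.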
def random_endpoint() -> hivemind.Endpoint:
    return f"{random.randint(0, 255)}.{random.randint(0, 255)}.{random.randint(0, 255)}." \
           f"{random.randint(0, 255)}:{random.randint(0, 65535)}"


def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_batch_size: int, random_seed: int,
                  wait_after_request: float, wait_before_read: float, wait_timeout: float, expiration: float):
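    """
    Measure DHT performance: spin up `num_peers` hivemind.DHT nodes, declare `num_experts`
    experts from one peer in batches of `expert_batch_size`, read them back from another peer,
    and report success rates, mean store/get latencies, and node survival.
    """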
    random.seed(random_seed)

    print("Creating peers...")
    peers = []
    for _ in trange(num_peers):
        neighbors = [f'0.0.0.0:{node.port}' for node in random.sample(peers, min(initial_peers, len(peers)))]
        peer = hivemind.DHT(initial_peers=neighbors, start=True, wait_timeout=wait_timeout,
                            listen_on='0.0.0.0:*')
        peers.append(peer)
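
    # Benchmark with the last two peers created: one stores the expert records, the other reads them back.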
    store_peer, get_peer = peers[-2:]

    expert_uids = list(set(f"expert.{random.randint(0, 999)}.{random.randint(0, 999)}.{random.randint(0, 999)}"
                           for _ in range(num_experts)))
    print(f"Sampled {len(expert_uids)} unique ids (after deduplication)")
    random.shuffle(expert_uids)

    print(f"Storing experts to dht in batches of {expert_batch_size}...")
    successful_stores = total_stores = total_store_time = 0
    benchmark_started = time.perf_counter()
    endpoints = []
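
    # Declare experts in batches; every expert in a batch shares one random endpoint,
    # so the get-phase below can check the exact value that was stored.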
    for start in trange(0, num_experts, expert_batch_size):
        store_start = time.perf_counter()
        endpoints.append(random_endpoint())
        store_ok = hivemind.declare_experts(store_peer, expert_uids[start: start + expert_batch_size], endpoints[-1],
                                            expiration=expiration)
        successes = store_ok.values()
        total_store_time += time.perf_counter() - store_start

        total_stores += len(successes)
        successful_stores += sum(successes)
        time.sleep(wait_after_request)

    print(f"Store success rate: {successful_stores / total_stores * 100:.1f}% ({successful_stores} / {total_stores})")
    print(f"Mean store time: {total_store_time / total_stores:.5f}, Total: {total_store_time:.5f}")
    time.sleep(wait_before_read)

    if time.perf_counter() - benchmark_started > expiration:
        logger.warning("All keys expired before the benchmark began reading them. Consider increasing the expiration parameter")

    successful_gets = total_get_time = 0
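
    # A get counts as successful only if the returned expert matches both the expected
    # uid and the endpoint stored for its batch.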
    for start in trange(0, len(expert_uids), expert_batch_size):
        get_start = time.perf_counter()
        get_result = hivemind.get_experts(get_peer, expert_uids[start: start + expert_batch_size])
        total_get_time += time.perf_counter() - get_start

        for i, expert in enumerate(get_result):
            if expert is not None and expert.uid == expert_uids[start + i] \
                    and expert.endpoint == endpoints[start // expert_batch_size]:
                successful_gets += 1

    if time.perf_counter() - benchmark_started > expiration:
        logger.warning("Keys expired midway through the get requests. If that isn't desired, increase the expiration parameter")

    print(f"Get success rate: {successful_gets / len(expert_uids) * 100:.1f}% ({successful_gets} / {len(expert_uids)})")
    print(f"Mean get time: {total_get_time / len(expert_uids):.5f}, Total: {total_get_time:.5f}")

    alive_peers = [peer.is_alive() for peer in peers]
    print(f"Node survival rate: {sum(alive_peers) / len(peers) * 100:.3f}%")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_peers', type=int, default=32, required=False)
    parser.add_argument('--initial_peers', type=int, default=1, required=False)
    parser.add_argument('--num_experts', type=int, default=256, required=False)
    parser.add_argument('--expert_batch_size', type=int, default=32, required=False)
    parser.add_argument('--expiration', type=float, default=300, required=False)
    parser.add_argument('--wait_after_request', type=float, default=0, required=False)
    parser.add_argument('--wait_before_read', type=float, default=0, required=False)
    parser.add_argument('--wait_timeout', type=float, default=5, required=False)
    parser.add_argument('--random_seed', type=int, default=random.randint(1, 1000))
    parser.add_argument('--increase_file_limit', action="store_true")

    args = vars(parser.parse_args())
    if args.pop('increase_file_limit', False):
        increase_file_limit()

    benchmark_dht(**args)
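
# Example invocation (hypothetical parameter values):
#   python benchmark_dht.py --num_peers 64 --num_experts 512 --expiration 600 --increase_file_limit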