# benchmark_dht.py

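"""Benchmark hivemind's DHT: spin up num_peers local DHT nodes, declare num_experts
expert uids in batches from one peer, read them back from another peer, and report
store/get success rates, mean request times, and node survival."""
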
import argparse
import random
import time
from warnings import warn

import hivemind
from tqdm import trange

from test_utils import increase_file_limit


def random_endpoint() -> hivemind.Endpoint:
    # random.randint is inclusive, so octets are drawn from [0, 255] and ports
    # from [1, 65535] to keep every generated endpoint valid
    return f"{random.randint(0, 255)}.{random.randint(0, 255)}.{random.randint(0, 255)}." \
           f"{random.randint(0, 255)}:{random.randint(1, 65535)}"


def benchmark_dht(num_peers: int, initial_peers: int, num_experts: int, expert_batch_size: int, random_seed: int,
                  wait_after_request: float, wait_before_read: float, wait_timeout: float, expiration_time: float):
    # temporarily override the DHT-wide key expiration for the duration of this run
    old_expiration_time, hivemind.DHT.EXPIRATION = hivemind.DHT.EXPIRATION, expiration_time
    random.seed(random_seed)

    print("Creating peers...")
    peers = []
    for _ in trange(num_peers):
        # each new peer bootstraps off up to initial_peers randomly chosen existing peers
        neighbors = [f'0.0.0.0:{node.port}' for node in random.sample(peers, min(initial_peers, len(peers)))]
        peer = hivemind.DHT(initial_peers=neighbors, start=True, wait_timeout=wait_timeout, listen_on='0.0.0.0:*')
        peers.append(peer)
    # the two most recently created peers serve as the writer and the reader
    store_peer, get_peer = peers[-2:]

    expert_uids = list(set(f"expert.{random.randint(0, 999)}.{random.randint(0, 999)}.{random.randint(0, 999)}"
                           for _ in range(num_experts)))
    print(f"Sampled {len(expert_uids)} unique ids (after deduplication)")
    random.shuffle(expert_uids)

    print(f"Storing experts to dht in batches of {expert_batch_size}...")
    successful_stores = total_stores = total_store_time = 0
    benchmark_started = time.perf_counter()
    endpoints = []

    # iterate over len(expert_uids): deduplication may have left fewer than num_experts uids
    for start in trange(0, len(expert_uids), expert_batch_size):
        store_start = time.perf_counter()
        endpoints.append(random_endpoint())
        success_list = store_peer.declare_experts(expert_uids[start: start + expert_batch_size], endpoints[-1])
        total_store_time += time.perf_counter() - store_start

        total_stores += len(success_list)
        successful_stores += sum(success_list)
        time.sleep(wait_after_request)

    print(f"Store success rate: {successful_stores / total_stores * 100:.1f}% ({successful_stores} / {total_stores})")
    print(f"Mean store time: {total_store_time / total_stores:.5f}, Total: {total_store_time:.5f}")
    time.sleep(wait_before_read)
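
    # Read phase: fetch the same uids from a different peer and count a hit only when
    # both the uid and the endpoint match what the writer stored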
    if time.perf_counter() - benchmark_started > expiration_time:
        warn("All keys expired before the benchmark started reading them. Consider increasing expiration_time")

    successful_gets = total_get_time = 0

    for start in trange(0, len(expert_uids), expert_batch_size):
        get_start = time.perf_counter()
        get_result = get_peer.get_experts(expert_uids[start: start + expert_batch_size])
        total_get_time += time.perf_counter() - get_start

        for i, expert in enumerate(get_result):
            if expert is not None and expert.uid == expert_uids[start + i] \
                    and expert.endpoint == endpoints[start // expert_batch_size]:
                successful_gets += 1

    if time.perf_counter() - benchmark_started > expiration_time:
        warn("Keys expired midway through the get requests. If that is not desired, increase the expiration_time param")

    print(f"Get success rate: {successful_gets / len(expert_uids) * 100:.1f}% ({successful_gets} / {len(expert_uids)})")
    print(f"Mean get time: {total_get_time / len(expert_uids):.5f}, Total: {total_get_time:.5f}")

    alive_peers = [peer.is_alive() for peer in peers]
    print(f"Node survival rate: {sum(alive_peers) / len(peers) * 100:.3f}%")

    # restore the original expiration time so later DHT use is unaffected
    hivemind.DHT.EXPIRATION = old_expiration_time


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_peers', type=int, default=32, required=False)
    parser.add_argument('--initial_peers', type=int, default=1, required=False)
    parser.add_argument('--num_experts', type=int, default=256, required=False)
    parser.add_argument('--expert_batch_size', type=int, default=32, required=False)
    parser.add_argument('--expiration_time', type=float, default=300, required=False)
    parser.add_argument('--wait_after_request', type=float, default=0, required=False)
    parser.add_argument('--wait_before_read', type=float, default=0, required=False)
    parser.add_argument('--wait_timeout', type=float, default=5, required=False)
    parser.add_argument('--random_seed', type=int, default=random.randint(1, 1000))
    parser.add_argument('--increase_file_limit', action="store_true")
    args = vars(parser.parse_args())

    if args.pop('increase_file_limit', False):
        increase_file_limit()
    benchmark_dht(**args)
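
# Example invocation (flag values below are illustrative, not recommendations):
#   python benchmark_dht.py --num_peers 64 --initial_peers 2 --num_experts 512 --expiration_time 600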