# benchmark_averaging.py — benchmark for hivemind's DecentralizedAverager.
  1. import argparse
  2. import math
  3. import threading
  4. import time
  5. import torch
  6. import hivemind
  7. from hivemind.proto import runtime_pb2
  8. from hivemind.utils.limits import increase_file_limit
  9. from hivemind.utils.logging import get_logger, use_hivemind_log_handler
  10. from hivemind.utils.networking import LOCALHOST
# Route hivemind's internal log records through the root logger so benchmark
# output obeys whatever logging configuration the host process has set up.
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__name__)
  13. def sample_tensors(hid_size, num_layers):
  14. tensors = []
  15. for i in range(num_layers):
  16. tensors.append(torch.randn(hid_size, 3 * hid_size))
  17. tensors.append(torch.randn(3 * hid_size))
  18. tensors.append(torch.randn(3 * hid_size))
  19. tensors.append(torch.randn(hid_size, hid_size))
  20. tensors.append(torch.ones(hid_size))
  21. tensors.append(torch.zeros(hid_size))
  22. tensors.append(torch.randn(hid_size, 4 * hid_size))
  23. tensors.append(torch.randn(4 * hid_size))
  24. tensors.append(torch.ones(4 * hid_size))
  25. tensors.append(torch.randn(2, hid_size, hid_size, 2))
  26. tensors.append(torch.randn(hid_size))
  27. tensors.append(torch.randn(hid_size))
  28. tensors.append(torch.randn(hid_size))
  29. return tuple(tensors)
  30. def benchmark_averaging(
  31. num_peers: int,
  32. target_group_size: int,
  33. num_rounds: int,
  34. averaging_expiration: float,
  35. request_timeout: float,
  36. round_timeout: float,
  37. hid_size: int,
  38. num_layers: int,
  39. spawn_dtime: float,
  40. ):
  41. dht_root = hivemind.DHT(start=True)
  42. initial_peers = dht_root.get_visible_maddrs()
  43. num_groups = 2 ** int(round(math.log2(num_peers / target_group_size)))
  44. nbits = int(round(math.log2(num_groups)))
  45. peer_tensors = [sample_tensors(hid_size, num_layers) for _ in range(num_peers)]
  46. processes = {dht_root}
  47. lock_stats = threading.Lock()
  48. successful_steps = total_steps = 0
  49. def run_averager(index):
  50. nonlocal successful_steps, total_steps, lock_stats
  51. dht = hivemind.DHT(initial_peers=initial_peers, start=True)
  52. initial_bits = bin(index % num_groups)[2:].rjust(nbits, "0")
  53. averager = hivemind.averaging.DecentralizedAverager(
  54. peer_tensors[index],
  55. dht,
  56. prefix="my_tensor",
  57. initial_group_bits=initial_bits,
  58. compression_type=runtime_pb2.CompressionType.FLOAT16,
  59. target_group_size=target_group_size,
  60. averaging_expiration=averaging_expiration,
  61. request_timeout=request_timeout,
  62. start=True,
  63. )
  64. processes.update({dht, averager})
  65. logger.info(
  66. f"Averager {index}: started with peer id {averager.peer_id}, group_bits: {averager.get_group_bits()}"
  67. )
  68. for step in range(num_rounds):
  69. try:
  70. success = averager.step(timeout=round_timeout) is not None
  71. except:
  72. success = False
  73. with lock_stats:
  74. successful_steps += int(success)
  75. total_steps += 1
  76. logger.info(f"Averager {index}: {'finished' if success else 'failed'} step #{step}")
  77. logger.info(f"Averager {index}: done.")
  78. threads = []
  79. for i in range(num_peers):
  80. thread = threading.Thread(target=run_averager, args=[i])
  81. threads.append(thread)
  82. thread.start()
  83. time.sleep(spawn_dtime)
  84. t = time.time()
  85. for thread in threads:
  86. thread.join()
  87. logger.info(f"Benchmark finished in {time.time() - t:.3f} seconds.")
  88. logger.info(f"Success rate: {successful_steps / total_steps} ({successful_steps} out of {total_steps} attempts)")
  89. if __name__ == "__main__":
  90. parser = argparse.ArgumentParser()
  91. parser.add_argument("--num_peers", type=int, default=16, required=False)
  92. parser.add_argument("--target_group_size", type=int, default=4, required=False)
  93. parser.add_argument("--num_rounds", type=int, default=5, required=False)
  94. parser.add_argument("--hid_size", type=int, default=256, required=False)
  95. parser.add_argument("--num_layers", type=int, default=3, required=False)
  96. parser.add_argument("--averaging_expiration", type=float, default=5, required=False)
  97. parser.add_argument("--round_timeout", type=float, default=15, required=False)
  98. parser.add_argument("--request_timeout", type=float, default=1, required=False)
  99. parser.add_argument("--spawn_dtime", type=float, default=0.1, required=False)
  100. parser.add_argument("--increase_file_limit", action="store_true")
  101. args = vars(parser.parse_args())
  102. if args.pop("increase_file_limit", False):
  103. increase_file_limit()
  104. benchmark_averaging(**args)