# benchmark_averaging.py
  1. import argparse
  2. import math
  3. import threading
  4. import time
  5. import torch
  6. import hivemind
  7. from hivemind.proto import runtime_pb2
  8. from hivemind.utils import LOCALHOST, get_logger, increase_file_limit
  9. logger = get_logger(__name__)
  10. def sample_tensors(hid_size, num_layers):
  11. tensors = []
  12. for i in range(num_layers):
  13. tensors.append(torch.randn(hid_size, 3 * hid_size))
  14. tensors.append(torch.randn(3 * hid_size))
  15. tensors.append(torch.randn(3 * hid_size))
  16. tensors.append(torch.randn(hid_size, hid_size))
  17. tensors.append(torch.ones(hid_size))
  18. tensors.append(torch.zeros(hid_size))
  19. tensors.append(torch.randn(hid_size, 4 * hid_size))
  20. tensors.append(torch.randn(4 * hid_size))
  21. tensors.append(torch.ones(4 * hid_size))
  22. tensors.append(torch.randn(2, hid_size, hid_size, 2))
  23. tensors.append(torch.randn(hid_size))
  24. tensors.append(torch.randn(hid_size))
  25. tensors.append(torch.randn(hid_size))
  26. return tuple(tensors)
  27. def benchmark_averaging(num_peers: int, target_group_size: int, num_rounds: int,
  28. averaging_expiration: float, request_timeout: float, round_timeout: float,
  29. hid_size: int, num_layers: int, spawn_dtime: float):
  30. dht_root = hivemind.DHT(listen_on=f'{LOCALHOST}:*', start=True)
  31. num_groups = 2 ** int(round(math.log2(num_peers / target_group_size)))
  32. nbits = int(round(math.log2(num_groups)))
  33. peer_tensors = [sample_tensors(hid_size, num_layers)
  34. for _ in range(num_peers)]
  35. processes = {dht_root}
  36. lock_stats = threading.Lock()
  37. successful_steps = total_steps = 0
  38. def run_averager(index):
  39. nonlocal successful_steps, total_steps, lock_stats
  40. dht = hivemind.DHT(listen_on=f'{LOCALHOST}:*',
  41. initial_peers=[f"{LOCALHOST}:{dht_root.port}"],
  42. start=True)
  43. initial_bits = bin(index % num_groups)[2:].rjust(nbits, '0')
  44. averager = hivemind.averaging.DecentralizedAverager(
  45. peer_tensors[i], dht, prefix='my_tensor', initial_group_bits=initial_bits, listen_on=f"{LOCALHOST}:*",
  46. compression_type=runtime_pb2.CompressionType.FLOAT16, target_group_size=target_group_size,
  47. averaging_expiration=averaging_expiration, request_timeout=request_timeout, start=True)
  48. processes.update({dht, averager})
  49. logger.info(f'Averager {index}: started on endpoint {averager.endpoint}, group_bits: {averager.get_group_bits()}')
  50. for step in range(num_rounds):
  51. try:
  52. success = averager.step(timeout=round_timeout) is not None
  53. except:
  54. success = False
  55. with lock_stats:
  56. successful_steps += int(success)
  57. total_steps += 1
  58. logger.info(f"Averager {index}: {'finished' if success else 'failed'} step {step}")
  59. logger.info(f"Averager {index}: done.")
  60. threads = []
  61. for i in range(num_peers):
  62. thread = threading.Thread(target=run_averager, args=[i])
  63. threads.append(thread)
  64. thread.start()
  65. time.sleep(spawn_dtime)
  66. t = time.time()
  67. for thread in threads:
  68. thread.join()
  69. logger.info(f"Benchmark finished in {time.time() - t:.3f} seconds.")
  70. logger.info(f"Success rate: {successful_steps / total_steps} ({successful_steps} out of {total_steps} attempts)")
  71. if __name__ == "__main__":
  72. parser = argparse.ArgumentParser()
  73. parser.add_argument('--num_peers', type=int, default=16, required=False)
  74. parser.add_argument('--target_group_size', type=int, default=4, required=False)
  75. parser.add_argument('--num_rounds', type=int, default=5, required=False)
  76. parser.add_argument('--hid_size', type=int, default=256, required=False)
  77. parser.add_argument('--num_layers', type=int, default=3, required=False)
  78. parser.add_argument('--averaging_expiration', type=float, default=5, required=False)
  79. parser.add_argument('--round_timeout', type=float, default=15, required=False)
  80. parser.add_argument('--request_timeout', type=float, default=1, required=False)
  81. parser.add_argument('--spawn_dtime', type=float, default=0.1, required=False)
  82. parser.add_argument('--increase_file_limit', action="store_true")
  83. args = vars(parser.parse_args())
  84. if args.pop('increase_file_limit', False):
  85. increase_file_limit()
  86. benchmark_averaging(**args)