# benchmark_averaging.py — stress-test for hivemind.DecentralizedAverager
  1. import math
  2. import time
  3. import threading
  4. import argparse
  5. import torch
  6. import hivemind
  7. from hivemind.utils import LOCALHOST, increase_file_limit
  8. from hivemind.proto import runtime_pb2
  9. def sample_tensors(hid_size, num_layers):
  10. tensors = []
  11. for i in range(num_layers):
  12. tensors.append(torch.randn(hid_size, 3 * hid_size))
  13. tensors.append(torch.randn(3 * hid_size))
  14. tensors.append(torch.randn(3 * hid_size))
  15. tensors.append(torch.randn(hid_size, hid_size))
  16. tensors.append(torch.ones(hid_size))
  17. tensors.append(torch.zeros(hid_size))
  18. tensors.append(torch.randn(hid_size, 4 * hid_size))
  19. tensors.append(torch.randn(4 * hid_size))
  20. tensors.append(torch.ones(4 * hid_size))
  21. tensors.append(torch.randn(2, hid_size, hid_size, 2))
  22. tensors.append(torch.randn(hid_size))
  23. tensors.append(torch.randn(hid_size))
  24. tensors.append(torch.randn(hid_size))
  25. return tuple(tensors)
  26. def benchmark_averaging(num_peers: int, target_group_size: int, num_rounds: int,
  27. averaging_expiration: float, request_timeout: float, round_timeout: float,
  28. hid_size: int, num_layers: int, spawn_dtime: float):
  29. dht_root = hivemind.DHT(listen_on=f'{LOCALHOST}:*', start=True)
  30. num_groups = 2 ** int(round(math.log2(num_peers / target_group_size)))
  31. nbits = int(round(math.log2(num_groups)))
  32. peer_tensors = [sample_tensors(hid_size, num_layers)
  33. for _ in range(num_peers)]
  34. processes = {dht_root}
  35. def run_averager(index):
  36. dht = hivemind.DHT(listen_on=f'{LOCALHOST}:*',
  37. initial_peers=[f"{LOCALHOST}:{dht_root.port}"],
  38. start=True)
  39. initial_bits = bin(index % num_groups)[2:].rjust(nbits, '0')
  40. averager = hivemind.DecentralizedAverager(
  41. peer_tensors[i], dht, prefix='my_tensor', initial_group_bits=initial_bits, listen_on=f"{LOCALHOST}:*",
  42. compression_type=runtime_pb2.CompressionType.FLOAT16, target_group_size=target_group_size,
  43. averaging_expiration=averaging_expiration, request_timeout=request_timeout, start=True)
  44. processes.update({dht, averager})
  45. print(end=f'<started {index}>\n', flush=True)
  46. for _ in range(num_rounds):
  47. success = averager.step(timeout=round_timeout)
  48. print(end=('+' if success else '-'), flush=True)
  49. print(end=f'<finished {index}>\n', flush=True)
  50. threads = []
  51. for i in range(num_peers):
  52. thread = threading.Thread(target=run_averager, args=[i])
  53. threads.append(thread)
  54. thread.start()
  55. time.sleep(spawn_dtime)
  56. t = time.time()
  57. for thread in threads:
  58. thread.join()
  59. print(f"\ntest run took {time.time() - t:.3f} seconds")
  60. for process in processes:
  61. process.terminate()
  62. if __name__ == "__main__":
  63. parser = argparse.ArgumentParser()
  64. parser.add_argument('--num_peers', type=int, default=16, required=False)
  65. parser.add_argument('--target_group_size', type=int, default=4, required=False)
  66. parser.add_argument('--num_rounds', type=int, default=5, required=False)
  67. parser.add_argument('--hid_size', type=int, default=256, required=False)
  68. parser.add_argument('--num_layers', type=int, default=3, required=False)
  69. parser.add_argument('--averaging_expiration', type=float, default=15, required=False)
  70. parser.add_argument('--round_timeout', type=float, default=30, required=False)
  71. parser.add_argument('--request_timeout', type=float, default=3, required=False)
  72. parser.add_argument('--spawn_dtime', type=float, default=0.1, required=False)
  73. parser.add_argument('--increase_file_limit', action="store_true")
  74. args = vars(parser.parse_args())
  75. if args.pop('increase_file_limit', False):
  76. increase_file_limit()
  77. benchmark_averaging(**args)