benchmark_throughput.py 8.1 KB

import argparse
import multiprocessing as mp
import random
import sys
import time

import torch

import hivemind
from hivemind import find_open_port
from hivemind.server import layers
from hivemind.utils.threading import increase_file_limit
from hivemind.utils.logging import get_logger

logger = get_logger(__name__)

def print_device_info(device=None):
    """Prints device stats. Code from https://stackoverflow.com/a/53374933/12891528"""
    device = torch.device(device or ('cuda' if torch.cuda.is_available() else 'cpu'))
    logger.info(f'Using device: {device}')

    # Additional info when using CUDA
    if device.type == 'cuda':
        logger.info(torch.cuda.get_device_name(0))
        logger.info('Memory Usage:')
        logger.info(f'Allocated: {round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1)} GB')
        logger.info(f'Cached: {round(torch.cuda.memory_cached(0) / 1024 ** 3, 1)} GB')

def client_process(can_start, benchmarking_failed, port, num_experts, batch_size, hid_dim, num_batches,
                   backprop=True):
    torch.set_num_threads(1)
    can_start.wait()
    # connect to every expert served on localhost; each batch is sent to a randomly chosen expert
    experts = [hivemind.RemoteExpert(f"expert{i}", endpoint=f"{hivemind.LOCALHOST}:{port}")
               for i in range(num_experts)]

    try:
        dummy_batch = torch.randn(batch_size, hid_dim)
        for batch_i in range(num_batches):
            expert = random.choice(experts)
            out = expert(dummy_batch)
            if backprop:
                out.sum().backward()
    except BaseException as e:
        benchmarking_failed.set()
        raise e

def benchmark_throughput(num_experts=16, num_handlers=None, num_clients=128, num_batches_per_client=16,
                         expert_cls='ffn', hid_dim=1024, batch_size=2048, max_batch_size=None, backprop=True,
                         device=None, port=None):
    # CUDA must not be initialized in this process before the client processes are forked (unless running on CPU)
    assert not hasattr(torch.cuda, 'is_initialized') or not torch.cuda.is_initialized() \
           or torch.device(device) == torch.device('cpu')
    assert expert_cls in layers.name_to_block
    port = port or find_open_port()
    max_batch_size = max_batch_size or batch_size * 4
    num_handlers = max(1, num_handlers or num_clients // 2)
    benchmarking_failed = mp.Event()
    can_start = mp.Event()
    timestamps = dict(started=time.perf_counter())

    try:
        # start clients and await server
        # Note: client processes must be launched BEFORE touching gpu, even torch.cuda.is_available can cause trouble
        clients = [
            mp.Process(
                target=client_process, name=f'client_process-{i}',
                args=(can_start, benchmarking_failed, port, num_experts, batch_size,
                      hid_dim, num_batches_per_client, backprop))
            for i in range(num_clients)]

        for client in clients:
            client.daemon = True
            client.start()

        timestamps['launched_clients'] = timestamps['began_launching_server'] = time.perf_counter()

        # start server
        device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        experts = {}
        for i in range(num_experts):
            expert = torch.jit.script(layers.name_to_block[expert_cls](hid_dim))
            experts[f'expert{i}'] = hivemind.ExpertBackend(name=f'expert{i}',
                                                           expert=expert,
                                                           optimizer=torch.optim.Adam(expert.parameters()),
                                                           args_schema=(hivemind.BatchTensorDescriptor(hid_dim),),
                                                           outputs_schema=hivemind.BatchTensorDescriptor(hid_dim),
                                                           max_batch_size=max_batch_size,
                                                           )
        timestamps['created_experts'] = time.perf_counter()

        server = hivemind.Server(None, experts, listen_on=f"{hivemind.LOCALHOST}:{port}",
                                 num_connection_handlers=num_handlers, device=device)
        server.start()
        server.ready.wait()
        timestamps['server_ready'] = time.perf_counter()
        can_start.set()

        for client in clients:
            client.join()
        timestamps['clients_finished'] = time.perf_counter()
    except BaseException as e:
        benchmarking_failed.set()
        raise e
    finally:
        for client in clients:
            if client.is_alive():
                client.terminate()
        server.shutdown()
        timestamps['server_shutdown_finished'] = time.perf_counter()
        server.join()

    sys.stdout.flush()
    sys.stderr.flush()

    time_between = lambda key1, key2: \
        abs(timestamps[key2] - timestamps[key1]) if (key1 in timestamps and key2 in timestamps) else float('nan')
    total_examples = batch_size * num_clients * num_batches_per_client

    logger.info("Benchmark finished, status:" + ["Success", "Failure"][benchmarking_failed.is_set()])
    logger.info(f"Server parameters: num_experts={num_experts}, num_handlers={num_handlers}, max_batch_size={max_batch_size},"
                f" expert_cls={expert_cls}, hid_dim={hid_dim}, device={device}")
    logger.info(f"Client parameters: num_clients={num_clients}, num_batches_per_client={num_batches_per_client}, "
                f"batch_size={batch_size}, backprop={backprop}")

    logger.info("Results: ")
    logger.info(f"\tServer startup took {time_between('began_launching_server', 'server_ready') :.3f} s. "
                f"({time_between('began_launching_server', 'created_experts') :.3f} s. experts + "
                f"{time_between('created_experts', 'server_ready') :.3f} s. networking)")
    logger.info(f"\tProcessed {total_examples} examples in {time_between('server_ready', 'clients_finished') :.3f}")
    logger.info(f"\tThroughput for {'forward + backward' if backprop else 'forward'} passes: "
                f"{total_examples / time_between('server_ready', 'clients_finished') :.3f} samples / s.")
    logger.info(f"\tBenchmarking took {time_between('started', 'server_shutdown_finished') :.3f} s.")
    if benchmarking_failed.is_set():
        logger.info("Note: benchmark code failed, timing/memory results only indicate time till failure!")
    print_device_info(device)
    sys.stdout.flush()
    sys.stderr.flush()

    assert not benchmarking_failed.is_set()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--preset', type=str, default='default', required=False)
    parser.add_argument('--num_batches_per_client', type=int, default=16, required=False)
    args = parser.parse_args()

    if args.preset in ('default', 'ffn_forward_backward'):
        benchmark_throughput()
    elif args.preset == 'ffn_forward':
        benchmark_throughput(backprop=False, num_batches_per_client=args.num_batches_per_client)
    elif args.preset == 'ffn_small_batch':
        benchmark_throughput(backprop=False, num_experts=4, batch_size=32, max_batch_size=8192,
                             num_batches_per_client=args.num_batches_per_client)
    elif args.preset == 'ffn_small_batch_512clients':
        benchmark_throughput(backprop=True, num_experts=1, batch_size=1, max_batch_size=8192,
                             num_clients=512, num_batches_per_client=args.num_batches_per_client)
    elif args.preset == 'ffn_small_batch_512clients_32handlers':
        benchmark_throughput(backprop=True, num_experts=1, batch_size=1, max_batch_size=8192, num_handlers=32,
                             num_clients=512, num_batches_per_client=args.num_batches_per_client)
    elif args.preset == 'ffn_massive':
        increase_file_limit()
        benchmark_throughput(backprop=False, num_clients=512, batch_size=512,
                             max_batch_size=8192, num_batches_per_client=args.num_batches_per_client)
    elif args.preset == 'minimalistic':
        benchmark_throughput(num_experts=1, num_clients=1, num_handlers=1,
                             num_batches_per_client=args.num_batches_per_client)
    elif args.preset == 'nop':
        benchmark_throughput(expert_cls='nop', backprop=False, num_batches_per_client=args.num_batches_per_client)
    else:
        raise ValueError(f"No such benchmark preset: {args.preset}")
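
For reference, a minimal driver for the benchmark above might look like the sketch below. It assumes this file is saved as benchmark_throughput.py on the import path and that hivemind and torch are installed; the tiny parameter values are illustrative only and do not correspond to any of the presets defined above.

# Minimal usage sketch (assumption: this script is importable as `benchmark_throughput`,
# with hivemind and torch installed). Parameter values are illustrative, not a preset.
from benchmark_throughput import benchmark_throughput

if __name__ == "__main__":
    benchmark_throughput(
        expert_cls='ffn',           # must be a key of hivemind.server.layers.name_to_block
        num_experts=2,              # far fewer experts than the default 16, to keep the run short
        num_clients=4,              # each client is a separate process sending dummy batches
        num_batches_per_client=2,
        batch_size=64,
        backprop=False,             # forward passes only, as in the 'ffn_forward' preset
    )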