benchmark_throughput.py

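"""Benchmark the throughput of a hivemind.moe.Server: the script spawns many client processes
that send random batches to RemoteExpert instances hosted by a local server and reports
end-to-end samples per second for several command-line presets."""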
import argparse
import multiprocessing as mp
import random
import sys
import time

import torch

import hivemind
from hivemind import get_free_port
from hivemind.moe.server import layers
from hivemind.utils.limits import increase_file_limit
from hivemind.utils.logging import get_logger, use_hivemind_log_handler

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__name__)


def print_device_info(device=None):
    """Prints device stats. Code from https://stackoverflow.com/a/53374933/12891528"""
    device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
    logger.info(f"Using device: {device}")

    # Additional info when using CUDA
    if device.type == "cuda":
        logger.info(torch.cuda.get_device_name(0))
        logger.info("Memory usage:")
        logger.info(f"Allocated: {round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1)} GB")
        logger.info(f"Reserved: {round(torch.cuda.memory_reserved(0) / 1024 ** 3, 1)} GB")


def client_process(can_start, benchmarking_failed, port, num_experts, batch_size, hid_dim, num_batches, backprop=True):
    """Wait for the server to come up, then send num_batches random batches to randomly chosen remote experts."""
    torch.set_num_threads(1)
    can_start.wait()
    experts = [
        hivemind.RemoteExpert(f"expert{i}", endpoint=f"{hivemind.LOCALHOST}:{port}") for i in range(num_experts)
    ]

    try:
        dummy_batch = torch.randn(batch_size, hid_dim)
        for batch_i in range(num_batches):
            expert = random.choice(experts)
            out = expert(dummy_batch)
            if backprop:
                out.sum().backward()
    except BaseException as e:
        benchmarking_failed.set()
        raise e


def benchmark_throughput(
    num_experts=16,
    num_handlers=None,
    num_clients=128,
    num_batches_per_client=16,
    expert_cls="ffn",
    hid_dim=1024,
    batch_size=2048,
    max_batch_size=None,
    backprop=True,
    device=None,
    port=None,
):
    """Launch client processes and a local server, then measure end-to-end expert throughput."""
    # CUDA must not be initialized yet: client processes are forked below, before the server touches the GPU
    assert (
        not hasattr(torch.cuda, "is_initialized")
        or not torch.cuda.is_initialized()
        or torch.device(device) == torch.device("cpu")
    )
    assert expert_cls in layers.name_to_block
    port = port or get_free_port()
    max_batch_size = max_batch_size or batch_size * 4
    num_handlers = max(1, num_handlers or num_clients // 2)
    benchmarking_failed = mp.Event()
    can_start = mp.Event()
    timestamps = dict(started=time.perf_counter())

    try:
        # start clients and await server
        # Note: client processes must be launched BEFORE touching gpu, even torch.cuda.is_available can cause trouble
        clients = [
            mp.Process(
                target=client_process,
                name=f"client_process-{i}",
                args=(
                    can_start,
                    benchmarking_failed,
                    port,
                    num_experts,
                    batch_size,
                    hid_dim,
                    num_batches_per_client,
                    backprop,
                ),
            )
            for i in range(num_clients)
        ]

        for client in clients:
            client.daemon = True
            client.start()

        timestamps["launched_clients"] = timestamps["began_launching_server"] = time.perf_counter()

        # start server with one ExpertBackend per expert
        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        experts = {}
        for i in range(num_experts):
            expert = torch.jit.script(layers.name_to_block[expert_cls](hid_dim))
            experts[f"expert{i}"] = hivemind.ExpertBackend(
                name=f"expert{i}",
                expert=expert,
                optimizer=torch.optim.Adam(expert.parameters()),
                args_schema=(hivemind.BatchTensorDescriptor(hid_dim),),
                outputs_schema=hivemind.BatchTensorDescriptor(hid_dim),
                max_batch_size=max_batch_size,
            )
        timestamps["created_experts"] = time.perf_counter()

        server = hivemind.moe.Server(
            None,
            experts,
            listen_on=f"{hivemind.LOCALHOST}:{port}",
            num_connection_handlers=num_handlers,
            device=device,
        )
        server.start()
        server.ready.wait()
        timestamps["server_ready"] = time.perf_counter()

        can_start.set()
        for client in clients:
            client.join()
        timestamps["clients_finished"] = time.perf_counter()
    except BaseException as e:
        benchmarking_failed.set()
        raise e
    finally:
        for client in clients:
            if client.is_alive():
                client.terminate()
        server.shutdown()
        timestamps["server_shutdown_finished"] = time.perf_counter()
        server.join()

    sys.stdout.flush()
    sys.stderr.flush()

    time_between = (
        lambda key1, key2: abs(timestamps[key2] - timestamps[key1])
        if (key1 in timestamps and key2 in timestamps)
        else float("nan")
    )
    total_examples = batch_size * num_clients * num_batches_per_client

    logger.info("Benchmark finished, status: " + ["Success", "Failure"][benchmarking_failed.is_set()])
    logger.info(
        f"Server parameters: num_experts={num_experts}, num_handlers={num_handlers}, "
        f"max_batch_size={max_batch_size}, expert_cls={expert_cls}, hid_dim={hid_dim}, device={device}"
    )
    logger.info(
        f"Client parameters: num_clients={num_clients}, num_batches_per_client={num_batches_per_client}, "
        f"batch_size={batch_size}, backprop={backprop}"
    )

    logger.info("Results: ")
    logger.info(
        f"\tServer startup took {time_between('began_launching_server', 'server_ready') :.3f} s. "
        f"({time_between('began_launching_server', 'created_experts') :.3f} s. experts + "
        f"{time_between('created_experts', 'server_ready') :.3f} s. networking)"
    )
    logger.info(f"\tProcessed {total_examples} examples in {time_between('server_ready', 'clients_finished') :.3f} s.")
    logger.info(
        f"\tThroughput for {'forward + backward' if backprop else 'forward'} passes: "
        f"{total_examples / time_between('server_ready', 'clients_finished') :.3f} samples / s."
    )
    logger.info(f"\tBenchmarking took {time_between('started', 'server_shutdown_finished') :.3f} s.")
    if benchmarking_failed.is_set():
        logger.info("Note: benchmark code failed, timing/memory results only indicate time till failure!")
    print_device_info(device)
    sys.stdout.flush()
    sys.stderr.flush()

    assert not benchmarking_failed.is_set()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--preset", type=str, default="default", required=False)
    parser.add_argument("--num_batches_per_client", type=int, default=16, required=False)
    args = parser.parse_args()

    if args.preset in ("default", "ffn_forward_backward"):
        benchmark_throughput()
    elif args.preset == "ffn_forward":
        benchmark_throughput(backprop=False, num_batches_per_client=args.num_batches_per_client)
    elif args.preset == "ffn_small_batch":
        benchmark_throughput(
            backprop=False,
            num_experts=4,
            batch_size=32,
            max_batch_size=8192,
            num_batches_per_client=args.num_batches_per_client,
        )
    elif args.preset == "ffn_small_batch_512clients":
        benchmark_throughput(
            backprop=True,
            num_experts=1,
            batch_size=1,
            max_batch_size=8192,
            num_clients=512,
            num_batches_per_client=args.num_batches_per_client,
        )
    elif args.preset == "ffn_small_batch_512clients_32handlers":
        benchmark_throughput(
            backprop=True,
            num_experts=1,
            batch_size=1,
            max_batch_size=8192,
            num_handlers=32,
            num_clients=512,
            num_batches_per_client=args.num_batches_per_client,
        )
    elif args.preset == "ffn_massive":
        increase_file_limit()
        benchmark_throughput(
            backprop=False,
            num_clients=512,
            batch_size=512,
            max_batch_size=8192,
            num_batches_per_client=args.num_batches_per_client,
        )
    elif args.preset == "minimalistic":
        benchmark_throughput(
            num_experts=1, num_clients=1, num_handlers=1, num_batches_per_client=args.num_batches_per_client
        )
    elif args.preset == "nop":
        benchmark_throughput(expert_cls="nop", backprop=False, num_batches_per_client=args.num_batches_per_client)
    else:
        raise ValueError(f"No such benchmark preset: {args.preset}")
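
# Example invocations (a sketch, assuming this file is saved as benchmark_throughput.py and
# hivemind plus torch are installed; preset names are the ones handled above):
#
#   python benchmark_throughput.py --preset ffn_forward_backward
#   python benchmark_throughput.py --preset ffn_small_batch --num_batches_per_client 8
#   python benchmark_throughput.py --preset minimalistic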