# benchmark_throughput.py

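"""
Benchmark the throughput of a hivemind mixture-of-experts server: spawn many client
processes that send random batches to remote experts and report examples per second.
"""
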
import argparse
import multiprocessing as mp
import random
import sys
import time

import torch

import hivemind
from hivemind import get_free_port
from hivemind.moe.server import layers
from hivemind.utils.limits import increase_file_limit
from hivemind.utils.logging import get_logger

logger = get_logger(__name__)


def print_device_info(device=None):
    """Prints device stats. Code from https://stackoverflow.com/a/53374933/12891528"""
    device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
    logger.info(f"Using device: {device}")

    # Additional info when using CUDA
    if device.type == "cuda":
        logger.info(torch.cuda.get_device_name(0))
        logger.info("Memory Usage:")
        logger.info(f"Allocated: {round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1)} GB")
        logger.info(f"Cached: {round(torch.cuda.memory_cached(0) / 1024 ** 3, 1)} GB")


def client_process(can_start, benchmarking_failed, port, num_experts, batch_size, hid_dim, num_batches, backprop=True):
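    """Wait for the shared start signal, then send num_batches random batches to randomly chosen remote experts."""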
    torch.set_num_threads(1)
    can_start.wait()
    experts = [
        hivemind.RemoteExpert(f"expert{i}", endpoint=f"{hivemind.LOCALHOST}:{port}") for i in range(num_experts)
    ]

    try:
        dummy_batch = torch.randn(batch_size, hid_dim)
        for batch_i in range(num_batches):
            expert = random.choice(experts)
            out = expert(dummy_batch)
            if backprop:
                out.sum().backward()
    except BaseException as e:
        benchmarking_failed.set()
        raise e


def benchmark_throughput(
    num_experts=16,
    num_handlers=None,
    num_clients=128,
    num_batches_per_client=16,
    expert_cls="ffn",
    hid_dim=1024,
    batch_size=2048,
    max_batch_size=None,
    backprop=True,
    device=None,
    port=None,
):
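    """Measure end-to-end throughput (examples/s) of a local hivemind server under many concurrent client processes."""
    # the device check below guards against CUDA being initialized before the client processes are
    # launched; see the note at the top of the try block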
    assert (
        not hasattr(torch.cuda, "is_initialized")
        or not torch.cuda.is_initialized()
        or torch.device(device) == torch.device("cpu")
    )
    assert expert_cls in layers.name_to_block
    port = port or get_free_port()
    max_batch_size = max_batch_size or batch_size * 4
    num_handlers = max(1, num_handlers or num_clients // 2)
    benchmarking_failed = mp.Event()
    can_start = mp.Event()
    timestamps = dict(started=time.perf_counter())
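
    # each stage below records a perf_counter() timestamp; time_between() turns the pairs into durations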
    try:
        # start clients and await server
        # Note: client processes must be launched BEFORE touching gpu, even torch.cuda.is_available() can cause trouble
        clients = [
            mp.Process(
                target=client_process,
                name=f"client_process-{i}",
                args=(
                    can_start,
                    benchmarking_failed,
                    port,
                    num_experts,
                    batch_size,
                    hid_dim,
                    num_batches_per_client,
                    backprop,
                ),
            )
            for i in range(num_clients)
        ]
        for client in clients:
            client.daemon = True
            client.start()

        timestamps["launched_clients"] = timestamps["began_launching_server"] = time.perf_counter()
        # start server
        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        experts = {}
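        # each expert is a TorchScript module wrapped in an ExpertBackend with its own Adam optimizer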
        for i in range(num_experts):
            expert = torch.jit.script(layers.name_to_block[expert_cls](hid_dim))
            experts[f"expert{i}"] = hivemind.ExpertBackend(
                name=f"expert{i}",
                expert=expert,
                optimizer=torch.optim.Adam(expert.parameters()),
                args_schema=(hivemind.BatchTensorDescriptor(hid_dim),),
                outputs_schema=hivemind.BatchTensorDescriptor(hid_dim),
                max_batch_size=max_batch_size,
            )
        timestamps["created_experts"] = time.perf_counter()
        server = hivemind.moe.Server(
            None,
            experts,
            listen_on=f"{hivemind.LOCALHOST}:{port}",
            num_connection_handlers=num_handlers,
            device=device,
        )
        server.start()
        server.ready.wait()
        timestamps["server_ready"] = time.perf_counter()
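        # unblock the waiting client processes only once the server is ready to accept requests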
        can_start.set()
        for client in clients:
            client.join()
        timestamps["clients_finished"] = time.perf_counter()
    except BaseException as e:
        benchmarking_failed.set()
        raise e
    finally:
        for client in clients:
            if client.is_alive():
                client.terminate()
        server.shutdown()
        timestamps["server_shutdown_finished"] = time.perf_counter()
        server.join()

    sys.stdout.flush()
    sys.stderr.flush()
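
    # time_between yields NaN for stages that were never reached (e.g., if the benchmark failed early)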
    time_between = (
        lambda key1, key2: abs(timestamps[key2] - timestamps[key1])
        if (key1 in timestamps and key2 in timestamps)
        else float("nan")
    )
    total_examples = batch_size * num_clients * num_batches_per_client
    logger.info("Benchmark finished, status: " + ["Success", "Failure"][benchmarking_failed.is_set()])
    logger.info(
        f"Server parameters: num_experts={num_experts}, num_handlers={num_handlers}, "
        f"max_batch_size={max_batch_size}, expert_cls={expert_cls}, hid_dim={hid_dim}, device={device}"
    )
    logger.info(
        f"Client parameters: num_clients={num_clients}, num_batches_per_client={num_batches_per_client}, "
        f"batch_size={batch_size}, backprop={backprop}"
    )
    logger.info("Results: ")
    logger.info(
        f"\tServer startup took {time_between('began_launching_server', 'server_ready') :.3f} s. "
        f"({time_between('began_launching_server', 'created_experts') :.3f} s. experts + "
        f"{time_between('created_experts', 'server_ready') :.3f} s. networking)"
    )
    logger.info(f"\tProcessed {total_examples} examples in {time_between('server_ready', 'clients_finished') :.3f} s.")
    logger.info(
        f"\tThroughput for {'forward + backward' if backprop else 'forward'} passes: "
        f"{total_examples / time_between('server_ready', 'clients_finished') :.3f} samples / s."
    )
    logger.info(f"\tBenchmarking took {time_between('started', 'server_shutdown_finished') :.3f} s.")

    if benchmarking_failed.is_set():
        logger.info("Note: benchmark code failed, timing/memory results only indicate time till failure!")

    print_device_info(device)
    sys.stdout.flush()
    sys.stderr.flush()

    assert not benchmarking_failed.is_set()
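

# CLI entry point: each preset below maps to a specific benchmark_throughput() configuration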
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--preset", type=str, default="default", required=False)
    parser.add_argument("--num_batches_per_client", type=int, default=16, required=False)
    args = parser.parse_args()

    if args.preset in ("default", "ffn_forward_backward"):
        benchmark_throughput()
    elif args.preset == "ffn_forward":
        benchmark_throughput(backprop=False, num_batches_per_client=args.num_batches_per_client)
    elif args.preset == "ffn_small_batch":
        benchmark_throughput(
            backprop=False,
            num_experts=4,
            batch_size=32,
            max_batch_size=8192,
            num_batches_per_client=args.num_batches_per_client,
        )
    elif args.preset == "ffn_small_batch_512clients":
        benchmark_throughput(
            backprop=True,
            num_experts=1,
            batch_size=1,
            max_batch_size=8192,
            num_clients=512,
            num_batches_per_client=args.num_batches_per_client,
        )
    elif args.preset == "ffn_small_batch_512clients_32handlers":
        benchmark_throughput(
            backprop=True,
            num_experts=1,
            batch_size=1,
            max_batch_size=8192,
            num_handlers=32,
            num_clients=512,
            num_batches_per_client=args.num_batches_per_client,
        )
    elif args.preset == "ffn_massive":
        increase_file_limit()
        benchmark_throughput(
            backprop=False,
            num_clients=512,
            batch_size=512,
            max_batch_size=8192,
            num_batches_per_client=args.num_batches_per_client,
        )
    elif args.preset == "minimalistic":
        benchmark_throughput(
            num_experts=1, num_clients=1, num_handlers=1, num_batches_per_client=args.num_batches_per_client
        )
    elif args.preset == "nop":
        benchmark_throughput(expert_cls="nop", backprop=False, num_batches_per_client=args.num_batches_per_client)
    else:
        raise ValueError(f"No such benchmark preset: {args.preset}")
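
# Example invocations (assuming this file is saved as benchmark_throughput.py):
#   python benchmark_throughput.py --preset ffn_forward
#   python benchmark_throughput.py --preset ffn_small_batch_512clients --num_batches_per_client 8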