# benchmark_throughput.py
  1. import argparse
  2. import multiprocessing as mp
  3. import random
  4. import sys
  5. import time
  6. import torch
  7. from hivemind.moe.client import RemoteExpert
  8. from hivemind.moe.server import ExpertBackend, Server
  9. from hivemind.moe.server.layers import name_to_block
  10. from hivemind.utils.limits import increase_file_limit
  11. from hivemind.utils.logging import get_logger, use_hivemind_log_handler
  12. from hivemind.utils.networking import LOCALHOST, get_free_port
  13. from hivemind.utils.tensor_descr import BatchTensorDescriptor
  14. use_hivemind_log_handler("in_root_logger")
  15. logger = get_logger(__name__)
  16. def print_device_info(device=None):
  17. """Prints device stats. Code from https://stackoverflow.com/a/53374933/12891528"""
  18. device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
  19. logger.info(f"Using device: {device}")
  20. # Additional Info when using cuda
  21. if device.type == "cuda":
  22. logger.info(torch.cuda.get_device_name(0))
  23. logger.info(f"Memory Usage:")
  24. logger.info(f"Allocated: {round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1)} GB")
  25. logger.info(f"Cached: {round(torch.cuda.memory_cached(0) / 1024 ** 3, 1)} GB")
  26. def client_process(can_start, benchmarking_failed, port, num_experts, batch_size, hid_dim, num_batches, backprop=True):
  27. torch.set_num_threads(1)
  28. can_start.wait()
  29. experts = [RemoteExpert(f"expert{i}", endpoint=f"{LOCALHOST}:{port}") for i in range(num_experts)]
  30. try:
  31. dummy_batch = torch.randn(batch_size, hid_dim)
  32. for batch_i in range(num_batches):
  33. expert = random.choice(experts)
  34. out = expert(dummy_batch)
  35. if backprop:
  36. out.sum().backward()
  37. except BaseException as e:
  38. benchmarking_failed.set()
  39. raise e
  40. def benchmark_throughput(
  41. num_experts=16,
  42. num_handlers=None,
  43. num_clients=128,
  44. num_batches_per_client=16,
  45. expert_cls="ffn",
  46. hid_dim=1024,
  47. batch_size=2048,
  48. max_batch_size=None,
  49. backprop=True,
  50. device=None,
  51. port=None,
  52. ):
  53. assert (
  54. not hasattr(torch.cuda, "is_initialized")
  55. or not torch.cuda.is_initialized()
  56. or torch.device(device) == torch.device("cpu")
  57. )
  58. assert expert_cls in name_to_block
  59. port = port or get_free_port()
  60. max_batch_size = max_batch_size or batch_size * 4
  61. num_handlers = max(1, num_handlers or num_clients // 2)
  62. benchmarking_failed = mp.Event()
  63. can_start = mp.Event()
  64. timestamps = dict(started=time.perf_counter())
  65. try:
  66. # start clients and await server
  67. # Note: client processes must be launched BEFORE touching gpu, even torch.cuda.is_available can cause trouble
  68. clients = [
  69. mp.Process(
  70. target=client_process,
  71. name=f"client_process-{i}",
  72. args=(
  73. can_start,
  74. benchmarking_failed,
  75. port,
  76. num_experts,
  77. batch_size,
  78. hid_dim,
  79. num_batches_per_client,
  80. backprop,
  81. ),
  82. )
  83. for i in range(num_clients)
  84. ]
  85. for client in clients:
  86. client.daemon = True
  87. client.start()
  88. timestamps["launched_clients"] = timestamps["began_launching_server"] = time.perf_counter()
  89. # start server
  90. device = device or ("cuda" if torch.cuda.is_available() else "cpu")
  91. experts = {}
  92. for i in range(num_experts):
  93. expert = torch.jit.script(name_to_block[expert_cls](hid_dim))
  94. experts[f"expert{i}"] = ExpertBackend(
  95. name=f"expert{i}",
  96. expert=expert,
  97. optimizer=torch.optim.Adam(expert.parameters()),
  98. args_schema=(BatchTensorDescriptor(hid_dim),),
  99. outputs_schema=BatchTensorDescriptor(hid_dim),
  100. max_batch_size=max_batch_size,
  101. )
  102. timestamps["created_experts"] = time.perf_counter()
  103. server = Server(
  104. None,
  105. experts,
  106. listen_on=f"{LOCALHOST}:{port}",
  107. num_connection_handlers=num_handlers,
  108. device=device,
  109. )
  110. server.start()
  111. server.ready.wait()
  112. timestamps["server_ready"] = time.perf_counter()
  113. can_start.set()
  114. for client in clients:
  115. client.join()
  116. timestamps["clients_finished"] = time.perf_counter()
  117. except BaseException as e:
  118. benchmarking_failed.set()
  119. raise e
  120. finally:
  121. for client in clients:
  122. if client.is_alive():
  123. client.terminate()
  124. server.shutdown()
  125. timestamps["server_shutdown_finished"] = time.perf_counter()
  126. server.join()
  127. sys.stdout.flush()
  128. sys.stderr.flush()
  129. time_between = (
  130. lambda key1, key2: abs(timestamps[key2] - timestamps[key1])
  131. if (key1 in timestamps and key2 in timestamps)
  132. else float("nan")
  133. )
  134. total_examples = batch_size * num_clients * num_batches_per_client
  135. logger.info("Benchmark finished, status:" + ["Success", "Failure"][benchmarking_failed.is_set()])
  136. logger.info(
  137. f"Server parameters: num_experts={num_experts}, num_handlers={num_handlers}, "
  138. f"max_batch_size={max_batch_size}, expert_cls={expert_cls}, hid_dim={hid_dim}, device={device}"
  139. )
  140. logger.info(
  141. f"Client parameters: num_clients={num_clients}, num_batches_per_client={num_batches_per_client}, "
  142. f"batch_size={batch_size}, backprop={backprop}"
  143. )
  144. logger.info("Results: ")
  145. logger.info(
  146. f"\tServer startup took {time_between('began_launching_server', 'server_ready') :.3f} s. "
  147. f"({time_between('began_launching_server', 'created_experts') :.3f} s. experts + "
  148. f"{time_between('created_experts', 'server_ready') :.3f} s. networking)"
  149. )
  150. logger.info(f"\tProcessed {total_examples} examples in {time_between('server_ready', 'clients_finished') :.3f}")
  151. logger.info(
  152. f"\tThroughput for {'forward + backward' if backprop else 'forward'} passes: "
  153. f"{total_examples / time_between('server_ready', 'clients_finished') :.3f} samples / s."
  154. )
  155. logger.info(f"\tBenchmarking took {time_between('started', 'server_shutdown_finished') :.3f} s.")
  156. if benchmarking_failed.is_set():
  157. logger.info("Note: benchmark code failed, timing/memory results only indicate time till failure!")
  158. print_device_info(device)
  159. sys.stdout.flush()
  160. sys.stderr.flush()
  161. assert not benchmarking_failed.is_set()
  162. if __name__ == "__main__":
  163. parser = argparse.ArgumentParser()
  164. parser.add_argument("--preset", type=str, default="default", required=False)
  165. parser.add_argument("--num_batches_per_client", type=int, default=16, required=False)
  166. args = parser.parse_args()
  167. if args.preset in ("default", "ffn_forward_backward"):
  168. benchmark_throughput()
  169. elif args.preset == "ffn_forward":
  170. benchmark_throughput(backprop=False, num_batches_per_client=args.num_batches_per_client)
  171. elif args.preset == "ffn_small_batch":
  172. benchmark_throughput(
  173. backprop=False,
  174. num_experts=4,
  175. batch_size=32,
  176. max_batch_size=8192,
  177. num_batches_per_client=args.num_batches_per_client,
  178. )
  179. elif args.preset == "ffn_small_batch_512clients":
  180. benchmark_throughput(
  181. backprop=True,
  182. num_experts=1,
  183. batch_size=1,
  184. max_batch_size=8192,
  185. num_clients=512,
  186. num_batches_per_client=args.num_batches_per_client,
  187. )
  188. elif args.preset == "ffn_small_batch_512clients_32handlers":
  189. benchmark_throughput(
  190. backprop=True,
  191. num_experts=1,
  192. batch_size=1,
  193. max_batch_size=8192,
  194. num_handlers=32,
  195. num_clients=512,
  196. num_batches_per_client=args.num_batches_per_client,
  197. )
  198. elif args.preset == "ffn_massive":
  199. increase_file_limit()
  200. benchmark_throughput(
  201. backprop=False,
  202. num_clients=512,
  203. batch_size=512,
  204. max_batch_size=8192,
  205. num_batches_per_client=args.num_batches_per_client,
  206. )
  207. elif args.preset == "minimalistic":
  208. benchmark_throughput(
  209. num_experts=1, num_clients=1, num_handlers=1, num_batches_per_client=args.num_batches_per_client
  210. )
  211. elif args.preset == "nop":
  212. benchmark_throughput(expert_cls="nop", backprop=False, num_batches_per_client=args.num_batches_per_client)
  213. else:
  214. raise ValueError(f"No such benchmark preset: {args.preset}")