throughput.py

import fcntl
import json
import os
import subprocess
import tempfile
import time
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Dict, Union

import torch
from hivemind.utils.logging import get_logger, use_hivemind_log_handler

from src import project_name
from src.bloom.block import BloomBlock
from src.bloom.model import BloomConfig
from src.bloom.ops import build_alibi_tensor

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)

DEFAULT_CACHE_PATH = Path(Path.home(), ".cache", project_name, "throughput.json")
DEFAULT_LOCK_PATH = Path(tempfile.gettempdir(), project_name, "throughput.lock")

SPEED_TEST_PATH = Path(Path(__file__).absolute().parents[2], "cli", "speed_test.py")
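

# Throughput is measured in "(inference) requests per second" (RPS): how many
# single-token forward passes a host can handle. The effective host throughput
# is the minimum of its network RPS and its compute (device) RPS.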
@dataclass
class ThroughputInfo:
    network_rps: float
    device_rps: Dict[str, float]


def get_host_throughput(
    device: Union[str, torch.device],
    force_eval: bool = False,
    cache_path: Path = DEFAULT_CACHE_PATH,
    lock_path: Path = DEFAULT_LOCK_PATH,
) -> float:
    # We only keep the device type, assuming that the throughput is similar among all host's GPUs
    device = torch.device(device).type

    # We use the system-wide lock since only one process at a time can measure the host throughput
    os.makedirs(lock_path.parent, exist_ok=True)
    with open(lock_path, "wb") as lock_fd:
        logger.info("Loading throughput info")
        fcntl.flock(lock_fd.fileno(), fcntl.LOCK_EX)
        # The OS will release the lock when lock_fd is closed or the process is killed

        info = None
        try:
            if not force_eval and os.path.exists(cache_path):
                with open(cache_path) as cache_fd:
                    info = ThroughputInfo(**json.load(cache_fd))
                if device not in info.device_rps:
                    force_eval = True
        except Exception:
            logger.exception(f"Failed to read throughput info from {cache_path}")
            force_eval = True

        if force_eval or info is None:
            info = measure_throughput_info()
            try:
                os.makedirs(cache_path.parent, exist_ok=True)
                with open(cache_path, "w") as cache_fd:
                    json.dump(asdict(info), cache_fd)
            except Exception:
                logger.exception(f"Failed to save throughput info in {cache_path}")

    # A host can only serve as fast as its slowest resource: the network or the device
    throughput = min(info.network_rps, info.device_rps[device])
    return throughput
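

# For reference, the cached JSON at DEFAULT_CACHE_PATH (written via `asdict`
# above) has this shape; the numbers below are illustrative, not measured:
#   {"network_rps": 573.4, "device_rps": {"cpu": 12.1, "cuda": 1675.9}}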


def measure_throughput_info() -> ThroughputInfo:
    logger.info(
        "Measuring network, CPU, and GPU throughput. This takes about a minute and will be cached for future runs"
    )

    # We measure throughput in "(inference) requests per second" (RPS) using a fixed model
    config = BloomConfig.from_pretrained("bigscience/test-bloomd-6b3")

    network_rps = measure_network_rps(config)
    device_rps = {"cpu": measure_device_rps("cpu", config)}
    if torch.cuda.is_available():
        device_rps["cuda"] = measure_device_rps("cuda", config)
    return ThroughputInfo(network_rps=network_rps, device_rps=device_rps)


def measure_network_rps(config: BloomConfig) -> float:
    proc = subprocess.run([SPEED_TEST_PATH, "--json"], capture_output=True)
    if proc.returncode != 0:
        raise RuntimeError(f"Failed to measure network throughput (stdout: {proc.stdout}, stderr: {proc.stderr})")
    network_info = json.loads(proc.stdout)
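    # One request transfers a single token's hidden state: `hidden_size` fp32
    # values, i.e. 32 bits per value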
    bits_per_request = config.hidden_size * 32
    network_rps = min(network_info["download"], network_info["upload"]) / bits_per_request

    logger.info(
        f"Network throughput: "
        f"{network_info['download'] / 1e6:.2f} Mbit/s on download, "
        f"{network_info['upload'] / 1e6:.2f} Mbit/s on upload, "
        f"{network_rps:.2f} RPS"
    )
    return network_rps
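

# The benchmark below emulates single-batch autoregressive inference: one new
# token per step, reusing the growing attention cache; only the forward pass
# itself is timed.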
def measure_device_rps(device: str, config: BloomConfig, layer_index: int = 0, n_steps: int = 500) -> float:
    with torch.inference_mode():
        block = BloomBlock(config, layer_index).to(device)

        cache = None
        elapsed = 0.0
        for i in range(n_steps):
            dummy_input = torch.randn(1, 1, config.hidden_size, device=device)
            alibi = build_alibi_tensor(i + 1, config.num_attention_heads, dtype=torch.float32, device=device)

            start_time = time.perf_counter()
            _, cache = block.forward(dummy_input, alibi=alibi, use_cache=True, layer_past=cache)
            if device == "cuda":
                torch.cuda.synchronize(device)  # CUDA kernels run asynchronously; wait so the timing is accurate
            elapsed += time.perf_counter() - start_time
        device_rps = n_steps / elapsed

    device_name = f"{torch.cuda.get_device_name(0)} GPU" if device == "cuda" else "CPU"
    logger.info(f"Compute throughput ({device_name}): {device_rps:.2f} RPS")
    return device_rps
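

# A minimal manual-run sketch (hypothetical addition, not part of the original
# module): assuming the `src` package and cli/speed_test.py are in place, this
# prints the bottleneck throughput for the local host.
if __name__ == "__main__":
    best_device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Host throughput ({best_device}): {get_host_throughput(best_device):.2f} RPS")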