
Measure throughput for different configs, devices, and dtypes separately (#114)

Alexander Borzunov, 2 years ago
parent commit 1ea44b0d3c
2 changed files with 77 additions and 53 deletions
  1. src/petals/server/server.py      +7  -5
  2. src/petals/server/throughput.py  +70 -48
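For context, the new get_host_throughput signature (introduced in the diff below) takes the block config, device, and dtype explicitly instead of measuring one number per device type. A minimal usage sketch, reusing the "bigscience/test-bloomd-6b3" model name from the removed code purely for illustration:

    import torch

    from petals.bloom.model import BloomConfig
    from petals.server.throughput import get_host_throughput

    # Sketch of the new call signature; the model name and dtype are illustrative.
    config = BloomConfig.from_pretrained("bigscience/test-bloomd-6b3")
    throughput = get_host_throughput(
        config,
        torch.device("cuda"),
        torch.bfloat16,
        load_in_8bit=False,  # keyword-only in the new signature
    )
    print(f"Host throughput: {throughput:.1f} RPS")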

+ 7 - 5
src/petals/server/server.py

@@ -119,11 +119,6 @@ class Server:
 
         self.memory_cache = MemoryCache(device, attn_cache_size, alloc_timeout)
 
-        assert isinstance(throughput, float) or throughput in ["auto", "eval"]
-        if throughput in ["auto", "eval"]:
-            throughput = get_host_throughput(device, force_eval=(throughput == "eval"))
-        self.throughput = throughput
-
         if isinstance(torch_dtype, str):
             torch_dtype = DTYPE_MAP[torch_dtype]
         assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
@@ -136,6 +131,13 @@ class Server:
         )
         self.module_uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]
 
+        assert isinstance(throughput, float) or throughput in ["auto", "eval"]
+        if throughput in ["auto", "eval"]:
+            throughput = get_host_throughput(
+                self.block_config, device, torch_dtype, load_in_8bit=load_in_8bit, force_eval=(throughput == "eval")
+            )
+        self.throughput = throughput
+
         assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both"
         if block_indices is not None:
             try:

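The hunk above reorders Server.__init__: throughput resolution now happens after torch_dtype is normalized and self.block_config is loaded, because the new get_host_throughput needs both. Condensed, the resulting order looks roughly like this (a sketch, with "..." eliding unrelated setup):

    # Inside Server.__init__ after this commit (condensed sketch)
    if isinstance(torch_dtype, str):
        torch_dtype = DTYPE_MAP[torch_dtype]
    self.block_config = BloomConfig.from_pretrained(...)  # now needed before throughput
    assert isinstance(throughput, float) or throughput in ["auto", "eval"]
    if throughput in ["auto", "eval"]:
        throughput = get_host_throughput(
            self.block_config, device, torch_dtype,
            load_in_8bit=load_in_8bit, force_eval=(throughput == "eval"),
        )
    self.throughput = throughput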
+ 70 - 48
src/petals/server/throughput.py

@@ -4,9 +4,8 @@ import os
 import subprocess
 import tempfile
 import time
-from dataclasses import asdict, dataclass
+from hashlib import sha256
 from pathlib import Path
-from typing import Dict, Union
 
 import torch
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
@@ -14,30 +13,26 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from petals.bloom.block import BloomBlock
 from petals.bloom.model import BloomConfig
 from petals.bloom.ops import build_alibi_tensor
+from petals.utils.convert_8bit import replace_8bit_linear
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 
-DEFAULT_CACHE_PATH = Path(Path.home(), ".cache", "petals", "throughput.json")
+DEFAULT_CACHE_PATH = Path(Path.home(), ".cache", "petals", "throughput_v2.json")
 DEFAULT_LOCK_PATH = Path(tempfile.gettempdir(), "petals", "throughput.lock")
 
 
-@dataclass
-class ThroughputInfo:
-    network_rps: float
-    device_rps: Dict[str, float]
-
-
 def get_host_throughput(
-    device: Union[str, torch.device],
+    config: BloomConfig,
+    device: torch.device,
+    torch_dtype: torch.dtype,
+    *,
+    load_in_8bit: bool,
     force_eval: bool = False,
     cache_path: str = DEFAULT_CACHE_PATH,
     lock_path: str = DEFAULT_LOCK_PATH,
 ) -> float:
-    # We only keep the device type, assuming that the throughput is similar among all host's GPUs
-    device = torch.device(device).type
-
     # We use the system-wide lock since only one process at a time can measure the host throughput
     os.makedirs(lock_path.parent, exist_ok=True)
     with open(lock_path, "wb") as lock_fd:
@@ -45,45 +40,49 @@ def get_host_throughput(
         fcntl.flock(lock_fd.fileno(), fcntl.LOCK_EX)
         # The OS will release the lock when lock_fd is closed or the process is killed
 
-        info = None
+        cache_key = f"config_{sha256(str(config).encode()).hexdigest()[-16:]}"
+        cache_key += f"_device_{_get_device_name(device).replace(' ', '_')}"
+        cache_key += f"_dtype_{_get_dtype_name(torch_dtype, load_in_8bit)}"
+
+        cache = {}
         try:
             if not force_eval and os.path.exists(cache_path):
                 with open(cache_path) as cache_fd:
-                    info = ThroughputInfo(**json.load(cache_fd))
-                if device not in info.device_rps:
-                    force_eval = True
+                    cache = json.load(cache_fd)
+                assert isinstance(cache, dict)
         except Exception:
             logger.exception(f"Failed to read throughput info from {cache_path}")
-            force_eval = True
+            cache = {}
+
+        if cache_key not in cache:
+            cache[cache_key] = measure_throughput_info(config, device, torch_dtype, load_in_8bit=load_in_8bit)
 
-        if force_eval or info is None:
-            info = measure_throughput_info()
             try:
                 os.makedirs(cache_path.parent, exist_ok=True)
                 with open(cache_path, "w") as cache_fd:
-                    json.dump(asdict(info), cache_fd)
+                    json.dump(cache, cache_fd)
             except Exception:
                 logger.exception(f"Failed to save throughput info in {cache_path}")
 
-    throughput = min(info.network_rps, info.device_rps[device])
-    return throughput
+    return cache[cache_key]
 
 
-def measure_throughput_info() -> ThroughputInfo:
+def measure_throughput_info(
+    config: BloomConfig,
+    device: torch.device,
+    dtype: torch.dtype,
+    *,
+    load_in_8bit: bool,
+) -> float:
+    """Measure network and compute throughput in forward pass tokens per second"""
+
     logger.info(
-        "Measuring network, CPU, and GPU throughput. " "This takes about a minute and will be cached for future runs"
+        "Measuring network and compute throughput. This takes about a minute and will be cached for future runs"
+    )
+    return min(
+        measure_network_rps(config),
+        measure_compute_rps(config, device, dtype, load_in_8bit=load_in_8bit),
     )
-
-    # We measure throughput in "(inference) requests per second" (RPS) using a fixed model
-    config = BloomConfig.from_pretrained("bigscience/test-bloomd-6b3")
-
-    network_rps = measure_network_rps(config)
-
-    device_rps = {"cpu": measure_device_rps("cpu", config)}
-    if torch.cuda.is_available():
-        device_rps["cuda"] = measure_device_rps("cuda", config)
-
-    return ThroughputInfo(network_rps=network_rps, device_rps=device_rps)
 
 
 def measure_network_rps(config: BloomConfig) -> float:
@@ -92,33 +91,56 @@ def measure_network_rps(config: BloomConfig) -> float:
         raise RuntimeError(f"Failed to measure network throughput (stdout: {proc.stdout}, stderr: {proc.stderr})")
     network_info = json.loads(proc.stdout)
 
-    bits_per_request = config.hidden_size * (16 if config.torch_dtype in (torch.float16, torch.bfloat16) else 32)
+    bits_per_request = config.hidden_size * 16  # Clients usually send 16-bit tensors for forward/backward
     network_rps = min(network_info["download"], network_info["upload"]) / bits_per_request
 
     logger.info(
         f"Network throughput: "
         f"{network_info['download'] / 1e6:.2f} Mbit/s on download, "
         f"{network_info['upload'] / 1e6:.2f} Mbit/s on upload, "
-        f"{network_rps:.2f} RPS"
+        f"{network_rps:.1f} RPS"
     )
     return network_rps
 
 
-def measure_device_rps(device: str, config: BloomConfig, layer_index: int = 0, n_steps: int = 500) -> float:
+def measure_compute_rps(
+    config: BloomConfig,
+    device: torch.device,
+    dtype: torch.dtype,
+    *,
+    load_in_8bit: bool,
+    n_tokens: int = 16,
+    n_steps: int = 500,
+    layer_index: int = 0,
+) -> float:
     with torch.inference_mode():
-        block = BloomBlock(config, layer_index).to(device)
+        block = BloomBlock(config, layer_index).to(dtype)
+        if load_in_8bit:
+            block = replace_8bit_linear(block)
+        block = block.to(device)
+
         cache = None
         elapsed = 0
-        for i in range(n_steps):
-            dummy_input = torch.randn(1, 1, config.hidden_size, device=device)
-            alibi = build_alibi_tensor(i + 1, config.num_attention_heads, dtype=torch.float32, device=device)
+        for step in range(n_steps + 1):
+            dummy_input = torch.randn(n_tokens, 1, config.hidden_size, device=device, dtype=dtype)
+            alibi = build_alibi_tensor(step + 1, config.num_attention_heads, device=device, dtype=dtype)
 
             start_time = time.perf_counter()
             _, cache = block.forward(dummy_input, alibi=alibi, use_cache=True, layer_past=cache)
-            elapsed += time.perf_counter() - start_time
-        device_rps = n_steps / elapsed
-
-    device_name = f"{torch.cuda.get_device_name(0)} GPU" if device == "cuda" else "CPU"
-    logger.info(f"Compute throughput ({device_name}): {device_rps:.2f} RPS")
+            if step >= 1:  # Skip the 1st step to exclude the initialization time
+                elapsed += time.perf_counter() - start_time
+        device_rps = n_steps * n_tokens / elapsed
 
+    logger.info(
+        f"Forward pass throughput ({_get_device_name(device)}, {_get_dtype_name(dtype, load_in_8bit)}): "
+        f"{device_rps:.1f} RPS"
+    )
     return device_rps
+
+
+def _get_device_name(device: torch.device) -> str:
+    return f"{torch.cuda.get_device_name(device)} GPU" if device == "cuda" else "CPU"
+
+
+def _get_dtype_name(dtype: torch.dtype, load_in_8bit: bool) -> str:
+    return "8-bit" if load_in_8bit else str(dtype)