@@ -4,9 +4,8 @@ import os
 import subprocess
 import tempfile
 import time
-from dataclasses import asdict, dataclass
+from hashlib import sha256
 from pathlib import Path
-from typing import Dict, Union
 
 import torch
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
@@ -14,30 +13,26 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from petals.bloom.block import BloomBlock
 from petals.bloom.model import BloomConfig
 from petals.bloom.ops import build_alibi_tensor
+from petals.utils.convert_8bit import replace_8bit_linear
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 
-DEFAULT_CACHE_PATH = Path(Path.home(), ".cache", "petals", "throughput.json")
+DEFAULT_CACHE_PATH = Path(Path.home(), ".cache", "petals", "throughput_v2.json")
 DEFAULT_LOCK_PATH = Path(tempfile.gettempdir(), "petals", "throughput.lock")
 
 
-@dataclass
-class ThroughputInfo:
-    network_rps: float
-    device_rps: Dict[str, float]
-
-
 def get_host_throughput(
-    device: Union[str, torch.device],
+    config: BloomConfig,
+    device: torch.device,
+    torch_dtype: torch.dtype,
+    *,
+    load_in_8bit: bool,
     force_eval: bool = False,
     cache_path: str = DEFAULT_CACHE_PATH,
     lock_path: str = DEFAULT_LOCK_PATH,
 ) -> float:
-    # We only keep the device type, assuming that the throughput is similar among all host's GPUs
-    device = torch.device(device).type
-
     # We use the system-wide lock since only one process at a time can measure the host throughput
     os.makedirs(lock_path.parent, exist_ok=True)
     with open(lock_path, "wb") as lock_fd:
@@ -45,45 +40,49 @@ def get_host_throughput(
         fcntl.flock(lock_fd.fileno(), fcntl.LOCK_EX)
         # The OS will release the lock when lock_fd is closed or the process is killed
 
-        info = None
+        cache_key = f"config_{sha256(str(config).encode()).hexdigest()[-16:]}"
+        cache_key += f"_device_{_get_device_name(device).replace(' ', '_')}"
+        cache_key += f"_dtype_{_get_dtype_name(torch_dtype, load_in_8bit)}"
+
+        cache = {}
         try:
             if not force_eval and os.path.exists(cache_path):
                 with open(cache_path) as cache_fd:
-                    info = ThroughputInfo(**json.load(cache_fd))
-            if device not in info.device_rps:
-                force_eval = True
+                    cache = json.load(cache_fd)
+                assert isinstance(cache, dict)
         except Exception:
             logger.exception(f"Failed to read throughput info from {cache_path}")
-            force_eval = True
+            cache = {}
+
+        if cache_key not in cache:
+            cache[cache_key] = measure_throughput_info(config, device, torch_dtype, load_in_8bit=load_in_8bit)
 
-        if force_eval or info is None:
-            info = measure_throughput_info()
             try:
                 os.makedirs(cache_path.parent, exist_ok=True)
                 with open(cache_path, "w") as cache_fd:
-                    json.dump(asdict(info), cache_fd)
+                    json.dump(cache, cache_fd)
             except Exception:
                 logger.exception(f"Failed to save throughput info in {cache_path}")
 
-        throughput = min(info.network_rps, info.device_rps[device])
-        return throughput
+        return cache[cache_key]
 
 
-def measure_throughput_info() -> ThroughputInfo:
+def measure_throughput_info(
+    config: BloomConfig,
+    device: torch.device,
+    dtype: torch.dtype,
+    *,
+    load_in_8bit: bool,
+) -> float:
+    """Measure network and compute throughput in forward pass tokens per second"""
+
     logger.info(
-        "Measuring network, CPU, and GPU throughput. " "This takes about a minute and will be cached for future runs"
+        "Measuring network and compute throughput. This takes about a minute and will be cached for future runs"
+    )
+    return min(
+        measure_network_rps(config),
+        measure_compute_rps(config, device, dtype, load_in_8bit=load_in_8bit),
     )
-
-    # We measure throughput in "(inference) requests per second" (RPS) using a fixed model
-    config = BloomConfig.from_pretrained("bigscience/test-bloomd-6b3")
-
-    network_rps = measure_network_rps(config)
-
-    device_rps = {"cpu": measure_device_rps("cpu", config)}
-    if torch.cuda.is_available():
-        device_rps["cuda"] = measure_device_rps("cuda", config)
-
-    return ThroughputInfo(network_rps=network_rps, device_rps=device_rps)
 
 
 def measure_network_rps(config: BloomConfig) -> float:
@@ -92,33 +91,56 @@ def measure_network_rps(config: BloomConfig) -> float:
         raise RuntimeError(f"Failed to measure network throughput (stdout: {proc.stdout}, stderr: {proc.stderr})")
     network_info = json.loads(proc.stdout)
 
-    bits_per_request = config.hidden_size * (16 if config.torch_dtype in (torch.float16, torch.bfloat16) else 32)
+    bits_per_request = config.hidden_size * 16  # Clients usually send 16-bit tensors for forward/backward
     network_rps = min(network_info["download"], network_info["upload"]) / bits_per_request
 
     logger.info(
         f"Network throughput: "
         f"{network_info['download'] / 1e6:.2f} Mbit/s on download, "
         f"{network_info['upload'] / 1e6:.2f} Mbit/s on upload, "
-        f"{network_rps:.2f} RPS"
+        f"{network_rps:.1f} RPS"
     )
     return network_rps
 
 
-def measure_device_rps(device: str, config: BloomConfig, layer_index: int = 0, n_steps: int = 500) -> float:
+def measure_compute_rps(
+    config: BloomConfig,
+    device: torch.device,
+    dtype: torch.dtype,
+    *,
+    load_in_8bit: bool,
+    n_tokens: int = 16,
+    n_steps: int = 500,
+    layer_index: int = 0,
+) -> float:
     with torch.inference_mode():
-        block = BloomBlock(config, layer_index).to(device)
+        block = BloomBlock(config, layer_index).to(dtype)
+        if load_in_8bit:
+            block = replace_8bit_linear(block)
+        block = block.to(device)
+
         cache = None
         elapsed = 0
-        for i in range(n_steps):
-            dummy_input = torch.randn(1, 1, config.hidden_size, device=device)
-            alibi = build_alibi_tensor(i + 1, config.num_attention_heads, dtype=torch.float32, device=device)
+        for step in range(n_steps + 1):
+            dummy_input = torch.randn(n_tokens, 1, config.hidden_size, device=device, dtype=dtype)
+            alibi = build_alibi_tensor(step + 1, config.num_attention_heads, device=device, dtype=dtype)
 
             start_time = time.perf_counter()
             _, cache = block.forward(dummy_input, alibi=alibi, use_cache=True, layer_past=cache)
-            elapsed += time.perf_counter() - start_time
-        device_rps = n_steps / elapsed
-
-        device_name = f"{torch.cuda.get_device_name(0)} GPU" if device == "cuda" else "CPU"
-        logger.info(f"Compute throughput ({device_name}): {device_rps:.2f} RPS")
+            if step >= 1:  # Skip the 1st step to exclude the initialization time
+                elapsed += time.perf_counter() - start_time
+        device_rps = n_steps * n_tokens / elapsed
 
+    logger.info(
+        f"Forward pass throughput ({_get_device_name(device)}, {_get_dtype_name(dtype, load_in_8bit)}): "
+        f"{device_rps:.1f} RPS"
+    )
     return device_rps
+
+
+def _get_device_name(device: torch.device) -> str:
+    return f"{torch.cuda.get_device_name(device)} GPU" if device == "cuda" else "CPU"
+
+
+def _get_dtype_name(dtype: torch.dtype, load_in_8bit: bool) -> str:
+    return "8-bit" if load_in_8bit else str(dtype)
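
For reference, a minimal usage sketch of the reworked get_host_throughput API. This assumes the module is importable as petals.server.throughput (not shown in the patch); the checkpoint name is taken from the hard-coded default removed above.

    import torch
    from petals.bloom.model import BloomConfig
    from petals.server.throughput import get_host_throughput  # import path assumed

    config = BloomConfig.from_pretrained("bigscience/test-bloomd-6b3")
    rps = get_host_throughput(
        config,
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        torch_dtype=torch.bfloat16,
        load_in_8bit=False,
    )
    # The first call measures throughput (about a minute) and writes it to
    # ~/.cache/petals/throughput_v2.json. Later calls with the same config,
    # device, dtype, and 8-bit flag reuse the cached value, since all four
    # are baked into the sha256-based cache_key.

Note that changing any of these inputs produces a new cache_key and triggers a fresh measurement; renaming the cache file to throughput_v2.json likewise retires entries in the old per-device format, which cannot be read back under the new key scheme.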