import fcntl
import json
import math
import multiprocessing as mp
import os
import time
from collections import Counter
from pathlib import Path
from typing import Dict, Optional, Sequence, Union

import torch
import torch.mps
from hivemind.utils.logging import get_logger
from transformers import PretrainedConfig

from petals.server.block_utils import resolve_block_dtype
from petals.utils.convert_block import QuantType, convert_block
from petals.utils.disk_cache import DEFAULT_CACHE_DIR

logger = get_logger(__name__)

try:
    import speedtest
except ImportError:
    raise ImportError("Please `pip install speedtest-cli==2.1.3`")

if not hasattr(speedtest, "Speedtest"):
    raise ImportError(
        "You are using the wrong speedtest module. Please replace speedtest with speedtest-cli.\n"
        "To do that, run `pip uninstall -y speedtest`. Depending on your Python environment, "
        "you may need to run this uninstall command two or more times, until it says 'not installed'.\n"
        "After that, please `pip install speedtest-cli==2.1.3` to install the correct version."
    )


def get_server_throughput(
    model_name: str,
    config: PretrainedConfig,
    device: torch.device,
    dtype: Union[str, torch.dtype],
    *,
    num_blocks: int,
    quant_type: QuantType,
    tensor_parallel_devices: Sequence[torch.device],
    reachable_via_relay: bool,
    relay_penalty: float = 0.2,
    force_eval: bool = False,
    cache_dir: Optional[str] = None,
) -> Dict[str, float]:
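    """
    Return throughput estimates for this server, reusing the on-disk cache when possible.

    The result contains the per-block compute estimates ("inference_rps", "forward_rps"),
    the network estimate ("network_rps"), and the combined "throughput" value that accounts
    for the number of hosted blocks and relay penalties, all in tokens (requests) per second.
    """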
    dtype = resolve_block_dtype(config, dtype)

    if cache_dir is None:
        cache_dir = DEFAULT_CACHE_DIR
    lock_path = Path(cache_dir, "throughput.lock")
    cache_path = Path(cache_dir, "throughput_v5.json")

    # We use the system-wide lock since only one process at a time can measure the host throughput
    os.makedirs(lock_path.parent, exist_ok=True)
    with open(lock_path, "wb") as lock_fd:
        logger.info("Loading throughput info")
        fcntl.flock(lock_fd.fileno(), fcntl.LOCK_EX)
        # The OS will release the lock when lock_fd is closed or the process is killed

        cache_key = f"model_{model_name}"
        cache_key += f"_device_{get_device_name(device).replace(' ', '_')}"
        cache_key += f"_dtype_{get_dtype_name(dtype, quant_type)}"
        if len(tensor_parallel_devices) > 1:
            for i, device_i in enumerate(tensor_parallel_devices):
                cache_key += f"_tp{i}_{get_device_name(device_i).replace(' ', '_')}"

        cache = {}
        try:
            if not force_eval and os.path.exists(cache_path):
                with open(cache_path) as cache_fd:
                    cache = json.load(cache_fd)
                assert isinstance(cache, dict)
        except Exception:
            logger.exception(f"Failed to read throughput info from {cache_path}")
            cache = {}

        if cache_key not in cache:
            cache[cache_key] = measure_throughput_info(
                config, device, dtype, quant_type=quant_type, tensor_parallel_devices=tensor_parallel_devices
            )

            try:
                os.makedirs(cache_path.parent, exist_ok=True)
                with open(cache_path, "w") as cache_fd:
                    json.dump(cache, cache_fd)
            except Exception:
                logger.exception(f"Failed to save throughput info in {cache_path}")

    throughput_info = cache[cache_key]

    # Most requests start at some block hosted by a server, then use all next blocks hosted on this server.
    # Assuming the start block index is distributed uniformly, the average number of blocks used per request is
    # E[Uniform{1, 2, ..., num_blocks}] = (num_blocks + 1) / 2
    average_blocks_used = (num_blocks + 1) / 2
    throughput = throughput_info["forward_rps"] / average_blocks_used

    network_rps = throughput_info["network_rps"] * (relay_penalty if reachable_via_relay else 1)
    throughput = min(throughput, network_rps)

    throughput_info["throughput"] = throughput
    logger.info(f"Reporting throughput: {throughput:.1f} tokens/sec for {num_blocks} blocks")

    return throughput_info


def measure_throughput_info(
    config: PretrainedConfig,
    device: torch.device,
    dtype: torch.dtype,
    *,
    quant_type: QuantType,
    tensor_parallel_devices: Sequence[torch.device],
) -> Dict[str, float]:
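    """Benchmark this host from scratch: single-token inference, batched forward passes, and network speed."""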
    logger.info(
        "Measuring network and compute throughput. This takes about a minute and will be cached for future runs"
    )
    return {
        "inference_rps": measure_compute_rps(
            config,
            device,
            dtype,
            quant_type=quant_type,
            tensor_parallel_devices=tensor_parallel_devices,
            n_tokens=1,
            n_steps=100,
            inference=True,
        ),
        "forward_rps": measure_compute_rps(
            config,
            device,
            dtype,
            quant_type=quant_type,
            tensor_parallel_devices=tensor_parallel_devices,
            n_tokens=1024,
            n_steps=10,
            inference=False,
        ),
        "network_rps": measure_network_rps(config),
    }


def measure_network_rps(
    config: PretrainedConfig, *, timeout: float = 60, default_speed: float = 100e6  # 100 Mbit/s
) -> Optional[float]:
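    """
    Estimate how many hidden-state requests per second the network link can move, based on a speedtest
    run in a subprocess. Falls back to `default_speed` if the test fails or does not finish within `timeout`.
    """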
    bits_per_request = config.hidden_size * 16  # Clients usually send 16-bit tensors for forward/backward
    try:
        pipe_recv, pipe_send = mp.Pipe(duplex=False)
        process = mp.Process(target=_measure_bits_per_second, args=(pipe_send,))
        process.start()

        if not pipe_recv.poll(timeout):
            process.terminate()
            raise RuntimeError(f"speedtest did not finish in {timeout} seconds")
        network_info = pipe_recv.recv()
        if "exception" in network_info:
            raise RuntimeError(f"speedtest failed: {network_info['exception']}")

        network_rps = min(network_info["download"], network_info["upload"]) / bits_per_request
        if network_rps == 0:
            raise RuntimeError("speedtest has returned network_rps == 0")

        logger.info(
            f"Network throughput: {network_rps:.1f} tokens/sec "
            f"({network_info['download'] / 1e6:.2f} Mbit/s on download, "
            f"{network_info['upload'] / 1e6:.2f} Mbit/s on upload)"
        )
        return network_rps
    except RuntimeError as e:
        logger.info(f"Network throughput is not available: {e}. Using default of {default_speed / 1e6:.2f} Mbit/s")
        return default_speed / bits_per_request


def _measure_bits_per_second(pipe_send: mp.Pipe):
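    """Speedtest worker: runs in a child process and reports the results (or the exception) through the pipe."""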
    try:
        s = speedtest.Speedtest()
        s.get_servers()
        s.get_best_server()
        s.download()
        s.upload()
        pipe_send.send(s.results.dict())
    except Exception as e:
        pipe_send.send({"exception": repr(e)})


def measure_compute_rps(
    config: PretrainedConfig,
    device: torch.device,
    dtype: torch.dtype,
    *,
    quant_type: QuantType,
    tensor_parallel_devices: Sequence[torch.device],
    n_tokens: int,
    n_steps: int,
    inference: bool,
) -> float:
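    """
    Benchmark a single transformer block on random data and return its throughput in tokens per second.
    With inference=True, the attention cache is reused between steps (autoregressive decoding);
    with inference=False, each step is an independent forward pass over an n_tokens-token batch.
    """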
    device = torch.device(device)
    if not tensor_parallel_devices:
        tensor_parallel_devices = (device,)
    with torch.inference_mode():
        block = config.block_class(config).to(dtype)
        block = convert_block(block, 0, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)

        cache = None
        elapsed = 0
        dummy_input = torch.randn(1, n_tokens, config.hidden_size, device=device, dtype=dtype)

        # Skip the 1st step to exclude the initialization time
        _, cache = block.forward(dummy_input, use_cache=True)
        synchronize(device)

        start_time = time.perf_counter()
        for _ in range(n_steps):
            _, cache = block.forward(dummy_input, use_cache=True, layer_past=cache if inference else None)
        synchronize(device)
        elapsed = time.perf_counter() - start_time
        device_rps = n_steps * n_tokens / elapsed

    devices_repr = get_device_name(device)
    if len(tensor_parallel_devices) > 1:
        device_names = tuple(map(get_device_name, map(torch.device, tensor_parallel_devices)))
        devices_repr = ", ".join(f"{count}x {name}" for name, count in Counter(device_names).most_common())

    logger.info(
        f"{'Inference' if inference else 'Forward pass'} throughput: {device_rps:.1f} tokens/sec per block "
        f"({n_tokens} tokens/batch, {devices_repr}, {get_dtype_name(dtype, quant_type)})"
    )
    return device_rps


def synchronize(device: torch.device):
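    """Wait for all pending kernels on the device to finish so that wall-clock timings are accurate."""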
    if device.type == "cuda":
        torch.cuda.synchronize(device)
    elif device.type == "mps":
        torch.mps.synchronize()


def get_device_name(device: torch.device) -> str:
    return f"{torch.cuda.get_device_name(device)} GPU" if device.type == "cuda" else device.type.upper()


def get_dtype_name(dtype: torch.dtype, quant_type: QuantType) -> str:
    name = str(dtype).replace("torch.", "")
    if quant_type != QuantType.NONE:
        name += f", quantized to {quant_type.name.lower()}"
    return name
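

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original module): measure the compute
    # throughput of a single block on this machine. The model name below is an arbitrary
    # example, and `AutoDistributedConfig` is assumed to be importable from the top-level
    # `petals` package; adjust both for your setup.
    from petals import AutoDistributedConfig

    config = AutoDistributedConfig.from_pretrained("bigscience/bloom-560m")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.bfloat16 if device.type == "cuda" else torch.float32

    rps = measure_compute_rps(
        config,
        device,
        dtype,
        quant_type=QuantType.NONE,
        tensor_parallel_devices=(),
        n_tokens=16,
        n_steps=3,
        inference=True,
    )
    print(f"Single-block inference throughput: {rps:.1f} tokens/sec")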