@@ -1,6 +1,8 @@
+from __future__ import annotations
+
+import argparse
 import fcntl
 import json
-import math
 import multiprocessing as mp
 import os
 import time
@@ -8,14 +10,19 @@ from collections import Counter
 from pathlib import Path
 from typing import Dict, Optional, Sequence, Union
 
+import configargparse
 import torch
+
 import torch.mps
 from hivemind.utils.logging import get_logger
 from transformers import PretrainedConfig
 
+from petals.constants import DTYPE_MAP
 from petals.server.block_utils import resolve_block_dtype
+from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.convert_block import QuantType, convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
+from petals.utils.version import get_compatible_model_repo
 
 logger = get_logger(__name__)
 
@@ -114,6 +121,7 @@ def measure_throughput_info(
     *,
     quant_type: QuantType,
     tensor_parallel_devices: Sequence[torch.device],
+    measure_network: bool = True,
 ) -> Dict[str, float]:
     logger.info(
         "Measuring network and compute throughput. This takes about a minute and will be cached for future runs"
@@ -139,14 +147,16 @@ def measure_throughput_info(
             n_steps=10,
             inference=False,
         ),
-        "network_rps": measure_network_rps(config),
+        "network_rps": measure_network_rps(config, use_default=not measure_network),
     }
 
 
 def measure_network_rps(
-    config: PretrainedConfig, *, timeout: float = 60, default_speed: float = 100e6  # 100 Mbit/s
+    config: PretrainedConfig, *, use_default=False, timeout: float = 60, default_speed: float = 100e6  # 100 Mbit/s
 ) -> Optional[float]:
     bits_per_request = config.hidden_size * 16  # Clients usually send 16-bit tensors for forward/backward
+    if use_default:
+        return default_speed / bits_per_request
     try:
         pipe_recv, pipe_send = mp.Pipe(duplex=False)
         process = mp.Process(target=_measure_bits_per_second, args=(pipe_send,))
@@ -207,13 +217,23 @@ def measure_compute_rps(
         cache = None
         elapsed = 0
         dummy_input = torch.randn(1, n_tokens, config.hidden_size, device=device, dtype=dtype)
+        # with torch.profiler.profile(
+        #     schedule=torch.profiler.schedule(wait=1, warmup=4, active=n_steps, repeat=1),
+        #     on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/profbf16_70b_qkv'),
+        #     record_shapes=True,
+        #     profile_memory=True,
+        #     with_stack=True
+        # ) as prof:
         _, cache = block.forward(dummy_input, use_cache=True)  # Skip the 1st step to exclude the initialization time
         synchronize(device)
+        # prof.step()
 
         start_time = time.perf_counter()
         for _ in range(n_steps):
             _, cache = block.forward(dummy_input, use_cache=True, layer_past=cache if inference else None)
-        synchronize(device)
+            synchronize(device)
+            # prof.step()
+
         elapsed = time.perf_counter() - start_time
         device_rps = n_steps * n_tokens / elapsed
 
@@ -245,3 +265,95 @@ def get_dtype_name(dtype: torch.dtype, quant_type: QuantType) -> str:
     if quant_type != QuantType.NONE:
         name += f", quantized to {quant_type.name.lower()}"
     return name
+
+
+def parse_args():
+    # fmt:off
+    parser = configargparse.ArgParser(default_config_files=["config.yml"],
+                                      formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add('-c', '--config', required=False, is_config_file=True, help='config file path')
+
+    group = parser.add_mutually_exclusive_group(required=True)
+    group.add_argument('--converted_model_name_or_path', type=str, default=None,
+                       help="path or name of a pretrained model, converted with cli/convert_model.py")
+    group.add_argument('model', nargs='?', type=str, help="same as --converted_model_name_or_path")
+
+    group = parser.add_mutually_exclusive_group(required=False)
+    group.add_argument("--token", type=str, default=None, help="Hugging Face hub auth token for .from_pretrained()")
+ group.add_argument("--use_auth_token", action="store_true", dest="token",
|
|
|
+ help="Read token saved by `huggingface-cli login")
|
|
|
+
+    parser.add_argument('--device', type=str, default=None, required=False,
+                        help='all blocks will use this device in torch notation; default: cuda if available else cpu')
+    parser.add_argument("--torch_dtype", type=str, choices=DTYPE_MAP.keys(), default="auto",
+                        help="Use this dtype to store block weights and do computations. "
+                             "By default, respect the dtypes in the pre-trained state dict.")
+    parser.add_argument('--revision', type=str, default=None,
+                        help="The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models "
+                             "and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.")
+
+    parser.add_argument('--quant_type', type=str, default=None, choices=[choice.name.lower() for choice in QuantType],
+                        help="Quantize blocks to 8-bit (int8 from the LLM.int8() paper) or "
+                             "4-bit (nf4 from the QLoRA paper) formats to save GPU memory. "
+                             "Default: 'int8' if GPU is available, 'none' otherwise")
+    parser.add_argument("--tensor_parallel_devices", nargs='+', default=None,
+                        help=
+                        "Split each block between the specified GPUs such that each device holds a portion of every "
+                        "weight matrix. See https://huggingface.co/transformers/v4.9.0/parallelism.html#tensor-parallelism")
+
+    # fmt:on
+    args = parser.parse_args()
+    args.converted_model_name_or_path = args.model or args.converted_model_name_or_path
+    return args
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    converted_model_name_or_path = get_compatible_model_repo(args.converted_model_name_or_path)
+    config = AutoDistributedConfig.from_pretrained(
+        converted_model_name_or_path,
+        use_auth_token=args.token,
+        revision=args.revision,
+    )
+
+    device = args.device
+    if device is None:
+        if torch.cuda.is_available():
+            device = "cuda"
+        elif torch.backends.mps.is_available():
+            device = "mps"
+        else:
+            device = "cpu"
+    device = torch.device(device)
+    if device.type == "cuda" and device.index is None:
+        device = torch.device(device.type, index=0)
+
+    torch_dtype = resolve_block_dtype(config, DTYPE_MAP[args.torch_dtype])
+    if device.type == "cpu" and torch_dtype == torch.float16:
+        raise ValueError(
+            f"Type float16 is not supported on CPU. Please use --torch_dtype float32 or --torch_dtype bfloat16"
+        )
+    if device.type == "mps" and torch_dtype == torch.bfloat16:
+        logger.warning(f"Type bfloat16 is not supported on MPS, using float16 instead")
+        torch_dtype = torch.float16
+
+    quant_type = args.quant_type
+    if quant_type is None:
+        if device.type == "cuda":
+            quant_type = QuantType.NF4 if config.model_type == "llama" else QuantType.INT8
+        else:
+            quant_type = QuantType.NONE
+    else:
+        quant_type = QuantType[quant_type.upper()]  # --quant_type is parsed as a lowercase name
+
+    if args.tensor_parallel_devices is None:
+        args.tensor_parallel_devices = (device,)
+
+    measure_throughput_info(
+        config,
+        device,
+        torch_dtype,
+        quant_type=quant_type,
+        tensor_parallel_devices=args.tensor_parallel_devices,
+        measure_network=False,
+    )
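
For reference, a minimal sketch of what the `use_default` fallback in `measure_network_rps` computes; the `hidden_size` value below is illustrative, not taken from the patch:

    # Sketch of the fallback path added above: instead of benchmarking the link,
    # convert the assumed default_speed (100 Mbit/s) into requests per second.
    hidden_size = 8192                    # illustrative; config.hidden_size in the real call
    bits_per_request = hidden_size * 16   # clients send 16-bit activations per token
    default_speed = 100e6                 # the function's default, 100 Mbit/s
    network_rps = default_speed / bits_per_request
    print(round(network_rps, 1))          # ~762.9 requests/s for this hidden size

Since the `__main__` block calls `measure_throughput_info(..., measure_network=False)`, `use_default=not measure_network` evaluates to True, so when the module is run as a script the reported `network_rps` comes from this fallback rather than from an actual network speed test.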