Prechádzať zdrojové kódy

Support loading blocks in 4-bit (QLoRA NF4 format, disabled by default) (#333)

Alexander Borzunov 2 rokov pred
rodič
commit
de930918a0

+ 1 - 1
setup.cfg

@@ -32,7 +32,7 @@ packages = find:
 python_requires = >=3.7
 install_requires =
     torch>=1.12
-    bitsandbytes==0.38.0.post2
+    bitsandbytes==0.39.1
     accelerate>=0.16.0,<1.0.0
     huggingface-hub>=0.11.1,<1.0.0
     tokenizers>=0.13.3

+ 8 - 6
src/petals/cli/run_server.py

@@ -8,6 +8,7 @@ from humanfriendly import parse_size
 
 from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
 from petals.server.server import Server
+from petals.utils.convert_block import QuantType
 from petals.utils.version import validate_version
 
 logger = get_logger(__name__)
@@ -133,9 +134,10 @@ def main():
                         help="Check the swarm's balance every N seconds (and rebalance it if necessary)")
 
     parser.add_argument("--use_auth_token", action='store_true', help="auth token for from_pretrained")
-    parser.add_argument('--load_in_8bit', type=str, default=None,
-                        help="Convert the loaded transformer blocks into mixed-8bit quantized model. "
-                             "Default: True if GPU is available. Use `--load_in_8bit False` to disable this")
+    parser.add_argument('--quant_type', type=str, default=None, choices=[choice.name.lower() for choice in QuantType],
+                        help="Quantize blocks to 8-bit (int8 from the LLM.int8() paper) or "
+                             "4-bit (nf4 from the QLoRA paper) formats to save GPU memory. "
+                             "Default: 'int8' if GPU is available, 'none' otherwise")
     parser.add_argument("--tensor_parallel_devices", nargs='+', default=None,
                         help=
                         "Split each block between the specified GPUs such that each device holds a portion of every "
@@ -186,9 +188,9 @@ def main():
     if args.pop("new_swarm"):
         args["initial_peers"] = []
 
-    load_in_8bit = args.pop("load_in_8bit")
-    if load_in_8bit is not None:
-        args["load_in_8bit"] = load_in_8bit.lower() in ["true", "1"]
+    quant_type = args.pop("quant_type")
+    if quant_type is not None:
+        args["quant_type"] = QuantType[quant_type.upper()]
 
     validate_version()
 

+ 16 - 11
src/petals/server/block_utils.py

@@ -4,6 +4,8 @@ import torch
 from accelerate import init_empty_weights
 from transformers import PretrainedConfig
 
+from petals.utils.convert_block import QuantType
+
 
 def resolve_block_dtype(config: PretrainedConfig, dtype: Union[str, torch.dtype]) -> torch.dtype:
     """If dtype is "auto", resolves it using BloomConfig. Returns `dtype` intact otherwise."""
@@ -19,27 +21,30 @@ def get_block_size(
     location: str,
     *,
     dtype: Optional[Union[str, torch.dtype]] = None,
-    load_in_8bit: Optional[bool] = None,
+    quant_type: QuantType = QuantType.NONE,
     eps: float = 0.01,  # eps accounts for ~1% of metainfo for tensor descriptions, quantization tables, etc.
 ) -> int:
     if location == "memory":
         assert (
-            dtype is not None and load_in_8bit is not None
-        ), 'get_block_size(..., location="memory") requires to specify dtype and load_in_8bit for calculations'
+            dtype is not None and quant_type is not None
+        ), 'get_block_size(..., location="memory") requires to specify dtype and quant_type for calculations'
 
     with init_empty_weights(include_buffers=True):
         block = config.block_class(config)
         n_params = sum(param.numel() for param in block.parameters())
 
-    if location == "memory" and load_in_8bit:
-        # Note: We may need a larger eps here for models of size < 1B
-        return n_params * (1 + eps)
-
     if location == "memory":
-        dtype = resolve_block_dtype(config, dtype)
+        if quant_type == QuantType.NONE:
+            dtype = resolve_block_dtype(config, dtype)
+            bytes_per_value = torch.finfo(dtype).bits // 8
+        elif quant_type == QuantType.INT8:
+            bytes_per_value = 1
+        elif quant_type == QuantType.NF4:
+            bytes_per_value = 4.25 / 8  # Bitness of NF4 with this config (measured empirically)
+        else:
+            raise ValueError(f"Unsupported quant_type={quant_type}")
     elif location == "disk":
         dtype = resolve_block_dtype(config, "auto")
-    else:
-        raise ValueError('get_block_size() expects location to be "memory" or "disk"')
+        bytes_per_value = torch.finfo(dtype).bits // 8
 
-    return round(n_params * torch.finfo(dtype).bits // 8 * (1 + eps))
+    return round(n_params * bytes_per_value * (1 + eps))

+ 15 - 15
src/petals/server/server.py

@@ -28,7 +28,7 @@ from petals.server.memory_cache import MemoryCache
 from petals.server.reachability import ReachabilityProtocol, check_direct_reachability, validate_reachability
 from petals.server.throughput import get_dtype_name, get_server_throughput
 from petals.utils.auto_config import AutoDistributedConfig
-from petals.utils.convert_block import check_device_balance, convert_block
+from petals.utils.convert_block import QuantType, check_device_balance, convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
 from petals.utils.version import get_compatible_model_repo
 
@@ -75,7 +75,7 @@ class Server:
         mean_balance_check_period: float = 120,
         mean_block_selection_delay: float = 2.5,
         use_auth_token: Optional[str] = None,
-        load_in_8bit: Optional[bool] = None,
+        quant_type: Optional[QuantType] = None,
         tensor_parallel_devices: Optional[Sequence[torch.device]] = None,
         skip_reachability_check: bool = False,
         dht_client_mode: Optional[bool] = None,
@@ -154,8 +154,8 @@ class Server:
             device = torch.device(device.type, index=0)
         self.device = device
 
-        torch_dtype = DTYPE_MAP[torch_dtype]
-        self.torch_dtype = resolve_block_dtype(self.block_config, torch_dtype)
+        torch_dtype = resolve_block_dtype(self.block_config, DTYPE_MAP[torch_dtype])
+        self.torch_dtype = torch_dtype
 
         if tensor_parallel_devices is None:
             tensor_parallel_devices = (device,)
@@ -164,10 +164,10 @@ class Server:
             logger.info(f"Model weights will be split between {', '.join(tensor_parallel_devices)}")
             check_device_balance(self.tensor_parallel_devices)
 
-        if load_in_8bit is None:
-            load_in_8bit = device.type == "cuda"
-        self.load_in_8bit = load_in_8bit
-        logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, load_in_8bit)} format")
+        if quant_type is None:
+            quant_type = QuantType.INT8 if device.type == "cuda" else QuantType.NONE
+        self.quant_type = quant_type
+        logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, quant_type)} format")
 
         cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens
         self._cache_bytes_per_block = cache_values_per_block * torch.finfo(self.torch_dtype).bits // 8
@@ -203,7 +203,7 @@ class Server:
                 device,
                 torch_dtype,
                 num_blocks=num_blocks,
-                load_in_8bit=load_in_8bit,
+                quant_type=quant_type,
                 tensor_parallel_devices=self.tensor_parallel_devices,
                 force_eval=(throughput == "eval"),
                 cache_dir=cache_dir,
@@ -237,11 +237,11 @@ class Server:
         else:
             total_memory = torch.cuda.get_device_properties(self.device).total_memory
 
-        block_size = get_block_size(self.block_config, "memory", dtype=self.torch_dtype, load_in_8bit=self.load_in_8bit)
+        block_size = get_block_size(self.block_config, "memory", dtype=self.torch_dtype, quant_type=self.quant_type)
 
-        # The estimates below are for bigscience/bloom-petals, serving as an upper bound for other models
         gib = 1024**3
-        autograd_memory = 2 * gib * num_devices  # GPU memory used for intermediate tensors in rpc_backward
+        # Estimate of GPU memory used in rpc_backward (2 GiB for BLOOM, proportional for other models)
+        autograd_memory = 2 * gib * num_devices / 14336 * self.block_config.hidden_size
 
         num_blocks = math.floor((total_memory - autograd_memory) / (block_size + self._cache_bytes_per_block))
         assert num_blocks >= 1, "Your GPU does not have enough memory to serve at least one block"
@@ -284,7 +284,7 @@ class Server:
                 sender_threads=self.sender_threads,
                 revision=self.revision,
                 use_auth_token=self.use_auth_token,
-                load_in_8bit=self.load_in_8bit,
+                quant_type=self.quant_type,
                 tensor_parallel_devices=self.tensor_parallel_devices,
                 should_validate_reachability=self.should_validate_reachability,
                 start=True,
@@ -377,7 +377,7 @@ class ModuleContainer(threading.Thread):
         expiration: Optional[float],
         revision: Optional[str],
         use_auth_token: Optional[str],
-        load_in_8bit: bool,
+        quant_type: QuantType,
         tensor_parallel_devices: Sequence[torch.device],
         should_validate_reachability: bool,
         **kwargs,
@@ -411,7 +411,7 @@ class ModuleContainer(threading.Thread):
                     cache_dir=cache_dir,
                     max_disk_space=max_disk_space,
                 )
-                block = convert_block(block, block_config, tensor_parallel_devices, device, load_in_8bit, freeze=True)
+                block = convert_block(block, block_config, tensor_parallel_devices, device, quant_type, freeze=True)
                 blocks[module_uid] = TransformerBackend(
                     module_uid,
                     block,

+ 12 - 12
src/petals/server/throughput.py

@@ -13,7 +13,7 @@ from hivemind.utils.logging import get_logger
 from transformers import PretrainedConfig
 
 from petals.server.block_utils import resolve_block_dtype
-from petals.utils.convert_block import convert_block
+from petals.utils.convert_block import QuantType, convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
 
 logger = get_logger(__name__)
@@ -39,7 +39,7 @@ def get_server_throughput(
     dtype: Union[str, torch.dtype],
     *,
     num_blocks: int,
-    load_in_8bit: bool,
+    quant_type: QuantType,
     tensor_parallel_devices: Sequence[torch.device],
     force_eval: bool = False,
     cache_dir: Optional[str] = None,
@@ -60,7 +60,7 @@ def get_server_throughput(
 
         cache_key = f"model_{model_name}"
         cache_key += f"_device_{get_device_name(device).replace(' ', '_')}"
-        cache_key += f"_dtype_{get_dtype_name(dtype, load_in_8bit)}"
+        cache_key += f"_dtype_{get_dtype_name(dtype, quant_type)}"
         if len(tensor_parallel_devices) > 1:
             for i, device_i in enumerate(tensor_parallel_devices):
                 cache_key += f"_tp{i}_{get_device_name(device_i).replace(' ', '_')}"
@@ -77,7 +77,7 @@ def get_server_throughput(
 
         if cache_key not in cache:
             cache[cache_key] = measure_throughput_info(
-                config, device, dtype, load_in_8bit=load_in_8bit, tensor_parallel_devices=tensor_parallel_devices
+                config, device, dtype, quant_type=quant_type, tensor_parallel_devices=tensor_parallel_devices
             )
 
             try:
@@ -104,7 +104,7 @@ def measure_throughput_info(
     device: torch.device,
     dtype: torch.dtype,
     *,
-    load_in_8bit: bool,
+    quant_type: QuantType,
     tensor_parallel_devices: Sequence[torch.device],
 ) -> Dict[str, float]:
     """Measure network and compute throughput in forward pass tokens per second"""
@@ -115,7 +115,7 @@ def measure_throughput_info(
 
     throughput_info = {
         "compute_rps": measure_compute_rps(
-            config, device, dtype, load_in_8bit=load_in_8bit, tensor_parallel_devices=tensor_parallel_devices
+            config, device, dtype, quant_type=quant_type, tensor_parallel_devices=tensor_parallel_devices
         )
     }
     try:
@@ -163,7 +163,7 @@ def measure_compute_rps(
     device: torch.device,
     dtype: torch.dtype,
     *,
-    load_in_8bit: bool,
+    quant_type: QuantType,
     tensor_parallel_devices: Sequence[torch.device],
     n_tokens: int = 16,
     n_steps: int = 500,
@@ -172,7 +172,7 @@ def measure_compute_rps(
         tensor_parallel_devices = (device,)
     with torch.inference_mode():
         block = config.block_class(config).to(dtype)
-        block = convert_block(block, config, tensor_parallel_devices, device, load_in_8bit=load_in_8bit, freeze=True)
+        block = convert_block(block, config, tensor_parallel_devices, device, quant_type=quant_type, freeze=True)
 
         cache = None
         elapsed = 0
@@ -192,7 +192,7 @@ def measure_compute_rps(
 
     logger.info(
         f"Forward pass throughput: {device_rps:.1f} RPS per block "
-        f"({devices_repr}, {get_dtype_name(dtype, load_in_8bit)})"
+        f"({devices_repr}, {get_dtype_name(dtype, quant_type)})"
     )
     return device_rps
 
@@ -201,8 +201,8 @@ def get_device_name(device: torch.device) -> str:
     return f"{torch.cuda.get_device_name(device)} GPU" if device.type == "cuda" else "CPU"
 
 
-def get_dtype_name(dtype: torch.dtype, load_in_8bit: bool) -> str:
+def get_dtype_name(dtype: torch.dtype, quant_type: QuantType) -> str:
     name = str(dtype)
-    if load_in_8bit:
-        name += ", 8-bit quantized"
+    if quant_type != QuantType.NONE:
+        name += f", quantized to {quant_type.name.lower()}"
     return name

+ 42 - 36
src/petals/utils/convert_block.py

@@ -3,6 +3,7 @@ Tools for converting transformer blocks, applying quantization and/or tensor par
 """
 import os
 import re
+from enum import Enum
 from typing import Sequence
 
 import tensor_parallel as tp
@@ -16,13 +17,18 @@ use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
 
 
+class QuantType(Enum):
+    NONE = 0
+    INT8 = 1  # 8-bit as in the LLM.int8() paper
+    NF4 = 2  # 4-bit as in the QLoRA paper
+
+
 def convert_block(
     block: nn.Module,
     config: PretrainedConfig,
     tensor_parallel_devices: Sequence[torch.device],
     output_device: torch.device,
-    load_in_8bit: bool,
-    threshold: float = 6.0,
+    quant_type: QuantType,
     freeze: bool = True,
 ) -> tp.TensorParallel:
     """
@@ -34,20 +40,18 @@ def convert_block(
     :param tensor_parallel_devices: if specified, use tensor parallelism to split the model between these devices
     :note: if there is only a single device, model wil still be wrapped with TensorParallel (for uniformity)
     :param output_device: if tensor_parallel_devices is True, output
-    :param load_in_8bit: if True, use LLM.int8() quantization to reduce the model memory footprint
-    :param threshold: a quantization threshold from LLM.int8() paper ( https://arxiv.org/abs/2208.07339 )
+    :param quant_type: quantization type
     :param freeze: if True (default), make all module parameters non-trainable
     :return: a module that acts like the original block, but runs with all specified optimizations
 
     """
     if freeze:
-        for param in block.parameters():
-            param.requires_grad = False
+        block.requires_grad_(False)
 
     block = make_tensor_parallel(block, config, tensor_parallel_devices, output_device=output_device)
 
-    if load_in_8bit:
-        block = replace_8bit_linear(block, threshold=threshold)
+    if quant_type != QuantType.NONE:
+        block = quantize_module(block, quant_type=quant_type)
 
     for shard, device in zip(block.module_shards, block.devices):
         shard.to(device)
@@ -55,43 +59,45 @@ def convert_block(
     return block
 
 
-def replace_8bit_linear(model: nn.Module, threshold=6.0) -> nn.Module:
-    """
-    A helper function to convert all `torch.nn.Linear` modules to `bnb.nn.Linear8bit` modules from the `bitsandbytes`
-    library. This will enable running your models using mixed int8 precision as described by the paper `GPT3.int8():
-    8-bit Matrix Multiplication for Transformers at Scale`. Make sure `bitsandbytes` compiled with the correct CUDA
-    version of your hardware is installed before running this function. `pip install -i https://test.pypi.org/simple/
-    bitsandbytes-cudaXXX` with `XXX` is your CUDA version (e.g., 11.6 = 116)
-    The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` and 'score' that should
-    be kept as a `torch.nn.Linear` module.
-    Parameters:
-        model (`torch.nn.Module`):
-            Input model or `torch.nn.Module` as the function is run recursively.
-        threshold (`float`, *optional*):
-            `int8_threshold` for outlier detection as described in the formentioned paper. This parameters is set to
-            `6.0` as described by the paper.
-    """
-
+def quantize_module(model: nn.Module, *, quant_type: QuantType) -> nn.Module:
     # Import bitsandbytes only when necessary, so Petals runs on platforms not supported by bitsandbytes
     os.environ["BITSANDBYTES_NOWELCOME"] = "1"
     import bitsandbytes as bnb
 
     for n, module in model.named_children():
         if len(list(module.children())) > 0:
-            replace_8bit_linear(module, threshold)
+            quantize_module(module, quant_type=quant_type)
 
         if isinstance(module, torch.nn.Linear) and n not in ["lm_head", "score"]:
             assert module.weight.device.type == "cpu", f"expected linear layers on CPU, got {module.weight.device}"
-            model._modules[n] = bnb.nn.Linear8bitLt(
-                module.in_features,
-                module.out_features,
-                module.bias is not None,
-                has_fp16_weights=False,
-                threshold=threshold,
-            )
-            model._modules[n].weight = bnb.nn.Int8Params(
-                module.weight.data, requires_grad=False, has_fp16_weights=False
-            ).to(module.weight.dtype)
+            if quant_type == QuantType.INT8:
+                model._modules[n] = bnb.nn.Linear8bitLt(
+                    module.in_features,
+                    module.out_features,
+                    module.bias is not None,
+                    has_fp16_weights=False,
+                    threshold=6.0,  # Default from the LLM.int8() paper
+                )
+                model._modules[n].weight = bnb.nn.Int8Params(
+                    module.weight.data, requires_grad=False, has_fp16_weights=False
+                ).to(module.weight.dtype)
+            elif quant_type == QuantType.NF4:
+                compress_statistics = True
+                model._modules[n] = bnb.nn.LinearNF4(
+                    module.in_features,
+                    module.out_features,
+                    module.bias is not None,
+                    compress_statistics=compress_statistics,
+                )
+                model._modules[n].weight = bnb.nn.Params4bit(
+                    module.weight.data,
+                    requires_grad=False,
+                    quant_type="nf4",
+                    blocksize=64,
+                    compress_statistics=compress_statistics,
+                ).to(module.weight.dtype)
+            else:
+                raise ValueError(f"Unsupported quant_type='{quant_type}'")
             model._modules[n].bias = module.bias
     return model
 

+ 2 - 1
tests/test_aux_functions.py

@@ -3,6 +3,7 @@ import torch
 
 from petals import AutoDistributedConfig
 from petals.server.throughput import measure_compute_rps
+from petals.utils.convert_block import QuantType
 from test_utils import MODEL_NAME
 
 
@@ -15,7 +16,7 @@ def test_compute_throughput(tensor_parallel: bool):
         config,
         device=torch.device("cpu"),
         dtype=torch.bfloat16,
-        load_in_8bit=False,
+        quant_type=QuantType.NONE,
         tensor_parallel_devices=tensor_parallel_devices,
         n_steps=10,
     )

+ 4 - 1
tests/test_remote_sequential.py

@@ -78,7 +78,10 @@ class DummyCustomSequenceManager(RemoteSequenceManager):
         if protocol == "rpc_forward":
             metadata["output_compression"] = (runtime_pb2.CompressionType.FLOAT16,)
         elif protocol == "rpc_backward":
-            metadata["output_compression"] = (runtime_pb2.CompressionType.BLOCKWISE_8BIT,)
+            metadata["output_compression"] = (runtime_pb2.CompressionType.FLOAT16,)
+            # FIXME: Initially, we used CompressionType.BLOCKWISE_8BIT for rpc_backward() here.
+            # This is currently broken since hivemind==1.1.8 is not compatible with bitsandbytes==0.39.1.
+            # Please revert to BLOCKWISE_8BIT once this is fixed: https://github.com/learning-at-home/hivemind/issues/572
         return metadata