瀏覽代碼

add quantization script for cpu

justheuristic 3 年之前
父節點
當前提交
05faa0b3c8
共有 4 個文件被更改,包括 76 次插入16 次删除
  1. 0 0
      cli/__init__.py
  2. 49 0
      cli/quantize_for_cpu.py
  3. 20 12
      src/block.py
  4. 7 4
      src/model.py

+ 0 - 0
cli/__init__.py


+ 49 - 0
cli/quantize_for_cpu.py

@@ -0,0 +1,49 @@
+import argparse
+import copy
+import os
+
+import psutil
+import torch.backends.quantized
+import transformers
+from hivemind.utils.logging import get_logger
+from tqdm.auto import trange
+
+logger = get_logger(__file__)
+
+DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Load bloom layers and convert to 8-bit using torch quantization.")
+    parser.add_argument("--output_path", required=True, type=str, help="Save quantized layers to this folder")
+    parser.add_argument("--model", type=str, default="bigscience/bloom", help="Model name for from_pretrained")
+    parser.add_argument("--revision", type=str, default=None, help="Optional commit id from HF hub")
+    parser.add_argument("--torch_dtype", type=str, default="auto", help="Load initial model in this dtype")
+    parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
+    args = parser.parse_args()
+
+    free_ram_gb = psutil.virtual_memory().available / 2**30
+    if free_ram_gb < 400:
+        logger.warning(f"ACHTUNG! converting bloom-176b will use up 370-400GB RAM, you have {free_ram_gb:.3f} free")
+
+    assert args.torch_dtype in DTYPE_MAP, f"torch_dtype must be one of {list(DTYPE_MAP.keys())}"
+    if os.path.exists(args.output_path) and (
+        len(os.listdir(args.output_path)) != 0 or not os.path.isdir(args.output_path)
+    ):
+        raise FileExistsError(f"Output path {args.output_path} already exists and is not an empty directory")
+
+    model = transformers.BloomForCausalLM.from_pretrained(
+        args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype]
+    )
+
+    qconfig = torch.quantization.get_default_qconfig("fbgemm")
+    torch.backends.quantized.engine = "fbgemm"
+
+    os.makedirs(args.output_path, exist_ok=True)
+
+    for i in trange(len(model.transformer.h)):
+        layer_fp32 = copy.deepcopy(model.transformer.h[i]).float()
+        layer_quantized = torch.quantization.quantize_dynamic(
+            layer_fp32, {torch.nn.Linear: qconfig}, dtype=torch.qint8, inplace=True
+        )
+        torch.save(layer_quantized.state_dict(), os.path.join(args.output_path, f"block_{i}_qint8.pth"))

+ 20 - 12
src/block.py

@@ -6,11 +6,17 @@ See commit history for authorship.
 import math
 
 import torch
-import torch.nn.quantized.dynamic.modules.linear
 import torch.nn as nn
+import torch.nn.quantized.dynamic.modules.linear
 
-from src.ops import BloomScaledSoftmax, BloomGelu
-from src.ops import attention_mask_func, pre_process_alibi_for_pad, split_tensor_along_last_dim, dropout_add
+from src.ops import (
+    BloomGelu,
+    BloomScaledSoftmax,
+    attention_mask_func,
+    dropout_add,
+    pre_process_alibi_for_pad,
+    split_tensor_along_last_dim,
+)
 
 
 class BloomAttention(nn.Module):
@@ -43,11 +49,13 @@ class BloomAttention(nn.Module):
             self.layer_number,
         )
 
-        if config.compression == 'qint8':
+        if config.compression == "qint8":
             self.query_key_value = nn.quantized.dynamic.modules.Linear(
-                self.hidden_size, 3 * self.hidden_size, bias_=True, dtype=torch.qint8)
+                self.hidden_size, 3 * self.hidden_size, bias_=True, dtype=torch.qint8
+            )
             self.dense = nn.quantized.dynamic.modules.Linear(
-                self.hidden_size, self.hidden_size, bias_=True, dtype=torch.qint8)
+                self.hidden_size, self.hidden_size, bias_=True, dtype=torch.qint8
+            )
         else:
             self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
             self.dense = nn.Linear(self.hidden_size, self.hidden_size)
@@ -120,9 +128,7 @@ class BloomAttention(nn.Module):
 
         # attention scores and attention mask [b, np, sq, sk]
         max_positions = max(attention_scores.shape[-1], attention_scores.shape[-2])
-        attention_probs = self.scale_mask_softmax(attention_scores, attention_mask, max_positions).to(
-            value_layer.dtype
-        )
+        attention_probs = self.scale_mask_softmax(attention_scores, attention_mask, max_positions).to(value_layer.dtype)
         attention_probs = self.attention_dropout(attention_probs)
 
         if head_mask is not None:
@@ -170,11 +176,13 @@ class BloomMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
         self.hidden_size = config.hidden_size
-        if config.compression == 'qint8':
+        if config.compression == "qint8":
             self.dense_h_to_4h = nn.quantized.dynamic.modules.Linear(
-                self.hidden_size, 4 * self.hidden_size, bias_=True, dtype=torch.qint8)
+                self.hidden_size, 4 * self.hidden_size, bias_=True, dtype=torch.qint8
+            )
             self.dense_4h_to_h = nn.quantized.dynamic.modules.Linear(
-                4 * self.hidden_size, self.hidden_size, bias_=True, dtype=torch.qint8)
+                4 * self.hidden_size, self.hidden_size, bias_=True, dtype=torch.qint8
+            )
         else:
             self.dense_h_to_4h = nn.Linear(self.hidden_size, 4 * self.hidden_size)
             self.dense_4h_to_h = nn.Linear(4 * self.hidden_size, self.hidden_size)

+ 7 - 4
src/model.py

@@ -10,12 +10,15 @@ import torch
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import CrossEntropyLoss, LayerNorm
-
-from transformers.file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
+from transformers.file_utils import (
+    add_code_sample_docstrings,
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+)
 from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
 from transformers.modeling_utils import PreTrainedModel
-from transformers.utils import logging
 from transformers.models.bloom.configuration_bloom import BloomConfig as _VanillaBloomConfig
+from transformers.utils import logging
 
 from src.block import BloomBlock
 from src.ops import build_alibi_tensor
@@ -28,7 +31,7 @@ _TOKENIZER_FOR_DOC = "BloomTokenizer"
 
 
 class MemoryEfficientBloomConfig(_VanillaBloomConfig):
-    compression: str = 'none'
+    compression: str = "none"
     slow_but_exact: bool = False