Răsfoiți Sursa

push converted model to hub

justheuristic 3 ani în urmă
părinte
comite
736f1d1085
2 a modificat fișierele cu 35 adăugiri și 20 ștergeri
  1. 2 1
      README.md
  2. 33 19
      cli/quantize_cpu_naive.py

+ 2 - 1
README.md

@@ -17,9 +17,10 @@ conda activate bloom-demo
 
 conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
 pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
+pip install accelerate==0.10.0 huggingface-hub==0.7.0
 pip install bitsandbytes-cuda113==0.26.0
 pip install https://github.com/learning-at-home/hivemind/archive/master.zip
-pip install https://github.com/huggingface/transformers/archive/224bde91caff4ccfd12277ab5e9bf97c61e22ee9.zip
+pip install https://github.com/huggingface/transformers/archive/6589e510fa4e6c442059de2fab84752535de9b23.zip
 ```
 
 

+ 33 - 19
cli/quantize_cpu_naive.py

@@ -1,12 +1,13 @@
 import argparse
-import copy
 import os
 
 import psutil
 import torch.backends.quantized
 import transformers
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
-from tqdm.auto import trange
+from huggingface_hub import Repository
+import torch.nn as nn
+from tqdm.auto import tqdm
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
@@ -16,16 +17,22 @@ DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.f
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Load bloom layers and convert to 8-bit using torch quantization.")
-    parser.add_argument("--output_path", required=True, type=str, help="Save quantized layers to this folder")
-    parser.add_argument("--model", type=str, default="bigscience/bloom", help="Model name for from_pretrained")
+
+    parser.add_argument("--model", type=str, default="bigscience/bloom-6b3", help="Model name for from_pretrained")
     parser.add_argument("--revision", type=str, default=None, help="Optional commit id from HF hub")
     parser.add_argument("--torch_dtype", type=str, default="auto", help="Load initial model in this dtype")
+    parser.add_argument("--output_path", type=str, default='./converted_model', help="Track output repo to this folder")
+    parser.add_argument("--output_repo", type=str, default='bigscience/test-bloomd', help="Push to this HF hub repo")
+    parser.add_argument("--base_branch", type=str, default='main', help="Use this branch as reference point")
+    parser.add_argument("--client_branch", type=str, default='client', help="Save client version to this branch")
+    parser.add_argument("--block_branch_prefix", type=str, default='block_', help="Save blocks to branches with this prefix")
+    parser.add_argument("--commit_message", type=str, default='push-o-matic', help="Use this commit message for all parts")
     parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
     args = parser.parse_args()
 
-    free_ram_gb = psutil.virtual_memory().available / 2**30
+    free_ram_gb = psutil.virtual_memory().available / 2 ** 30
     if free_ram_gb < 400:
-        logger.warning(f"ACHTUNG! converting bloom-176b will use up 370-400GB RAM, you have {free_ram_gb:.3f} free")
+        logger.warning(f"ACHTUNG! converting bloom-176b will use up 350-400GB RAM, you have {free_ram_gb:.3f} free")
 
     assert args.torch_dtype in DTYPE_MAP, f"torch_dtype must be one of {list(DTYPE_MAP.keys())}"
     if os.path.exists(args.output_path) and (
@@ -33,21 +40,28 @@ if __name__ == "__main__":
     ):
         raise FileExistsError(f"Output path {args.output_path} already exists and is not an empty directory")
 
-    model = transformers.BloomForCausalLM.from_pretrained(
+    model = transformers.AutoModelForCausalLM.from_pretrained(
         args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype]
     )
-
-    qconfig = torch.quantization.get_default_qconfig("fbgemm")
-    torch.backends.quantized.engine = "fbgemm"
-
+    tokenizer = transformers.AutoTokenizer.from_pretrained(
+        args.model, use_auth_token=args.use_auth_token, revision=args.revision
+    )
     os.makedirs(args.output_path, exist_ok=True)
 
-    for i in trange(len(model.transformer.h)):
-        layer_fp32 = copy.deepcopy(model.transformer.h[i]).float()
-        layer_quantized = torch.quantization.quantize_dynamic(
-            layer_fp32, {torch.nn.Linear: qconfig}, dtype=torch.qint8, inplace=True
-        )
-        torch.save(layer_quantized.state_dict(), os.path.join(args.output_path, f"block_{i}_qint8.pth"))
+    repo = Repository(args.output_path, clone_from=args.output_repo, use_auth_token=args.use_auth_token)
+    repo.git_pull()
+
+    transformer_blocks = model.transformer.h
+    for i, block in enumerate(tqdm(transformer_blocks)):
+        repo.git_checkout(args.base_branch, create_branch_ok=True)
+        with repo.commit(
+                commit_message=args.commit_message, branch=args.block_branch_prefix + str(i), track_large_files=True
+        ):
+            print(block.self_attention.layer_number)
+            torch.save(block.state_dict(), "./pytorch_model.bin")
 
-    model.transformer.h = torch.nn.ModuleList()
-    torch.save(model.state_dict(), os.path.join(args.output_path, f"client.pth"))
+    repo.git_checkout(args.base_branch, create_branch_ok=True)
+    with repo.commit(commit_message=args.commit_message, branch=args.client_branch, track_large_files=True):
+        model.transformer.h = nn.ModuleList()
+        model.save_pretrained(".")
+    logger.info(f"Converted {args.model} and saved to {args.output_repo}")