quantize_cpu_naive.py

import argparse
import copy
import os

import psutil
import torch
import torch.backends.quantized
import transformers
from hivemind.utils.logging import get_logger
from tqdm.auto import trange

logger = get_logger(__file__)

DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Load bloom layers and convert to 8-bit using torch quantization.")
    parser.add_argument("--output_path", required=True, type=str, help="Save quantized layers to this folder")
    parser.add_argument("--model", type=str, default="bigscience/bloom", help="Model name for from_pretrained")
    parser.add_argument("--revision", type=str, default=None, help="Optional commit id from HF hub")
    parser.add_argument("--torch_dtype", type=str, default="auto", help="Load initial model in this dtype")
    parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
    args = parser.parse_args()

    free_ram_gb = psutil.virtual_memory().available / 2**30
    if free_ram_gb < 400:
        logger.warning(f"ACHTUNG! converting bloom-176b will use up 370-400GB RAM, you have {free_ram_gb:.3f} free")

    assert args.torch_dtype in DTYPE_MAP, f"torch_dtype must be one of {list(DTYPE_MAP.keys())}"
    if os.path.exists(args.output_path) and (
        len(os.listdir(args.output_path)) != 0 or not os.path.isdir(args.output_path)
    ):
        raise FileExistsError(f"Output path {args.output_path} already exists and is not an empty directory")

    model = transformers.BloomForCausalLM.from_pretrained(
        args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype]
    )
    qconfig = torch.quantization.get_default_qconfig("fbgemm")
    torch.backends.quantized.engine = "fbgemm"
    os.makedirs(args.output_path, exist_ok=True)
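    # A minimal sketch of what quantize_dynamic does to a single module (hypothetical
    # sizes, illustration only; not executed by this script):
    #
    #   lin = torch.nn.Linear(1024, 1024)
    #   qlin = torch.quantization.quantize_dynamic(lin, {torch.nn.Linear: qconfig}, dtype=torch.qint8)
    #   out = qlin(torch.randn(2, 1024))  # int8 weights; activations are quantized on the fly per batch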
    # Quantize each transformer block independently and save it as a separate checkpoint.
    for i in trange(len(model.transformer.h)):
        layer_fp32 = copy.deepcopy(model.transformer.h[i]).float()
        layer_quantized = torch.quantization.quantize_dynamic(
            layer_fp32, {torch.nn.Linear: qconfig}, dtype=torch.qint8, inplace=True
        )
        torch.save(layer_quantized.state_dict(), os.path.join(args.output_path, f"block_{i}_qint8.pth"))
    # Drop the transformer blocks (saved separately above) and save the rest of the
    # model (embeddings, final layernorm, lm head) as the client-side checkpoint.
    model.transformer.h = torch.nn.ModuleList()
    torch.save(model.state_dict(), os.path.join(args.output_path, "client.pth"))
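    # Loading sketch (an assumption, not part of the original script): to restore a
    # saved block, rebuild an identically structured quantized module first, then load
    # the int8 state dict into it. The exact BloomBlock constructor arguments vary
    # across transformers versions, so treat this as a starting point:
    #
    #   config = transformers.BloomConfig.from_pretrained(args.model)
    #   block = transformers.models.bloom.modeling_bloom.BloomBlock(config)
    #   block = torch.quantization.quantize_dynamic(
    #       block, {torch.nn.Linear: qconfig}, dtype=torch.qint8, inplace=True
    #   )
    #   block.load_state_dict(torch.load(os.path.join(args.output_path, "block_0_qint8.pth")))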