# convert_model.py — load a BLOOM checkpoint and push it to a Hugging Face Hub
# repo in the distributed layout (one branch per transformer block).
  1. import argparse
  2. import os
  3. import psutil
  4. import torch.backends.quantized
  5. import torch.nn as nn
  6. import transformers
  7. from hivemind.utils.logging import get_logger, use_hivemind_log_handler
  8. from huggingface_hub import Repository
  9. from tqdm.auto import tqdm
  10. use_hivemind_log_handler("in_root_logger")
  11. logger = get_logger(__file__)
  12. DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
  13. if __name__ == "__main__":
  14. parser = argparse.ArgumentParser(description="Load bloom layers and convert to 8-bit using torch quantization.")
  15. parser.add_argument("--model", type=str, default="bigscience/bloom-6b3", help="Model name for from_pretrained")
  16. parser.add_argument("--revision", type=str, default=None, help="Optional commit id from HF hub")
  17. parser.add_argument("--torch_dtype", type=str, default="auto", help="Load initial model in this dtype")
  18. parser.add_argument("--output_path", type=str, default="./converted_model", help="Track output repo to this folder")
  19. parser.add_argument("--output_repo", type=str, default="bigscience/test-bloomd", help="Push to this HF hub repo")
  20. parser.add_argument("--base_branch", type=str, default="main", help="Use this branch as reference point")
  21. parser.add_argument("--client_branch", type=str, default="client", help="Save client version to this branch")
  22. parser.add_argument(
  23. "--block_branch_prefix", type=str, default="block_", help="Save blocks to branches with this prefix"
  24. )
  25. parser.add_argument(
  26. "--commit_message", type=str, default="push-o-matic", help="Use this commit message for all parts"
  27. )
  28. parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
  29. args = parser.parse_args()
  30. free_ram_gb = psutil.virtual_memory().available / 2**30
  31. if args.model == "bigscience/bloom" and free_ram_gb < 400:
  32. logger.warning(f"ACHTUNG! converting bloom-176b will use up 350-400GB RAM, you have {free_ram_gb:.3f} free")
  33. assert args.torch_dtype in DTYPE_MAP, f"torch_dtype must be one of {list(DTYPE_MAP.keys())}"
  34. if os.path.exists(args.output_path) and (
  35. len(os.listdir(args.output_path)) != 0 or not os.path.isdir(args.output_path)
  36. ):
  37. raise FileExistsError(f"Output path {args.output_path} already exists and is not an empty directory")
  38. logger.info(f"Loading source model {args.model} (this may take a few minutes)")
  39. config = transformers.AutoConfig.from_pretrained(
  40. args.model, use_auth_token=args.use_auth_token, revision=args.revision
  41. )
  42. # model = transformers.AutoModelForCausalLM.from_pretrained(
  43. model = transformers.AutoModel.from_pretrained(
  44. args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype]
  45. )
  46. tokenizer = transformers.AutoTokenizer.from_pretrained(
  47. args.model, use_auth_token=args.use_auth_token, revision=args.revision
  48. )
  49. os.makedirs(args.output_path, exist_ok=True)
  50. repo = Repository(args.output_path, clone_from=args.output_repo, use_auth_token=args.use_auth_token)
  51. repo.git_pull()
  52. transformer_blocks = model.h #transformer.h
  53. logger.info(
  54. f"Saving transformer blocks to {args.output_repo}@{args.block_branch_prefix}0"
  55. f" - {args.output_repo}@{args.block_branch_prefix}{len(transformer_blocks)}"
  56. )
  57. for i, block in enumerate(tqdm(transformer_blocks)):
  58. repo.git_checkout(args.base_branch, create_branch_ok=True)
  59. with repo.commit(
  60. commit_message=args.commit_message, branch=args.block_branch_prefix + str(i), track_large_files=True
  61. ):
  62. torch.save(block.state_dict(), "./pytorch_model.bin")
  63. logger.info(f"Saving client-side modules to {args.output_repo}@{args.client_branch}")
  64. repo.git_checkout(args.base_branch, create_branch_ok=True)
  65. with repo.commit(commit_message=args.commit_message, branch=args.client_branch, track_large_files=True):
  66. model.h = nn.ModuleList()
  67. #model.transformer.h = nn.ModuleList()
  68. model.save_pretrained(".")
  69. logger.info(f"Saving config and tokenizer to {args.output_repo}@{args.base_branch}")
  70. repo.git_checkout(args.base_branch, create_branch_ok=True)
  71. with repo.commit(commit_message=args.commit_message, branch=args.base_branch, track_large_files=True):
  72. tokenizer.save_pretrained(".")
  73. config.save_pretrained(".")
  74. logger.info(f"Converted {args.model} and pushed to {args.output_repo}")