justheuristic 3 years ago
parent
commit
b6f3bbfd97
1 changed files with 18 additions and 12 deletions
  1. 18 12
      cli/convert_model.py

+ 18 - 12
cli/convert_model.py

@@ -21,17 +21,21 @@ if __name__ == "__main__":
     parser.add_argument("--model", type=str, default="bigscience/bloom-6b3", help="Model name for from_pretrained")
     parser.add_argument("--revision", type=str, default=None, help="Optional commit id from HF hub")
     parser.add_argument("--torch_dtype", type=str, default="auto", help="Load initial model in this dtype")
-    parser.add_argument("--output_path", type=str, default='./converted_model', help="Track output repo to this folder")
-    parser.add_argument("--output_repo", type=str, default='bigscience/test-bloomd', help="Push to this HF hub repo")
-    parser.add_argument("--base_branch", type=str, default='main', help="Use this branch as reference point")
-    parser.add_argument("--client_branch", type=str, default='client', help="Save client version to this branch")
-    parser.add_argument("--block_branch_prefix", type=str, default='block_', help="Save blocks to branches with this prefix")
-    parser.add_argument("--commit_message", type=str, default='push-o-matic', help="Use this commit message for all parts")
+    parser.add_argument("--output_path", type=str, default="./converted_model", help="Track output repo to this folder")
+    parser.add_argument("--output_repo", type=str, default="bigscience/test-bloomd", help="Push to this HF hub repo")
+    parser.add_argument("--base_branch", type=str, default="main", help="Use this branch as reference point")
+    parser.add_argument("--client_branch", type=str, default="client", help="Save client version to this branch")
+    parser.add_argument(
+        "--block_branch_prefix", type=str, default="block_", help="Save blocks to branches with this prefix"
+    )
+    parser.add_argument(
+        "--commit_message", type=str, default="push-o-matic", help="Use this commit message for all parts"
+    )
     parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
     args = parser.parse_args()
 
-    free_ram_gb = psutil.virtual_memory().available / 2 ** 30
-    if free_ram_gb < 400:
+    free_ram_gb = psutil.virtual_memory().available / 2**30
+    if args.model == "bigscience/bloom" and free_ram_gb < 400:
         logger.warning(f"ACHTUNG! converting bloom-176b will use up 350-400GB RAM, you have {free_ram_gb:.3f} free")
 
     assert args.torch_dtype in DTYPE_MAP, f"torch_dtype must be one of {list(DTYPE_MAP.keys())}"
@@ -40,7 +44,7 @@ if __name__ == "__main__":
     ):
         raise FileExistsError(f"Output path {args.output_path} already exists and is not an empty directory")
 
-    logger.info(f"Loading source model {args.model}")
+    logger.info(f"Loading source model {args.model} (this may take a few minutes)")
     model = transformers.AutoModelForCausalLM.from_pretrained(
         args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype]
     )
@@ -53,12 +57,14 @@ if __name__ == "__main__":
     repo.git_pull()
 
     transformer_blocks = model.transformer.h
-    logger.info(f"Saving transformer blocks to {args.output_repo}@{args.block_branch_prefix}0"
-                f" - {args.output_repo}@{args.block_branch_prefix}{len(transformer_blocks)}")
+    logger.info(
+        f"Saving transformer blocks to {args.output_repo}@{args.block_branch_prefix}0"
+        f" - {args.output_repo}@{args.block_branch_prefix}{len(transformer_blocks)}"
+    )
     for i, block in enumerate(tqdm(transformer_blocks)):
         repo.git_checkout(args.base_branch, create_branch_ok=True)
         with repo.commit(
-                commit_message=args.commit_message, branch=args.block_branch_prefix + str(i), track_large_files=True
+            commit_message=args.commit_message, branch=args.block_branch_prefix + str(i), track_large_files=True
         ):
             torch.save(block.state_dict(), "./pytorch_model.bin")