Sfoglia il codice sorgente

push config and tokenizer separately

justheuristic 3 anni fa
parent
commit
6047a2ffe0
1 ha cambiato i file con 13 aggiunte e 1 eliminazioni
  1. 13 1
      cli/convert_model.py

+ 13 - 1
cli/convert_model.py

@@ -45,6 +45,9 @@ if __name__ == "__main__":
         raise FileExistsError(f"Output path {args.output_path} already exists and is not an empty directory")
 
     logger.info(f"Loading source model {args.model} (this may take a few minutes)")
+    config = transformers.AutoConfig.from_pretrained(
+        args.model, use_auth_token=args.use_auth_token, revision=args.revision
+    )
     model = transformers.AutoModelForCausalLM.from_pretrained(
         args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype]
     )
@@ -73,4 +76,13 @@ if __name__ == "__main__":
     with repo.commit(commit_message=args.commit_message, branch=args.client_branch, track_large_files=True):
         model.transformer.h = nn.ModuleList()
         model.save_pretrained(".")
-    logger.info(f"Converted {args.model} and saved to {args.output_repo}")
+
+    logger.info(f"Saving config and tokenizer to {args.output_repo}@{args.base_branch}")
+
+    repo.git_checkout(args.base_branch, create_branch_ok=True)
+    with repo.commit(commit_message=args.commit_message, branch=args.base_branch, track_large_files=True):
+        tokenizer.save_pretrained(".")
+        config.save_pretrained(".")
+
+    logger.info(f"Converted {args.model} and pushed to {args.output_repo}")
+