Browse Source

refactoring

Dmitry Baranchuk 3 years ago
parent
commit
be83e6d0cb
2 changed files with 1 additions and 9 deletions
  1. 0 6
      .gitignore
  2. 1 3
      cli/convert_model.py

+ 0 - 6
.gitignore

@@ -1,9 +1,3 @@
-server_configs
-converted_model
-*.id
-*.ipynb
-*.out
-
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]

+ 1 - 3
cli/convert_model.py

@@ -48,7 +48,6 @@ if __name__ == "__main__":
     config = transformers.AutoConfig.from_pretrained(
         args.model, use_auth_token=args.use_auth_token, revision=args.revision
     )
-    # model = transformers.AutoModelForCausalLM.from_pretrained(
     model = transformers.AutoModel.from_pretrained(    
         args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype]
     )
@@ -60,7 +59,7 @@ if __name__ == "__main__":
     repo = Repository(args.output_path, clone_from=args.output_repo, use_auth_token=args.use_auth_token)
     repo.git_pull()
 
-    transformer_blocks = model.h #transformer.h
+    transformer_blocks = model.h
     logger.info(
         f"Saving transformer blocks to {args.output_repo}@{args.block_branch_prefix}0"
         f" - {args.output_repo}@{args.block_branch_prefix}{len(transformer_blocks)}"
@@ -76,7 +75,6 @@ if __name__ == "__main__":
     repo.git_checkout(args.base_branch, create_branch_ok=True)
     with repo.commit(commit_message=args.commit_message, branch=args.client_branch, track_large_files=True):
         model.h = nn.ModuleList()
-        #model.transformer.h = nn.ModuleList()
         model.save_pretrained(".")
 
     logger.info(f"Saving config and tokenizer to {args.output_repo}@{args.base_branch}")