|
@@ -48,7 +48,6 @@ if __name__ == "__main__":
|
|
config = transformers.AutoConfig.from_pretrained(
|
|
config = transformers.AutoConfig.from_pretrained(
|
|
args.model, use_auth_token=args.use_auth_token, revision=args.revision
|
|
args.model, use_auth_token=args.use_auth_token, revision=args.revision
|
|
)
|
|
)
|
|
- # model = transformers.AutoModelForCausalLM.from_pretrained(
|
|
|
|
model = transformers.AutoModel.from_pretrained(
|
|
model = transformers.AutoModel.from_pretrained(
|
|
args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype]
|
|
args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype]
|
|
)
|
|
)
|
|
@@ -60,7 +59,7 @@ if __name__ == "__main__":
|
|
repo = Repository(args.output_path, clone_from=args.output_repo, use_auth_token=args.use_auth_token)
|
|
repo = Repository(args.output_path, clone_from=args.output_repo, use_auth_token=args.use_auth_token)
|
|
repo.git_pull()
|
|
repo.git_pull()
|
|
|
|
|
|
- transformer_blocks = model.h #transformer.h
|
|
|
|
|
|
+ transformer_blocks = model.h
|
|
logger.info(
|
|
logger.info(
|
|
f"Saving transformer blocks to {args.output_repo}@{args.block_branch_prefix}0"
|
|
f"Saving transformer blocks to {args.output_repo}@{args.block_branch_prefix}0"
|
|
f" - {args.output_repo}@{args.block_branch_prefix}{len(transformer_blocks)}"
|
|
f" - {args.output_repo}@{args.block_branch_prefix}{len(transformer_blocks)}"
|
|
@@ -76,7 +75,6 @@ if __name__ == "__main__":
|
|
repo.git_checkout(args.base_branch, create_branch_ok=True)
|
|
repo.git_checkout(args.base_branch, create_branch_ok=True)
|
|
with repo.commit(commit_message=args.commit_message, branch=args.client_branch, track_large_files=True):
|
|
with repo.commit(commit_message=args.commit_message, branch=args.client_branch, track_large_files=True):
|
|
model.h = nn.ModuleList()
|
|
model.h = nn.ModuleList()
|
|
- #model.transformer.h = nn.ModuleList()
|
|
|
|
model.save_pretrained(".")
|
|
model.save_pretrained(".")
|
|
|
|
|
|
logger.info(f"Saving config and tokenizer to {args.output_repo}@{args.base_branch}")
|
|
logger.info(f"Saving config and tokenizer to {args.output_repo}@{args.base_branch}")
|