|
@@ -35,6 +35,8 @@ if __name__ == "__main__":
|
|
|
"--commit_message", type=str, default="push-o-matic", help="Use this commit message for all parts"
|
|
|
)
|
|
|
parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
|
|
|
+ parser.add_argument("--resize_token_embeddings", type=int, default=None,
|
|
|
+ help="change the vocabulary size of the converted model to this value")
|
|
|
args = parser.parse_args()
|
|
|
|
|
|
free_ram_gb = psutil.virtual_memory().available / 2**30
|
|
@@ -56,6 +58,10 @@ if __name__ == "__main__":
|
|
|
model = BloomModel.from_pretrained(
|
|
|
args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype]
|
|
|
)
|
|
|
+ if args.resize_token_embeddings:
|
|
|
+ logger.info(f"Resizing token embeddings, new size = {args.resize_token_embeddings}")
|
|
|
+ model.resize_token_embeddings(args.resize_token_embeddings)
|
|
|
+
|
|
|
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
|
|
args.model, use_auth_token=args.use_auth_token, revision=args.revision
|
|
|
)
|