Michael Diskin 4 年之前
父節點
當前提交
dcc6b02703
共有 1 個文件被更改,包括 1 次插入、0 次删除
  1. 1 0
      examples/albert/run_trainer.py

+ 1 - 0
examples/albert/run_trainer.py

@@ -231,6 +231,7 @@ def main():
     tokenizer = AlbertTokenizerFast.from_pretrained(dataset_args.tokenizer_path, cache_dir=dataset_args.cache_dir)
     model = get_model(training_args, config, tokenizer)
     model.to(training_args.device)
+    model.tie_weights()
 
     tokenized_datasets = load_from_disk(Path(dataset_args.dataset_path))
     # This data collator will take care of randomly masking the tokens.