@@ -231,6 +231,7 @@ def main():
     tokenizer = AlbertTokenizerFast.from_pretrained(dataset_args.tokenizer_path, cache_dir=dataset_args.cache_dir)
     model = get_model(training_args, config, tokenizer)
     model.to(training_args.device)
+    model.tie_weights()
 
     tokenized_datasets = load_from_disk(Path(dataset_args.dataset_path))
     # This data collator will take care of randomly masking the tokens.
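
Context for the added call: in Hugging Face Transformers, `PreTrainedModel.tie_weights()` ties the input embedding matrix to the output (LM-head) projection when `config.tie_word_embeddings` is enabled, so invoking it after the model has been built and moved to the device re-establishes the sharing in case either tensor was replaced along the way. A minimal sketch of the same tying idea in plain PyTorch (the `TinyMLM` class and its attribute names are illustrative, not from this repo):

```python
import torch
import torch.nn as nn

class TinyMLM(nn.Module):
    """Toy masked-LM whose output head shares its weight with the embedding."""

    def __init__(self, vocab_size: int = 128, hidden_size: int = 16):
        super().__init__()
        # nn.Embedding weight has shape (vocab_size, hidden_size).
        self.embeddings = nn.Embedding(vocab_size, hidden_size)
        # nn.Linear(hidden_size, vocab_size) weight is also (vocab_size,
        # hidden_size), so the two matrices can share one storage.
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        self.tie_weights()

    def tie_weights(self) -> None:
        # Point the head at the embedding matrix: one parameter, one gradient.
        self.lm_head.weight = self.embeddings.weight

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.lm_head(self.embeddings(input_ids))

model = TinyMLM()
# The tie holds: both modules reference the same underlying storage.
assert model.lm_head.weight.data_ptr() == model.embeddings.weight.data_ptr()
```

Besides saving one vocabulary-sized weight matrix, tying keeps the input and output token representations consistent during masked-LM training.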