|
@@ -9,8 +9,8 @@ from transformers import TrainingArguments
|
|
class HFTrainerArguments(TrainingArguments):
|
|
class HFTrainerArguments(TrainingArguments):
|
|
"""Arguments for huggingface/transformers.Trainer"""
|
|
"""Arguments for huggingface/transformers.Trainer"""
|
|
dataloader_num_workers: int = 1
|
|
dataloader_num_workers: int = 1
|
|
- per_device_train_batch_size: int = 1
|
|
|
|
- per_device_eval_batch_size: int = 1
|
|
|
|
|
|
+ per_device_train_batch_size: int = 2
|
|
|
|
+ per_device_eval_batch_size: int = 2
|
|
gradient_accumulation_steps: int = 1
|
|
gradient_accumulation_steps: int = 1
|
|
text_seq_length: int = 256
|
|
text_seq_length: int = 256
|
|
|
|
|