Ver código fonte

Set per_device_batch_sizes = 3

Aleksandr Borzunov 3 anos atrás
pai
commit
259e6b9009
1 arquivos alterados com 2 adições e 2 exclusões
  1. 2 2
      arguments.py

+ 2 - 2
arguments.py

@@ -9,8 +9,8 @@ from transformers import TrainingArguments
 class HFTrainerArguments(TrainingArguments):
     """Arguments for huggingface/transformers.Trainer"""
     dataloader_num_workers: int = 1
-    per_device_train_batch_size: int = 2
-    per_device_eval_batch_size: int = 2
+    per_device_train_batch_size: int = 3
+    per_device_eval_batch_size: int = 3
     gradient_accumulation_steps: int = 1
     text_seq_length: int = 256