Try removing OffloadOptimizer

Aleksandr Borzunov, 3 years ago
parent
commit
e97e7b8811
2 changed files with 3 additions and 5 deletions
  1. arguments.py (+2 -2)
  2. task.py (+1 -3)

arguments.py (+2 -2)

@@ -9,8 +9,8 @@ from transformers import TrainingArguments
 class HFTrainerArguments(TrainingArguments):
     """Arguments for huggingface/transformers.Trainer"""
     dataloader_num_workers: int = 1
-    per_device_train_batch_size: int = 3
-    per_device_eval_batch_size: int = 3
+    per_device_train_batch_size: int = 2
+    per_device_eval_batch_size: int = 2
     gradient_accumulation_steps: int = 1
     text_seq_length: int = 256


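The arguments.py change lowers the default per-device batch sizes from 3 to 2, presumably to fit in memory now that optimizer state is no longer offloaded. A minimal usage sketch of the new defaults, assuming HFTrainerArguments behaves like any other transformers TrainingArguments subclass (the output_dir value below is hypothetical and not part of the diff):

    from arguments import HFTrainerArguments

    # Hypothetical check of the defaults introduced by this commit; output_dir
    # is the usual TrainingArguments requirement and is not shown in the diff.
    args = HFTrainerArguments(output_dir="outputs")
    assert args.per_device_train_batch_size == 2  # was 3 before this commit
    assert args.per_device_eval_batch_size == 2   # was 3 before this commit
    assert args.text_seq_length == 256            # unchanged
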
task.py (+1 -3)

@@ -17,7 +17,6 @@ from arguments import HFTrainerArguments, BasePeerArguments, CollaborativeArgume
 from data import make_dataset
 from huggingface_auth import authorize_with_huggingface
 from lib.training.clipped_lamb import LambWithGradientClipping
-from lib.training.offload import OffloadOptimizer


 logger = hivemind.get_logger(__name__)
@@ -142,9 +141,8 @@ class TrainingTask:
             },
         ]

-        opt = OffloadOptimizer(
+        opt = LambWithGradientClipping(
             optimizer_grouped_parameters,
-            optim_cls=LambWithGradientClipping,
             lr=training_args.learning_rate,
             betas=(training_args.adam_beta1, training_args.adam_beta2),
             eps=training_args.adam_epsilon,
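
Net effect of the task.py change: the LAMB optimizer is no longer wrapped in OffloadOptimizer (which, as its name suggests, offloads optimizer state, typically to CPU); LambWithGradientClipping is now constructed directly. A before/after sketch using only the names visible in this hunk (optimizer_grouped_parameters and training_args are defined earlier in TrainingTask; keyword arguments beyond eps are truncated in the diff and omitted here):

    # Before this commit: LAMB was instantiated indirectly through the
    # offloading wrapper, with optim_cls selecting the optimizer class.
    # opt = OffloadOptimizer(
    #     optimizer_grouped_parameters,
    #     optim_cls=LambWithGradientClipping,
    #     lr=training_args.learning_rate,
    #     betas=(training_args.adam_beta1, training_args.adam_beta2),
    #     eps=training_args.adam_epsilon,
    # )

    # After this commit: the same LAMB-with-gradient-clipping optimizer is
    # built directly, so its state is no longer offloaded.
    opt = LambWithGradientClipping(
        optimizer_grouped_parameters,
        lr=training_args.learning_rate,
        betas=(training_args.adam_beta1, training_args.adam_beta2),
        eps=training_args.adam_epsilon,
    )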