Account for multi-gpu devices in examples/albert (#309)

justheuristic 4 years ago
parent commit 21fda756c1
1 changed file with 3 additions and 0 deletions
    examples/albert/run_trainer.py (+3, -0)

examples/albert/run_trainer.py (+3, -0)

@@ -228,6 +228,9 @@ def main():
         endpoint=collaboration_args_dict.pop('endpoint'), record_validators=validators)
 
     total_batch_size_per_step = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
+    if torch.cuda.device_count() != 0:
+        total_batch_size_per_step *= torch.cuda.device_count()
+
     statistics_expiration = collaboration_args_dict.pop('statistics_expiration')
     adjusted_target_batch_size = collaboration_args_dict.pop('target_batch_size') \
                                  - collaboration_args_dict.pop('batch_size_lead')
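
For context, the change accounts for the fact that the HuggingFace Trainer runs one per-device batch on every visible GPU, so the samples a peer contributes per optimizer step scale with torch.cuda.device_count(). Below is a minimal sketch of that calculation; the standalone helper function and the example numbers are illustrative and not part of the repository.

import torch

def samples_per_optimizer_step(per_device_train_batch_size: int,
                               gradient_accumulation_steps: int) -> int:
    # Base count: one per-device batch for each gradient accumulation step.
    total = per_device_train_batch_size * gradient_accumulation_steps
    # On a multi-GPU machine each device processes its own per-device batch,
    # so multiply by the number of visible CUDA devices (0 means CPU-only).
    if torch.cuda.device_count() != 0:
        total *= torch.cuda.device_count()
    return total

# Example: batch size 4 with 2 accumulation steps; on an 8-GPU machine this yields 64.
print(samples_per_optimizer_step(4, 2))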