|
@@ -228,6 +228,9 @@ def main():
|
|
endpoint=collaboration_args_dict.pop('endpoint'), record_validators=validators)
|
|
endpoint=collaboration_args_dict.pop('endpoint'), record_validators=validators)
|
|
|
|
|
|
total_batch_size_per_step = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
|
|
total_batch_size_per_step = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
|
|
|
|
+ if torch.cuda.device_count() != 0:
|
|
|
|
+ total_batch_size_per_step *= torch.cuda.device_count()
|
|
|
|
+
|
|
statistics_expiration = collaboration_args_dict.pop('statistics_expiration')
|
|
statistics_expiration = collaboration_args_dict.pop('statistics_expiration')
|
|
adjusted_target_batch_size = collaboration_args_dict.pop('target_batch_size') \
|
|
adjusted_target_batch_size = collaboration_args_dict.pop('target_batch_size') \
|
|
- collaboration_args_dict.pop('batch_size_lead')
|
|
- collaboration_args_dict.pop('batch_size_lead')
|