Forráskód Böngészése

Set num threads to <= 4

Aleksandr Borzunov 3 éve
szülő
commit
380a5f3231
2 módosított fájl, 5 hozzáadás és 0 törlés
  1. +3 −0  run_trainer.py
  2. +2 −0  run_trainer_tpu.py

+3 −0
run_trainer.py

@@ -3,6 +3,7 @@
 import os
 from pathlib import Path
 
+import torch
 import transformers
 from transformers import HfArgumentParser
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
@@ -19,6 +20,8 @@ transformers.utils.logging.set_verbosity_warning()
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
 
+torch.set_num_threads(min(torch.get_num_threads(), 4))  # Otherwise, it becomes very slow on machines with ~100 CPUs
+
 
 def main():
     parser = HfArgumentParser((TrainingPeerArguments, HFTrainerArguments, CollaborativeArguments))

+2 −0
run_trainer_tpu.py

@@ -20,6 +20,8 @@ logger = get_logger()
 
 transformers.training_args.is_torch_tpu_available = lambda: False  # disable builtin TPU support to use custom code
 
+torch.set_num_threads(min(torch.get_num_threads(), 4))  # Otherwise, it becomes very slow on machines with ~100 CPUs
+
 
 def main():
     parser = HfArgumentParser((TrainingPeerArguments, TPUTrainerArguments, CollaborativeArguments))