|
@@ -3,6 +3,7 @@
|
|
import os
|
|
import os
|
|
from pathlib import Path
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
+import torch
|
|
import transformers
|
|
import transformers
|
|
from transformers import HfArgumentParser
|
|
from transformers import HfArgumentParser
|
|
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
|
|
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
|
|
@@ -19,6 +20,8 @@ transformers.utils.logging.set_verbosity_warning()
|
|
use_hivemind_log_handler("in_root_logger")
|
|
use_hivemind_log_handler("in_root_logger")
|
|
logger = get_logger(__name__)
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
|
|
+torch.set_num_threads(min(torch.get_num_threads(), 4)) # Otherwise, it becomes very slow on machines with ~100 CPUs
|
|
|
|
+
|
|
|
|
|
|
def main():
|
|
def main():
|
|
parser = HfArgumentParser((TrainingPeerArguments, HFTrainerArguments, CollaborativeArguments))
|
|
parser = HfArgumentParser((TrainingPeerArguments, HFTrainerArguments, CollaborativeArguments))
|