run_trainer.py

#!/usr/bin/env python
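"""Entry point for one training peer in a collaborative (hivemind-based) run.

Parses peer, trainer, and collaboration arguments, builds the model and the
CollaborativeHFTrainer for the task, and starts (or resumes) training from
the most recent local checkpoint.
"""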
import os
from pathlib import Path

import torch
import transformers
from transformers import HfArgumentParser
from hivemind.utils.logging import get_logger, use_hivemind_log_handler

from lib.training.hf_trainer import CollaborativeHFTrainer
import callback
import utils
from arguments import TrainingPeerArguments, HFTrainerArguments, CollaborativeArguments
from task import TrainingTask

transformers.utils.logging.set_verbosity_warning()
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__name__)

torch.set_num_threads(1)  # Otherwise, it becomes very slow on machines with ~100 CPUs


def main():
    parser = HfArgumentParser((TrainingPeerArguments, HFTrainerArguments, CollaborativeArguments))
    training_peer_args, trainer_args, collab_args = parser.parse_args_into_dataclasses()
    logger.info(f"Trying {len(training_peer_args.initial_peers)} initial peers: {training_peer_args.initial_peers}")
    utils.log_process_rank(trainer_args)

    task = TrainingTask(training_peer_args, trainer_args, collab_args)
    model = task.model.to(trainer_args.device)

    collaborative_callback = callback.CollaborativeCallback(task, training_peer_args)
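    # This entry point is training-only: evaluation must be disabled.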
    assert trainer_args.do_train and not trainer_args.do_eval

    # Note: the code below creates the trainer with a dummy scheduler and removes some callbacks.
    # This is done because collaborative training has its own callbacks that take other peers into account.
    trainer = CollaborativeHFTrainer(
        model=model,
        args=trainer_args,
        tokenizer=task.tokenizer,
        data_collator=task.data_collator,
        data_seed=hash(task.local_public_key),  # per-peer data order, seeded from the peer's public key
        train_dataset=task.training_dataset,
        eval_dataset=None,
        collaborative_optimizer=task.collaborative_optimizer,
        callbacks=[collaborative_callback],
    )
    trainer.remove_callback(transformers.trainer_callback.PrinterCallback)
    trainer.remove_callback(transformers.trainer_callback.ProgressCallback)
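
    # Resume from the most recently created checkpoint in output_dir, if one exists;
    # with default=None, training starts from scratch.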
    latest_checkpoint_dir = max(Path(trainer_args.output_dir).glob("checkpoint*"), key=os.path.getctime, default=None)
    trainer.train(model_path=latest_checkpoint_dir)


if __name__ == "__main__":
    main()
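
# Example invocation (a sketch only: HfArgumentParser derives the CLI flags from the
# dataclass fields, so the exact flag names depend on how arguments.py defines them):
#   python run_trainer.py --output_dir outputs --initial_peers /ip4/...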