run_trainer.py

#!/usr/bin/env python
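"""Runs one collaborative training peer (see TrainingTask and CollaborativeHFTrainer for the model, data and optimizer setup).

Example invocation (illustrative only; the actual flags are generated by HfArgumentParser
from the TrainingPeerArguments, HFTrainerArguments and CollaborativeArguments dataclasses):

    python run_trainer.py --initial_peers <multiaddr of an existing peer> --output_dir outputs
"""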

import os
from pathlib import Path

import torch
import transformers
from transformers import HfArgumentParser
from hivemind.utils.logging import get_logger, use_hivemind_log_handler

from lib.training.hf_trainer import CollaborativeHFTrainer
import callback
import utils
from arguments import TrainingPeerArguments, HFTrainerArguments, CollaborativeArguments
from task import TrainingTask

transformers.utils.logging.set_verbosity_warning()
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__name__)

torch.set_num_threads(min(torch.get_num_threads(), 4))  # Otherwise, it becomes very slow on machines with ~100 CPUs


def main():
    parser = HfArgumentParser((TrainingPeerArguments, HFTrainerArguments, CollaborativeArguments))
    training_peer_args, trainer_args, collab_args = parser.parse_args_into_dataclasses()

    logger.info(f"Trying {len(training_peer_args.initial_peers)} initial peers: {training_peer_args.initial_peers}")
    # if len(training_peer_args.initial_peers) == 0:
    #     logger.warning("Please specify at least one network endpoint in initial peers.")

    utils.log_process_rank(trainer_args)
    task = TrainingTask(training_peer_args, trainer_args, collab_args)
    model = task.model.to(trainer_args.device)

    collaborative_callback = callback.CollaborativeCallback(task, training_peer_args)
    assert trainer_args.do_train and not trainer_args.do_eval

    # Note: the code below creates the trainer with a dummy scheduler and removes some callbacks.
    # This is done because collaborative training has its own callbacks that take other peers into account.
    trainer = CollaborativeHFTrainer(
        model=model,
        args=trainer_args,
        tokenizer=task.tokenizer,
        data_collator=task.data_collator,
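        # the data seed is derived from this peer's public key, presumably so that
        # different peers iterate over the training data in different orders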
        data_seed=hash(task.local_public_key),
        train_dataset=task.training_dataset,
        eval_dataset=None,
        collaborative_optimizer=task.collaborative_optimizer,
        callbacks=[collaborative_callback],
    )
    trainer.remove_callback(transformers.trainer_callback.PrinterCallback)
    trainer.remove_callback(transformers.trainer_callback.ProgressCallback)
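
    # Resume from the most recently created checkpoint* directory in output_dir, if any;
    # with default=None, max() returns None and training then starts from scratch.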
    latest_checkpoint_dir = max(Path(trainer_args.output_dir).glob("checkpoint*"), key=os.path.getctime, default=None)
    trainer.train(model_path=latest_checkpoint_dir)


if __name__ == "__main__":
    main()