run_trainer.py

#!/usr/bin/env python
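
"""Run one training peer in a collaborative hivemind-based training run.

Parses the peer, trainer, and collaboration arguments, builds a TrainingTask,
and starts a CollaborativeHFTrainer that synchronizes with other peers through
the task's collaborative optimizer, resuming from the latest local checkpoint
if one exists. (Docstring added as a summary of the code below.)
"""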

import os
from pathlib import Path

import transformers
from transformers import HfArgumentParser
from hivemind.utils.logging import get_logger, use_hivemind_log_handler

from lib.training.hf_trainer import CollaborativeHFTrainer
import callback
import utils
from arguments import TrainingPeerArguments, HFTrainerArguments, CollaborativeArguments
from task import TrainingTask

transformers.utils.logging.set_verbosity_warning()
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__name__)


def main():
    parser = HfArgumentParser((TrainingPeerArguments, HFTrainerArguments, CollaborativeArguments))
    training_peer_args, trainer_args, collab_args = parser.parse_args_into_dataclasses()

    logger.info(f"Trying {len(training_peer_args.initial_peers)} initial peers: {training_peer_args.initial_peers}")
    # if len(training_peer_args.initial_peers) == 0:
    #     logger.warning("Please specify at least one network endpoint in initial peers.")

    utils.log_process_rank(trainer_args)
    task = TrainingTask(training_peer_args, trainer_args, collab_args)
    model = task.model.to(trainer_args.device)

    collaborative_callback = callback.CollaborativeCallback(task, training_peer_args)
    # This script only trains; evaluation is expected to be disabled.
    assert trainer_args.do_train and not trainer_args.do_eval

    # Note: the code below creates the trainer with a dummy scheduler and removes some callbacks.
    # This is done because collaborative training has its own callbacks that take other peers into account.
    trainer = CollaborativeHFTrainer(
        model=model,
        args=trainer_args,
        tokenizer=task.tokenizer,
        data_collator=task.data_collator,
        data_seed=hash(task.local_public_key),  # per-peer data seed derived from this peer's public key
        train_dataset=task.training_dataset,
        eval_dataset=None,
        collaborative_optimizer=task.collaborative_optimizer,
        callbacks=[collaborative_callback],
    )
    trainer.remove_callback(transformers.trainer_callback.PrinterCallback)
    trainer.remove_callback(transformers.trainer_callback.ProgressCallback)
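
    # Resume from the most recently created checkpoint under output_dir, if any
    # (max() over creation time returns None when no "checkpoint*" directory exists).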
    latest_checkpoint_dir = max(Path(trainer_args.output_dir).glob("checkpoint*"), key=os.path.getctime, default=None)
    trainer.train(model_path=latest_checkpoint_dir)


if __name__ == "__main__":
    main()
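
# Illustrative launch command (a sketch, not taken from this repo's docs): the
# exact flag names come from the TrainingPeerArguments / HFTrainerArguments /
# CollaborativeArguments dataclasses in arguments.py; --do_train, --output_dir,
# and --initial_peers correspond to fields this script actually reads, and the
# multiaddress below is a placeholder.
#
#   python run_trainer.py \
#       --do_train \
#       --output_dir outputs \
#       --initial_peers /ip4/203.0.113.5/tcp/31337/p2p/QmPeerIDPlaceholder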