run_trainer.py

#!/usr/bin/env python

import os
from pathlib import Path

import transformers
from transformers import HfArgumentParser

from hivemind.utils.logging import get_logger, use_hivemind_log_handler

import callback
import utils
from arguments import TrainingPeerArguments, HFTrainerArguments, CollaborativeArguments
from lib.training.hf_trainer import CollaborativeHFTrainer
from task import TrainingTask

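# Route hivemind's log records through the root logger so its messages share
# one format with the rest of this script's output.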
use_hivemind_log_handler("in_root_logger")
logger = get_logger()


def main():
    # One command line is split across three dataclasses: peer-to-peer options,
    # standard Hugging Face trainer options, and collaboration options.
    parser = HfArgumentParser((TrainingPeerArguments, HFTrainerArguments, CollaborativeArguments))
    training_peer_args, trainer_args, collab_args = parser.parse_args_into_dataclasses()

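    # Joining an existing run normally requires at least one bootstrap address
    # of a peer that is already online.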
    logger.info(f"Trying {len(training_peer_args.initial_peers)} initial peers: {training_peer_args.initial_peers}")
    if len(training_peer_args.initial_peers) == 0:
        logger.warning("Please specify at least one network endpoint in initial peers.")

    utils.setup_logging(trainer_args)
    task = TrainingTask(training_peer_args, trainer_args, collab_args)
    model = task.model.to(trainer_args.device)

    collaborative_callback = callback.CollaborativeCallback(task, training_peer_args)
    assert trainer_args.do_train and not trainer_args.do_eval  # this entry point only trains

    # Note: the code below creates the trainer with a dummy scheduler and removes some callbacks.
    # This is done because collaborative training has its own callbacks that take other peers into account.
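    # data_seed is derived from this peer's public key, so each participant
    # presumably shuffles the training data in a different order.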
    trainer = CollaborativeHFTrainer(
        model=model,
        args=trainer_args,
        tokenizer=task.tokenizer,
        data_collator=task.data_collator,
        data_seed=hash(task.local_public_key),
        train_dataset=task.training_dataset,
        eval_dataset=None,
        collaborative_optimizer=task.collaborative_optimizer,
        callbacks=[collaborative_callback],
    )
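    # Drop transformers' default console reporters; per-peer progress is handled
    # by CollaborativeCallback instead.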
    trainer.remove_callback(transformers.trainer_callback.PrinterCallback)
    trainer.remove_callback(transformers.trainer_callback.ProgressCallback)

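    # Resume from the most recently created local checkpoint, if any; with
    # default=None, max() over an empty glob yields None and training starts fresh.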
    latest_checkpoint_dir = max(Path(trainer_args.output_dir).glob("checkpoint*"), key=os.path.getctime, default=None)
    # `model_path` is the older transformers name for this argument; recent
    # versions call it `resume_from_checkpoint`.
    trainer.train(model_path=latest_checkpoint_dir)


if __name__ == "__main__":
    main()
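
# A minimal launch sketch. --initial_peers is referenced above; the other flag
# names are assumptions based on HFTrainerArguments presumably extending
# transformers.TrainingArguments (check arguments.py for the real options),
# and the /p2p/ peer ID is a placeholder:
#
#   python run_trainer.py \
#       --initial_peers /ip4/1.2.3.4/tcp/1337/p2p/QmPeerID \
#       --output_dir ./outputs \
#       --do_train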