run_trainer_tpu.py

#!/usr/bin/env python
import time

import wandb
import torch
import transformers.training_args
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from transformers import HfArgumentParser

import utils
from arguments import TrainingPeerArguments, TPUTrainerArguments, CollaborativeArguments
from lib.training.tpu import TPUManager
from callback import CollaborativeCallback
from task import TrainingTask

use_hivemind_log_handler("in_root_logger")
logger = get_logger()

transformers.training_args.is_torch_tpu_available = lambda: False  # disable builtin TPU support to use custom code
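
# Illustrative invocation (assumed flag names; the actual CLI flags are generated by
# HfArgumentParser from the TrainingPeerArguments / TPUTrainerArguments /
# CollaborativeArguments dataclasses in arguments.py):
#
#   python run_trainer_tpu.py \
#       --initial_peers /ip4/203.0.113.5/tcp/31337/p2p/QmExamplePeerID \
#       --n_tpus 8 --per_device_train_batch_size 1 --gradient_accumulation_steps 1 \
#       --run_name my-tpu-trainer --wandb_project my-collaborative-run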


def main():
    parser = HfArgumentParser((TrainingPeerArguments, TPUTrainerArguments, CollaborativeArguments))
    peer_args, trainer_args, collab_args = parser.parse_args_into_dataclasses()

    logger.info(f"Found {len(peer_args.initial_peers)} initial peers: {peer_args.initial_peers}")
    if len(peer_args.initial_peers) == 0:
        logger.warning("Please specify at least one network endpoint in initial peers.")

    utils.setup_logging(trainer_args)
    task = TrainingTask(peer_args, trainer_args, collab_args)
    model = task.model

    # BEGIN init TPU
    assert trainer_args.do_train and not trainer_args.do_eval
    tpu_manager = TPUManager(
        model,
        dataset=task.training_dataset,
        collate_fn=task.data_collator,
        grad_accumulation_steps=trainer_args.gradient_accumulation_steps,
        batch_size_per_device=trainer_args.per_device_train_batch_size,
        nprocs=trainer_args.n_tpus,
        start=True,
    )
    model = task.model = tpu_manager._synchronizer.master_model

    # warmup tpus
    logger.info("Waiting for TPUs to warm up, this may take a minute...")
    tpu_manager.step()
    logger.info("Warmup step 1 / 3 done.")
    tpu_manager.update_model_parameters(model.parameters())
    tpu_manager.step()
    logger.info("Warmup step 2 / 3 done.")
    tpu_manager.step()
    tpu_manager.get_aggregated_gradients()
    tpu_manager.zero_grad()
    logger.info("Warmup step 3 / 3 done.")
    # END init TPU
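
    # Whenever the collaborative optimizer finishes a global step or loads fresh state
    # from other peers, push the updated parameters back onto the TPU workers.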
    def push_params_onto_tpu():
        logger.info("Pushing new params onto TPU.")
        tpu_manager.update_model_parameters(model.parameters())
        tpu_manager.zero_grad()

    collaborative_optimizer = task.collaborative_optimizer
    collaborative_optimizer.callbacks.on_after_global_step.add(push_params_onto_tpu)
    collaborative_optimizer.callbacks.on_load_state_from_peers.add(push_params_onto_tpu)

    collaborative_training_callback = CollaborativeCallback(task, peer_args)

    state = transformers.TrainerState()
    control = transformers.TrainerControl()
    collaborative_training_callback.on_train_begin(trainer_args, state, control)
    tpu_manager.update_model_parameters(model.parameters())

    wandb.init(project=trainer_args.wandb_project, name=trainer_args.run_name)
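
    # Main training loop: each iteration accumulates gradients on the TPU workers,
    # copies the aggregated gradients into the master model, and applies them through
    # the collaborative optimizer.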
    while True:
        start_time = time.perf_counter()

        loss, num_accumulated = tpu_manager.step()
        time_delta = time.perf_counter() - start_time
        logger.info(f"Accumulated {num_accumulated} gradients at {num_accumulated / time_delta:.3f} samples/second.")
        wandb.log({"train/loss": loss, "train/learning_rate": collaborative_optimizer.scheduler.get_lr()[0]})

        with torch.no_grad():
            # copy gradients aggregated across TPU cores into the master model's .grad buffers
            for param, grad_from_tpu in zip(model.parameters(), tpu_manager.get_aggregated_gradients()):
                param.grad[...] = grad_from_tpu

            collaborative_optimizer.step()

        state.log_history.append(dict(loss=loss))
        collaborative_training_callback.on_step_end(trainer_args, state, control)


if __name__ == "__main__":
    main()