run_trainer_tpu.py

#!/usr/bin/env python3
import time

import wandb
import torch
import transformers
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
from transformers import HfArgumentParser

import utils
from arguments import TrainingPeerArguments, TPUTrainerArguments, CollaborativeArguments
from lib.training.tpu import TPUManager
from callback import CollaborativeCallback
from task import TrainingTask

transformers.utils.logging.set_verbosity_warning()
use_hivemind_log_handler("in_root_logger")
logger = get_logger()

transformers.training_args.is_torch_tpu_available = lambda: False  # disable builtin TPU support to use custom code
torch.set_num_threads(min(torch.get_num_threads(), 4))  # otherwise, it becomes very slow on machines with ~100 CPUs


def main():
    parser = HfArgumentParser((TrainingPeerArguments, TPUTrainerArguments, CollaborativeArguments))
    peer_args, trainer_args, collab_args = parser.parse_args_into_dataclasses()

    logger.info(f"Found {len(peer_args.initial_peers)} initial peers: {peer_args.initial_peers}")
    if len(peer_args.initial_peers) == 0:
        logger.warning("Please specify at least one network endpoint in initial peers.")

    utils.log_process_rank(trainer_args)
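
    # TrainingTask bundles what this peer trains with: the model, the training dataset
    # and collator, and the collaborative (hivemind) optimizer used further below.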
    task = TrainingTask(peer_args, trainer_args, collab_args)
    model = task.model

    # BEGIN init TPU
    assert trainer_args.do_train and not trainer_args.do_eval
    tpu_manager = TPUManager(
        model,
        dataset=task.training_dataset,
        collate_fn=task.data_collator,
        grad_accumulation_steps=trainer_args.gradient_accumulation_steps,
        batch_size_per_device=trainer_args.per_device_train_batch_size,
        nprocs=trainer_args.n_tpus,
        start=True,
    )
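    # Use the synchronizer's master model as the canonical host-side copy of the parameters
    # from here on; this is the copy that the collaborative optimizer will update.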
    model = task.model = tpu_manager._synchronizer.master_model

    # warmup tpus
    logger.info("Waiting for TPUs to warm up, this may take a minute...")
    tpu_manager.step()
    logger.info("Warmup step 1 / 3 done.")
    tpu_manager.update_model_parameters(model.parameters())
    tpu_manager.step()
    logger.info("Warmup step 2 / 3 done.")
    tpu_manager.step()
    tpu_manager.get_aggregated_gradients()
    tpu_manager.zero_grad()
    logger.info("Warmup step 3 / 3 done.")
    # END init TPU
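
    # After every global collaborative step (or after downloading state from peers),
    # push the freshly updated parameters back onto the TPU workers and clear their gradients.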
    def push_params_onto_tpu():
        logger.info("Pushing new params onto TPU.")
        tpu_manager.update_model_parameters(model.parameters())
        tpu_manager.zero_grad()

    collaborative_optimizer = task.collaborative_optimizer
    collaborative_optimizer.callbacks.on_after_global_step.add(push_params_onto_tpu)
    collaborative_optimizer.callbacks.on_load_state_from_peers.add(push_params_onto_tpu)
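
    # No transformers.Trainer is used here, so drive the CollaborativeCallback manually
    # with a hand-constructed TrainerState / TrainerControl pair.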
    collaborative_training_callback = CollaborativeCallback(task, peer_args)
    state = transformers.TrainerState()
    control = transformers.TrainerControl()
    collaborative_training_callback.on_train_begin(trainer_args, state, control)
    tpu_manager.update_model_parameters(model.parameters())

    wandb.init(project=trainer_args.wandb_project, name=trainer_args.run_name)
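
    # Main loop: accumulate gradients on the TPU cores, copy the aggregated gradients onto
    # the master model, then let the collaborative optimizer perform the actual update.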
    while True:
        start_time = time.perf_counter()
        loss, num_accumulated = tpu_manager.step()
        time_delta = time.perf_counter() - start_time
        logger.info(f"Accumulated {num_accumulated} gradients at {num_accumulated / time_delta:.3f} samples/second.")
        wandb.log({"train/loss": loss,
                   "train/learning_rate": collaborative_optimizer.state_averager.scheduler.get_lr()[0]})
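
        # copy gradients aggregated across the TPU cores into the master model's .grad buffers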
        with torch.no_grad():
            for param, grad_from_tpu in zip(model.parameters(), tpu_manager.get_aggregated_gradients()):
                param.grad[...] = grad_from_tpu

        collaborative_optimizer.step()

        state.log_history.append(dict(loss=loss))
        collaborative_training_callback.on_step_end(trainer_args, state, control)


if __name__ == "__main__":
    main()