benchmark_training.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. #!/usr/bin/env python3
  2. import argparse
  3. import multiprocessing as mp
  4. from time import perf_counter
  5. import numpy as np
  6. import torch
  7. from hivemind.utils.logging import get_logger
  8. from petals import AutoDistributedModelForCausalLM, AutoDistributedModelForSequenceClassification
  9. from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
  10. logger = get_logger()
  11. def main():
  12. parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  13. parser.add_argument("--model", type=str, required=True, help="Model")
  14. parser.add_argument("--device", type=str, default="cpu", help="Torch device hosting the client")
  15. parser.add_argument("--task", type=str, default="cls", help="Training task type")
  16. parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS, help="Initial peers")
  17. parser.add_argument("--torch_dtype", type=str, default="bfloat16", help="Torch dtype")
  18. parser.add_argument("--n_processes", type=str, default=1, help="Number of concurrent processes")
  19. parser.add_argument("--seq_len", type=int, default=128, help="Sequence length")
  20. parser.add_argument("--pre_seq_len", type=int, default=16, help="Number of trainable tokens")
  21. parser.add_argument("--n_steps", type=int, default=10, help="Number of benchmark steps")
  22. parser.add_argument("--batch_size", type=int, required=True, help="Batch size")
  23. parser.add_argument("--warmup_steps", type=int, default=1, help="Number of warmup steps")
  24. args = parser.parse_args()
  25. assert args.task in ["cls", "causal_lm"]
  26. if args.n_processes == "n_gpus":
  27. args.n_processes = torch.cuda.device_count()
  28. else:
  29. args.n_processes = int(args.n_processes)
  30. processes = [mp.Process(target=benchmark_training, args=(i, args)) for i in range(args.n_processes)]
  31. for proc in processes:
  32. proc.start()
  33. for proc in processes:
  34. proc.join()
  35. def benchmark_training(process_idx, args):
  36. if args.task == "cls":
  37. model = AutoDistributedModelForSequenceClassification.from_pretrained(
  38. args.model,
  39. initial_peers=args.initial_peers,
  40. torch_dtype=DTYPE_MAP[args.torch_dtype],
  41. tuning_mode="deep_ptune",
  42. pre_seq_len=args.pre_seq_len,
  43. num_labels=2,
  44. )
  45. elif args.task == "causal_lm":
  46. model = AutoDistributedModelForCausalLM.from_pretrained(
  47. args.model,
  48. initial_peers=args.initial_peers,
  49. torch_dtype=DTYPE_MAP[args.torch_dtype],
  50. tuning_mode="deep_ptune",
  51. pre_seq_len=args.pre_seq_len,
  52. )
  53. model = model.to(args.device)
  54. opt = torch.optim.Adam(model.parameters())
  55. logger.info(f"Created model: {process_idx=} {model.device=}")
  56. torch.manual_seed(42)
  57. fwd_times = []
  58. bwd_times = []
  59. for step in range(args.warmup_steps + args.n_steps):
  60. input_ids = torch.randint(0, model.config.vocab_size, size=(args.batch_size, args.seq_len), device=args.device)
  61. if args.task == "cls":
  62. labels = torch.randint(0, 2, size=[args.batch_size], device=args.device)
  63. else:
  64. labels = input_ids
  65. logger.info(f"{process_idx=} {step=} Forward")
  66. start_time = perf_counter()
  67. outputs = model(input_ids, labels=labels)
  68. if step >= args.warmup_steps:
  69. fwd_times.append(perf_counter() - start_time)
  70. logger.info(f"{process_idx=} {step=} Backward")
  71. start_time = perf_counter()
  72. outputs.loss.backward()
  73. if step >= args.warmup_steps:
  74. bwd_times.append(perf_counter() - start_time)
  75. logger.info(f"{process_idx=} {step=} Optimizer step")
  76. opt.step()
  77. opt.zero_grad()
  78. if step >= args.warmup_steps:
  79. fwd_speed = input_ids.numel() / np.mean(fwd_times)
  80. bwd_speed = input_ids.numel() / np.mean(bwd_times)
  81. logger.info(f"{process_idx=} Fwd speed: {fwd_speed:.2f} | Bwd speed: {bwd_speed:.2f}")
  82. logger.info(f"Final result: {process_idx=} {fwd_speed=:.2f} | {bwd_speed=:.2f}")
  83. if __name__ == "__main__":
  84. main()