5
0

benchmark_forward.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. #!/usr/bin/env python3
  2. import argparse
  3. import multiprocessing as mp
  4. from time import perf_counter
  5. import numpy as np
  6. import torch
  7. from hivemind.utils.logging import get_logger
  8. from petals import AutoDistributedModel
  9. from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
  10. logger = get_logger()
  11. def main():
  12. parser = argparse.ArgumentParser()
  13. parser.add_argument("--model", type=str, default="bigscience/bloom")
  14. parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS)
  15. parser.add_argument("--torch_dtype", type=str, default="bfloat16")
  16. parser.add_argument("--n_processes", type=str, default=1)
  17. parser.add_argument("--seq_len", type=int, default=128)
  18. parser.add_argument("--n_steps", type=int, default=100)
  19. parser.add_argument("--batch_size", type=int, required=True)
  20. parser.add_argument("--warmup_steps", type=int, default=1)
  21. args = parser.parse_args()
  22. if args.n_processes == "n_gpus":
  23. args.n_processes = torch.cuda.device_count()
  24. else:
  25. args.n_processes = int(args.n_processes)
  26. processes = [mp.Process(target=benchmark_forward, args=(i, args)) for i in range(args.n_processes)]
  27. for proc in processes:
  28. proc.start()
  29. for proc in processes:
  30. proc.join()
  31. @torch.inference_mode()
  32. def benchmark_forward(process_idx, args):
  33. model = AutoDistributedModel.from_pretrained(
  34. args.model,
  35. initial_peers=args.initial_peers,
  36. torch_dtype=DTYPE_MAP[args.torch_dtype],
  37. )
  38. logger.info(f"Created model: {process_idx=} {model.device=}")
  39. torch.manual_seed(42)
  40. step_times = []
  41. for step in range(args.warmup_steps + args.n_steps):
  42. start_time = perf_counter()
  43. input_ids = torch.randint(0, model.config.vocab_size, size=(args.batch_size, args.seq_len))
  44. logger.info(f"{process_idx=} Fwd begin {input_ids.shape=}")
  45. h = model(input_ids)
  46. # We don't use model.lm_head
  47. logger.info(f"{process_idx=} Fwd end")
  48. if step >= args.warmup_steps:
  49. step_times.append(perf_counter() - start_time)
  50. speed = input_ids.numel() / np.mean(step_times)
  51. logger.info(f"{process_idx=} {step=} {speed=:.2f}")
  52. logger.info(f"Final result: {process_idx=} {speed=:.2f}")
# Run the benchmark only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()