benchmark_inference.py

#!/usr/bin/env python3

import argparse
import multiprocessing as mp
from time import perf_counter

import torch
from hivemind.utils.logging import get_logger
from transformers import AutoTokenizer

from petals import AutoDistributedModelForCausalLM
from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS

logger = get_logger()


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="bigscience/bloom")
    parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS)
    parser.add_argument("--torch_dtype", type=str, default="bfloat16")
    parser.add_argument("--n_processes", type=str, default="1")
    parser.add_argument("--seq_len", type=int, default=2048)
    parser.add_argument("--warmup_steps", type=int, default=1)
    args = parser.parse_args()

    # "--n_processes n_gpus" launches one benchmark process per visible GPU.
    if args.n_processes == "n_gpus":
        args.n_processes = torch.cuda.device_count()
    else:
        args.n_processes = int(args.n_processes)

    processes = [mp.Process(target=benchmark_inference, args=(i, args)) for i in range(args.n_processes)]
    for proc in processes:
        proc.start()
    for proc in processes:
        proc.join()


@torch.inference_mode()
def benchmark_inference(process_idx, args):
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    model = AutoDistributedModelForCausalLM.from_pretrained(
        args.model, initial_peers=args.initial_peers, torch_dtype=DTYPE_MAP[args.torch_dtype]
    )
    logger.info(f"Created model: {process_idx=} {model.device=} {model.config.torch_dtype=}")

    result = ""
    with model.transformer.h.inference_session(max_length=args.seq_len) as sess:
        for step in range(args.seq_len):
            # Start the timer only once the warmup steps are done, so that
            # one-time setup costs (e.g. opening the session) are excluded.
            if step == args.warmup_steps:
                start_time = perf_counter()

            outputs = model.generate(max_new_tokens=1, session=sess)
            result += tokenizer.decode(outputs[0])

            if step >= args.warmup_steps:
                # Tokens/s = tokens generated since the timer started, over elapsed time.
                # (The original used `step / elapsed`, which is only correct when
                # warmup_steps == 1.)
                speed = (step + 1 - args.warmup_steps) / (perf_counter() - start_time)
                logger.info(f"{process_idx=} {step=} {speed=:.3f}")

    logger.info(f"Final result: {process_idx=} {speed=:.3f}")


if __name__ == "__main__":
    main()
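
A minimal usage sketch: the flag names come from the parser above, while the concrete values (a short sequence length for a quick run, one process per GPU) are illustrative choices, not defaults from the script:

    python benchmark_inference.py --model bigscience/bloom --seq_len 128 --n_processes n_gpus

Each process logs its own tokens/s after every post-warmup step, so per-GPU throughput can be compared directly from the log lines.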