benchmark_inference.py

#!/usr/bin/env python3
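"""Benchmark single-token generation speed over a Petals swarm.

Example invocation (hypothetical; assumes the public swarm is reachable
and the model is served there):

    python benchmark_inference.py --model bigscience/bloom --n_processes 2 --seq_len 512
"""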
import argparse
import multiprocessing as mp
from time import perf_counter

import numpy as np
import torch
from hivemind.utils.logging import get_logger
from transformers import AutoTokenizer

from petals import AutoDistributedModelForCausalLM
from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS

logger = get_logger()
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="bigscience/bloom")
    parser.add_argument("--initial_peers", type=str, nargs="+", default=PUBLIC_INITIAL_PEERS)
    parser.add_argument("--torch_dtype", type=str, default="bfloat16")
    # Kept as str so the special value "n_gpus" can be passed as well as a plain integer
    parser.add_argument("--n_processes", type=str, default=1)
    parser.add_argument("--seq_len", type=int, default=2048)
    parser.add_argument("--warmup_steps", type=int, default=1)
    args = parser.parse_args()

    if args.n_processes == "n_gpus":
        args.n_processes = torch.cuda.device_count()
    else:
        args.n_processes = int(args.n_processes)

    # Run one independent client per process; each opens its own inference session
    processes = [mp.Process(target=benchmark_inference, args=(i, args)) for i in range(args.n_processes)]
    for proc in processes:
        proc.start()
    for proc in processes:
        proc.join()
@torch.inference_mode()
def benchmark_inference(process_idx, args):
    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)
    # Using use_fast=False since LlamaTokenizerFast takes a long time to start,
    # and we decode 1 token at a time anyway
    model = AutoDistributedModelForCausalLM.from_pretrained(
        args.model, initial_peers=args.initial_peers, torch_dtype=DTYPE_MAP[args.torch_dtype]
    )
    logger.info(f"Created model: {process_idx=} {model.device=}")

    result = ""
    step_times = []
    # Keep a single inference session open so attention caches persist across steps
    with model.transformer.h.inference_session(max_length=args.seq_len) as sess:
        for step in range(args.seq_len):
            start_time = perf_counter()

            outputs = model.generate(max_new_tokens=1, session=sess)
            result += tokenizer.decode(outputs[0])

            if step >= args.warmup_steps:
                # Step time includes decoding; warmup steps are excluded from the average
                step_times.append(perf_counter() - start_time)
                speed = 1 / np.mean(step_times)  # tokens/sec
                logger.info(f"{process_idx=} {step=} {speed=:.2f}")

    logger.info(f"Final result: {process_idx=} {speed=:.2f}")
if __name__ == "__main__":
    main()