5
0

test_aux_functions.py 731 B

12345678910111213141516171819202122
import math

import pytest
import torch

from petals.client import DistributedBloomConfig
from petals.server.throughput import measure_compute_rps
from test_utils import MODEL_NAME


  6. @pytest.mark.forked
  7. @pytest.mark.parametrize("tensor_parallel", [False, True])
  8. def test_compute_throughput(tensor_parallel: bool):
  9. config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
  10. tensor_parallel_devices = ("cpu", "cpu") if tensor_parallel else ()
  11. compute_rps = measure_compute_rps(
  12. config,
  13. device=torch.device("cpu"),
  14. dtype=torch.bfloat16,
  15. load_in_8bit=False,
  16. tensor_parallel_devices=tensor_parallel_devices,
  17. n_steps=10,
  18. )
  19. assert isinstance(compute_rps, float) and compute_rps > 0