test_aux_functions.py 778 B

1234567891011121314151617181920212223
  1. import pytest
  2. import torch
  3. from petals import AutoDistributedConfig
  4. from petals.server.throughput import measure_compute_rps
  5. from petals.utils.convert_block import QuantType
  6. from test_utils import MODEL_NAME
  @pytest.mark.forked
  @pytest.mark.parametrize("tensor_parallel", [False, True])
  def test_compute_throughput(tensor_parallel: bool):
      """Check that ``measure_compute_rps`` yields a positive float on CPU.

      The test is parametrized to run once without tensor-parallel devices
      (an empty tuple) and once with two "cpu" devices.
      """
      # Pick the device tuple passed as ``tensor_parallel_devices``.
      if tensor_parallel:
          tp_devices = ("cpu", "cpu")
      else:
          tp_devices = ()

      model_config = AutoDistributedConfig.from_pretrained(MODEL_NAME)
      rps = measure_compute_rps(
          model_config,
          device=torch.device("cpu"),
          dtype=torch.bfloat16,
          quant_type=QuantType.NONE,
          tensor_parallel_devices=tp_devices,
          n_steps=10,
      )

      # Separate asserts so a failure shows which condition was violated.
      assert isinstance(rps, float)
      assert rps > 0