test_aux_functions.py 858 B

123456789101112131415161718192021222324
  1. import pytest
  2. import torch
  3. from test_utils import MODEL_NAME
  4. from petals.client import DistributedBloomConfig
  5. from petals.server.throughput import measure_compute_rps, measure_network_rps
  6. @pytest.mark.forked
  7. @pytest.mark.parametrize("tensor_parallel", [False, True])
  8. def test_throughput_basic(tensor_parallel: bool):
  9. config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
  10. tensor_parallel_devices = ("cpu", "cpu") if tensor_parallel else ()
  11. compute_rps = measure_compute_rps(
  12. config,
  13. device=torch.device("cpu"),
  14. dtype=torch.bfloat16,
  15. load_in_8bit=False,
  16. tensor_parallel_devices=tensor_parallel_devices,
  17. n_steps=10,
  18. )
  19. assert isinstance(compute_rps, float) and compute_rps > 0
  20. network_rps = measure_network_rps(config)
  21. assert isinstance(network_rps, float) and network_rps > 0