# test_aux_functions.py
import math

import pytest
import torch
from test_utils import MODEL_NAME

from petals.client import DistributedBloomConfig
from petals.server.throughput import measure_compute_rps, measure_network_rps


  6. @pytest.mark.forked
  7. def test_throughput_basic():
  8. config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
  9. compute_rps = measure_compute_rps(
  10. config, device=torch.device("cpu"), dtype=torch.bfloat16, load_in_8bit=False, n_steps=10
  11. )
  12. assert isinstance(compute_rps, float) and compute_rps > 0
  13. network_rps = measure_network_rps(config)
  14. assert isinstance(network_rps, float) and network_rps > 0