5
0

test_aux_functions.py 484 B

123456789101112131415
  1. import pytest
  2. import torch
  3. from test_utils import MODEL_NAME
  4. from petals.client import DistributedBloomConfig
  5. from petals.server.throughput import measure_compute_rps
  6. @pytest.mark.forked
  7. def test_throughput_basic():
  8. config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
  9. throughput = measure_compute_rps(
  10. config, device=torch.device("cpu"), dtype=torch.bfloat16, load_in_8bit=False, n_steps=10
  11. )
  12. assert isinstance(throughput, float) and throughput > 0