1234567891011121314151617 |
- import pytest
- import torch
- from test_utils import MODEL_NAME
- from petals.client import DistributedBloomConfig
- from petals.server.throughput import measure_compute_rps, measure_network_rps
@pytest.mark.forked
def test_throughput_basic():
    """Smoke-test the throughput probes: each must return a positive float.

    Builds a ``DistributedBloomConfig`` from ``MODEL_NAME``, then calls
    ``measure_compute_rps`` and ``measure_network_rps`` with it and checks
    both the type and the sign of each result.
    """
    model_config = DistributedBloomConfig.from_pretrained(MODEL_NAME)

    # Compute probe: CPU device, bfloat16, 8-bit loading disabled, n_steps=10.
    cpu = torch.device("cpu")
    compute_result = measure_compute_rps(
        model_config,
        device=cpu,
        dtype=torch.bfloat16,
        load_in_8bit=False,
        n_steps=10,
    )
    assert isinstance(compute_result, float)
    assert compute_result > 0

    # Network probe: only needs the config.
    network_result = measure_network_rps(model_config)
    assert isinstance(network_result, float)
    assert network_result > 0
|