Bladeren bron

CI: Update deprecated actions, don't measure network RPS (#215)

* CI: Switch to actions/cache@v3 (v2 is deprecated)
* Don't run measure_network_rps() in tests since it doesn't work well in
CI
Alexander Borzunov 2 jaren geleden
bovenliggende
commit
702bb5a2c2
2 gewijzigde bestanden met toevoegingen van 4 en 6 verwijderingen
  1. 3 3
      .github/workflows/run-tests.yaml
  2. 1 3
      tests/test_aux_functions.py

+ 3 - 3
.github/workflows/run-tests.yaml

@@ -16,7 +16,7 @@ jobs:
         uses: actions/checkout@v2
       - name: Check if the model is cached
         id: cache-model
-        uses: actions/cache@v2
+        uses: actions/cache@v3
         with:
           path: ~/.dummy
           key: model-v1-${{ hashFiles('setup.cfg', 'src/petals/cli/convert_model.py') }}
@@ -27,7 +27,7 @@ jobs:
           python-version: 3.9
       - name: Cache dependencies
         if: steps.cache-model.outputs.cache-hit != 'true'
-        uses: actions/cache@v2
+        uses: actions/cache@v3
         with:
           path: ~/.cache/pip
           key: Key-v1-3.9-${{ hashFiles('setup.cfg') }}
@@ -70,7 +70,7 @@ jobs:
         with:
           python-version: ${{ matrix.python-version }}
       - name: Cache dependencies
-        uses: actions/cache@v2
+        uses: actions/cache@v3
         with:
           path: ~/.cache/pip
           key: Key-v1-${{ matrix.python-version }}-${{ hashFiles('setup.cfg') }}

+ 1 - 3
tests/test_aux_functions.py

@@ -8,7 +8,7 @@ from petals.server.throughput import measure_compute_rps, measure_network_rps
 
 @pytest.mark.forked
 @pytest.mark.parametrize("tensor_parallel", [False, True])
-def test_throughput_basic(tensor_parallel: bool):
+def test_compute_throughput(tensor_parallel: bool):
     config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
     tensor_parallel_devices = ("cpu", "cpu") if tensor_parallel else ()
     compute_rps = measure_compute_rps(
@@ -20,5 +20,3 @@ def test_throughput_basic(tensor_parallel: bool):
         n_steps=10,
     )
     assert isinstance(compute_rps, float) and compute_rps > 0
-    network_rps = measure_network_rps(config)
-    assert isinstance(network_rps, float) and network_rps > 0