瀏覽代碼

run test inside CI

justheuristic 3 年之前
父節點
當前提交
0acfc84cad
共有 2 個文件被更改,包括 54 次插入28 次删除
  1. 50 26
      .github/workflows/run-tests.yaml
  2. 4 2
      tests/test_block_exact_match.py

+ 50 - 26
.github/workflows/run-tests.yaml

@@ -32,29 +32,53 @@ jobs:
             --output_repo bloom-testing/test-bloomd-350m   --use_auth_token $BLOOM_TESTING_WRITE_TOKEN
 
 
-#  run-tests:
-#    runs-on: ubuntu-latest
-#    strategy:
-#      matrix:
-#        python-version: [ 3.7, 3.8, 3.9 ]
-#    timeout-minutes: 15
-#    steps:
-#      - uses: actions/checkout@v2
-#      - name: Set up Python
-#        uses: actions/setup-python@v2
-#        with:
-#          python-version: ${{ matrix.python-version }}
-#      - name: Cache dependencies
-#        uses: actions/cache@v2
-#        with:
-#          path: ~/.cache/pip
-#          key: Key-v1-${{ matrix.python-version }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('requirements-dev.txt') }}
-#      - name: Install dependencies
-#        run: |
-#          python -m pip install --upgrade pip
-#          pip install -r requirements.txt
-#          pip install -r requirements-dev.txt
-#      - name: Test
-#        run: |
-#          python -m cli.convert_model --model bigscience/bloom-350m  --output_path ./converted_model \
-#            --output_repo testing/test-bloomd-350m   --use_auth_token $MY_WRITE_TOKEN
+  run-tests:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: [ 3.7, 3.8, 3.9 ]
+    timeout-minutes: 15
+    steps:
+      - uses: actions/checkout@v2
+      - name: Set up Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Cache dependencies
+        uses: actions/cache@v2
+        with:
+          path: ~/.cache/pip
+          key: Key-v1-${{ matrix.python-version }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('requirements-dev.txt') }}
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install -r requirements.txt
+          pip install -r requirements-dev.txt
+      - name: Test
+        run: |
+          # launch a DHT-only peer to connect others
+          hivemind-dht &> dht.log &
+          INITIAL_PID=$$
+          sleep 3
+          INITIAL_PEERS=$(python -c "with open('dht.log') as f: print(f.readlines()[1].split()[-1])" )
+          echo "Initial peer: ${INITIAL_PEERS}"
+          
+          python -m cli.run_server --converted_model_name_or_path bloom-testing/test-bloomd-350m \
+            --block_indices 0:12 --torch_dtype float32 --initial_peers $INITIAL_PEERS &> server1.log &
+          SERVER1_PID=$$
+          
+          sleep 60  # wait for server to download layers
+          # test individual blocks
+          export PYTHONPATH=. REF_NAME=bloom-testing/test-bloomd-350m
+          BLOCK_UID=bloom-testing/test-bloomd-350m.0 REF_INDEX=0 pytest tests/test_block_exact_match.py
+          BLOCK_UID=bloom-testing/test-bloomd-350m.0 REF_INDEX=2 pytest tests/test_block_exact_match.py
+          
+          kill $INITIAL_PID $SERVER1_PID
+          
+          #python -m cli.run_server --converted_model_name_or_path bloom-testing/test-bloomd-350m \
+          #  --block_indices 9:19 --torch_dtype float32 --initial_peers $INITIAL_PEERS &> server2.log &
+          #SERVER2_PID=$$
+          #
+          #python -m cli.run_server --converted_model_name_or_path bloom-testing/test-bloomd-350m \
+          #  --block_indices 18:24 --torch_dtype float32 --initial_peers $INITIAL_PEER &> server3.log &
+          #SERVER3_PID=$$

+ 4 - 2
tests/test_block_exact_match.py

@@ -3,6 +3,7 @@ import os
 
 import hivemind
 import torch
+import transformers
 
 from src.bloom.from_pretrained import load_pretrained_block
 from src.client.remote_block import RemoteTransformerBlock
@@ -19,7 +20,7 @@ if not BLOCK_UID:
     raise RuntimeError("Must specify BLOCK_UID as an index of a transformer block to be tested")
 
 REF_NAME = os.environ.get("REF_NAME", "bigscience/test-bloomd-6b3")
-REF_INDEX = int(os.environ.get("REF_INDEX", BLOCK_UID[-1].split(".")[-1]))
+REF_INDEX = int(os.environ.get("REF_INDEX", BLOCK_UID.split(".")[-1]))
 
 
 def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3):
@@ -27,8 +28,9 @@ def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3):
     remote_block = get_remote_module(dht, BLOCK_UID)
     assert remote_block is not None, f"Could not find {BLOCK_UID} in DHT"
     assert isinstance(remote_block, RemoteTransformerBlock)
+    ref_config = transformers.AutoConfig.from_pretrained(REF_NAME)
 
-    inputs = torch.randn(1, 8, 4096)
+    inputs = torch.randn(1, 8, ref_config.hidden_size)
     (outputs_forward,) = remote_block(inputs)
 
     outputs_inference = []