瀏覽代碼

CI: Convert model only when convert_model.py or setup.cfg change (#213)

This reduces the test running time by 2 times, unless convert_model.py or setup.cfg are changed.
Alexander Borzunov 2 年之前
父節點
當前提交
825f5dbf2d
共有 2 個文件被更改,包括 25 次插入7 次删除
  1. 20 6
      .github/workflows/run-tests.yaml
  2. 5 1
      src/petals/cli/convert_model.py

+ 20 - 6
.github/workflows/run-tests.yaml

@@ -12,32 +12,45 @@ jobs:
       BLOOM_TESTING_WRITE_TOKEN: ${{ secrets.BLOOM_TESTING_WRITE_TOKEN }}
     timeout-minutes: 15
     steps:
-      - uses: actions/checkout@v2
+      - name: Checkout
+        uses: actions/checkout@v2
+      - name: Check if the model is cached
+        id: cache-model
+        uses: actions/cache@v2
+        with:
+          path: ~/.dummy
+          key: model-v1-${{ hashFiles('setup.cfg', 'src/petals/cli/convert_model.py') }}
       - name: Set up Python
+        if: steps.cache-model.outputs.cache-hit != 'true'
         uses: actions/setup-python@v2
         with:
           python-version: 3.9
       - name: Cache dependencies
+        if: steps.cache-model.outputs.cache-hit != 'true'
         uses: actions/cache@v2
         with:
           path: ~/.cache/pip
           key: Key-v1-3.9-${{ hashFiles('setup.cfg') }}
       - name: Install dependencies
+        if: steps.cache-model.outputs.cache-hit != 'true'
         run: |
           python -m pip install --upgrade pip
-          pip install .[dev]
+          pip install .
       - name: Delete any test models older than 1 week
+        if: steps.cache-model.outputs.cache-hit != 'true'
         run: |
           python tests/scripts/remove_old_models.py --author bloom-testing --use_auth_token $BLOOM_TESTING_WRITE_TOKEN
       - name: Delete previous version of this model, if exists
+        if: steps.cache-model.outputs.cache-hit != 'true'
         run: |
           export HF_TAG=$(python -c "import os; print(os.environ.get('GITHUB_HEAD_REF') or os.environ.get('GITHUB_REF_NAME'))")
           python -c "from huggingface_hub import delete_repo; delete_repo(token='$BLOOM_TESTING_WRITE_TOKEN', \
           repo_id='bloom-testing/test-bloomd-560m-$HF_TAG')" || true
       - name: Convert model and push to hub
+        if: steps.cache-model.outputs.cache-hit != 'true'
         run: |
-          export HF_TAG=$(python -c "import os; print(os.environ.get('GITHUB_HEAD_REF') or os.environ.get('GITHUB_REF_NAME'))")
-          python -m petals.cli.convert_model --model bigscience/bloom-560m  --output_path ./converted_model \
+          export HF_TAG=${{ hashFiles('setup.cfg', 'src/petals/cli/convert_model.py') }}
+          python -m petals.cli.convert_model --model bigscience/bloom-560m --output_path ./converted_model \
             --output_repo bloom-testing/test-bloomd-560m-$HF_TAG --use_auth_token $BLOOM_TESTING_WRITE_TOKEN \
             --resize_token_embeddings 50000
 
@@ -50,7 +63,8 @@ jobs:
       fail-fast: false
     timeout-minutes: 15
     steps:
-      - uses: actions/checkout@v2
+      - name: Checkout
+        uses: actions/checkout@v2
       - name: Set up Python
         uses: actions/setup-python@v2
         with:
@@ -66,7 +80,7 @@ jobs:
           pip install .[dev]
       - name: Test
         run: |
-          export HF_TAG=$(python -c "import os; print(os.environ.get('GITHUB_HEAD_REF') or os.environ.get('GITHUB_REF_NAME'))")
+          export HF_TAG=${{ hashFiles('setup.cfg', 'src/petals/cli/convert_model.py') }}
           export MODEL_NAME=bloom-testing/test-bloomd-560m-$HF_TAG
           export REF_NAME=bigscience/bloom-560m
 

+ 5 - 1
src/petals/cli/convert_model.py

@@ -18,7 +18,7 @@ logger = get_logger(__file__)
 DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
 
 
-if __name__ == "__main__":
+def main():
     parser = argparse.ArgumentParser(description="Load bloom layers and convert to 8-bit using torch quantization.")
 
     parser.add_argument("--model", type=str, default="bigscience/bloom-6b3", help="Model name for from_pretrained")
@@ -90,3 +90,7 @@ if __name__ == "__main__":
         config.save_pretrained(".")
 
     logger.info(f"Converted {args.model} and pushed to {args.output_repo}")
+
+
+if __name__ == "__main__":
+    main()