Просмотр исходного кода

run all ./tests automatically (instead of calling each one manually)

justheuristic 3 лет назад
Родитель
Сommit
b737a60be4
3 измененных файлов с 31 добавлено и 38 удалено
  1. 3 9
      .github/workflows/run-tests.yaml
  2. 23 22
      tests/test_block_exact_match.py
  3. 5 7
      tests/test_chained_calls.py

+ 3 - 9
.github/workflows/run-tests.yaml

@@ -66,6 +66,8 @@ jobs:
         run: |
           export HF_TAG=$(python -c "import os; print(os.environ.get('GITHUB_BASE_REF') or os.environ.get('GITHUB_REF_NAME'))")
           export MODEL_NAME=bloom-testing/test-bloomd-350m-$HF_TAG
+          export REF_NAME=bigscience/bloom-350m
+
           python -m cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:12 \
             --torch_dtype float32 --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 --throughput 1 &
           SERVER1_PID=$!
@@ -79,15 +81,7 @@ jobs:
 
           sleep 30  # wait for server to download layers
           
-          # test individual blocks
-          export PYTHONPATH=.
-          BLOCK_UID=$MODEL_NAME.0 REF_NAME=$MODEL_NAME REF_INDEX=0 pytest tests/test_block_exact_match.py
-          BLOCK_UID=$MODEL_NAME.19 REF_NAME=$MODEL_NAME REF_INDEX=19 pytest tests/test_block_exact_match.py
-
-          REF_NAME=$MODEL_NAME pytest tests/test_chained_calls.py
-          
-          pytest tests/test_remote_sequential.py
-          REF_NAME=bigscience/bloom-350m pytest tests/test_full_model.py
+          PYTHONPATH=. pytest tests
           
           kill -s SIGINT $SERVER1_PID $SERVER2_PID
           echo "Done!"

+ 23 - 22
tests/test_block_exact_match.py

@@ -1,12 +1,15 @@
 # Note: this code is being actively modified by justheuristic. If you want to change anything about it, please warn me.
 import os
+import random
 
 import hivemind
+import pytest
 import torch
 import transformers
 
 from src.bloom.from_pretrained import load_pretrained_block
 from src.client.remote_block import RemoteTransformerBlock
+from src.data_structures import UID_DELIMITER
 from src.dht_utils import get_remote_module
 
 INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
@@ -14,34 +17,32 @@ if not INITIAL_PEERS:
     raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
 INITIAL_PEERS = INITIAL_PEERS.split()
 
-
-BLOCK_UID = os.environ.get("BLOCK_UID")
-if not BLOCK_UID:
-    raise RuntimeError("Must specify BLOCK_UID as an index of a transformer block to be tested")
-
-REF_NAME = os.environ.get("REF_NAME", "bigscience/test-bloomd-6b3")
-REF_INDEX = int(os.environ.get("REF_INDEX", BLOCK_UID.split(".")[-1]))
+MODEL_NAME = os.environ.get("MODEL_NAME")
+if not MODEL_NAME:
+    raise RuntimeError("Must specify MODEL_NAME as a name of a model to be tested")
 
 
 def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3):
     dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
+    config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
 
-    remote_block = get_remote_module(dht, BLOCK_UID)
-    assert remote_block is not None, f"Could not find {BLOCK_UID} in DHT"
-    assert isinstance(remote_block, RemoteTransformerBlock)
-    ref_config = transformers.AutoConfig.from_pretrained(REF_NAME)
+    for block_index in random.sample(range(config.n_layer), 3):
+        block_uid = f"{MODEL_NAME}{UID_DELIMITER}{block_index}"
+        remote_block = get_remote_module(dht, block_uid)
+        assert remote_block is not None, f"Could not find {block_uid} in DHT"
+        assert isinstance(remote_block, RemoteTransformerBlock)
 
-    inputs = torch.randn(1, 8, ref_config.hidden_size)
-    (outputs_forward,) = remote_block(inputs)
+        inputs = torch.randn(1, 8, config.hidden_size)
+        (outputs_forward,) = remote_block(inputs)
 
-    outputs_inference = []
-    with remote_block.inference_session() as sess:
-        for i in range(inputs.shape[1]):
-            outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
-    outputs_inference = torch.cat(outputs_inference, dim=1)
+        outputs_inference = []
+        with remote_block.inference_session() as sess:
+            for i in range(inputs.shape[1]):
+                outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))
+        outputs_inference = torch.cat(outputs_inference, dim=1)
 
-    ref_block = load_pretrained_block(REF_NAME, REF_INDEX, torch_dtype=torch.float32)
-    (outputs_local,) = ref_block(inputs)
+        ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
+        (outputs_local,) = ref_block(inputs)
 
-    assert torch.allclose(outputs_local, outputs_forward, rtol=0, atol=atol_forward)
-    assert torch.allclose(outputs_local, outputs_inference, rtol=0, atol=atol_inference)
+        assert torch.allclose(outputs_local, outputs_forward, rtol=0, atol=atol_forward)
+        assert torch.allclose(outputs_local, outputs_inference, rtol=0, atol=atol_inference)

+ 5 - 7
tests/test_chained_calls.py

@@ -24,8 +24,6 @@ MODEL_NAME = os.environ.get("MODEL_NAME")
 if not MODEL_NAME:
     raise RuntimeError("Must specify MODEL_NAME as a name of a model to be tested")
 
-REF_NAME = os.environ.get("REF_NAME", "bigscience/test-bloomd-6b3")
-
 
 def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1):
     dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
@@ -38,9 +36,9 @@ def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq
     remote_block._info = ExpertInfo(f"{MODEL_NAME}.3 {MODEL_NAME}.4 {MODEL_NAME}.5", remote_block._info.peer_id)
 
     ref_blocks = [
-        load_pretrained_block(REF_NAME, 3, torch_dtype=torch.float32),
-        load_pretrained_block(REF_NAME, 4, torch_dtype=torch.float32),
-        load_pretrained_block(REF_NAME, 5, torch_dtype=torch.float32),
+        load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch.float32),
+        load_pretrained_block(MODEL_NAME, 4, torch_dtype=torch.float32),
+        load_pretrained_block(MODEL_NAME, 5, torch_dtype=torch.float32),
     ]
     inputs = torch.randn(1, seq_length, config.hidden_size, requires_grad=True)
     outputs_rpc = remote_block.forward(inputs)[0]
@@ -78,8 +76,8 @@ def test_chained_inference_exact_match(atol_inference=1e-4):
     outputs_inference = torch.cat(outputs_inference, dim=1)
 
     ref_blocks = [
-        load_pretrained_block(REF_NAME, 3, torch_dtype=torch.float32),
-        load_pretrained_block(REF_NAME, 4, torch_dtype=torch.float32),
+        load_pretrained_block(MODEL_NAME, 3, torch_dtype=torch.float32),
+        load_pretrained_block(MODEL_NAME, 4, torch_dtype=torch.float32),
     ]
     outputs_ref = []
     caches = [None, None]