Quellcode durchsuchen

Merge remote-tracking branch 'origin/main' into efficient-forward-backward

dbaranchuk vor 3 Jahren
Ursprung
Commit
b60eedc8ad
3 geänderte Dateien mit 17 neuen und 14 gelöschten Zeilen
  1. 6 4
      .github/workflows/run-tests.yaml
  2. 1 1
      src/bloom/ops.py
  3. 10 9
      tests/test_full_model.py

+ 6 - 4
.github/workflows/run-tests.yaml

@@ -28,12 +28,12 @@ jobs:
           pip install -r requirements.txt
       - name: Delete previous model, if exists
         run: |
-          export HF_TAG=$(python -c "import os; print(os.environ.get('GITHUB_BASE_REF') or os.environ.get('GITHUB_REF_NAME'))")
+          export HF_TAG=$(python -c "import os; print(os.environ.get('GITHUB_HEAD_REF') or os.environ.get('GITHUB_REF_NAME'))")
           python -c "from huggingface_hub import delete_repo; delete_repo(token='$BLOOM_TESTING_WRITE_TOKEN', \
           name='test-bloomd-350m-$HF_TAG', organization='bloom-testing')" || true
       - name: Convert model and push to hub
         run: |
-          export HF_TAG=$(python -c "import os; print(os.environ.get('GITHUB_BASE_REF') or os.environ.get('GITHUB_REF_NAME'))")
+          export HF_TAG=$(python -c "import os; print(os.environ.get('GITHUB_HEAD_REF') or os.environ.get('GITHUB_REF_NAME'))")
           python -m cli.convert_model --model bigscience/bloom-350m  --output_path ./converted_model \
             --output_repo bloom-testing/test-bloomd-350m-$HF_TAG --use_auth_token $BLOOM_TESTING_WRITE_TOKEN
 
@@ -64,7 +64,7 @@ jobs:
           pip install -r requirements-dev.txt
       - name: Test
         run: |
-          export HF_TAG=$(python -c "import os; print(os.environ.get('GITHUB_BASE_REF') or os.environ.get('GITHUB_REF_NAME'))")
+          export HF_TAG=$(python -c "import os; print(os.environ.get('GITHUB_HEAD_REF') or os.environ.get('GITHUB_REF_NAME'))")
           export MODEL_NAME=bloom-testing/test-bloomd-350m-$HF_TAG
           export REF_NAME=bigscience/bloom-350m
 
@@ -72,6 +72,8 @@ jobs:
             --torch_dtype float32 --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 --throughput 1 &
           SERVER1_PID=$!
           
+          sleep 5  # wait for the first server to initialize DHT
+          
           export INITIAL_PEERS=/ip4/127.0.0.1/tcp/31337/p2p/QmS9KwZptnVdB9FFV7uGgaTq4sEKBwcYeKZDfSpyKDUd1g
           # ^-- server 1 multiaddr is determined by --identity and --host_maddrs
           
@@ -79,7 +81,7 @@ jobs:
             --torch_dtype float32 --initial_peers $INITIAL_PEERS --throughput 1 &> server2.log &
           SERVER2_PID=$!
 
-          sleep 30  # wait for server to download layers
+          sleep 60  # wait for server to download layers
           
           PYTHONPATH=. pytest tests
           

+ 1 - 1
src/bloom/ops.py

@@ -101,7 +101,7 @@ def pre_process_alibi_for_pad(alibi: torch.Tensor, attention_mask: torch.Tensor)
         attention_mask: ([`torch.tensor`], *required*):
             attention mask to pre-process
     """
-    assert attention_mask.shape.ndim == 2, "mask should be [batch_size, seq_length]"
+    assert attention_mask.ndim == 2, "mask should be [batch_size, seq_length]"
     unpadded_indices = torch.relu(attention_mask.cumsum(dim=1) - 1)
     # ^-- [batch, max_len], values correspond to element indices after removing padding
     # We shift the alibi tensor + replace all the values where attention_mask==0.0 by 0

+ 10 - 9
tests/test_full_model.py

@@ -13,13 +13,15 @@ logger = get_logger(__file__)
 @pytest.mark.forked
 def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
     tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
-    model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+    model = DistributedBloomForCausalLM.from_pretrained(
+        MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
+    )
     assert isinstance(model, DistributedBloomForCausalLM)
     assert len(model.transformer.h) == model.config.n_layer
 
     test_inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
 
-    with torch.no_grad():
+    with torch.inference_mode():
         parallel_outputs = model.forward(test_inputs).logits
         assert torch.all(torch.isfinite(parallel_outputs))
         logger.info("Forward outputs are finite")
@@ -32,21 +34,20 @@ def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
                 recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))
         recurrent_outputs = torch.cat(recurrent_outputs, dim=1)
         recurrent_outputs = model.transformer.ln_f(recurrent_outputs)
-
-        dictionary = model.transformer.word_embeddings.weight.t()
-        recurrent_outputs = recurrent_outputs.to(dictionary.dtype)
-        recurrent_outputs = (recurrent_outputs @ dictionary).float()
+        recurrent_outputs = model.lm_head(recurrent_outputs)
         assert torch.allclose(recurrent_outputs, parallel_outputs, rtol=0, atol=atol_inference)
         logger.info("Inference is consistent with forward")
 
-        del model, recurrent_outputs
+        del model, embs, recurrent_outputs
 
         if REF_NAME:
-            ref_model = transformers.AutoModelForCausalLM.from_pretrained(REF_NAME)
+            ref_model = transformers.BloomForCausalLM.from_pretrained(
+                REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
+            )
             dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
             # note: this creates a dummy mask to make the test compatible with older transformer versions
             # prior to https://github.com/huggingface/transformers/pull/17837
-            ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits
+            ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits.float()
             assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward)
             logger.warning(f"Distributed forward is consistent with {type(ref_model)}.forward")
             del ref_model, ref_outputs, dummy_mask