Selaa lähdekoodia

compare logits to logits

justheuristic 3 vuotta sitten
vanhempi
commit
a6c4a606e0
1 muutettua tiedostoa jossa 4 lisäystä ja 1 poistoa
  1. 4 1
      tests/test_full_model.py

+ 4 - 1
tests/test_full_model.py

@@ -36,7 +36,10 @@ def test_full_model_exact_match(atol_forward=1e-5, atol_inference=1e-3):
 
     if REF_NAME:
         ref_model = transformers.AutoModelForCausalLM.from_pretrained(REF_NAME)
-        ref_outputs = ref_model.forward(test_inputs, attention_mask=torch.ones_like(test_inputs, dtype=torch.bool))
+        dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
+        # note: this creates a dummy mask to make the test compatible with older transformer versions
+        # prior to https://github.com/huggingface/transformers/pull/17837
+        ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits
         assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward)
     else:
         logger.warning("Did not test exact match with local model: REF_NAME environment variable is not set")