|
@@ -39,7 +39,7 @@ def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
|
|
|
del model, recurrent_outputs
|
|
|
|
|
|
if REF_NAME:
|
|
|
- ref_model = transformers.AutoModelForCausalLM.from_pretrained(REF_NAME, torch_dtype="auto")
|
|
|
+ ref_model = transformers.AutoModelForCausalLM.from_pretrained(REF_NAME)
|
|
|
dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
|
|
|
# note: this creates a dummy mask to make the test compatible with older transformer versions
|
|
|
# prior to https://github.com/huggingface/transformers/pull/17837
|