|
@@ -1,4 +1,3 @@
|
|
-import peft
|
|
|
|
import pytest
|
|
import pytest
|
|
import torch
|
|
import torch
|
|
import transformers
|
|
import transformers
|
|
@@ -67,6 +66,8 @@ def test_full_model_exact_match(use_peft: bool, pass_empty_tensors: bool, atol_f
|
|
REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
|
|
REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
|
|
)
|
|
)
|
|
if use_peft:
|
|
if use_peft:
|
|
|
|
+ import peft
|
|
|
|
+
|
|
ref_model = peft.PeftModel.from_pretrained(ref_model, ADAPTER_NAME)
|
|
ref_model = peft.PeftModel.from_pretrained(ref_model, ADAPTER_NAME)
|
|
ref_model.train(False)
|
|
ref_model.train(False)
|
|
if config.vocab_size < ref_model.config.vocab_size:
|
|
if config.vocab_size < ref_model.config.vocab_size:
|