@@ -149,3 +149,23 @@ def test_beam_search_generation(tokenizer, model, ref_model, max_new_tokens=4, n
     outputs = make_generate_calls(model, inputs, **options)
     ref_outputs = ref_model.generate(inputs, **options)
     assert torch.allclose(outputs, ref_outputs), f"Beam search results are not identical to HF"
+
+
+@pytest.mark.forked
+def test_input_ids(tokenizer, model, ref_model, max_new_tokens=4):
+    inputs = tokenizer("A cat sat on a mat", return_tensors="pt")
+    assert inputs.keys() == {"input_ids", "attention_mask"}
+
+    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
+    ref_outputs = ref_model.generate(**inputs, max_new_tokens=max_new_tokens)
+    assert torch.allclose(outputs, ref_outputs), f"Outputs are not identical to HF"
+
+    with model.inference_session(max_length=inputs["input_ids"].shape[1] + max_new_tokens):
+        outputs = torch.cat(
+            [
+                model.generate(**inputs, max_new_tokens=2),
+                model.generate(None, max_new_tokens=max_new_tokens - 2),
+            ],
+            dim=1,
+        )
+    assert torch.allclose(outputs, ref_outputs), f"Multi-call outputs are not identical to HF"