@@ -93,7 +93,7 @@ def test_greedy_generation(max_new_tokens=4):
 
 
 @pytest.mark.forked
-def test_greedy_generation(max_new_tokens=4, num_beams=2):
+def test_beam_search_generation(max_new_tokens=4, num_beams=2):
     tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
     model = DistributedBloomForCausalLM.from_pretrained(
         MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
@@ -112,7 +112,8 @@ def test_greedy_generation(max_new_tokens=4, num_beams=2):
         do_early_stopping=False,
         num_beam_hyps_to_keep=2,
     )
+    hf_inputs = tokenizer(["A cat sat on a mat"] * 2, return_tensors="pt")["input_ids"]
     hf_outputs = BloomForCausalLM.beam_search(
-        model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens, beam_scorer=beam_scorer
+        model, input_ids=hf_inputs, max_length=inputs.size(1) + max_new_tokens, beam_scorer=beam_scorer
     )
     assert torch.allclose(remote_outputs, hf_outputs), "Beam search are not identical to HF"
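
Context for the hf_inputs change above, as a sketch rather than part of the patch: HF's low-level GenerationMixin.beam_search expects input_ids pre-expanded to batch_size * num_beams rows, so the single-row `inputs` tensor used for the remote generate() call cannot be passed to it directly; tokenizing the prompt twice (num_beams=2) builds that expanded batch. A minimal standalone version of the same pattern, assuming a transformers release that still exposes beam_search publicly and using bigscience/bloom-560m as a placeholder checkpoint:

    import transformers
    from transformers import BloomForCausalLM, BeamSearchScorer

    tokenizer = transformers.BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
    model = BloomForCausalLM.from_pretrained("bigscience/bloom-560m")

    num_beams = 2
    # beam_search() consumes one input row per beam, hence the repeated prompt.
    input_ids = tokenizer(["A cat sat on a mat"] * num_beams, return_tensors="pt")["input_ids"]

    beam_scorer = BeamSearchScorer(
        batch_size=1,  # one logical prompt; input rows = batch_size * num_beams
        num_beams=num_beams,
        device=input_ids.device,
        do_early_stopping=False,
        num_beam_hyps_to_keep=2,
    )
    outputs = model.beam_search(input_ids, beam_scorer, max_length=input_ids.size(1) + 4)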