Browse Source

fix hf beam_search generation

Artem Chumachenko 2 năm trước cách đây
mục cha
commit
5fbb97ab9e
1 tập tin đã thay đổi với 3 bổ sung2 xóa
  1. 3 2
      tests/test_full_model.py

+ 3 - 2
tests/test_full_model.py

@@ -93,7 +93,7 @@ def test_greedy_generation(max_new_tokens=4):
 
 
 @pytest.mark.forked
-def test_greedy_generation(max_new_tokens=4, num_beams=2):
+def test_beam_search_generation(max_new_tokens=4, num_beams=2):
     tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
     model = DistributedBloomForCausalLM.from_pretrained(
         MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
@@ -112,7 +112,8 @@ def test_greedy_generation(max_new_tokens=4, num_beams=2):
         do_early_stopping=False,
         num_beam_hyps_to_keep=2,
     )
+    hf_inputs = tokenizer(["A cat sat on a mat"] * 2, return_tensors="pt")["input_ids"]
     hf_outputs = BloomForCausalLM.beam_search(
-        model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens, beam_scorer=beam_scorer
+        model, input_ids=hf_inputs, max_length=inputs.size(1) + max_new_tokens, beam_scorer=beam_scorer
     )
     assert torch.allclose(remote_outputs, hf_outputs), "Beam search are not identical to HF"