|
@@ -87,7 +87,7 @@ def test_speculative_greedy_generation(tokenizer, model, model2, ref_model, max_
|
|
|
new_tokens[:, random_pos] = random.randrange(1, 100)
|
|
|
|
|
|
combined_ids = torch.cat((generated_ids, new_tokens), dim=1)
|
|
|
- logits = model(combined_ids, start_from_position=1).logits
|
|
|
+ logits = model(combined_ids).logits
|
|
|
|
|
|
# Найти первую позицию, где токены совпали
|
|
|
match_length = 0
|