xtinkt 1 سال پیش
والد
کامیت
5ca5822686
1فایلهای تغییر یافته به همراه101 افزوده شده و 18 حذف شده
  1. 101 18
      tests/test_speculative_generation.py

+ 101 - 18
tests/test_speculative_generation.py

@@ -3,33 +3,116 @@ import random
 import pytest
 import torch
 
+import transformers
+
+from petals import AutoDistributedModelForCausalLM
 from petals import AutoDistributedConfig, RemoteSequential
 from petals.server.block_functions import MAX_SHORT_INFERENCE_TOKENS
 from petals.server.from_pretrained import load_pretrained_block
 from test_utils import *
 
 
+@pytest.fixture
+def tokenizer():
+    # We set use_fast=False since LlamaTokenizerFast is slow on load
+    return transformers.AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
+
+
+@pytest.fixture
+def model():
+    return AutoDistributedModelForCausalLM.from_pretrained(
+        MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32
+    )
+
+@pytest.fixture
+def model2():
+    return transformers.AutoModelForCausalLM.from_pretrained(
+        REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
+    )
+
+@pytest.fixture
+def ref_model():
+    return transformers.AutoModelForCausalLM.from_pretrained(
+        MODEL_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
+    )
+
+# @pytest.mark.forked
+# def test_remote_block_with_cache_invalidation_exact_match(atol_forward=1e-4, atol_inference=1e-3):
+#     config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+#     remote_sequential = RemoteSequential(config)
+
+#     block_index = random.randint(0, config.num_hidden_layers - 1)
+#     remote_block = remote_sequential[block_index]
+
+#     inputs = torch.randn(1, MAX_SHORT_INFERENCE_TOKENS - 50, config.hidden_size)
+#     short_inputs = torch.randn(1, MAX_SHORT_INFERENCE_TOKENS - 50, config.hidden_size)
+#     short_inputs[:, :2, :] = inputs[:, :2, :]
+
+#     initial_outputs_inference = None
+#     secondary_outputs_inference = None
+#     with torch.inference_mode():
+#         with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
+#             initial_outputs_inference = sess.step(inputs)
+#             secondary_outputs_inference = sess.step(short_inputs[:, 2:, :], start_from_position=2)
+#             result = torch.cat([initial_outputs_inference[:, :2, :], secondary_outputs_inference], dim=1)
+
+#     ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
+#     (outputs_local,) = ref_block(short_inputs)
+
+#     assert torch.allclose(outputs_local, result, rtol=0, atol=atol_inference)
+
+# @pytest.mark.forked
+# def test_speculative_greedy_generation(tokenizer, model, ref_model, max_new_tokens=4):
+#     inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
+
+#     options = dict(max_new_tokens=max_new_tokens, do_sample=False)
+#     outputs = model.generate(inputs, **options)
+#     print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@", outputs.shape, outputs)
+#     ref_outputs = ref_model.generate(inputs, **options)
+#     assert torch.allclose(
+#         outputs, ref_outputs
+#     ), f"Greedy generation is not identical to HF with {multiple_calls=}, {inputs.shape=}"
+
 @pytest.mark.forked
-def test_remote_block_with_cache_invalidation_exact_match(atol_forward=1e-4, atol_inference=1e-3):
-    config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
-    remote_sequential = RemoteSequential(config)
+def test_speculative_greedy_generation(tokenizer, model, model2, ref_model, max_new_tokens=50, batch_size=10):    
+    inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
+    generated_ids = inputs
+
+    with torch.no_grad():
+        while generated_ids.shape[1] < max_new_tokens + inputs.shape[1]:
+            outputs2 = model2.generate(generated_ids, max_new_tokens=batch_size, do_sample=False)
+            new_tokens = outputs2[:, -batch_size:]
+
+            random_pos = random.randrange(1, batch_size)
+            new_tokens[:, random_pos] = random.randrange(1, 100)
 
-    block_index = random.randint(0, config.num_hidden_layers - 1)
-    remote_block = remote_sequential[block_index]
+            combined_ids = torch.cat((generated_ids, new_tokens), dim=1)
+            logits = model(combined_ids, start_from_position=1).logits
 
-    inputs = torch.randn(1, MAX_SHORT_INFERENCE_TOKENS - 50, config.hidden_size)
-    short_inputs = torch.randn(1, MAX_SHORT_INFERENCE_TOKENS - 50, config.hidden_size)
-    short_inputs[:, :2, :] = inputs[:, :2, :]
+            # Найти первую позицию, где токены совпали
+            match_length = 0
+            for i in range(batch_size):
+                top_predicted_id_model2 = new_tokens[:, i]
+                top_predicted_id_model = torch.argmax(logits[:, generated_ids.shape[1] + i - 1, :], dim=-1)
+                
+                if top_predicted_id_model2 == top_predicted_id_model:
+                    match_length += 1
+                else:
+                    break
+            print(f"Принято {match_length} из {batch_size}")
 
-    initial_outputs_inference = None
-    secondary_outputs_inference = None
-    with torch.inference_mode():
-        with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
-            initial_outputs_inference = sess.step(inputs)
-            secondary_outputs_inference = sess.step(short_inputs[:, 2:, :], start_from_position=2)
-            result = torch.cat([initial_outputs_inference[:, :2, :], secondary_outputs_inference], dim=1)
+            if match_length > 0:
+                generated_ids = torch.cat((generated_ids, new_tokens[:, :match_length]), dim=1)
+                print(f"Всего {generated_ids.shape[1]}")
+            else:
+                break
+        
+        ref_outputs = ref_model.generate(inputs, max_new_tokens=max_new_tokens, do_sample=False)
+        
+    gen_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+    ref_text = tokenizer.decode(ref_outputs[0], skip_special_tokens=True)
 
-    ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
-    (outputs_local,) = ref_block(short_inputs)
+    print(f"Generated by speculative decoding: {gen_text}")
+    print(f"Reference generation: {ref_text}")
 
-    assert torch.allclose(outputs_local, result, rtol=0, atol=atol_inference)
+    assert gen_text == ref_text, "The outputs do not match!"