test_speculative_generation.py
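"""Tests for speculative greedy generation with a Petals distributed model."""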
import random

import pytest
import torch
import transformers

from petals import AutoDistributedConfig, AutoDistributedModelForCausalLM, RemoteSequential
from petals.server.block_functions import MAX_SHORT_INFERENCE_TOKENS
from petals.server.from_pretrained import load_pretrained_block
from test_utils import *


@pytest.fixture
def tokenizer():
    # We set use_fast=False since LlamaTokenizerFast is slow on load
    return transformers.AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)


@pytest.fixture
def model():
    # Distributed model accessed through the Petals swarm at INITIAL_PEERS
    return AutoDistributedModelForCausalLM.from_pretrained(
        MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32
    )


@pytest.fixture
def model2():
    # Local model used to draft tokens in the speculative generation test
    return transformers.AutoModelForCausalLM.from_pretrained(
        REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
    )


@pytest.fixture
def ref_model():
    # Local copy of the target model, used to produce the reference greedy output
    return transformers.AutoModelForCausalLM.from_pretrained(
        MODEL_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
    )

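# The two tests below are kept disabled: the first checks that an inference session can
# rewrite its cache from an earlier position (start_from_position) and still match a
# locally loaded reference block; the second is an earlier end-to-end greedy comparison.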
# @pytest.mark.forked
# def test_remote_block_with_cache_invalidation_exact_match(atol_forward=1e-4, atol_inference=1e-3):
#     config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
#     remote_sequential = RemoteSequential(config)
#     block_index = random.randint(0, config.num_hidden_layers - 1)
#     remote_block = remote_sequential[block_index]
#
#     inputs = torch.randn(1, MAX_SHORT_INFERENCE_TOKENS - 50, config.hidden_size)
#     short_inputs = torch.randn(1, MAX_SHORT_INFERENCE_TOKENS - 50, config.hidden_size)
#     short_inputs[:, :2, :] = inputs[:, :2, :]
#
#     initial_outputs_inference = None
#     secondary_outputs_inference = None
#     with torch.inference_mode():
#         with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
#             initial_outputs_inference = sess.step(inputs)
#             secondary_outputs_inference = sess.step(short_inputs[:, 2:, :], start_from_position=2)
#             result = torch.cat([initial_outputs_inference[:, :2, :], secondary_outputs_inference], dim=1)
#
#     ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
#     (outputs_local,) = ref_block(short_inputs)
#     assert torch.allclose(outputs_local, result, rtol=0, atol=atol_inference)

# @pytest.mark.forked
# def test_speculative_greedy_generation(tokenizer, model, ref_model, max_new_tokens=4):
#     inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
#     options = dict(max_new_tokens=max_new_tokens, do_sample=False)
#     outputs = model.generate(inputs, **options)
#     print(outputs.shape, outputs)
#     ref_outputs = ref_model.generate(inputs, **options)
#     assert torch.allclose(
#         outputs, ref_outputs
#     ), f"Greedy generation is not identical to HF with {inputs.shape=}"

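# Speculative greedy generation: `model2` drafts `batch_size` tokens per round with greedy
# decoding; the distributed `model` scores the whole draft in a single forward pass, and only
# the prefix on which both models' argmax predictions agree is accepted. One draft token is
# deliberately corrupted each round to exercise the rejection path. The accepted text must
# match plain greedy decoding with `ref_model`.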
@pytest.mark.forked
def test_speculative_greedy_generation(tokenizer, model, model2, ref_model, max_new_tokens=50, batch_size=10):
    inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
    generated_ids = inputs
    with torch.no_grad():
        while generated_ids.shape[1] < max_new_tokens + inputs.shape[1]:
            # Draft: propose `batch_size` tokens with the local model
            outputs2 = model2.generate(generated_ids, max_new_tokens=batch_size, do_sample=False)
            new_tokens = outputs2[:, -batch_size:]
            # Corrupt one random draft token so that part of the draft gets rejected
            random_pos = random.randrange(1, batch_size)
            new_tokens[:, random_pos] = random.randrange(1, 100)
            combined_ids = torch.cat((generated_ids, new_tokens), dim=1)
            # Verify: score all draft tokens with the distributed model in one forward pass
            logits = model(combined_ids).logits

            # Find the length of the prefix on which both models predict the same token
            match_length = 0
            for i in range(batch_size):
                top_predicted_id_model2 = new_tokens[:, i]
                top_predicted_id_model = torch.argmax(logits[:, generated_ids.shape[1] + i - 1, :], dim=-1)
                if top_predicted_id_model2 == top_predicted_id_model:
                    match_length += 1
                else:
                    break

            print(f"Accepted {match_length} of {batch_size} draft tokens")
            if match_length > 0:
                generated_ids = torch.cat((generated_ids, new_tokens[:, :match_length]), dim=1)
                print(f"Total length: {generated_ids.shape[1]}")
            else:
                break

    # Trim a possible overshoot so that exactly `max_new_tokens` new tokens are compared
    generated_ids = generated_ids[:, : inputs.shape[1] + max_new_tokens]

    ref_outputs = ref_model.generate(inputs, max_new_tokens=max_new_tokens, do_sample=False)
    gen_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    ref_text = tokenizer.decode(ref_outputs[0], skip_special_tokens=True)
    print(f"Generated by speculative decoding: {gen_text}")
    print(f"Reference generation: {ref_text}")
    assert gen_text == ref_text, "The outputs do not match!"