test_full_model.py

import peft
import pytest
import torch
import transformers
from hivemind import get_logger

from petals import AutoDistributedModelForCausalLM

from test_utils import *

logger = get_logger(__name__)


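# MODEL_NAME, REF_NAME, INITIAL_PEERS, and ADAPTER_NAME are assumed to be supplied
# by test_utils through the wildcard import above. The fixtures below build the
# tokenizer, the distributed Petals client model, and a local Hugging Face
# reference model that the tests compare against.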
@pytest.fixture
def tokenizer():
    # We set use_fast=False since LlamaTokenizerFast is slow on load
    return transformers.AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)


@pytest.fixture
def model():
    return AutoDistributedModelForCausalLM.from_pretrained(
        MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32
    )


@pytest.fixture
def ref_model():
    return transformers.AutoModelForCausalLM.from_pretrained(
        REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
    )


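# Verifies that a single distributed forward pass, step-by-step inference sessions
# (including chunked steps and empty-tensor steps), and the local reference model
# all yield matching logits within `atol`.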
@pytest.mark.forked
@pytest.mark.parametrize("use_peft", (True, False) if ADAPTER_NAME else (False,))
@pytest.mark.parametrize("pass_empty_tensors", (True, False))
def test_full_model_exact_match(tokenizer, model, ref_model, use_peft, pass_empty_tensors, atol=1e-3):
    if use_peft:
        model.config.active_adapter = ADAPTER_NAME

        ref_model = peft.PeftModel.from_pretrained(ref_model, ADAPTER_NAME)
        ref_model.train(False)

    test_inputs = tokenizer("A quick brown fox was minding its own buisness", return_tensors="pt")["input_ids"]

    with torch.inference_mode():
        parallel_outputs = model.forward(test_inputs).logits
        assert torch.all(torch.isfinite(parallel_outputs))
        logger.info("Forward outputs are finite")

        embs = model.transformer.word_embeddings(test_inputs)
        embs = model.transformer.word_embeddings_layernorm(embs)
        recurrent_outputs = []
        with model.transformer.h.inference_session(max_length=embs.shape[1]) as sess:
            if pass_empty_tensors:
                recurrent_outputs.append(sess.step(torch.empty(1, 0, model.config.hidden_size)))

            for t in range(embs.shape[1]):
                if t == 4:
                    recurrent_outputs.append(sess.step(embs[:, 4:9, :]))
                elif 4 < t < 9:
                    continue
                else:
                    recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))

                if t == 2 and pass_empty_tensors:
                    recurrent_outputs.append(sess.step(torch.empty(1, 0, model.config.hidden_size)))
                    recurrent_outputs.append(sess.step(torch.empty(1, 0, model.config.hidden_size)))

        recurrent_outputs = torch.cat(recurrent_outputs, dim=1)
        recurrent_outputs = model.transformer.ln_f(recurrent_outputs)
        recurrent_outputs = model.lm_head(recurrent_outputs)
        assert torch.allclose(
            recurrent_outputs, parallel_outputs, rtol=0, atol=atol
        ), "Inference differs from forward pass"

        ref_outputs = ref_model.forward(test_inputs).logits.float()
        assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol), "Outputs are not identical to HF"


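# Helper for the generation tests: either issues one model.generate() call, or
# splits generation across an explicit inference session plus implicit follow-up
# calls and concatenates the resulting token ids along the sequence dimension.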
def make_generate_calls(model, inputs, *, max_new_tokens, multiple_calls=False, **kwargs):
    if not multiple_calls:
        return model.generate(inputs, max_new_tokens=max_new_tokens, **kwargs)

    with model.inference_session(max_length=inputs.shape[1] + max_new_tokens) as sess:
        return torch.cat(
            [
                # Sessions provided both explicitly and implicitly should work
                model.generate(inputs, max_new_tokens=1, **kwargs, session=sess),
                model.generate(None, max_new_tokens=max_new_tokens - 2, **kwargs),
                model.generate(None, max_new_tokens=1, **kwargs),
            ],
            dim=1,
        )


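# Greedy decoding should match the reference model exactly, for both a single
# prompt and a padded batch, whether generation happens in one call or is split
# across an inference session.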
@pytest.mark.forked
def test_greedy_generation(tokenizer, model, ref_model, max_new_tokens=4):
    inputs_single = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]

    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    inputs_batch = tokenizer(["A cat sat on a mat", "A dog sat on a mat"], return_tensors="pt", padding=True)[
        "input_ids"
    ]

    options = dict(max_new_tokens=max_new_tokens, do_sample=False)
    for multiple_calls in [False, True]:
        for inputs in [inputs_single, inputs_batch]:
            outputs = make_generate_calls(model, inputs, multiple_calls=multiple_calls, **options)
            ref_outputs = ref_model.generate(inputs, **options)
            assert torch.allclose(
                outputs, ref_outputs
            ), f"Greedy generation is not identical to HF with {multiple_calls=}, {inputs.shape=}"


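# With the RNG seeded identically, sampling (top-k/top-p and repetition-penalty
# settings) should reproduce the reference model's outputs token for token.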
@pytest.mark.forked
def test_sampling(tokenizer, model, ref_model, max_new_tokens=10):
    inputs_single = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]

    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    inputs_batch = tokenizer(["A cat sat on a mat", "A dog sat on a mat"], return_tensors="pt", padding=True)[
        "input_ids"
    ]

    for options in [
        dict(do_sample=True, temperature=0.5, top_k=5, top_p=0.9),
        dict(do_sample=True, temperature=0.5, repetition_penalty=1.2),
    ]:
        options.update(max_new_tokens=max_new_tokens)
        for multiple_calls in [False, True]:
            for inputs in [inputs_single, inputs_batch]:
                torch.manual_seed(0)
                outputs = make_generate_calls(model, inputs, multiple_calls=multiple_calls, **options)

                torch.manual_seed(0)
                ref_outputs = ref_model.generate(inputs, **options)

                assert torch.allclose(
                    outputs, ref_outputs
                ), f"Sampling is not identical to HF with {options=}, {multiple_calls=}, {inputs.shape=}"


@pytest.mark.skipif(
    "bloom" not in MODEL_NAME.lower(),
    reason="Mixtral and Llama use DynamicCache, which can change based on beam search choices",
)
@pytest.mark.forked
def test_beam_search_generation(tokenizer, model, ref_model, max_new_tokens=4, num_beams=5):
    inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]

    options = dict(max_new_tokens=max_new_tokens, num_beams=num_beams, do_sample=False)
    outputs = make_generate_calls(model, inputs, **options)
    ref_outputs = ref_model.generate(inputs, **options)
    assert torch.allclose(outputs, ref_outputs), "Beam search results are not identical to HF"


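# generate() should also accept the full tokenizer output (input_ids and
# attention_mask) passed as keyword arguments, in both single-call and
# multi-call form.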
@pytest.mark.forked
def test_input_ids(tokenizer, model, ref_model, max_new_tokens=4):
    inputs = tokenizer("A cat sat on a mat", return_tensors="pt")
    assert inputs.keys() == {"input_ids", "attention_mask"}

    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    ref_outputs = ref_model.generate(**inputs, max_new_tokens=max_new_tokens)
    assert torch.allclose(outputs, ref_outputs), "Outputs are not identical to HF"

    with model.inference_session(max_length=inputs["input_ids"].shape[1] + max_new_tokens):
        outputs = torch.cat(
            [
                model.generate(**inputs, max_new_tokens=2),
                model.generate(None, max_new_tokens=max_new_tokens - 2),
            ],
            dim=1,
        )
    assert torch.allclose(outputs, ref_outputs), "Multi-call outputs are not identical to HF"