test_full_model.py

import peft
import pytest
import torch
import transformers
from hivemind import get_logger
from transformers.generation import BeamSearchScorer, GenerationMixin as HfGenerationMixin

from petals import AutoDistributedModelForCausalLM
from test_utils import *

logger = get_logger(__name__)


@pytest.fixture
def tokenizer():
    # We set use_fast=False since LlamaTokenizerFast is slow on load
    return transformers.AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)


@pytest.mark.forked
@pytest.mark.parametrize("use_peft", (True, False) if ADAPTER_NAME else (False,))
@pytest.mark.parametrize("pass_empty_tensors", (True, False))
def test_full_model_exact_match(tokenizer, use_peft, pass_empty_tensors, atol_forward=1e-3, atol_inference=1e-3):
    model = AutoDistributedModelForCausalLM.from_pretrained(
        MODEL_NAME,
        initial_peers=INITIAL_PEERS,
        torch_dtype=torch.float32,
        active_adapter=ADAPTER_NAME if use_peft else None,
    )
    config = model.config
    assert len(model.transformer.h) == model.config.num_hidden_layers

    test_inputs = tokenizer("A quick brown fox was minding its own business", return_tensors="pt")["input_ids"]
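    # Compare a single parallel forward pass against step-by-step inference over the same tokens.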
    with torch.inference_mode():
        parallel_outputs = model.forward(test_inputs).logits
        assert torch.all(torch.isfinite(parallel_outputs))
        logger.info("Forward outputs are finite")

        embs = model.transformer.word_embeddings(test_inputs)
        embs = model.transformer.word_embeddings_layernorm(embs)
        recurrent_outputs = []
        with model.transformer.h.inference_session(max_length=embs.shape[1]) as sess:
            if pass_empty_tensors:
                recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
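            # Feed the embeddings in mixed chunk sizes (single tokens, one 5-token slice,
            # and occasional empty tensors) to exercise different ways of calling sess.step().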
            for t in range(embs.shape[1]):
                if t == 4:
                    recurrent_outputs.append(sess.step(embs[:, 4:9, :]))
                elif 4 < t < 9:
                    continue
                else:
                    recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))

                if t == 2 and pass_empty_tensors:
                    recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
                    recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
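        # Concatenate the per-step hidden states, apply the final layernorm and LM head,
        # and check that stepwise inference matches the parallel forward pass.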
        recurrent_outputs = torch.cat(recurrent_outputs, dim=1)
        recurrent_outputs = model.transformer.ln_f(recurrent_outputs)
        recurrent_outputs = model.lm_head(recurrent_outputs)
        assert torch.allclose(recurrent_outputs, parallel_outputs, rtol=0, atol=atol_inference)
        logger.info("Inference is consistent with forward")

        del model, embs, recurrent_outputs
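        # If REF_NAME is set, also compare the distributed forward pass against a local HF reference model.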
        if REF_NAME:
            ref_model = transformers.AutoModelForCausalLM.from_pretrained(
                REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
            )
            if use_peft:
                ref_model = peft.PeftModel.from_pretrained(ref_model, ADAPTER_NAME)
                ref_model.train(False)
            if config.vocab_size < ref_model.config.vocab_size:
                ref_model.resize_token_embeddings(config.vocab_size)
                logger.warning(f"Resized the reference model embeddings, new total = {ref_model.config.vocab_size}")

            dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
            # note: this creates a dummy mask to make the test compatible with older transformer versions
            # prior to https://github.com/huggingface/transformers/pull/17837
            ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits.float()
            assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward)
            logger.warning(f"Distributed forward is consistent with {type(ref_model)}.forward")
            del ref_model, ref_outputs, dummy_mask
        else:
            logger.warning("Did not test exact match with local model: REF_NAME environment variable is not set")
            assert False


@pytest.mark.forked
def test_greedy_generation(tokenizer, max_new_tokens=4):
    model = AutoDistributedModelForCausalLM.from_pretrained(
        MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32
    )
    inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
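    # model.generate() with default options should reproduce HF greedy search run on the same model.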
    remote_outputs = model.generate(
        inputs,
        max_new_tokens=max_new_tokens,
    )
    hf_outputs = HfGenerationMixin.greedy_search(model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens)
    assert torch.allclose(remote_outputs, hf_outputs), "Greedy search results are not identical to HF"
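    # Repeat the check with a padded batch of two prompts.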
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    inputs_batch = tokenizer(["A cat sat on a mat", "A dog sat on a mat"], return_tensors="pt", padding=True)[
        "input_ids"
    ]
    remote_outputs_batch = model.generate(
        inputs_batch,
        max_new_tokens=max_new_tokens,
    )
    hf_outputs_batch = HfGenerationMixin.greedy_search(
        model, input_ids=inputs_batch, max_length=inputs_batch.size(1) + max_new_tokens
    )
    assert torch.allclose(
        remote_outputs_batch, hf_outputs_batch
    ), "Greedy search results are not identical to HF in multibatch mode"


@pytest.mark.forked
@pytest.mark.parametrize("sampling_options", [dict(), dict(temperature=100.0), dict(top_k=5), dict(top_p=0.9)])
@pytest.mark.skip("Sampling is currently not consistent with outputs from Transformers")
def test_sampling(tokenizer, sampling_options, max_new_tokens=4):
    torch.manual_seed(0)
    model = AutoDistributedModelForCausalLM.from_pretrained(
        MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32
    )
    logits_warper = HfGenerationMixin._get_logits_warper(model, num_beams=1, **sampling_options)
    inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
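    # Each fork_rng() block starts from the same RNG state, so remote and local sampling
    # draw identical random numbers and their outputs can be compared token by token.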
    with torch.random.fork_rng():
        remote_outputs = model.generate(
            inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            **sampling_options,
        )
    with torch.random.fork_rng():
        hf_outputs = HfGenerationMixin.sample(
            model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens, logits_warper=logits_warper
        )
    assert torch.allclose(remote_outputs, hf_outputs), "Sampling results are not identical to HF"

    inputs_batch = tokenizer(["A cat sat on a mat", "A dog sat on a mat"], return_tensors="pt", padding=True)[
        "input_ids"
    ]
    with torch.random.fork_rng():
        remote_outputs_batch = model.generate(
            inputs_batch,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            **sampling_options,
        )
    with torch.random.fork_rng():
        hf_outputs_batch = HfGenerationMixin.sample(
            model,
            input_ids=inputs_batch,
            max_length=inputs_batch.size(1) + max_new_tokens,
            logits_warper=logits_warper,
        )
    assert torch.allclose(
        remote_outputs_batch, hf_outputs_batch
    ), "Sampling results are not identical to HF in multibatch mode"


@pytest.mark.forked
def test_beam_search_generation(tokenizer, max_new_tokens=4, num_beams=2):
    model = AutoDistributedModelForCausalLM.from_pretrained(
        MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32
    )
    text = "A cat sat on a mat"
    inputs = tokenizer(text, return_tensors="pt")["input_ids"]
    remote_outputs = model.generate(
        inputs,
        max_new_tokens=max_new_tokens,
        num_beams=num_beams,
    )
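    # Reproduce the same beam search with the HF implementation; HF's beam_search expects
    # input_ids already expanded to batch_size * num_beams rows, hence the duplicated prompt.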
    beam_scorer = BeamSearchScorer(
        batch_size=inputs.size(0),
        num_beams=num_beams,
        device=inputs.device,
        length_penalty=0,
        do_early_stopping=False,
    )
    hf_inputs = tokenizer([text] * 2, return_tensors="pt")["input_ids"]
    hf_outputs = HfGenerationMixin.beam_search(
        model, input_ids=hf_inputs, max_length=inputs.size(1) + max_new_tokens, beam_scorer=beam_scorer
    )
    assert torch.allclose(remote_outputs, hf_outputs), "Beam search results are not identical to HF"