test_full_model.py

import pytest
import torch
import transformers
from hivemind import get_logger
from test_utils import *
from transformers.generation import BeamSearchScorer
from transformers.models.bloom import BloomForCausalLM

from petals.client.remote_model import DistributedBloomForCausalLM

logger = get_logger(__file__)


@pytest.mark.forked
@pytest.mark.parametrize("pass_empty_tensors", (True, False))
@pytest.mark.parametrize("second_token_attention_mask", (1, 0))
def test_full_model_exact_match(
    pass_empty_tensors: bool, second_token_attention_mask: int, atol_forward=1e-3, atol_inference=1e-3
):
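    # Runs the same prompt through a single parallel forward pass and through
    # token-by-token inference sessions, then checks that the logits agree
    # (and, if REF_NAME is set, that they also match a local BloomForCausalLM).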
    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
    model = DistributedBloomForCausalLM.from_pretrained(
        MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
    )
    config = model.config
    assert isinstance(model, DistributedBloomForCausalLM)
    assert len(model.transformer.h) == model.config.n_layer

    test_inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
    attention_mask = torch.ones_like(test_inputs)
    attention_mask[0, 1] = second_token_attention_mask
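
    # Single forward pass over the whole prompt through the distributed model.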
    with torch.inference_mode():
        parallel_outputs = model.forward(test_inputs, attention_mask=attention_mask).logits
        assert torch.all(torch.isfinite(parallel_outputs))
        logger.info("Forward outputs are finite")

        embs = model.transformer.word_embeddings(test_inputs)
        embs = model.transformer.word_embeddings_layernorm(embs)
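
        # Recompute the same logits token by token via an inference session;
        # empty steps are interleaved when pass_empty_tensors is set, to verify they are no-ops.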
        recurrent_outputs = []
        with model.transformer.h.inference_session(max_length=embs.shape[1]) as sess:
            if pass_empty_tensors:
                recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))

            for t in range(embs.shape[1]):
                recurrent_outputs.append(sess.step(embs[:, t : t + 1, :], attention_mask=attention_mask[:, : t + 1]))
                if t == int(embs.shape[1] // 2) and pass_empty_tensors:
                    recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
                    recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))

        recurrent_outputs = torch.cat(recurrent_outputs, dim=1)
        recurrent_outputs = model.transformer.ln_f(recurrent_outputs)
        recurrent_outputs = model.lm_head(recurrent_outputs)
        assert torch.allclose(recurrent_outputs, parallel_outputs, rtol=0, atol=atol_inference)
        logger.info("Inference is consistent with forward")
        del model, embs, recurrent_outputs
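
        # Optional exact-match check against a local reference model loaded from REF_NAME.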
        if REF_NAME:
            ref_model = transformers.BloomForCausalLM.from_pretrained(
                REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
            )
            if config.vocab_size < ref_model.config.vocab_size:
                ref_model.resize_token_embeddings(config.vocab_size)
                logger.warning(f"Resized the reference model embeddings, new total = {ref_model.config.vocab_size}")

            ref_outputs = ref_model.forward(test_inputs, attention_mask=attention_mask).logits.float()
            assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward)
            logger.warning(f"Distributed forward is consistent with {type(ref_model)}.forward")
            del ref_model, ref_outputs
        else:
            logger.warning("Did not test exact match with local model: REF_NAME environment variable is not set")
            assert False  # fail explicitly instead of silently skipping the reference comparison


@pytest.mark.forked
def test_greedy_generation(max_new_tokens=4):
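    # Remote greedy decoding should match Hugging Face greedy_search exactly,
    # both for a single prompt and for a padded batch of two prompts.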
    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
    model = DistributedBloomForCausalLM.from_pretrained(
        MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
    )
    inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
    remote_outputs = model.generate(
        inputs,
        max_new_tokens=max_new_tokens,
    )
    hf_outputs = BloomForCausalLM.greedy_search(model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens)
    assert torch.allclose(remote_outputs, hf_outputs), "Greedy search results are not identical to HF"
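
    # Same check for a batch of two prompts padded to a common length.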
    inputs_batch = tokenizer(["A cat sat on a mat", "A dog sat on a mat"], return_tensors="pt", padding=True)[
        "input_ids"
    ]
    remote_outputs_batch = model.generate(
        inputs_batch,
        max_new_tokens=max_new_tokens,
    )
    hf_outputs_batch = BloomForCausalLM.greedy_search(
        model, input_ids=inputs_batch, max_length=inputs_batch.size(1) + max_new_tokens
    )
    assert torch.allclose(
        remote_outputs_batch, hf_outputs_batch
    ), "Greedy search results are not identical to HF in multibatch mode"


@pytest.mark.forked
@pytest.mark.parametrize("sampling_options", [dict(), dict(temperature=100.0), dict(top_k=5), dict(top_p=0.9)])
@pytest.mark.skip("Sampling is currently not consistent with outputs from Transformers")
def test_sampling(sampling_options, max_new_tokens=4):
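    # Compares remote sampling with Hugging Face sample() under a fixed seed;
    # fork_rng restores the RNG state afterwards, so both generations start from
    # the same state. Skipped while sampling is not bit-exact with Transformers.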
    torch.manual_seed(0)
    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
    model = DistributedBloomForCausalLM.from_pretrained(
        MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
    )
    logits_warper = BloomForCausalLM._get_logits_warper(model, num_beams=1, **sampling_options)
    inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
    with torch.random.fork_rng():
        remote_outputs = model.generate(
            inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            **sampling_options,
        )
    with torch.random.fork_rng():
        hf_outputs = BloomForCausalLM.sample(
            model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens, logits_warper=logits_warper
        )
    assert torch.allclose(remote_outputs, hf_outputs), "Sampling results are not identical to HF"
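
    # Same comparison for a padded batch of two prompts.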
    inputs_batch = tokenizer(["A cat sat on a mat", "A dog sat on a mat"], return_tensors="pt", padding=True)[
        "input_ids"
    ]
    with torch.random.fork_rng():
        remote_outputs_batch = model.generate(
            inputs_batch,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            **sampling_options,
        )
    with torch.random.fork_rng():
        hf_outputs_batch = BloomForCausalLM.sample(
            model,
            input_ids=inputs_batch,
            max_length=inputs_batch.size(1) + max_new_tokens,
            logits_warper=logits_warper,
        )
    assert torch.allclose(
        remote_outputs_batch, hf_outputs_batch
    ), "Sampling results are not identical to HF in multibatch mode"


@pytest.mark.forked
def test_beam_search_generation(max_new_tokens=4, num_beams=2):
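    # Remote beam search should match Hugging Face beam_search. On the HF side,
    # the prompt is duplicated num_beams times because beam_search expects
    # input_ids already expanded to batch_size * num_beams rows.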
    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
    model = DistributedBloomForCausalLM.from_pretrained(
        MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
    )
    text = "A cat sat on a mat"
    inputs = tokenizer(text, return_tensors="pt")["input_ids"]
    remote_outputs = model.generate(
        inputs,
        max_new_tokens=max_new_tokens,
        num_beams=num_beams,
    )
    beam_scorer = BeamSearchScorer(
        batch_size=inputs.size(0),
        num_beams=num_beams,
        device=inputs.device,
        length_penalty=0,
        do_early_stopping=False,
    )
    hf_inputs = tokenizer([text] * 2, return_tensors="pt")["input_ids"]
    hf_outputs = BloomForCausalLM.beam_search(
        model, input_ids=hf_inputs, max_length=inputs.size(1) + max_new_tokens, beam_scorer=beam_scorer
    )
    assert torch.allclose(remote_outputs, hf_outputs), "Beam search results are not identical to HF"