test_full_model.py

import pytest
import torch
import transformers
from hivemind import get_logger, use_hivemind_log_handler
from test_utils import *
from transformers.generation_utils import BeamSearchScorer

from src.bloom.model import BloomForCausalLM
from src.client.remote_model import DistributedBloomForCausalLM

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
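# Note: MODEL_NAME, INITIAL_PEERS, and REF_NAME are expected to be provided by test_utils (star import above).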

@pytest.mark.forked
def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
    model = DistributedBloomForCausalLM.from_pretrained(
        MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
    )
    config = model.config
    assert isinstance(model, DistributedBloomForCausalLM)
    assert len(model.transformer.h) == model.config.n_layer

    test_inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]

    with torch.inference_mode():
        parallel_outputs = model.forward(test_inputs).logits
        assert torch.all(torch.isfinite(parallel_outputs))
        logger.info("Forward outputs are finite")
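        # Re-run the same forward pass token by token through a remote inference session,
        # then apply the final layernorm and LM head locally; the result should match the
        # parallel forward pass within atol_inference.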
        embs = model.transformer.word_embeddings(test_inputs)
        embs = model.transformer.word_embeddings_layernorm(embs)
        recurrent_outputs = []
        with model.transformer.h.inference_session(max_length=embs.shape[1]) as sess:
            for t in range(embs.shape[1]):
                recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))
        recurrent_outputs = torch.cat(recurrent_outputs, dim=1)
        recurrent_outputs = model.transformer.ln_f(recurrent_outputs)
        recurrent_outputs = model.lm_head(recurrent_outputs)
        assert torch.allclose(recurrent_outputs, parallel_outputs, rtol=0, atol=atol_inference)
        logger.info("Inference is consistent with forward")

        del model, embs, recurrent_outputs
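        # If REF_NAME is set, compare the distributed forward pass against a local
        # (non-distributed) HF BloomForCausalLM loaded from that checkpoint.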
        if REF_NAME:
            ref_model = transformers.BloomForCausalLM.from_pretrained(
                REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
            )
            if config.vocab_size < ref_model.config.vocab_size:
                ref_model.resize_token_embeddings(config.vocab_size)
                logger.warning(f"Resized the reference model embeddings, new total = {ref_model.config.vocab_size}")

            dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
            # note: this creates a dummy mask to make the test compatible with older transformers versions
            # prior to https://github.com/huggingface/transformers/pull/17837
            ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits.float()
            assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward)
            logger.warning(f"Distributed forward is consistent with {type(ref_model)}.forward")
            del ref_model, ref_outputs, dummy_mask
        else:
            logger.warning("Did not test exact match with local model: REF_NAME environment variable is not set")
            assert False
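
# Check that the distributed model's .generate() with greedy decoding matches
# HF's greedy_search run over the same weights.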
@pytest.mark.forked
def test_greedy_generation(max_new_tokens=4):
    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
    model = DistributedBloomForCausalLM.from_pretrained(
        MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
    )
    inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
    remote_outputs = model.generate(
        inputs,
        max_new_tokens=max_new_tokens,
    )
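    # greedy_search is called as an unbound method with the distributed model as `self`,
    # so both decoding paths use the same weights and differ only in the generation loop.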
    hf_outputs = BloomForCausalLM.greedy_search(model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens)
    assert torch.allclose(remote_outputs, hf_outputs), "Greedy search results are not identical to HF"
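
    # Repeat the comparison with a padded batch of two prompts.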
    inputs_batch = tokenizer(["A cat sat on a mat", "A dog sat on a mat"], return_tensors="pt", padding=True)[
        "input_ids"
    ]
    remote_outputs_batch = model.generate(
        inputs_batch,
        max_new_tokens=max_new_tokens,
    )
    hf_outputs_batch = BloomForCausalLM.greedy_search(
        model, input_ids=inputs_batch, max_length=inputs_batch.size(1) + max_new_tokens
    )
    assert torch.allclose(
        remote_outputs_batch, hf_outputs_batch
    ), "Greedy search results are not identical to HF in multibatch mode"
@pytest.mark.forked
def test_beam_search_generation(max_new_tokens=4, num_beams=2):
    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
    model = DistributedBloomForCausalLM.from_pretrained(
        MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
    )
    text = "A cat sat on a mat"
    inputs = tokenizer(text, return_tensors="pt")["input_ids"]
    remote_outputs = model.generate(
        inputs,
        max_new_tokens=max_new_tokens,
        num_beams=num_beams,
    )
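    # Build a BeamSearchScorer with the same beam settings as the .generate() call above.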
    beam_scorer = BeamSearchScorer(
        batch_size=inputs.size(0),
        num_beams=num_beams,
        device=inputs.device,
        length_penalty=0,
        do_early_stopping=False,
    )
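    # HF's beam_search expects input_ids already expanded to num_beams copies per sample,
    # hence the prompt is duplicated here (num_beams == 2).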
    hf_inputs = tokenizer([text] * 2, return_tensors="pt")["input_ids"]
    hf_outputs = BloomForCausalLM.beam_search(
        model, input_ids=hf_inputs, max_length=inputs.size(1) + max_new_tokens, beam_scorer=beam_scorer
    )
    assert torch.allclose(remote_outputs, hf_outputs), "Beam search results are not identical to HF"