test_full_model.py

import os

import torch
import transformers
from hivemind import get_logger, use_hivemind_log_handler

from src.client.remote_model import DistributedBloomForCausalLM

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
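
# Configuration comes from environment variables: INITIAL_PEERS and MODEL_NAME are
# required; REF_NAME optionally names a local reference model for an exact-match check.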
INITIAL_PEERS = os.environ.get("INITIAL_PEERS")
if not INITIAL_PEERS:
    raise RuntimeError("Must specify INITIAL_PEERS environment variable with one or more peer ids")
INITIAL_PEERS = INITIAL_PEERS.split()

MODEL_NAME = os.environ.get("MODEL_NAME")
if not MODEL_NAME:
    raise RuntimeError("Must specify MODEL_NAME environment variable with the name of the model to be tested")

REF_NAME = os.environ.get("REF_NAME")


def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
    model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
    assert isinstance(model, DistributedBloomForCausalLM)
    assert len(model.transformer.h) == model.config.n_layer

    test_inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]

    with torch.no_grad():
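        # Parallel pass: run the full input sequence through the remote transformer blocks in one call.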
        parallel_outputs = model.forward(test_inputs).logits
        assert torch.all(torch.isfinite(parallel_outputs))
        logger.info("Forward outputs are finite")
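
        # Recurrent pass: embed the tokens locally, then step through the remote blocks
        # one token at a time via an inference session.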
        embs = model.transformer.word_embeddings(test_inputs)
        embs = model.transformer.word_embeddings_layernorm(embs)
        recurrent_outputs = []
        with model.transformer.h.inference_session() as sess:
            for t in range(embs.shape[1]):
                recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))
        recurrent_outputs = torch.cat(recurrent_outputs, dim=1)
        recurrent_outputs = model.transformer.ln_f(recurrent_outputs)
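
        # Project the hidden states back to vocabulary logits using the transposed input
        # embedding matrix (BLOOM ties its input embeddings with the output projection).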
        dictionary = model.transformer.word_embeddings.weight.t()
        recurrent_outputs = recurrent_outputs.to(dictionary.dtype)
        recurrent_outputs = (recurrent_outputs @ dictionary).float()
        assert torch.allclose(recurrent_outputs, parallel_outputs, rtol=0, atol=atol_inference)
        logger.info("Inference is consistent with forward")
        del model, recurrent_outputs

        if REF_NAME:
            ref_model = transformers.AutoModelForCausalLM.from_pretrained(REF_NAME)
            dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
            # note: this creates a dummy mask to make the test compatible with transformers versions
            # prior to https://github.com/huggingface/transformers/pull/17837
            ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits
            assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward)
            # `model` was deleted above, so referencing type(model) here would raise a NameError
            logger.warning(f"Distributed forward is consistent with {type(ref_model)}.forward")
            del ref_model, ref_outputs, dummy_mask
        else:
            logger.warning("Did not test exact match with local model: REF_NAME environment variable is not set")
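

# A minimal invocation sketch, assuming this file lives under tests/ and a swarm serving
# MODEL_NAME is already running; the multiaddr and model names below are placeholders,
# not real values:
#
#   export INITIAL_PEERS="/ip4/127.0.0.1/tcp/31337/p2p/Qm..."  # one or more peer multiaddrs
#   export MODEL_NAME="bigscience/test-bloomd-6b3"             # the distributed model under test
#   export REF_NAME="bigscience/bloom-6b3"                     # optional local reference model
#   pytest tests/test_full_model.py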