import peft
import pytest
import torch
import transformers
from hivemind import get_logger

from petals import AutoDistributedModelForCausalLM
from test_utils import *
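# test_utils is assumed to provide MODEL_NAME, REF_NAME, INITIAL_PEERS, and ADAPTER_NAME used below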

logger = get_logger(__name__)


@pytest.fixture
def tokenizer():
    # We set use_fast=False since LlamaTokenizerFast is slow on load
    return transformers.AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)


@pytest.fixture
def model():
    return AutoDistributedModelForCausalLM.from_pretrained(
        MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32
    )


@pytest.fixture
def ref_model():
    return transformers.AutoModelForCausalLM.from_pretrained(
        REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
    )


@pytest.mark.forked
@pytest.mark.parametrize("use_peft", (True, False) if ADAPTER_NAME else (False,))
@pytest.mark.parametrize("pass_empty_tensors", (True, False))
def test_full_model_exact_match(tokenizer, model, ref_model, use_peft, pass_empty_tensors, atol=1e-3):
    if use_peft:
        model.config.active_adapter = ADAPTER_NAME
        ref_model = peft.PeftModel.from_pretrained(ref_model, ADAPTER_NAME)
        ref_model.train(False)

    test_inputs = tokenizer("A quick brown fox was minding its own business", return_tensors="pt")["input_ids"]
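
    # Compare three code paths: (1) a single distributed forward pass, (2) step-by-step
    # inference through an inference session, and (3) a local Hugging Face reference model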
    with torch.inference_mode():
        parallel_outputs = model.forward(test_inputs).logits
        assert torch.all(torch.isfinite(parallel_outputs))
        logger.info("Forward outputs are finite")

        embs = model.transformer.word_embeddings(test_inputs)
        embs = model.transformer.word_embeddings_layernorm(embs)
        recurrent_outputs = []
        with model.transformer.h.inference_session(max_length=embs.shape[1]) as sess:
            if pass_empty_tensors:
                recurrent_outputs.append(sess.step(torch.empty(1, 0, model.config.hidden_size)))
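
            # Mix step sizes: positions 4..8 are fed as one 5-token chunk, the rest one token
            # at a time, with extra zero-length steps when pass_empty_tensors is set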
            for t in range(embs.shape[1]):
                if t == 4:
                    recurrent_outputs.append(sess.step(embs[:, 4:9, :]))
                elif 4 < t < 9:
                    continue
                else:
                    recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))

                if t == 2 and pass_empty_tensors:
                    recurrent_outputs.append(sess.step(torch.empty(1, 0, model.config.hidden_size)))
                    recurrent_outputs.append(sess.step(torch.empty(1, 0, model.config.hidden_size)))

        recurrent_outputs = torch.cat(recurrent_outputs, dim=1)
        recurrent_outputs = model.transformer.ln_f(recurrent_outputs)
        recurrent_outputs = model.lm_head(recurrent_outputs)
        assert torch.allclose(
            recurrent_outputs, parallel_outputs, rtol=0, atol=atol
        ), "Inference differs from forward pass"

        ref_outputs = ref_model.forward(test_inputs).logits.float()
        assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol), "Outputs are not identical to HF"
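

# Helper that either generates everything in one call or splits generation across several
# calls sharing one inference session, so each test covers both code paths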
def make_generate_calls(model, inputs, *, max_new_tokens, multiple_calls=False, **kwargs):
    if not multiple_calls:
        return model.generate(inputs, max_new_tokens=max_new_tokens, **kwargs)

    with model.inference_session(max_length=inputs.shape[1] + max_new_tokens) as sess:
        return torch.cat(
            [
                # Sessions provided both explicitly and implicitly should work
                model.generate(inputs, max_new_tokens=1, **kwargs, session=sess),
                model.generate(None, max_new_tokens=max_new_tokens - 2, **kwargs),
                model.generate(None, max_new_tokens=1, **kwargs),
            ],
            dim=1,
        )


@pytest.mark.forked
def test_greedy_generation(tokenizer, model, ref_model, max_new_tokens=4):
    inputs_single = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
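
    # Batched inputs are padded, so the tokenizer needs a pad token (fall back to EOS)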
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    inputs_batch = tokenizer(["A cat sat on a mat", "A dog sat on a mat"], return_tensors="pt", padding=True)[
        "input_ids"
    ]

    options = dict(max_new_tokens=max_new_tokens, do_sample=False)
    for multiple_calls in [False, True]:
        for inputs in [inputs_single, inputs_batch]:
            outputs = make_generate_calls(model, inputs, multiple_calls=multiple_calls, **options)
            ref_outputs = ref_model.generate(inputs, **options)
            assert torch.allclose(
                outputs, ref_outputs
            ), f"Greedy generation is not identical to HF with {multiple_calls=}, {inputs.shape=}"


@pytest.mark.forked
def test_sampling(tokenizer, model, ref_model, max_new_tokens=10):
    inputs_single = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]

    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    inputs_batch = tokenizer(["A cat sat on a mat", "A dog sat on a mat"], return_tensors="pt", padding=True)[
        "input_ids"
    ]

    for options in [
        dict(do_sample=True, temperature=0.5, top_k=5, top_p=0.9),
        dict(do_sample=True, temperature=0.5, repetition_penalty=1.2),
    ]:
        options.update(max_new_tokens=max_new_tokens)
        for multiple_calls in [False, True]:
            for inputs in [inputs_single, inputs_batch]:
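                # Reseed before each generate() so both models draw the same random numbers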
                torch.manual_seed(0)
                outputs = make_generate_calls(model, inputs, multiple_calls=multiple_calls, **options)

                torch.manual_seed(0)
                ref_outputs = ref_model.generate(inputs, **options)

                assert torch.allclose(
                    outputs, ref_outputs
                ), f"Sampling is not identical to HF with {options=}, {multiple_calls=}, {inputs.shape=}"


@pytest.mark.forked
def test_beam_search_generation(tokenizer, model, ref_model, max_new_tokens=4, num_beams=5):
    inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]

    options = dict(max_new_tokens=max_new_tokens, num_beams=num_beams, do_sample=False)
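    # With do_sample=False, beam search is deterministic, so token ids must match exactly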
    outputs = make_generate_calls(model, inputs, **options)
    ref_outputs = ref_model.generate(inputs, **options)
    assert torch.allclose(outputs, ref_outputs), "Beam search results are not identical to HF"


@pytest.mark.forked
def test_input_ids(tokenizer, model, ref_model, max_new_tokens=4):
    inputs = tokenizer("A cat sat on a mat", return_tensors="pt")
    assert inputs.keys() == {"input_ids", "attention_mask"}
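
    # generate() should accept the full tokenizer output, including attention_mask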
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    ref_outputs = ref_model.generate(**inputs, max_new_tokens=max_new_tokens)
    assert torch.allclose(outputs, ref_outputs), "Outputs are not identical to HF"
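
    # Within one session, a later generate() call continues from the previous output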
    with model.inference_session(max_length=inputs["input_ids"].shape[1] + max_new_tokens):
        outputs = torch.cat(
            [
                model.generate(**inputs, max_new_tokens=2),
                model.generate(None, max_new_tokens=max_new_tokens - 2),
            ],
            dim=1,
        )
    assert torch.allclose(outputs, ref_outputs), "Multi-call outputs are not identical to HF"