|
@@ -3,29 +3,31 @@ import pytest
|
|
|
import torch
|
|
|
import transformers
|
|
|
from hivemind import get_logger
|
|
|
-from transformers.generation import BeamSearchScorer
|
|
|
-from transformers.models.bloom import BloomForCausalLM
|
|
|
+from transformers.generation import BeamSearchScorer, GenerationMixin as HfGenerationMixin
|
|
|
|
|
|
-from petals import DistributedBloomForCausalLM
|
|
|
+from petals import AutoDistributedModelForCausalLM
|
|
|
from test_utils import *
|
|
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
|
|
|
+@pytest.fixture
|
|
|
+def tokenizer():
|
|
|
+ # We set use_fast=False since LlamaTokenizerFast is slow on load
|
|
|
+ return transformers.AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
|
|
|
+
|
|
|
+
|
|
|
@pytest.mark.forked
|
|
|
@pytest.mark.parametrize("use_peft", (True, False) if ADAPTER_NAME else (False,))
|
|
|
@pytest.mark.parametrize("pass_empty_tensors", (True, False))
|
|
|
-def test_full_model_exact_match(use_peft: bool, pass_empty_tensors: bool, atol_forward=1e-3, atol_inference=1e-3):
|
|
|
- tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
|
|
|
- model = DistributedBloomForCausalLM.from_pretrained(
|
|
|
+def test_full_model_exact_match(tokenizer, use_peft, pass_empty_tensors, atol_forward=1e-3, atol_inference=1e-3):
|
|
|
+ model = AutoDistributedModelForCausalLM.from_pretrained(
|
|
|
MODEL_NAME,
|
|
|
initial_peers=INITIAL_PEERS,
|
|
|
- low_cpu_mem_usage=True,
|
|
|
torch_dtype=torch.float32,
|
|
|
active_adapter=ADAPTER_NAME if use_peft else None,
|
|
|
)
|
|
|
config = model.config
|
|
|
- assert isinstance(model, DistributedBloomForCausalLM)
|
|
|
assert len(model.transformer.h) == model.config.num_hidden_layers
|
|
|
|
|
|
test_inputs = tokenizer("A quick brown fox was minding its own buisness", return_tensors="pt")["input_ids"]
|
|
@@ -63,7 +65,7 @@ def test_full_model_exact_match(use_peft: bool, pass_empty_tensors: bool, atol_f
|
|
|
del model, embs, recurrent_outputs
|
|
|
|
|
|
if REF_NAME:
|
|
|
- ref_model = transformers.BloomForCausalLM.from_pretrained(
|
|
|
+ ref_model = transformers.AutoModelForCausalLM.from_pretrained(
|
|
|
REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
|
|
|
)
|
|
|
if use_peft:
|
|
@@ -86,27 +88,29 @@ def test_full_model_exact_match(use_peft: bool, pass_empty_tensors: bool, atol_f
|
|
|
|
|
|
|
|
|
@pytest.mark.forked
|
|
|
-def test_greedy_generation(max_new_tokens=4):
|
|
|
- tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
|
|
|
- model = DistributedBloomForCausalLM.from_pretrained(
|
|
|
- MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
|
|
|
+def test_greedy_generation(tokenizer, max_new_tokens=4):
|
|
|
+ model = AutoDistributedModelForCausalLM.from_pretrained(
|
|
|
+ MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32
|
|
|
)
|
|
|
inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
|
|
|
remote_outputs = model.generate(
|
|
|
inputs,
|
|
|
max_new_tokens=max_new_tokens,
|
|
|
)
|
|
|
- hf_outputs = BloomForCausalLM.greedy_search(model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens)
|
|
|
+ hf_outputs = HfGenerationMixin.greedy_search(model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens)
|
|
|
assert torch.allclose(remote_outputs, hf_outputs), "Greedy search results are not identical to HF"
|
|
|
|
|
|
+ if tokenizer.pad_token_id is None:
|
|
|
+ tokenizer.pad_token_id = tokenizer.eos_token_id
|
|
|
inputs_batch = tokenizer(["A cat sat on a mat", "A dog sat on a mat"], return_tensors="pt", padding=True)[
|
|
|
"input_ids"
|
|
|
]
|
|
|
+
|
|
|
remote_outputs_batch = model.generate(
|
|
|
inputs_batch,
|
|
|
max_new_tokens=max_new_tokens,
|
|
|
)
|
|
|
- hf_outputs_batch = BloomForCausalLM.greedy_search(
|
|
|
+ hf_outputs_batch = HfGenerationMixin.greedy_search(
|
|
|
model, input_ids=inputs_batch, max_length=inputs_batch.size(1) + max_new_tokens
|
|
|
)
|
|
|
assert torch.allclose(
|
|
@@ -117,13 +121,13 @@ def test_greedy_generation(max_new_tokens=4):
|
|
|
@pytest.mark.forked
|
|
|
@pytest.mark.parametrize("sampling_options", [dict(), dict(temperature=100.0), dict(top_k=5), dict(top_p=0.9)])
|
|
|
@pytest.mark.skip("Sampling is currently not consistent with outputs from Transformers")
|
|
|
-def test_sampling(sampling_options, max_new_tokens=4):
|
|
|
+def test_sampling(tokenizer, sampling_options, max_new_tokens=4):
|
|
|
torch.manual_seed(0)
|
|
|
- tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
|
|
|
- model = DistributedBloomForCausalLM.from_pretrained(
|
|
|
- MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
|
|
|
+
|
|
|
+ model = AutoDistributedModelForCausalLM.from_pretrained(
|
|
|
+ MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32
|
|
|
)
|
|
|
- logits_warper = BloomForCausalLM._get_logits_warper(model, num_beams=1, **sampling_options)
|
|
|
+ logits_warper = HfGenerationMixin._get_logits_warper(model, num_beams=1, **sampling_options)
|
|
|
inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
|
|
|
with torch.random.fork_rng():
|
|
|
remote_outputs = model.generate(
|
|
@@ -133,7 +137,7 @@ def test_sampling(sampling_options, max_new_tokens=4):
|
|
|
**sampling_options,
|
|
|
)
|
|
|
with torch.random.fork_rng():
|
|
|
- hf_outputs = BloomForCausalLM.sample(
|
|
|
+ hf_outputs = HfGenerationMixin.sample(
|
|
|
model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens, logits_warper=logits_warper
|
|
|
)
|
|
|
assert torch.allclose(remote_outputs, hf_outputs), "Sampling results are not identical to HF"
|
|
@@ -149,7 +153,7 @@ def test_sampling(sampling_options, max_new_tokens=4):
|
|
|
**sampling_options,
|
|
|
)
|
|
|
with torch.random.fork_rng():
|
|
|
- hf_outputs_batch = BloomForCausalLM.sample(
|
|
|
+ hf_outputs_batch = HfGenerationMixin.sample(
|
|
|
model,
|
|
|
input_ids=inputs_batch,
|
|
|
max_length=inputs_batch.size(1) + max_new_tokens,
|
|
@@ -161,10 +165,9 @@ def test_sampling(sampling_options, max_new_tokens=4):
|
|
|
|
|
|
|
|
|
@pytest.mark.forked
|
|
|
-def test_beam_search_generation(max_new_tokens=4, num_beams=2):
|
|
|
- tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
|
|
|
- model = DistributedBloomForCausalLM.from_pretrained(
|
|
|
- MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
|
|
|
+def test_beam_search_generation(tokenizer, max_new_tokens=4, num_beams=2):
|
|
|
+ model = AutoDistributedModelForCausalLM.from_pretrained(
|
|
|
+ MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32
|
|
|
)
|
|
|
text = "A cat sat on a mat"
|
|
|
inputs = tokenizer(text, return_tensors="pt")["input_ids"]
|
|
@@ -181,7 +184,7 @@ def test_beam_search_generation(max_new_tokens=4, num_beams=2):
|
|
|
do_early_stopping=False,
|
|
|
)
|
|
|
hf_inputs = tokenizer([text] * 2, return_tensors="pt")["input_ids"]
|
|
|
- hf_outputs = BloomForCausalLM.beam_search(
|
|
|
+ hf_outputs = HfGenerationMixin.beam_search(
|
|
|
model, input_ids=hf_inputs, max_length=inputs.size(1) + max_new_tokens, beam_scorer=beam_scorer
|
|
|
)
|
|
|
assert torch.allclose(remote_outputs, hf_outputs), "Beam search results are not identical to HF"
|