@@ -11,20 +11,12 @@ use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 
-@pytest.fixture
-def tokenizer():
-    return transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
-
-
-@pytest.fixture
-def model():
-    return DistributedBloomForCausalLM.from_pretrained(
+@pytest.mark.forked
+def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
+    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
+    model = DistributedBloomForCausalLM.from_pretrained(
         MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
     )
-
-
-@pytest.mark.forked
-def test_full_model_exact_match(tokenizer, model, atol_forward=1e-3, atol_inference=1e-3):
     assert isinstance(model, DistributedBloomForCausalLM)
     assert len(model.transformer.h) == model.config.n_layer
 
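Note (not part of the diff): `@pytest.mark.forked` comes from the `pytest-forked` plugin and runs each marked test in its own forked child process. Inlining the tokenizer/model construction, previously shared via `@pytest.fixture`, presumably keeps the hivemind-backed model and its live peer connections confined to that child process. A minimal sketch of the resulting pattern, assuming `MODEL_NAME` and `INITIAL_PEERS` are exported by the repo's `test_utils` as the diff's imports suggest; the test name is hypothetical:

```python
import pytest
import torch
import transformers
from test_utils import *  # assumed to provide MODEL_NAME and INITIAL_PEERS

from src.client.remote_model import DistributedBloomForCausalLM


@pytest.mark.forked  # pytest-forked: run this test in a forked child process
def test_model_builds_per_process():  # hypothetical name, for illustration
    # Constructing inside the test keeps the heavyweight, network-backed
    # setup unambiguously inside the forked process instead of fixture state.
    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
    model = DistributedBloomForCausalLM.from_pretrained(
        MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
    )
    assert isinstance(model, DistributedBloomForCausalLM)
```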
@@ -65,31 +57,30 @@ def test_full_model_exact_match(tokenizer, model, atol_forward=1e-3, atol_infere
         assert False
 
 
-def test_greedy_generation(tokenizer, model, max_new_tokens=4):
+@pytest.mark.forked
+def test_greedy_generation(max_new_tokens=4):
+    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
+    model = DistributedBloomForCausalLM.from_pretrained(
+        MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
+    )
     inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
     remote_outputs = model.generate(
         inputs,
         max_new_tokens=max_new_tokens,
     )
-    hf_outputs = BloomForCausalLM.greedy_search(
-        model,
-        input_ids=inputs,
-        max_length=inputs.size(1) + max_new_tokens
-    )
+    hf_outputs = BloomForCausalLM.greedy_search(model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens)
     assert torch.allclose(remote_outputs, hf_outputs), "Greedy search are not identical to HF"
 
-    inputs_batch = tokenizer(
-        ["A cat sat on a mat", "A dog sat on a mat"],
-        return_tensors='pt',
-        padding=True
-    )["input_ids"]
+    inputs_batch = tokenizer(["A cat sat on a mat", "A dog sat on a mat"], return_tensors="pt", padding=True)[
+        "input_ids"
+    ]
     remote_outputs_batch = model.generate(
         inputs_batch,
         max_new_tokens=max_new_tokens,
     )
     hf_outputs_batch = BloomForCausalLM.greedy_search(
-        model,
-        input_ids=inputs_batch,
-        max_length=inputs_batch.size(1) + max_new_tokens
+        model, input_ids=inputs_batch, max_length=inputs_batch.size(1) + max_new_tokens
     )
-    assert torch.allclose(remote_outputs_batch, hf_outputs_batch), "Greedy search are not identical to HF in multibatch mode"
+    assert torch.allclose(
+        remote_outputs_batch, hf_outputs_batch
+    ), "Greedy search are not identical to HF in multibatch mode"
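Note (not part of the diff): the reference outputs above come from calling `BloomForCausalLM.greedy_search` unbound, with the model passed as `self`. That pins the decoding loop to Hugging Face's stock greedy implementation while the distributed model still supplies the forward pass, so the remote `model.generate` can be checked against it token for token. A standalone sketch of the unbound-call trick, assuming a transformers version whose `GenerationMixin.greedy_search` accepts `max_length` directly (as the test itself does); the checkpoint name is illustrative:

```python
import torch
import transformers
from transformers import BloomForCausalLM

MODEL_NAME = "bigscience/bloom-560m"  # illustrative checkpoint, not from the diff
tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
model = BloomForCausalLM.from_pretrained(MODEL_NAME)
inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]

# Bound call: dispatches through the object's class, so a subclass that
# overrides greedy_search would use its own implementation.
bound = model.greedy_search(input_ids=inputs, max_length=inputs.size(1) + 4)

# Unbound call, as in the test: forces BloomForCausalLM's implementation even
# when `model` is a subclass (e.g. DistributedBloomForCausalLM in the diff).
unbound = BloomForCausalLM.greedy_search(model, input_ids=inputs, max_length=inputs.size(1) + 4)

assert torch.equal(bound, unbound)  # identical here, since model is the base class
```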