|
@@ -2,6 +2,7 @@ import pytest
|
|
|
import torch
|
|
|
import transformers
|
|
|
from hivemind import get_logger, use_hivemind_log_handler
|
|
|
+
|
|
|
from test_utils import *
|
|
|
|
|
|
from src.client.remote_model import DistributedBloomForCausalLM
|
|
@@ -13,13 +14,14 @@ logger = get_logger(__file__)
|
|
|
@pytest.mark.forked
|
|
|
def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
|
|
|
tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
|
|
|
- model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
|
|
|
+ model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS,
|
|
|
+ low_cpu_mem_usage=True, torch_dtype=torch.float32)
|
|
|
assert isinstance(model, DistributedBloomForCausalLM)
|
|
|
assert len(model.transformer.h) == model.config.n_layer
|
|
|
|
|
|
test_inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
|
|
|
|
|
|
- with torch.no_grad():
|
|
|
+ with torch.inference_mode():
|
|
|
parallel_outputs = model.forward(test_inputs).logits
|
|
|
assert torch.all(torch.isfinite(parallel_outputs))
|
|
|
logger.info("Forward outputs are finite")
|
|
@@ -36,14 +38,15 @@ def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
|
|
|
assert torch.allclose(recurrent_outputs, parallel_outputs, rtol=0, atol=atol_inference)
|
|
|
logger.info("Inference is consistent with forward")
|
|
|
|
|
|
- del model, recurrent_outputs
|
|
|
+ del model, embs, recurrent_outputs
|
|
|
|
|
|
if REF_NAME:
|
|
|
- ref_model = transformers.AutoModelForCausalLM.from_pretrained(REF_NAME)
|
|
|
+ ref_model = transformers.BloomForCausalLM.from_pretrained(
|
|
|
+ REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32)
|
|
|
dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
|
|
|
# note: this creates a dummy mask to make the test compatible with older transformer versions
|
|
|
# prior to https://github.com/huggingface/transformers/pull/17837
|
|
|
- ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits
|
|
|
+ ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits.float()
|
|
|
assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward)
|
|
|
logger.warning(f"Distributed forward is consistent with {type(ref_model)}.forward")
|
|
|
del ref_model, ref_outputs, dummy_mask
|