@@ -17,6 +17,7 @@ def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
     model = DistributedBloomForCausalLM.from_pretrained(
         MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
     )
+    config = model.config
     assert isinstance(model, DistributedBloomForCausalLM)
     assert len(model.transformer.h) == model.config.n_layer
@@ -45,6 +46,10 @@ def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
     ref_model = transformers.BloomForCausalLM.from_pretrained(
         REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
     )
+    if config.vocab_size < ref_model.config.vocab_size:
+        ref_model.resize_token_embeddings(config.vocab_size)
+        logger.warning(f"Resized the reference model embeddings, new total = {ref_model.config.vocab_size}")
+
     dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
     # note: this creates a dummy mask to make the test compatible with older transformer versions
     # prior to https://github.com/huggingface/transformers/pull/17837
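
For context on the second hunk: resize_token_embeddings() truncates (or pads) the input embedding matrix and the tied LM head, after which logits from the two models can be compared row-for-row. Below is a minimal, self-contained sketch of the same guard, using a randomly initialized toy BLOOM model with made-up sizes rather than the test's real checkpoints:

import transformers

# Toy stand-in for the reference checkpoint; the sizes below are
# hypothetical demo values, not the ones used in the actual test.
ref_model = transformers.BloomForCausalLM(
    transformers.BloomConfig(vocab_size=1024, hidden_size=64, n_layer=2, n_head=4)
)
distributed_vocab_size = 1000  # pretend the distributed model's vocab is smaller

# Same guard as in the diff: shrink the reference embeddings (and the
# tied LM head) to the smaller vocabulary so logits stay comparable.
if distributed_vocab_size < ref_model.config.vocab_size:
    ref_model.resize_token_embeddings(distributed_vocab_size)

assert ref_model.config.vocab_size == distributed_vocab_size
assert ref_model.get_input_embeddings().weight.shape[0] == distributed_vocab_size

Shrinking the reference model, rather than growing the distributed one, keeps the exact-match comparison restricted to the vocabulary rows both checkpoints actually share.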