|
@@ -13,7 +13,7 @@ logger = get_logger(__file__)
|
|
|
@pytest.mark.forked
|
|
|
def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
|
|
|
tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
|
|
|
- model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype="auto")
|
|
|
+ model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
|
|
|
assert isinstance(model, DistributedBloomForCausalLM)
|
|
|
assert len(model.transformer.h) == model.config.n_layer
|
|
|
|