|
@@ -24,7 +24,7 @@ if not MODEL_NAME:
|
|
|
REF_NAME = os.environ.get("REF_NAME")
|
|
|
|
|
|
|
|
|
-def test_full_model_exact_match(atol_forward=1e-5, atol_inference=1e-3, prefix="bloom6b3"):
|
|
|
+def test_full_model_exact_match(atol_forward=1e-5, atol_inference=1e-3):
|
|
|
tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
|
|
|
model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
|
|
|
assert len(model.transformer.h) == model.config.n_layer
|