@@ -11,14 +11,10 @@ use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)
-@pytest.fixture
-def tokenizer():
- return transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
-
-def model():
- return DistributedBloomForCausalLM.from_pretrained(
+@pytest.mark.forked
+def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
+ tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
+ model = DistributedBloomForCausalLM.from_pretrained(
MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
)
config = model.config