@@ -11,10 +11,15 @@ use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
-@pytest.mark.forked
-def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
-    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
-    model = DistributedBloomForCausalLM.from_pretrained(
+@pytest.fixture
+def tokenizer():
+    return transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
+
+@pytest.fixture
+def model():
+    return DistributedBloomForCausalLM.from_pretrained(
MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
)
config = model.config