Artem Chumachenko vor 3 Jahren
Ursprung
Commit
53e19de6e0
1 geänderte Dateien mit 8 neuen und 4 gelöschten Zeilen
  1. 8 4
      tests/test_full_model.py

+ 8 - 4
tests/test_full_model.py

@@ -11,10 +11,14 @@ use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 
-@pytest.mark.forked
-def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
-    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
-    model = DistributedBloomForCausalLM.from_pretrained(
+@pytest.fixture
+def tokenizer():
+    return transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
+
+
+@pytest.fixture
+def model():
+    return DistributedBloomForCausalLM.from_pretrained(
         MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
     )
     config = model.config