Artem Chumachenko 3 years ago
parent
commit
f62c65ec23
1 changed files with 4 additions and 8 deletions
  1. 4 8
      tests/test_full_model.py

+ 4 - 8
tests/test_full_model.py

@@ -11,14 +11,10 @@ use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 
-@pytest.fixture
-def tokenizer():
-    return transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
-
-
-@pytest.fixture
-def model():
-    return DistributedBloomForCausalLM.from_pretrained(
+@pytest.mark.forked
+def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
+    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
+    model = DistributedBloomForCausalLM.from_pretrained(
         MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
     )
     config = model.config