|
@@ -11,10 +11,14 @@ use_hivemind_log_handler("in_root_logger")
|
|
|
logger = get_logger(__file__)
|
|
|
|
|
|
|
|
|
-@pytest.mark.forked
|
|
|
-def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
|
|
|
- tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
|
|
|
- model = DistributedBloomForCausalLM.from_pretrained(
|
|
|
+@pytest.fixture
|
|
|
+def tokenizer():
|
|
|
+ return transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
|
|
|
+
|
|
|
+
|
|
|
+@pytest.fixture
|
|
|
+def model():
|
|
|
+ return DistributedBloomForCausalLM.from_pretrained(
|
|
|
MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
|
|
|
)
|
|
|
config = model.config
|