Browse source code

run each test in a forked process

justheuristic, 3 years ago
parent
commit
57a3a065a9

+ 2 - 0
tests/test_block_exact_match.py

@@ -1,6 +1,7 @@
 import random
 
 import hivemind
+import pytest
 import torch
 import transformers
 from test_utils import *
@@ -11,6 +12,7 @@ from src.data_structures import UID_DELIMITER
 from src.dht_utils import get_remote_module
 
 
+@pytest.mark.forked
 def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3):
     dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
     config = transformers.AutoConfig.from_pretrained(MODEL_NAME)

+ 3 - 0
tests/test_chained_calls.py

@@ -5,6 +5,7 @@
 
 
 import hivemind
+import pytest
 import torch
 import transformers
 from hivemind.moe.expert_uid import UID_DELIMITER, ExpertInfo
@@ -15,6 +16,7 @@ from src.client.remote_block import RemoteTransformerBlock
 from src.dht_utils import get_remote_module
 
 
+@pytest.mark.forked
 def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1):
     dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
     config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
@@ -47,6 +49,7 @@ def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq
     assert torch.allclose(grads_ref, grads_rpc, rtol=0, atol=atol_backward)
 
 
+@pytest.mark.forked
 def test_chained_inference_exact_match(atol_inference=1e-4):
     dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
     config = transformers.AutoConfig.from_pretrained(MODEL_NAME)

+ 2 - 0
tests/test_full_model.py

@@ -1,3 +1,4 @@
+import pytest
 import torch
 import transformers
 from hivemind import get_logger, use_hivemind_log_handler
@@ -9,6 +10,7 @@ use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 
+@pytest.mark.forked
 def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
     tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
     model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)

+ 2 - 0
tests/test_remote_sequential.py

@@ -1,3 +1,4 @@
+import pytest
 import torch
 from hivemind import DHT, get_logger, use_hivemind_log_handler
 from test_utils import *
@@ -9,6 +10,7 @@ use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)
 
 
+@pytest.mark.forked
 def test_remote_sequential():
     config = DistributedBloomConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
     dht = DHT(initial_peers=config.initial_peers, client_mode=True, start=True)