@@ -5,6 +5,7 @@
 import hivemind
+import pytest
 import torch
 import transformers
 from hivemind.moe.expert_uid import UID_DELIMITER, ExpertInfo
@@ -15,6 +16,7 @@ from src.client.remote_block import RemoteTransformerBlock
 from src.dht_utils import get_remote_module


+@pytest.mark.forked
 def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1):
     dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
     config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
@@ -47,6 +49,7 @@ def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq
     assert torch.allclose(grads_ref, grads_rpc, rtol=0, atol=atol_backward)


+@pytest.mark.forked
 def test_chained_inference_exact_match(atol_inference=1e-4):
     dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
     config = transformers.AutoConfig.from_pretrained(MODEL_NAME)