|
@@ -13,16 +13,18 @@ from test_utils import *
|
|
|
|
|
|
import src
|
|
import src
|
|
from src.bloom.from_pretrained import load_pretrained_block
|
|
from src.bloom.from_pretrained import load_pretrained_block
|
|
-from src.client.remote_sequential import RemoteSequential
|
|
|
|
|
|
+from src.client.remote_sequential import RemoteTransformerBlock
|
|
|
|
+from src.data_structures import UID_DELIMITER
|
|
|
|
+from src.dht_utils import get_remote_module
|
|
|
|
|
|
|
|
|
|
@pytest.mark.forked
|
|
@pytest.mark.forked
|
|
def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1):
|
|
def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq_length=1):
|
|
dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
|
|
dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
|
|
- config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
|
|
|
|
- remote_block = RemoteSequential(src.DistributedBloomConfig.from_pretrained(MODEL_NAME), dht)[0]
|
|
|
|
|
|
+ config = src.DistributedBloomConfig.from_pretrained(MODEL_NAME)
|
|
|
|
+ remote_block = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}0", config)
|
|
assert remote_block is not None, f"Could not find {MODEL_NAME}{UID_DELIMITER}0 in DHT"
|
|
assert remote_block is not None, f"Could not find {MODEL_NAME}{UID_DELIMITER}0 in DHT"
|
|
- assert isinstance(remote_block, RemoteSequential)
|
|
|
|
|
|
+ assert isinstance(remote_block, RemoteTransformerBlock)
|
|
|
|
|
|
_ = remote_block.info # lazy-init info now, because otherwise we will _break_ info init by changing _info
|
|
_ = remote_block.info # lazy-init info now, because otherwise we will _break_ info init by changing _info
|
|
remote_block._info = ExpertInfo(f"{MODEL_NAME}.3 {MODEL_NAME}.4 {MODEL_NAME}.5", remote_block._info.peer_id)
|
|
remote_block._info = ExpertInfo(f"{MODEL_NAME}.3 {MODEL_NAME}.4 {MODEL_NAME}.5", remote_block._info.peer_id)
|
|
@@ -52,10 +54,10 @@ def test_forward_backward_exact_match(atol_forward=1e-4, atol_backward=1e-4, seq
|
|
@pytest.mark.forked
|
|
@pytest.mark.forked
|
|
def test_chained_inference_exact_match(atol_inference=1e-4):
|
|
def test_chained_inference_exact_match(atol_inference=1e-4):
|
|
dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
|
|
dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
|
|
- config = transformers.AutoConfig.from_pretrained(MODEL_NAME)
|
|
|
|
- remote_block = RemoteSequential(src.DistributedBloomConfig.from_pretrained(MODEL_NAME), dht)[0]
|
|
|
|
|
|
+ config = src.DistributedBloomConfig.from_pretrained(MODEL_NAME)
|
|
|
|
+ remote_block = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}0", config)
|
|
assert remote_block is not None, f"Could not find {MODEL_NAME}{UID_DELIMITER}0 in DHT"
|
|
assert remote_block is not None, f"Could not find {MODEL_NAME}{UID_DELIMITER}0 in DHT"
|
|
- assert isinstance(remote_block, RemoteSequential)
|
|
|
|
|
|
+ assert isinstance(remote_block, RemoteTransformerBlock)
|
|
|
|
|
|
_ = remote_block.info # lazy-init info now, because otherwise we will _break_ info init by changing _info
|
|
_ = remote_block.info # lazy-init info now, because otherwise we will _break_ info init by changing _info
|
|
remote_block._info = ExpertInfo(f"{MODEL_NAME}.3 {MODEL_NAME}.4", remote_block._info.peer_id)
|
|
remote_block._info = ExpertInfo(f"{MODEL_NAME}.3 {MODEL_NAME}.4", remote_block._info.peer_id)
|