|
@@ -17,7 +17,7 @@ from src.dht_utils import get_remote_module
|
|
|
@pytest.mark.forked
|
|
|
def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3):
|
|
|
dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=True, start=True)
|
|
|
- config = src.DistributedBloomConfig.from_pretrained(MODEL_NAME)
|
|
|
+ config = DistributedBloomConfig.from_pretrained(MODEL_NAME)
|
|
|
|
|
|
for block_index in random.sample(range(config.n_layer), 3):
|
|
|
remote_block = get_remote_module(dht, f"{MODEL_NAME}{UID_DELIMITER}{block_index}", config)
|