@@ -1,200 +1,27 @@
 import asyncio
 import heapq
-import multiprocessing as mp
 import random
-import signal
 from itertools import product
-from typing import List, Sequence, Tuple
 
 import numpy as np
 import pytest
-from multiaddr import Multiaddr
 
 import hivemind
 from hivemind import get_dht_time
 from hivemind.dht.node import DHTID, DHTNode
-from hivemind.dht.protocol import DHTProtocol
-from hivemind.dht.storage import DictionaryDHTValue
-from hivemind.p2p import P2P, PeerID
 from hivemind.utils.logging import get_logger
 
 from test_utils.dht_swarms import launch_star_shaped_swarm, launch_swarm_in_separate_processes
 
 logger = get_logger(__name__)
 
-
-def maddrs_to_peer_ids(maddrs: List[Multiaddr]) -> List[PeerID]:
-    return list({PeerID.from_base58(maddr["p2p"]) for maddr in maddrs})
-
-
-def run_protocol_listener(
-    dhtid: DHTID, maddr_conn: mp.connection.Connection, initial_peers: Sequence[Multiaddr]
-) -> None:
-    loop = asyncio.get_event_loop()
-
-    p2p = loop.run_until_complete(P2P.create(initial_peers=initial_peers))
-    visible_maddrs = loop.run_until_complete(p2p.get_visible_maddrs())
-
-    protocol = loop.run_until_complete(
-        DHTProtocol.create(p2p, dhtid, bucket_size=20, depth_modulo=5, num_replicas=3, wait_timeout=5)
-    )
-
-    logger.info(f"Started peer id={protocol.node_id} visible_maddrs={visible_maddrs}")
-
-    for peer_id in maddrs_to_peer_ids(initial_peers):
-        loop.run_until_complete(protocol.call_ping(peer_id))
-
-    maddr_conn.send((p2p.peer_id, visible_maddrs))
-
-    async def shutdown():
-        await p2p.shutdown()
-        logger.info(f"Finished peer id={protocol.node_id} maddrs={visible_maddrs}")
-        loop.stop()
-
-    loop.add_signal_handler(signal.SIGTERM, lambda: loop.create_task(shutdown()))
-    loop.run_forever()
-
-
-def launch_protocol_listener(
-    initial_peers: Sequence[Multiaddr] = (),
-) -> Tuple[DHTID, mp.Process, PeerID, List[Multiaddr]]:
-    remote_conn, local_conn = mp.Pipe()
-    dht_id = DHTID.generate()
-    process = mp.Process(target=run_protocol_listener, args=(dht_id, remote_conn, initial_peers), daemon=True)
-    process.start()
-    peer_id, visible_maddrs = local_conn.recv()
-
-    return dht_id, process, peer_id, visible_maddrs
-
-
 # note: we run network-related tests in a separate process to re-initialize all global states from scratch
 # this helps us avoid undesirable gRPC side-effects (e.g. segfaults) when running multiple tests in sequence
 
 
 @pytest.mark.forked
-def test_dht_protocol():
-    peer1_node_id, peer1_proc, peer1_id, peer1_maddrs = launch_protocol_listener()
-    peer2_node_id, peer2_proc, peer2_id, _ = launch_protocol_listener(initial_peers=peer1_maddrs)
-
-    loop = asyncio.get_event_loop()
-    for client_mode in [True, False]:  # note: order matters, this test assumes that first run uses client mode
-        peer_id = DHTID.generate()
-        p2p = loop.run_until_complete(P2P.create(initial_peers=peer1_maddrs))
-        protocol = loop.run_until_complete(
-            DHTProtocol.create(
-                p2p, peer_id, bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, client_mode=client_mode
-            )
-        )
-        logger.info(f"Self id={protocol.node_id}")
-
-        assert loop.run_until_complete(protocol.call_ping(peer1_id)) == peer1_node_id
-
-        key, value, expiration = DHTID.generate(), [random.random(), {"ololo": "pyshpysh"}], get_dht_time() + 1e3
-        store_ok = loop.run_until_complete(
-            protocol.call_store(peer1_id, [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
-        )
-        assert all(store_ok), "DHT rejected a trivial store"
-
-        # peer 1 must know about peer 2
-        (recv_value_bytes, recv_expiration), nodes_found = loop.run_until_complete(
-            protocol.call_find(peer1_id, [key])
-        )[key]
-        recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
-        (recv_id, recv_peer_id) = next(iter(nodes_found.items()))
-        assert (
-            recv_id == peer2_node_id and recv_peer_id == peer2_id
-        ), f"expected id={peer2_node_id}, peer={peer2_id} but got {recv_id}, {recv_peer_id}"
-
-        assert recv_value == value and recv_expiration == expiration, (
-            f"call_find_value expected {value} (expires by {expiration}) "
-            f"but got {recv_value} (expires by {recv_expiration})"
-        )
-
-        # peer 2 must know about peer 1, but not have a *random* nonexistent value
-        dummy_key = DHTID.generate()
-        empty_item, nodes_found_2 = loop.run_until_complete(protocol.call_find(peer2_id, [dummy_key]))[dummy_key]
-        assert empty_item is None, "Non-existent keys shouldn't have values"
-        (recv_id, recv_peer_id) = next(iter(nodes_found_2.items()))
-        assert (
-            recv_id == peer1_node_id and recv_peer_id == peer1_id
-        ), f"expected id={peer1_node_id}, peer={peer1_id} but got {recv_id}, {recv_peer_id}"
-
-        # cause a non-response by querying a nonexistent peer
-        assert loop.run_until_complete(protocol.call_find(PeerID.from_base58("fakeid"), [key])) is None
-
-        # store/get a dictionary with sub-keys
-        nested_key, subkey1, subkey2 = DHTID.generate(), "foo", "bar"
-        value1, value2 = [random.random(), {"ololo": "pyshpysh"}], "abacaba"
-        assert loop.run_until_complete(
-            protocol.call_store(
-                peer1_id,
-                keys=[nested_key],
-                values=[hivemind.MSGPackSerializer.dumps(value1)],
-                expiration_time=[expiration],
-                subkeys=[subkey1],
-            )
-        )
-        assert loop.run_until_complete(
-            protocol.call_store(
-                peer1_id,
-                keys=[nested_key],
-                values=[hivemind.MSGPackSerializer.dumps(value2)],
-                expiration_time=[expiration + 5],
-                subkeys=[subkey2],
-            )
-        )
-        (recv_dict, recv_expiration), nodes_found = loop.run_until_complete(
-            protocol.call_find(peer1_id, [nested_key])
-        )[nested_key]
-        assert isinstance(recv_dict, DictionaryDHTValue)
-        assert len(recv_dict.data) == 2 and recv_expiration == expiration + 5
-        assert recv_dict.data[subkey1] == (protocol.serializer.dumps(value1), expiration)
-        assert recv_dict.data[subkey2] == (protocol.serializer.dumps(value2), expiration + 5)
-
-        if not client_mode:
-            loop.run_until_complete(p2p.shutdown())
-
-    peer1_proc.terminate()
-    peer2_proc.terminate()
-
-
-@pytest.mark.forked
-def test_empty_table():
-    """Test RPC methods with empty routing table"""
-    peer_id, peer_proc, peer_peer_id, peer_maddrs = launch_protocol_listener()
-
-    loop = asyncio.get_event_loop()
-    p2p = loop.run_until_complete(P2P.create(initial_peers=peer_maddrs))
-    protocol = loop.run_until_complete(
-        DHTProtocol.create(
-            p2p, DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, client_mode=True
-        )
-    )
-
-    key, value, expiration = DHTID.generate(), [random.random(), {"ololo": "pyshpysh"}], get_dht_time() + 1e3
-
-    empty_item, nodes_found = loop.run_until_complete(protocol.call_find(peer_peer_id, [key]))[key]
-    assert empty_item is None and len(nodes_found) == 0
-    assert all(
-        loop.run_until_complete(
-            protocol.call_store(peer_peer_id, [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
-        )
-    ), "peer rejected store"
-
-    (recv_value_bytes, recv_expiration), nodes_found = loop.run_until_complete(
-        protocol.call_find(peer_peer_id, [key])
-    )[key]
-    recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
-    assert len(nodes_found) == 0
-    assert recv_value == value and recv_expiration == expiration
-
-    assert loop.run_until_complete(protocol.call_ping(peer_peer_id)) == peer_id
-    assert loop.run_until_complete(protocol.call_ping(PeerID.from_base58("fakeid"))) is None
-    peer_proc.terminate()
-
-
-@pytest.mark.forked
-def test_dht_node(
+@pytest.mark.asyncio
+async def test_dht_node(
     n_peers: int = 20, n_sequential_peers: int = 5, parallel_rpc: int = 10, bucket_size: int = 5, num_replicas: int = 3
 ):
     # step A: create a swarm of 50 dht nodes in separate processes
@@ -205,26 +32,23 @@ def test_dht_node(
     )
 
     # step B: run 51-st node in this process
-    loop = asyncio.get_event_loop()
     initial_peers = random.choice(swarm_maddrs)
-    me = loop.run_until_complete(
-        DHTNode.create(
-            initial_peers=initial_peers,
-            parallel_rpc=parallel_rpc,
-            bucket_size=bucket_size,
-            num_replicas=num_replicas,
-            cache_refresh_before_expiry=False,
-        )
+    me = await DHTNode.create(
+        initial_peers=initial_peers,
+        parallel_rpc=parallel_rpc,
+        bucket_size=bucket_size,
+        num_replicas=num_replicas,
+        cache_refresh_before_expiry=False,
     )
 
     # test 1: find self
-    nearest = loop.run_until_complete(me.find_nearest_nodes([me.node_id], k_nearest=1))[me.node_id]
+    nearest = (await me.find_nearest_nodes([me.node_id], k_nearest=1))[me.node_id]
     assert len(nearest) == 1 and nearest[me.node_id] == me.peer_id
 
     # test 2: find others
     for _ in range(10):
         ref_peer_id, query_id = random.choice(list(dht.items()))
-        nearest = loop.run_until_complete(me.find_nearest_nodes([query_id], k_nearest=1))[query_id]
+        nearest = (await me.find_nearest_nodes([query_id], k_nearest=1))[query_id]
         assert len(nearest) == 1
         found_node_id, found_peer_id = next(iter(nearest.items()))
         assert found_node_id == query_id and found_peer_id == ref_peer_id
@@ -238,10 +62,8 @@ def test_dht_node(
         query_id = DHTID.generate()
         k_nearest = random.randint(1, 10)
         exclude_self = random.random() > 0.5
-        nearest = loop.run_until_complete(
-            me.find_nearest_nodes([query_id], k_nearest=k_nearest, exclude_self=exclude_self)
-        )[query_id]
-        nearest_nodes = list(nearest)  # keys from ordered dict
+        find_result = await me.find_nearest_nodes([query_id], k_nearest=k_nearest, exclude_self=exclude_self)
+        nearest_nodes = list(find_result[query_id])  # keys from ordered dict
 
         assert len(nearest_nodes) == k_nearest, "beam search must return exactly k_nearest results"
         assert me.node_id not in nearest_nodes or not exclude_self, "if exclude, results shouldn't contain self"
@@ -268,65 +90,63 @@ def test_dht_node(
 
     # test 4: find all nodes
     dummy = DHTID.generate()
-    nearest = loop.run_until_complete(me.find_nearest_nodes([dummy], k_nearest=len(dht) + 100))[dummy]
+    nearest = (await me.find_nearest_nodes([dummy], k_nearest=len(dht) + 100))[dummy]
     assert len(nearest) == len(dht) + 1
     assert len(set.difference(set(nearest.keys()), set(all_node_ids) | {me.node_id})) == 0
 
     # test 5: node without peers
-    detached_node = loop.run_until_complete(DHTNode.create())
-    nearest = loop.run_until_complete(detached_node.find_nearest_nodes([dummy]))[dummy]
+    detached_node = await DHTNode.create()
+    nearest = (await detached_node.find_nearest_nodes([dummy]))[dummy]
     assert len(nearest) == 1 and nearest[detached_node.node_id] == detached_node.peer_id
-    nearest = loop.run_until_complete(detached_node.find_nearest_nodes([dummy], exclude_self=True))[dummy]
+    nearest = (await detached_node.find_nearest_nodes([dummy], exclude_self=True))[dummy]
     assert len(nearest) == 0
 
     # test 6: store and get value
     true_time = get_dht_time() + 1200
-    assert loop.run_until_complete(me.store("mykey", ["Value", 10], true_time))
+    assert await me.store("mykey", ["Value", 10], true_time)
 
     initial_peers = random.choice(swarm_maddrs)
-    that_guy = loop.run_until_complete(
-        DHTNode.create(
-            initial_peers=initial_peers,
-            parallel_rpc=parallel_rpc,
-            cache_refresh_before_expiry=False,
-            cache_locally=False,
-        )
+    that_guy = await DHTNode.create(
+        initial_peers=initial_peers,
+        parallel_rpc=parallel_rpc,
+        cache_refresh_before_expiry=False,
+        cache_locally=False,
     )
 
     for node in [me, that_guy]:
-        val, expiration_time = loop.run_until_complete(node.get("mykey"))
+        val, expiration_time = await node.get("mykey")
         assert val == ["Value", 10], "Wrong value"
         assert expiration_time == true_time, f"Wrong time"
 
-    assert loop.run_until_complete(detached_node.get("mykey")) is None
+    assert not await detached_node.get("mykey")
 
     # test 7: bulk store and bulk get
     keys = "foo", "bar", "baz", "zzz"
     values = 3, 2, "batman", [1, 2, 3]
-    store_ok = loop.run_until_complete(me.store_many(keys, values, expiration_time=get_dht_time() + 999))
+    store_ok = await me.store_many(keys, values, expiration_time=get_dht_time() + 999)
     assert all(store_ok.values()), "failed to store one or more keys"
-    response = loop.run_until_complete(me.get_many(keys[::-1]))
+    response = await me.get_many(keys[::-1])
     for key, value in zip(keys, values):
         assert key in response and response[key][0] == value
 
     # test 8: store dictionaries as values (with sub-keys)
     upper_key, subkey1, subkey2, subkey3 = "ololo", "k1", "k2", "k3"
     now = get_dht_time()
-    assert loop.run_until_complete(me.store(upper_key, subkey=subkey1, value=123, expiration_time=now + 10))
-    assert loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=456, expiration_time=now + 20))
+    assert await me.store(upper_key, subkey=subkey1, value=123, expiration_time=now + 10)
+    assert await me.store(upper_key, subkey=subkey2, value=456, expiration_time=now + 20)
     for node in [that_guy, me]:
-        value, time = loop.run_until_complete(node.get(upper_key))
+        value, time = await node.get(upper_key)
         assert isinstance(value, dict) and time == now + 20
         assert value[subkey1] == (123, now + 10)
         assert value[subkey2] == (456, now + 20)
         assert len(value) == 2
 
-    assert not loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=345, expiration_time=now + 10))
-    assert loop.run_until_complete(me.store(upper_key, subkey=subkey2, value=567, expiration_time=now + 30))
-    assert loop.run_until_complete(me.store(upper_key, subkey=subkey3, value=890, expiration_time=now + 50))
+    assert not await me.store(upper_key, subkey=subkey2, value=345, expiration_time=now + 10)
+    assert await me.store(upper_key, subkey=subkey2, value=567, expiration_time=now + 30)
+    assert await me.store(upper_key, subkey=subkey3, value=890, expiration_time=now + 50)
 
     for node in [that_guy, me]:
-        value, time = loop.run_until_complete(node.get(upper_key, latest=True))
+        value, time = await node.get(upper_key, latest=True)
         assert isinstance(value, dict) and time == now + 50, (value, time)
         assert value[subkey1] == (123, now + 10)
         assert value[subkey2] == (567, now + 30)
@@ -336,7 +156,7 @@ def test_dht_node(
     for proc in processes:
         proc.terminate()
     # The nodes don't own their hivemind.p2p.P2P instances, so we shutdown them separately
-    loop.run_until_complete(asyncio.wait([node.shutdown() for node in [me, detached_node, that_guy]]))
+    await asyncio.gather(me.shutdown(), that_guy.shutdown(), detached_node.shutdown())
 
 
 @pytest.mark.forked