@@ -2,72 +2,99 @@ import asyncio
 import heapq
 import multiprocessing as mp
 import random
+import signal
 from itertools import product
-from typing import Optional, List, Dict
+from typing import List, Sequence, Tuple
 
 import numpy as np
 import pytest
+from multiaddr import Multiaddr
 
 import hivemind
-from hivemind import get_dht_time, replace_port
-from hivemind.dht.node import DHTID, Endpoint, DHTNode, LOCALHOST
-from hivemind.dht.protocol import DHTProtocol, ValidationError
+from hivemind import get_dht_time
+from hivemind.dht.node import DHTID, DHTNode
+from hivemind.dht.protocol import DHTProtocol
 from hivemind.dht.storage import DictionaryDHTValue
+from hivemind.p2p import P2P, PeerID
+from hivemind.utils.logging import get_logger
+from test_utils.dht_swarms import launch_swarm_in_separate_processes, launch_star_shaped_swarm
 
 
-def run_protocol_listener(port: int, dhtid: DHTID, started: mp.synchronize.Event, ping: Optional[Endpoint] = None):
+logger = get_logger(__name__)
+
+
+def maddrs_to_peer_ids(maddrs: List[Multiaddr]) -> List[PeerID]:
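+    # each multiaddr includes a /p2p/<base58 peer id> component; collect the unique peer ids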
+    return list({PeerID.from_base58(maddr['p2p']) for maddr in maddrs})
+
+
+def run_protocol_listener(dhtid: DHTID, maddr_conn: mp.connection.Connection,
+                          initial_peers: Sequence[Multiaddr]) -> None:
     loop = asyncio.get_event_loop()
+
+    p2p = loop.run_until_complete(P2P.create(initial_peers=initial_peers))
+    visible_maddrs = loop.run_until_complete(p2p.get_visible_maddrs())
+
     protocol = loop.run_until_complete(DHTProtocol.create(
-        dhtid, bucket_size=20, depth_modulo=5, num_replicas=3, wait_timeout=5, listen_on=f"{LOCALHOST}:{port}"))
+        p2p, dhtid, bucket_size=20, depth_modulo=5, num_replicas=3, wait_timeout=5))
+
+    logger.info(f"Started peer id={protocol.node_id} visible_maddrs={visible_maddrs}")
 
-    assert protocol.port == port
-    print(f"Started peer id={protocol.node_id} port={port}", flush=True)
+    for endpoint in maddrs_to_peer_ids(initial_peers):
+        loop.run_until_complete(protocol.call_ping(endpoint))
 
-    if ping is not None:
-        loop.run_until_complete(protocol.call_ping(ping))
-    started.set()
-    loop.run_until_complete(protocol.server.wait_for_termination())
-    print(f"Finished peer id={protocol.node_id} port={port}", flush=True)
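+    # report this peer's ID and visible multiaddrs back to the parent process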
+    maddr_conn.send((p2p.id, visible_maddrs))
 
+    async def shutdown():
+        await p2p.shutdown()
+        logger.info(f"Finished peer id={protocol.node_id} maddrs={visible_maddrs}")
+        loop.stop()
 
-# note: we run grpc-related tests in a separate process to re-initialize all global states from scratch
-# this helps us avoid undesirable side-effects (e.g. segfaults) when running multiple tests in sequence
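+    # serve requests until the parent process terminates this listener with SIGTERM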
+    loop.add_signal_handler(signal.SIGTERM, lambda: loop.create_task(shutdown()))
+    loop.run_forever()
+
+
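+# start a DHTProtocol listener in a child process and wait for it to report its peer ID and multiaddrs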
+def launch_protocol_listener(initial_peers: Sequence[Multiaddr] = ()) -> \
+        Tuple[DHTID, mp.Process, PeerID, List[Multiaddr]]:
+    remote_conn, local_conn = mp.Pipe()
+    dht_id = DHTID.generate()
+    process = mp.Process(target=run_protocol_listener, args=(dht_id, remote_conn, initial_peers), daemon=True)
+    process.start()
+    peer_id, visible_maddrs = local_conn.recv()
+
+    return dht_id, process, peer_id, visible_maddrs
+
+
+# note: we run network-related tests in a separate process to re-initialize all global states from scratch
+# this helps us avoid undesirable gRPC side-effects (e.g. segfaults) when running multiple tests in sequence
 
 
 @pytest.mark.forked
 def test_dht_protocol():
-    # create the first peer
-    peer1_port, peer1_id, peer1_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
-    peer1_proc = mp.Process(target=run_protocol_listener, args=(peer1_port, peer1_id, peer1_started), daemon=True)
-    peer1_proc.start(), peer1_started.wait()
-
-    # create another peer that connects to the first peer
-    peer2_port, peer2_id, peer2_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
-    peer2_proc = mp.Process(target=run_protocol_listener, args=(peer2_port, peer2_id, peer2_started),
-                            kwargs={'ping': f'{LOCALHOST}:{peer1_port}'}, daemon=True)
-    peer2_proc.start(), peer2_started.wait()
+    peer1_id, peer1_proc, peer1_endpoint, peer1_maddrs = launch_protocol_listener()
+    peer2_id, peer2_proc, peer2_endpoint, _ = launch_protocol_listener(initial_peers=peer1_maddrs)
 
     loop = asyncio.get_event_loop()
     for listen in [False, True]:  # note: order matters, this test assumes that first run uses listen=False
+        p2p = loop.run_until_complete(P2P.create(initial_peers=peer1_maddrs))
         protocol = loop.run_until_complete(DHTProtocol.create(
-            DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, listen=listen))
-        print(f"Self id={protocol.node_id}", flush=True)
+            p2p, DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, listen=listen))
+        logger.info(f"Self id={protocol.node_id}")
 
-        assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{peer1_port}')) == peer1_id
+        assert loop.run_until_complete(protocol.call_ping(peer1_endpoint)) == peer1_id
 
         key, value, expiration = DHTID.generate(), [random.random(), {'ololo': 'pyshpysh'}], get_dht_time() + 1e3
         store_ok = loop.run_until_complete(protocol.call_store(
-            f'{LOCALHOST}:{peer1_port}', [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
+            peer1_endpoint, [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
         )
         assert all(store_ok), "DHT rejected a trivial store"
 
         # peer 1 must know about peer 2
         (recv_value_bytes, recv_expiration), nodes_found = loop.run_until_complete(
-            protocol.call_find(f'{LOCALHOST}:{peer1_port}', [key]))[key]
+            protocol.call_find(peer1_endpoint, [key]))[key]
         recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
         (recv_id, recv_endpoint) = next(iter(nodes_found.items()))
-        assert recv_id == peer2_id and ':'.join(recv_endpoint.split(':')[-2:]) == f"{LOCALHOST}:{peer2_port}", \
-            f"expected id={peer2_id}, peer={LOCALHOST}:{peer2_port} but got {recv_id}, {recv_endpoint}"
+        assert recv_id == peer2_id and recv_endpoint == peer2_endpoint, \
+            f"expected id={peer2_id}, peer={peer2_endpoint} but got {recv_id}, {recv_endpoint}"
 
         assert recv_value == value and recv_expiration == expiration, \
             f"call_find_value expected {value} (expires by {expiration}) " \
@@ -76,38 +103,35 @@ def test_dht_protocol():
         # peer 2 must know about peer 1, but not have a *random* nonexistent value
         dummy_key = DHTID.generate()
         empty_item, nodes_found_2 = loop.run_until_complete(
-            protocol.call_find(f'{LOCALHOST}:{peer2_port}', [dummy_key]))[dummy_key]
+            protocol.call_find(peer2_endpoint, [dummy_key]))[dummy_key]
         assert empty_item is None, "Non-existent keys shouldn't have values"
         (recv_id, recv_endpoint) = next(iter(nodes_found_2.items()))
-        assert recv_id == peer1_id and recv_endpoint == f"{LOCALHOST}:{peer1_port}", \
-            f"expected id={peer1_id}, peer={LOCALHOST}:{peer1_port} but got {recv_id}, {recv_endpoint}"
+        assert recv_id == peer1_id and recv_endpoint == peer1_endpoint, \
+            f"expected id={peer1_id}, peer={peer1_endpoint} but got {recv_id}, {recv_endpoint}"
 
         # cause a non-response by querying a nonexistent peer
-        dummy_port = hivemind.find_open_port()
-        assert loop.run_until_complete(protocol.call_find(f"{LOCALHOST}:{dummy_port}", [key])) is None
+        assert loop.run_until_complete(protocol.call_find(PeerID.from_base58('fakeid'), [key])) is None
 
         # store/get a dictionary with sub-keys
         nested_key, subkey1, subkey2 = DHTID.generate(), 'foo', 'bar'
         value1, value2 = [random.random(), {'ololo': 'pyshpysh'}], 'abacaba'
         assert loop.run_until_complete(protocol.call_store(
-            f'{LOCALHOST}:{peer1_port}', keys=[nested_key], values=[hivemind.MSGPackSerializer.dumps(value1)],
+            peer1_endpoint, keys=[nested_key], values=[hivemind.MSGPackSerializer.dumps(value1)],
             expiration_time=[expiration], subkeys=[subkey1])
         )
         assert loop.run_until_complete(protocol.call_store(
-            f'{LOCALHOST}:{peer1_port}', keys=[nested_key], values=[hivemind.MSGPackSerializer.dumps(value2)],
+            peer1_endpoint, keys=[nested_key], values=[hivemind.MSGPackSerializer.dumps(value2)],
             expiration_time=[expiration + 5], subkeys=[subkey2])
         )
         (recv_dict, recv_expiration), nodes_found = loop.run_until_complete(
-            protocol.call_find(f'{LOCALHOST}:{peer1_port}', [nested_key]))[nested_key]
+            protocol.call_find(peer1_endpoint, [nested_key]))[nested_key]
         assert isinstance(recv_dict, DictionaryDHTValue)
         assert len(recv_dict.data) == 2 and recv_expiration == expiration + 5
         assert recv_dict.data[subkey1] == (protocol.serializer.dumps(value1), expiration)
         assert recv_dict.data[subkey2] == (protocol.serializer.dumps(value2), expiration + 5)
 
-        assert LOCALHOST in loop.run_until_complete(protocol.get_outgoing_request_endpoint(f'{LOCALHOST}:{peer1_port}'))
-
         if listen:
-            loop.run_until_complete(protocol.shutdown())
+            loop.run_until_complete(p2p.shutdown())
 
     peer1_proc.terminate()
     peer2_proc.terminate()
@@ -116,83 +140,63 @@ def test_dht_protocol():
 @pytest.mark.forked
 def test_empty_table():
     """ Test RPC methods with empty routing table """
-    peer_port, peer_id, peer_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
-    peer_proc = mp.Process(target=run_protocol_listener, args=(peer_port, peer_id, peer_started), daemon=True)
-    peer_proc.start(), peer_started.wait()
+    peer_id, peer_proc, peer_endpoint, peer_maddrs = launch_protocol_listener()
 
     loop = asyncio.get_event_loop()
+    p2p = loop.run_until_complete(P2P.create(initial_peers=peer_maddrs))
     protocol = loop.run_until_complete(DHTProtocol.create(
-        DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, listen=False))
+        p2p, DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, listen=False))
 
     key, value, expiration = DHTID.generate(), [random.random(), {'ololo': 'pyshpysh'}], get_dht_time() + 1e3
 
     empty_item, nodes_found = loop.run_until_complete(
-        protocol.call_find(f'{LOCALHOST}:{peer_port}', [key]))[key]
+        protocol.call_find(peer_endpoint, [key]))[key]
     assert empty_item is None and len(nodes_found) == 0
     assert all(loop.run_until_complete(protocol.call_store(
-        f'{LOCALHOST}:{peer_port}', [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
+        peer_endpoint, [key], [hivemind.MSGPackSerializer.dumps(value)], expiration)
    )), "peer rejected store"
 
     (recv_value_bytes, recv_expiration), nodes_found = loop.run_until_complete(
-        protocol.call_find(f'{LOCALHOST}:{peer_port}', [key]))[key]
+        protocol.call_find(peer_endpoint, [key]))[key]
     recv_value = hivemind.MSGPackSerializer.loads(recv_value_bytes)
     assert len(nodes_found) == 0
     assert recv_value == value and recv_expiration == expiration
 
-    assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{peer_port}')) == peer_id
-    assert loop.run_until_complete(protocol.call_ping(f'{LOCALHOST}:{hivemind.find_open_port()}')) is None
+    assert loop.run_until_complete(protocol.call_ping(peer_endpoint)) == peer_id
+    assert loop.run_until_complete(protocol.call_ping(PeerID.from_base58('fakeid'))) is None
     peer_proc.terminate()
 
 
-def run_node(node_id, peers, status_pipe: mp.Pipe):
-    if asyncio.get_event_loop().is_running():
-        asyncio.get_event_loop().stop()  # if we're in jupyter, get rid of its built-in event loop
-        asyncio.set_event_loop(asyncio.new_event_loop())
-    loop = asyncio.get_event_loop()
-    node = loop.run_until_complete(DHTNode.create(node_id, initial_peers=peers))
-    status_pipe.send(node.port)
-    while True:
-        loop.run_forever()
-
-
 @pytest.mark.forked
 def test_dht_node():
-    # create dht with 50 nodes + your 51-st node
-    dht: Dict[Endpoint, DHTID] = {}
-    processes: List[mp.Process] = []
-
-    for i in range(50):
-        node_id = DHTID.generate()
-        peers = random.sample(dht.keys(), min(len(dht), 5))
-        pipe_recv, pipe_send = mp.Pipe(duplex=False)
-        proc = mp.Process(target=run_node, args=(node_id, peers, pipe_send), daemon=True)
-        proc.start()
-        port = pipe_recv.recv()
-        processes.append(proc)
-        dht[f"{LOCALHOST}:{port}"] = node_id
+    # step A: create a swarm of 50 dht nodes in separate processes
+    # (first 5 created sequentially, others created in parallel)
+    processes, dht, swarm_maddrs = launch_swarm_in_separate_processes(n_peers=50, n_sequential_peers=5)
 
+    # step B: run 51-st node in this process
     loop = asyncio.get_event_loop()
-    me = loop.run_until_complete(DHTNode.create(initial_peers=random.sample(dht.keys(), 5), parallel_rpc=10,
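+    # bootstrap from the multiaddrs of a randomly chosen existing peer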
+    initial_peers = random.choice(swarm_maddrs)
+    me = loop.run_until_complete(DHTNode.create(initial_peers=initial_peers, parallel_rpc=10,
                                                 cache_refresh_before_expiry=False))
 
     # test 1: find self
     nearest = loop.run_until_complete(me.find_nearest_nodes([me.node_id], k_nearest=1))[me.node_id]
-    assert len(nearest) == 1 and ':'.join(nearest[me.node_id].split(':')[-2:]) == f"{LOCALHOST}:{me.port}"
+    assert len(nearest) == 1 and nearest[me.node_id] == me.endpoint
 
     # test 2: find others
-    for i in range(10):
+    for _ in range(10):
         ref_endpoint, query_id = random.choice(list(dht.items()))
         nearest = loop.run_until_complete(me.find_nearest_nodes([query_id], k_nearest=1))[query_id]
         assert len(nearest) == 1
         found_node_id, found_endpoint = next(iter(nearest.items()))
-        assert found_node_id == query_id and ':'.join(found_endpoint.split(':')[-2:]) == ref_endpoint
+        assert found_node_id == query_id and found_endpoint == ref_endpoint
 
     # test 3: find neighbors to random nodes
     accuracy_numerator = accuracy_denominator = 0  # top-1 nearest neighbor accuracy
     jaccard_numerator = jaccard_denominator = 0  # jaccard similarity aka intersection over union
     all_node_ids = list(dht.values())
 
-    for i in range(10):
+    for _ in range(10):
         query_id = DHTID.generate()
         k_nearest = random.randint(1, 10)
         exclude_self = random.random() > 0.5
@@ -217,9 +221,9 @@ def test_dht_node():
         jaccard_denominator += k_nearest
 
     accuracy = accuracy_numerator / accuracy_denominator
-    print("Top-1 accuracy:", accuracy)  # should be 98-100%
+    logger.debug(f"Top-1 accuracy: {accuracy}")  # should be 98-100%
     jaccard_index = jaccard_numerator / jaccard_denominator
-    print("Jaccard index (intersection over union):", jaccard_index)  # should be 95-100%
+    logger.debug(f"Jaccard index (intersection over union): {jaccard_index}")  # should be 95-100%
     assert accuracy >= 0.9, f"Top-1 accuracy only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
     assert jaccard_index >= 0.9, f"Jaccard index only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
 
@@ -232,14 +236,16 @@ def test_dht_node():
     # test 5: node without peers
     detached_node = loop.run_until_complete(DHTNode.create())
     nearest = loop.run_until_complete(detached_node.find_nearest_nodes([dummy]))[dummy]
-    assert len(nearest) == 1 and nearest[detached_node.node_id] == f"{LOCALHOST}:{detached_node.port}"
+    assert len(nearest) == 1 and nearest[detached_node.node_id] == detached_node.endpoint
     nearest = loop.run_until_complete(detached_node.find_nearest_nodes([dummy], exclude_self=True))[dummy]
     assert len(nearest) == 0
 
-    # test 6 store and get value
+    # test 6: store and get value
     true_time = get_dht_time() + 1200
     assert loop.run_until_complete(me.store("mykey", ["Value", 10], true_time))
-    that_guy = loop.run_until_complete(DHTNode.create(initial_peers=random.sample(dht.keys(), 3), parallel_rpc=10,
+
+    initial_peers = random.choice(swarm_maddrs)
+    that_guy = loop.run_until_complete(DHTNode.create(initial_peers=initial_peers, parallel_rpc=10,
                                                       cache_refresh_before_expiry=False, cache_locally=False))
 
     for node in [me, that_guy]:
@@ -285,19 +291,15 @@ def test_dht_node():
 
     for proc in processes:
         proc.terminate()
+    # The nodes don't own their hivemind.p2p.P2P instances, so we shut them down separately
+    loop.run_until_complete(asyncio.wait([node.shutdown() for node in [me, detached_node, that_guy]]))
 
 
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_dhtnode_replicas():
-    dht_size = 20
-    initial_peers = 3
     num_replicas = random.randint(1, 20)
-
-    peers = []
-    for i in range(dht_size):
-        neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(initial_peers, len(peers)))]
-        peers.append(await DHTNode.create(initial_peers=neighbors_i, num_replicas=num_replicas))
+    peers = await launch_star_shaped_swarm(n_peers=20, num_replicas=num_replicas)
 
     you = random.choice(peers)
     assert await you.store('key1', 'foo', get_dht_time() + 999)
@@ -318,8 +320,8 @@ async def test_dhtnode_replicas():
 @pytest.mark.asyncio
 async def test_dhtnode_caching(T=0.05):
     node2 = await DHTNode.create(cache_refresh_before_expiry=5 * T, reuse_get_requests=False)
-    node1 = await DHTNode.create(initial_peers=[f'localhost:{node2.port}'],
-                                 cache_refresh_before_expiry=5 * T, listen=False, reuse_get_requests=False)
+    node1 = await DHTNode.create(initial_peers=await node2.protocol.p2p.get_visible_maddrs(),
+                                 cache_refresh_before_expiry=5 * T, listen=False, reuse_get_requests=False)
     await node2.store('k', [123, 'value'], expiration_time=hivemind.get_dht_time() + 7 * T)
     await node2.store('k2', [654, 'value'], expiration_time=hivemind.get_dht_time() + 7 * T)
     await node2.store('k3', [654, 'value'], expiration_time=hivemind.get_dht_time() + 15 * T)
@@ -363,10 +365,7 @@ async def test_dhtnode_caching(T=0.05):
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_dhtnode_reuse_get():
-    peers = []
-    for i in range(10):
-        neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
-        peers.append(await DHTNode.create(initial_peers=neighbors_i, parallel_rpc=256))
+    peers = await launch_star_shaped_swarm(n_peers=10, parallel_rpc=256)
 
     await asyncio.gather(
         random.choice(peers).store('k1', 123, hivemind.get_dht_time() + 999),
@@ -396,51 +395,30 @@ async def test_dhtnode_reuse_get():
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_dhtnode_blacklist():
-    node1 = await DHTNode.create(blacklist_time=999)
-    node2 = await DHTNode.create(blacklist_time=999, initial_peers=[f"{LOCALHOST}:{node1.port}"])
-    node3 = await DHTNode.create(blacklist_time=999, initial_peers=[f"{LOCALHOST}:{node1.port}"])
-    node4 = await DHTNode.create(blacklist_time=999, initial_peers=[f"{LOCALHOST}:{node1.port}"])
+    node1, node2, node3, node4 = await launch_star_shaped_swarm(n_peers=4, blacklist_time=999)
 
     assert await node2.store('abc', 123, expiration_time=hivemind.get_dht_time() + 99)
     assert len(node2.blacklist.ban_counter) == 0
 
-    await node3.shutdown()
-    await node4.shutdown()
+    await asyncio.gather(node3.shutdown(), node4.shutdown())
 
     assert await node2.store('def', 456, expiration_time=hivemind.get_dht_time() + 99)
 
-    assert len(node2.blacklist.ban_counter) == 2
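+    # exactly the two peers that went offline must now be banned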
+    assert set(node2.blacklist.ban_counter.keys()) == {node3.endpoint, node4.endpoint}
 
-    for banned_peer in node2.blacklist.ban_counter:
-        assert any(banned_peer.endswith(str(port)) for port in [node3.port, node4.port])
-
-    node3_endpoint = await node3.protocol.get_outgoing_request_endpoint(f"{hivemind.LOCALHOST}:{node1.port}")
-    node3_endpoint = replace_port(node3_endpoint, node3.port)
     assert await node1.get('abc', latest=True)  # force node1 to crawl dht and discover unresponsive peers
-    assert node3_endpoint in node1.blacklist
+    assert node3.endpoint in node1.blacklist
 
-    node2_endpoint = await node2.protocol.get_outgoing_request_endpoint(f"{hivemind.LOCALHOST}:{node1.port}")
-    node2_endpoint = replace_port(node2_endpoint, node2.port)
     assert await node1.get('abc', latest=True)  # force node1 to crawl dht and discover unresponsive peers
-    assert node2_endpoint not in node1.blacklist
-
+    assert node2.endpoint not in node1.blacklist
 
-@pytest.mark.forked
-@pytest.mark.asyncio
-async def test_dhtnode_validate(fake_endpoint='127.0.0.721:*'):
-    node1 = await DHTNode.create(blacklist_time=999)
-    with pytest.raises(ValidationError):
-        node2 = await DHTNode.create(blacklist_time=999, initial_peers=[f"{LOCALHOST}:{node1.port}"],
-                                     endpoint=fake_endpoint)
+    await asyncio.gather(node1.shutdown(), node2.shutdown())
 
 
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_dhtnode_edge_cases():
-    peers = []
-    for i in range(5):
-        neighbors_i = [f'{LOCALHOST}:{node.port}' for node in random.sample(peers, min(3, len(peers)))]
-        peers.append(await DHTNode.create(initial_peers=neighbors_i, parallel_rpc=4))
+    peers = await launch_star_shaped_swarm(n_peers=4, parallel_rpc=4)
 
     subkeys = [0, '', False, True, 'abyrvalg', 4555]
     keys = subkeys + [()]
@@ -453,3 +431,5 @@ async def test_dhtnode_edge_cases():
         assert stored is not None
         assert subkey in stored.value
         assert stored.value[subkey].value == value
+
+    await asyncio.wait([node.shutdown() for node in peers])