import asyncio
import heapq
import random
from itertools import product

import numpy as np
import pytest

import hivemind
from hivemind import get_dht_time
from hivemind.dht.node import DHTID, DHTNode
from hivemind.utils.logging import get_logger

from test_utils.dht_swarms import launch_star_shaped_swarm, launch_swarm_in_separate_processes

logger = get_logger(__name__)

# note: we run network-related tests in a separate process to re-initialize all global states from scratch
# this helps us avoid undesirable gRPC side-effects (e.g. segfaults) when running multiple tests in sequence

@pytest.mark.forked
@pytest.mark.asyncio
async def test_dht_node(
    n_peers: int = 20, n_sequential_peers: int = 5, parallel_rpc: int = 10, bucket_size: int = 5, num_replicas: int = 3
):
    # step A: create a swarm of n_peers dht nodes in separate processes
    # (the first n_sequential_peers are created sequentially, the rest are created in parallel)
    processes, dht, swarm_maddrs = launch_swarm_in_separate_processes(
        n_peers=n_peers, n_sequential_peers=n_sequential_peers, bucket_size=bucket_size, num_replicas=num_replicas
    )

    # step B: run one more node in the current process
    initial_peers = random.choice(swarm_maddrs)
    me = await DHTNode.create(
        initial_peers=initial_peers,
        parallel_rpc=parallel_rpc,
        bucket_size=bucket_size,
        num_replicas=num_replicas,
        cache_refresh_before_expiry=False,
    )

    # test 1: find self
    nearest = (await me.find_nearest_nodes([me.node_id], k_nearest=1))[me.node_id]
    assert len(nearest) == 1 and nearest[me.node_id] == me.peer_id

    # test 2: find others
    for _ in range(10):
        ref_peer_id, query_id = random.choice(list(dht.items()))
        nearest = (await me.find_nearest_nodes([query_id], k_nearest=1))[query_id]
        assert len(nearest) == 1
        found_node_id, found_peer_id = next(iter(nearest.items()))
        assert found_node_id == query_id and found_peer_id == ref_peer_id

    # test 3: find neighbors to random nodes
    accuracy_numerator = accuracy_denominator = 0  # top-1 nearest neighbor accuracy
    jaccard_numerator = jaccard_denominator = 0  # jaccard similarity aka intersection over union
    all_node_ids = list(dht.values())

    for _ in range(20):
        query_id = DHTID.generate()
        k_nearest = random.randint(1, 10)
        exclude_self = random.random() > 0.5
        find_result = await me.find_nearest_nodes([query_id], k_nearest=k_nearest, exclude_self=exclude_self)
        nearest_nodes = list(find_result[query_id])  # keys from ordered dict

        assert len(nearest_nodes) == k_nearest, "beam search must return exactly k_nearest results"
        assert me.node_id not in nearest_nodes or not exclude_self, "if exclude, results shouldn't contain self"
        assert np.all(np.diff(query_id.xor_distance(nearest_nodes)) >= 0), "results must be sorted by distance"
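        # reference check: brute-force the k nearest node ids (fetch one extra in case our own id must be dropped)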
        ref_nearest = heapq.nsmallest(k_nearest + 1, all_node_ids, key=query_id.xor_distance)
        if exclude_self and me.node_id in ref_nearest:
            ref_nearest.remove(me.node_id)
        if len(ref_nearest) > k_nearest:
            ref_nearest.pop()

        accuracy_numerator += nearest_nodes[0] == ref_nearest[0]
        accuracy_denominator += 1
        jaccard_numerator += len(set.intersection(set(nearest_nodes), set(ref_nearest)))
        jaccard_denominator += k_nearest

    accuracy = accuracy_numerator / accuracy_denominator
    logger.debug(f"Top-1 accuracy: {accuracy}")  # should be 90-100%
    jaccard_index = jaccard_numerator / jaccard_denominator
    logger.debug(f"Jaccard index (intersection over union): {jaccard_index}")  # should be 95-100%
    assert accuracy >= 0.8, f"Top-1 accuracy only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
    assert jaccard_index >= 0.9, f"Jaccard index only {jaccard_index} ({jaccard_numerator} / {jaccard_denominator})"
    # test 4: find all nodes
    dummy = DHTID.generate()
    nearest = (await me.find_nearest_nodes([dummy], k_nearest=len(dht) + 100))[dummy]
    assert len(nearest) == len(dht) + 1
    assert len(set.difference(set(nearest.keys()), set(all_node_ids) | {me.node_id})) == 0

    # test 5: node without peers
    detached_node = await DHTNode.create()
    nearest = (await detached_node.find_nearest_nodes([dummy]))[dummy]
    assert len(nearest) == 1 and nearest[detached_node.node_id] == detached_node.peer_id
    nearest = (await detached_node.find_nearest_nodes([dummy], exclude_self=True))[dummy]
    assert len(nearest) == 0

    # test 6: store and get value
    true_time = get_dht_time() + 1200
    assert await me.store("mykey", ["Value", 10], true_time)

    initial_peers = random.choice(swarm_maddrs)
    that_guy = await DHTNode.create(
        initial_peers=initial_peers,
        parallel_rpc=parallel_rpc,
        cache_refresh_before_expiry=False,
        cache_locally=False,
    )

    for node in [me, that_guy]:
        val, expiration_time = await node.get("mykey")
        assert val == ["Value", 10], "Wrong value"
        assert expiration_time == true_time, "Wrong time"
    assert not await detached_node.get("mykey")

    # test 7: bulk store and bulk get
    keys = "foo", "bar", "baz", "zzz"
    values = 3, 2, "batman", [1, 2, 3]
    store_ok = await me.store_many(keys, values, expiration_time=get_dht_time() + 999)
    assert all(store_ok.values()), "failed to store one or more keys"
    response = await me.get_many(keys[::-1])
    for key, value in zip(keys, values):
        assert key in response and response[key][0] == value

    # test 8: store dictionaries as values (with sub-keys)
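    # each subkey carries its own expiration time; the dictionary as a whole expires with its longest-lived subkey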
    upper_key, subkey1, subkey2, subkey3 = "ololo", "k1", "k2", "k3"
    now = get_dht_time()
    assert await me.store(upper_key, subkey=subkey1, value=123, expiration_time=now + 10)
    assert await me.store(upper_key, subkey=subkey2, value=456, expiration_time=now + 20)
    for node in [that_guy, me]:
        value, time = await node.get(upper_key)
        assert isinstance(value, dict) and time == now + 20
        assert value[subkey1] == (123, now + 10)
        assert value[subkey2] == (456, now + 20)
        assert len(value) == 2
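
    # a subkey can only be overwritten by a value with a later expiration time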
    assert not await me.store(upper_key, subkey=subkey2, value=345, expiration_time=now + 10)
    assert await me.store(upper_key, subkey=subkey2, value=567, expiration_time=now + 30)
    assert await me.store(upper_key, subkey=subkey3, value=890, expiration_time=now + 50)
    for node in [that_guy, me]:
        value, time = await node.get(upper_key, latest=True)
        assert isinstance(value, dict) and time == now + 50, (value, time)
        assert value[subkey1] == (123, now + 10)
        assert value[subkey2] == (567, now + 30)
        assert value[subkey3] == (890, now + 50)
        assert len(value) == 3

    for proc in processes:
        proc.terminate()
    # The nodes don't own their hivemind.p2p.P2P instances, so we shut them down separately
    await asyncio.gather(me.shutdown(), that_guy.shutdown(), detached_node.shutdown())


@pytest.mark.forked
@pytest.mark.asyncio
async def test_dhtnode_replicas():
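    # check that a stored key ends up on exactly num_replicas peers of a star-shaped swarm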
    num_replicas = random.randint(1, 20)
    peers = await launch_star_shaped_swarm(n_peers=20, num_replicas=num_replicas)

    you = random.choice(peers)
    assert await you.store("key1", "foo", get_dht_time() + 999)
    actual_key1_replicas = sum(len(peer.protocol.storage) for peer in peers)
    assert num_replicas == actual_key1_replicas

    assert await you.store("key2", "bar", get_dht_time() + 999)
    total_size = sum(len(peer.protocol.storage) for peer in peers)
    actual_key2_replicas = total_size - actual_key1_replicas
    assert num_replicas == actual_key2_replicas
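
    # overwriting a key should update the existing replicas rather than create new ones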
    assert await you.store("key2", "baz", get_dht_time() + 1000)
    assert sum(len(peer.protocol.storage) for peer in peers) == total_size, "total size should not have changed"


@pytest.mark.forked
@pytest.mark.asyncio
async def test_dhtnode_caching(T=0.05):
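    # T is the base time unit of this test: values live for a few T, and cached entries
    # are scheduled for a background refresh 5 * T before they expire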
    node2 = await DHTNode.create(cache_refresh_before_expiry=5 * T, reuse_get_requests=False)
    node1 = await DHTNode.create(
        initial_peers=await node2.protocol.p2p.get_visible_maddrs(),
        cache_refresh_before_expiry=5 * T,
        client_mode=True,
        reuse_get_requests=False,
    )
    await node2.store("k", [123, "value"], expiration_time=hivemind.get_dht_time() + 7 * T)
    await node2.store("k2", [654, "value"], expiration_time=hivemind.get_dht_time() + 7 * T)
    await node2.store("k3", [654, "value"], expiration_time=hivemind.get_dht_time() + 15 * T)
    await node1.get_many(["k", "k2", "k3", "k4"])
    assert len(node1.protocol.cache) == 3
    assert len(node1.cache_refresh_queue) == 0
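
    # the first get merely fills the cache; requesting already-cached keys schedules them for refresh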
    await node1.get_many(["k", "k2", "k3", "k4"])
    assert len(node1.cache_refresh_queue) == 3
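
    # as entries are refreshed in the background or expire, they leave the refresh queue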
    await node2.store("k", [123, "value"], expiration_time=hivemind.get_dht_time() + 12 * T)
    await asyncio.sleep(4 * T)
    await node1.get("k")
    await asyncio.sleep(1 * T)
    assert len(node1.protocol.cache) == 3
    assert len(node1.cache_refresh_queue) == 2

    await asyncio.sleep(3 * T)
    assert len(node1.cache_refresh_queue) == 1

    await asyncio.sleep(5 * T)
    assert len(node1.cache_refresh_queue) == 0

    await asyncio.sleep(5 * T)
    assert len(node1.cache_refresh_queue) == 0

    await node2.store("k", [123, "value"], expiration_time=hivemind.get_dht_time() + 10 * T)
    await node1.get("k")
    await asyncio.sleep(1 * T)
    assert len(node1.cache_refresh_queue) == 0

    await node1.get("k")
    await asyncio.sleep(1 * T)
    assert len(node1.cache_refresh_queue) == 1

    await asyncio.sleep(5 * T)
    assert len(node1.cache_refresh_queue) == 0

    await asyncio.gather(node1.shutdown(), node2.shutdown())


@pytest.mark.forked
@pytest.mark.asyncio
async def test_dhtnode_reuse_get():
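    # reuse_get_requests (enabled by default) makes concurrent get() calls for the same key
    # share a single network traversal instead of each crawling the DHT independently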
    peers = await launch_star_shaped_swarm(n_peers=10, parallel_rpc=256)

    await asyncio.gather(
        random.choice(peers).store("k1", 123, hivemind.get_dht_time() + 999),
        random.choice(peers).store("k2", 567, hivemind.get_dht_time() + 999),
    )

    you = random.choice(peers)
    futures1 = await you.get_many(["k1", "k2"], return_futures=True)
    assert len(you.pending_get_requests[DHTID.generate("k1")]) == 1
    assert len(you.pending_get_requests[DHTID.generate("k2")]) == 1
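
    # a second request for "k2" piggybacks on the one already in flight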
    futures2 = await you.get_many(["k2", "k3"], return_futures=True)
    assert len(you.pending_get_requests[DHTID.generate("k2")]) == 2

    await asyncio.gather(*futures1.values(), *futures2.values())
    futures3 = await you.get_many(["k3"], return_futures=True)
    assert len(you.pending_get_requests[DHTID.generate("k1")]) == 0
    assert len(you.pending_get_requests[DHTID.generate("k2")]) == 0
    assert len(you.pending_get_requests[DHTID.generate("k3")]) == 1

    assert (await futures1["k1"])[0] == 123
    assert await futures1["k2"] == await futures2["k2"] and (await futures1["k2"])[0] == 567
    assert await futures2["k3"] == await futures3["k3"] and (await futures3["k3"]) is None


@pytest.mark.forked
@pytest.mark.asyncio
async def test_dhtnode_blacklist():
    node1, node2, node3, node4 = await launch_star_shaped_swarm(n_peers=4, blacklist_time=999)
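    # peers that fail to respond to requests should be temporarily banned (here for blacklist_time=999 seconds)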
    assert await node2.store("abc", 123, expiration_time=hivemind.get_dht_time() + 99)
    assert len(node2.blacklist.ban_counter) == 0

    await asyncio.gather(node3.shutdown(), node4.shutdown())

    assert await node2.store("def", 456, expiration_time=hivemind.get_dht_time() + 99)
    assert set(node2.blacklist.ban_counter.keys()) == {node3.peer_id, node4.peer_id}

    assert await node1.get("abc", latest=True)  # force node1 to crawl dht and discover unresponsive peers
    assert node3.peer_id in node1.blacklist
    assert await node1.get("abc", latest=True)  # ensure that responsive peers were not blacklisted
    assert node2.peer_id not in node1.blacklist

    await asyncio.gather(node1.shutdown(), node2.shutdown())


@pytest.mark.forked
@pytest.mark.asyncio
async def test_dhtnode_edge_cases():
    peers = await launch_star_shaped_swarm(n_peers=4, parallel_rpc=4)

    subkeys = [0, "", False, True, "abyrvalg", 4555]
    keys = subkeys + [()]
    values = subkeys + [[]]
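
    # every key/subkey/value combination must round-trip, including falsy ones such as 0, "" and False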
    for key, subkey, value in product(keys, subkeys, values):
        await random.choice(peers).store(
            key=key, subkey=subkey, value=value, expiration_time=hivemind.get_dht_time() + 999
        )

        stored = await random.choice(peers).get(key=key, latest=True)
        assert stored is not None
        assert subkey in stored.value
        assert stored.value[subkey].value == value

    await asyncio.gather(*[node.shutdown() for node in peers])