import time
import asyncio
import multiprocessing as mp
import multiprocessing.synchronize  # makes mp.synchronize available for the annotation below
import random
import heapq
import uuid
from functools import partial
from itertools import chain
from typing import Optional, List, Dict

import numpy as np

import hivemind
from hivemind import get_dht_time
from hivemind.dht.node import DHTID, Endpoint, DHTNode, LOCALHOST, KademliaProtocol
from hivemind.dht.protocol import LocalStorage


def run_protocol_listener(port: int, dhtid: DHTID, started: mp.synchronize.Event,
                          ping: Optional[hivemind.Endpoint] = None):
    loop = asyncio.new_event_loop()
    protocol = partial(KademliaProtocol, dhtid, bucket_size=20, depth_modulo=5,
                       wait_timeout=5, max_concurrent_rpc=128)
    listen = loop.create_datagram_endpoint(protocol, local_addr=('127.0.0.1', port))
    transport, protocol = loop.run_until_complete(listen)
    print(f"Started peer id={protocol.node_id} port={port}", flush=True)

    if ping is not None:
        loop.run_until_complete(protocol.call_ping(ping))
    started.set()
    loop.run_forever()
    print(f"Finished peer id={protocol.node_id} port={port}", flush=True)


def test_kademlia_protocol():
    try:
        # create the first peer
        peer1_port, peer1_id, peer1_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
        peer1_proc = mp.Process(target=run_protocol_listener, args=(peer1_port, peer1_id, peer1_started),
                                daemon=True)
        peer1_proc.start(), peer1_started.wait()

        # create another peer that connects to the first peer
        peer2_port, peer2_id, peer2_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
        peer2_proc = mp.Process(target=run_protocol_listener, args=(peer2_port, peer2_id, peer2_started),
                                kwargs={'ping': ('127.0.0.1', peer1_port)}, daemon=True)
        peer2_proc.start(), peer2_started.wait()

        # create a third protocol instance in this process and query both peers through it
        port = hivemind.find_open_port()
        loop = asyncio.new_event_loop()
        protocol = partial(KademliaProtocol, DHTID.generate(), bucket_size=20, depth_modulo=5,
                           wait_timeout=5, max_concurrent_rpc=128)
        listen = loop.create_datagram_endpoint(protocol, local_addr=('127.0.0.1', port))
        transport, protocol = loop.run_until_complete(listen)
        print(f"Self id={protocol.node_id} port={port}", flush=True)

        assert loop.run_until_complete(protocol.call_ping(('127.0.0.1', peer1_port))) == peer1_id

        key, value, expiration = DHTID.generate(), [123, {'ololo': 'pyshpysh'}], get_dht_time() + 1e3
        assert loop.run_until_complete(protocol.call_store(('127.0.0.1', peer1_port), key, value, expiration))

        # peer 1 must know about peer 2
        nodes_found = loop.run_until_complete(protocol.call_find_node(('127.0.0.1', peer1_port), key))
        (recv_id, recv_endpoint) = next(iter(nodes_found.items()))
        assert recv_id == peer2_id and recv_endpoint == ('127.0.0.1', peer2_port), \
            f"expected id={peer2_id}, port={('127.0.0.1', peer2_port)} but got {recv_id}, {recv_endpoint}"

        # peer 2 must know about peer 1
        nodes_found_2 = loop.run_until_complete(protocol.call_find_node(('127.0.0.1', peer2_port), key))
        (recv_id, recv_endpoint) = next(iter(nodes_found_2.items()))
        assert recv_id == peer1_id and recv_endpoint == ('127.0.0.1', peer1_port), \
            f"expected id={peer1_id}, port={('127.0.0.1', peer1_port)} but got {recv_id}, {recv_endpoint}"

        recv_value, recv_expiration, recv_peers = loop.run_until_complete(
            protocol.call_find_value(('127.0.0.1', peer1_port), key))
        assert recv_value == value and recv_expiration == expiration, \
            f"call_find_value expected {value} (expires by {expiration}) " \
            f"but got {recv_value} (expires by {recv_expiration})"
        print(recv_peers, nodes_found)
        assert recv_peers == nodes_found, "call_find_value must return the same peers as call_find_node"
        print("Kademlia test finished successfully!")
    finally:
        peer1_proc.terminate()
        peer2_proc.terminate()
print("Kademlia test finished sucessfully!") finally: peer1_proc.terminate() peer2_proc.terminate() def run_node(node_id, port, peers, status_pipe: mp.Pipe): if asyncio.get_event_loop().is_running(): asyncio.get_event_loop().stop() # if we're in jupyter, get rid of its built-in event loop asyncio.set_event_loop(asyncio.new_event_loop()) try: node = DHTNode(node_id, port, initial_peers=peers) status_pipe.send('STARTED') while True: asyncio.get_event_loop().run_forever() except BaseException as e: status_pipe.send(e) # report exception to master if not isinstance(e, OSError): raise e def test_dht(): # create dht with 50 nodes + your 51-st node dht: Dict[Endpoint, DHTID] = {} processes: List[mp.Process] = [] port_fails, max_port_fails = 0, 10 while len(dht) < 50: node_id = DHTID.generate() peers = random.sample(dht.keys(), min(len(dht), 5)) port = hivemind.find_open_port() pipe_recv, pipe_send = mp.Pipe(duplex=False) proc = mp.Process(target=run_node, args=(node_id, port, peers, pipe_send), daemon=True) proc.start() status = pipe_recv.recv() if status == 'STARTED': processes.append(proc) dht[(LOCALHOST, port)] = node_id else: assert isinstance(status, BaseException) proc.terminate() if isinstance(status, OSError): # port already in use. It just happens sometimes. port_fails += 1 if port_fails > max_port_fails: raise OSError("Too many 'Address already in use' errors.") else: raise ValueError(f"Failed to create node due to an error {status}, see traceback above") loop = asyncio.get_event_loop() me = hivemind.dht.node.DHTNode(initial_peers=random.sample(peers, 5), port=0) # port=0 means os-specified port # test 1: find self nearest = loop.run_until_complete(me.find_nearest_nodes(query_id=me.node_id, k_nearest=1)) assert len(nearest) == 1 and nearest[me.node_id] == (LOCALHOST, me.port) # test 2: find others for i in range(10): ref_endpoint, query_id = random.choice(list(dht.items())) nearest = loop.run_until_complete(me.find_nearest_nodes(query_id=query_id, k_nearest=1)) assert len(nearest) == 1 and next(iter(nearest.items())) == (query_id, ref_endpoint) # test 3: find neighbors to random nodes accuracy_numerator = accuracy_denominator = 0 # top-1 nearest neighbor accuracy jaccard_numerator = jaccard_denominator = 0 # jaccard similarity aka intersection over union all_node_ids = list(dht.values()) for i in range(100): query_id = DHTID.generate() k_nearest = random.randint(1, 20) exclude_self = random.random() > 0.5 nearest = loop.run_until_complete( me.find_nearest_nodes(query_id=query_id, k_nearest=k_nearest, exclude_self=exclude_self)) nearest_nodes = list(nearest) # keys from ordered dict assert len(nearest_nodes) == k_nearest, "beam search must return exactly k_nearest results" assert me.node_id not in nearest_nodes or not exclude_self, "if exclude, results should not contain own node id" assert np.all(np.diff(query_id.xor_distance(nearest_nodes)) >= 0), "results must be sorted by distance" ref_nearest = heapq.nsmallest(k_nearest + 1, all_node_ids, key=query_id.xor_distance) if exclude_self and me.node_id in ref_nearest: ref_nearest.remove(me.node_id) if len(ref_nearest) > k_nearest: ref_nearest.pop() accuracy_numerator += nearest_nodes[0] == ref_nearest[0] accuracy_denominator += 1 jaccard_numerator += len(set.intersection(set(nearest_nodes), set(ref_nearest))) jaccard_denominator += k_nearest accuracy = accuracy_numerator / accuracy_denominator print("Top-1 accuracy:", accuracy) # should be 98-100% jaccard_index = jaccard_numerator / jaccard_denominator print("Jaccard index (intersection 
over union):", jaccard_index) # should be 95-100% assert accuracy >= 0.9, f"Top-1 accuracy only {accuracy} ({accuracy_numerator} / {accuracy_denominator})" assert jaccard_index >= 0.9, f"Jaccard index only {accuracy} ({accuracy_numerator} / {accuracy_denominator})" # test 4: find all nodes nearest = loop.run_until_complete( me.find_nearest_nodes(query_id=DHTID.generate(), k_nearest=len(dht) + 100)) assert len(nearest) == len(dht) + 1 assert len(set.difference(set(nearest.keys()), set(all_node_ids) | {me.node_id})) == 0 # test 5: node without peers other_node = hivemind.dht.node.DHTNode() nearest = loop.run_until_complete(other_node.find_nearest_nodes(DHTID.generate())) assert len(nearest) == 1 and nearest[other_node.node_id] == (LOCALHOST, other_node.port) nearest = loop.run_until_complete(other_node.find_nearest_nodes(DHTID.generate(), exclude_self=True)) assert len(nearest) == 0 # test 6 store and get value true_time = get_dht_time() + 1200 assert loop.run_until_complete(me.store("mykey", ["Value", 10], true_time)) val, expiration_time = loop.run_until_complete(me.get("mykey")) assert expiration_time == true_time, "Wrong time" assert val == ["Value", 10], "Wrong value" # terminate remaining processes for proc in processes: proc.terminate() def test_hivemind_dht(): peers = [hivemind.dht.DHT(start=True)] for i in range(10): neighbors_i = [('localhost', node.port) for node in random.sample(peers, min(3, len(peers)))] peers.append(hivemind.DHT(*neighbors_i, start=True)) you: hivemind.dht.DHT = random.choice(peers) theguyshetoldyounottoworryabout: hivemind.dht.DHT = random.choice(peers) expert_uids = [str(uuid.uuid4()) for _ in range(110)] batch_size = 10 for batch_start in range(0, len(expert_uids), batch_size): you.declare_experts(expert_uids[batch_start: batch_start + batch_size], 'localhost', 1234) found = theguyshetoldyounottoworryabout.get_experts(random.sample(expert_uids, 5) + ['foo', 'bar']) assert all(res is not None for res in found[:-2]), "Could not find some existing experts" assert all(res is None for res in found[-2:]), "Found non-existing experts" that_guys_expert, that_guys_port = str(uuid.uuid4()), random.randint(1000, 9999) theguyshetoldyounottoworryabout.declare_experts([that_guys_expert], 'that_host', that_guys_port) you_notfound, you_found = you.get_experts(['foobar', that_guys_expert]) assert isinstance(you_found, hivemind.RemoteExpert) assert you_found.host == 'that_host', you_found.port == that_guys_port # test first_k_active assert theguyshetoldyounottoworryabout.first_k_active(expert_uids, k=10) == expert_uids[:10] some_permuted_experts = random.sample(expert_uids, k=32) assert theguyshetoldyounottoworryabout.first_k_active(some_permuted_experts, k=32) == some_permuted_experts assert theguyshetoldyounottoworryabout.first_k_active(some_permuted_experts, k=1) == some_permuted_experts[:1] fake_and_real_experts = list(chain(*zip( [str(uuid.uuid4()) for _ in some_permuted_experts], some_permuted_experts))) assert theguyshetoldyounottoworryabout.first_k_active(fake_and_real_experts, k=9) == some_permuted_experts[:9] for peer in peers: peer.shutdown() def test_store(): d = LocalStorage() d.store("key", "val", get_dht_time() + 10) assert d.get("key")[0] == "val", "Wrong value" print("Test store passed") def test_get_expired(): d = LocalStorage() d.store("key", "val", get_dht_time() + 1) time.sleep(2) assert d.get("key") == (None, None), "Expired value must be deleted" print("Test get expired passed") def test_get_empty(): d = LocalStorage() assert d.get("key") == (None, 
None), "Expired value must be deleted" print("Test get expired passed") def test_change_expiration_time(): d = LocalStorage() d.store("key", "val1", get_dht_time() + 2) d.store("key", "val2", get_dht_time() + 200) time.sleep(4) assert d.get("key")[0] == "val2", "Value must be changed, but still kept in table" print("Test change expiration time passed") def test_maxsize_cache(): d = LocalStorage(maxsize=1) d.store("key1", "val1", get_dht_time() + 1) d.store("key2", "val2", get_dht_time() + 200) assert d.get("key2")[0] == "val2", "Value with bigger exp. time must be kept" assert d.get("key1")[0] is None, "Value with less exp time, must be deleted"