import time
import asyncio
import multiprocessing as mp
import random
import heapq
import uuid
from functools import partial
from itertools import chain
from typing import Optional, List, Dict

import numpy as np

import hivemind
from hivemind import get_dht_time
from hivemind.dht.node import DHTID, Endpoint, DHTNode, LOCALHOST, KademliaProtocol
from hivemind.dht.protocol import LocalStorage


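# Helper process target: creates a fresh event loop, binds a KademliaProtocol listener to the given
# UDP port, optionally pings a known peer to introduce itself, then serves requests forever.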
def run_protocol_listener(port: int, dhtid: DHTID, started: mp.synchronize.Event,
                          ping: Optional[hivemind.Endpoint] = None):
    loop = asyncio.new_event_loop()
    protocol = partial(KademliaProtocol, dhtid, bucket_size=20, depth_modulo=5, wait_timeout=5)
    listen = loop.create_datagram_endpoint(protocol, local_addr=('127.0.0.1', port))
    transport, protocol = loop.run_until_complete(listen)
    print(f"Started peer id={protocol.node_id} port={port}", flush=True)

    if ping is not None:
        loop.run_until_complete(protocol.call_ping(ping))
    started.set()
    loop.run_forever()
    print(f"Finished peer id={protocol.node_id} port={port}", flush=True)


def test_kademlia_protocol():
    try:
        # create the first peer
        peer1_port, peer1_id, peer1_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
        peer1_proc = mp.Process(target=run_protocol_listener, args=(peer1_port, peer1_id, peer1_started), daemon=True)
        peer1_proc.start(), peer1_started.wait()

        # create another peer that connects to the first peer
        peer2_port, peer2_id, peer2_started = hivemind.find_open_port(), DHTID.generate(), mp.Event()
        peer2_proc = mp.Process(target=run_protocol_listener, args=(peer2_port, peer2_id, peer2_started),
                                kwargs={'ping': ('127.0.0.1', peer1_port)}, daemon=True)
        peer2_proc.start(), peer2_started.wait()

        # create a third protocol instance in the current process and interrogate both peers
        port = hivemind.find_open_port()
        loop = asyncio.new_event_loop()
        protocol = partial(KademliaProtocol, DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5)
        listen = loop.create_datagram_endpoint(protocol, local_addr=('127.0.0.1', port))
        transport, protocol = loop.run_until_complete(listen)
        print(f"Self id={protocol.node_id} port={port}", flush=True)

        assert loop.run_until_complete(protocol.call_ping(('127.0.0.1', peer1_port))) == peer1_id

        key, value, expiration = DHTID.generate(), [123, {'ololo': 'pyshpysh'}], get_dht_time() + 1e3
        assert loop.run_until_complete(protocol.call_store(('127.0.0.1', peer1_port), key, value, expiration))

        # peer 1 must know about peer 2
        nodes_found = loop.run_until_complete(protocol.call_find_node(('127.0.0.1', peer1_port), key))
        (recv_id, recv_endpoint) = next(iter(nodes_found.items()))
        assert recv_id == peer2_id and recv_endpoint == ('127.0.0.1', peer2_port), \
            f"expected id={peer2_id}, endpoint={('127.0.0.1', peer2_port)} but got {recv_id}, {recv_endpoint}"

        # peer 2 must know about peer 1
        nodes_found_2 = loop.run_until_complete(protocol.call_find_node(('127.0.0.1', peer2_port), key))
        (recv_id, recv_endpoint) = next(iter(nodes_found_2.items()))
        assert recv_id == peer1_id and recv_endpoint == ('127.0.0.1', peer1_port), \
            f"expected id={peer1_id}, endpoint={('127.0.0.1', peer1_port)} but got {recv_id}, {recv_endpoint}"

        # the stored value must be retrievable, together with the same set of peers
        recv_value, recv_expiration, recv_peers = loop.run_until_complete(
            protocol.call_find_value(('127.0.0.1', peer1_port), key))
        assert recv_value == value and recv_expiration == expiration, \
            f"call_find_value expected {value} (expires by {expiration}) " \
            f"but got {recv_value} (expires by {recv_expiration})"
        print(recv_peers, nodes_found)
        assert recv_peers == nodes_found, "call_find_value must return the same peers as call_find_node"
        print("Kademlia test finished successfully!")
    finally:
        peer1_proc.terminate()
        peer2_proc.terminate()


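# Process target for test_dht: starts a DHTNode on the given port, reports success (or the raised
# exception) through a pipe, and keeps the event loop running until the process is terminated.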
def run_node(node_id, port, peers, status_pipe: mp.Pipe):
    if asyncio.get_event_loop().is_running():
        asyncio.get_event_loop().stop()  # if we're in jupyter, get rid of its built-in event loop
        asyncio.set_event_loop(asyncio.new_event_loop())
    try:
        node = DHTNode(node_id, port, initial_peers=peers)
        status_pipe.send('STARTED')
        while True:
            asyncio.get_event_loop().run_forever()
    except BaseException as e:
        status_pipe.send(e)  # report exception to master
        if not isinstance(e, OSError):
            raise e


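# Boots a swarm of 50 DHT node processes, then joins it with a local node and checks nearest-node
# search accuracy, full-table traversal, behaviour without peers, and store/get round-trips.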
def test_dht():
    # create dht with 50 nodes + your 51-st node
    dht: Dict[Endpoint, DHTID] = {}
    processes: List[mp.Process] = []
    port_fails, max_port_fails = 0, 10

    while len(dht) < 50:
        node_id = DHTID.generate()
        peers = random.sample(list(dht.keys()), min(len(dht), 5))
        port = hivemind.find_open_port()
        pipe_recv, pipe_send = mp.Pipe(duplex=False)
        proc = mp.Process(target=run_node, args=(node_id, port, peers, pipe_send), daemon=True)
        proc.start()

        status = pipe_recv.recv()
        if status == 'STARTED':
            processes.append(proc)
            dht[(LOCALHOST, port)] = node_id
        else:
            assert isinstance(status, BaseException)
            proc.terminate()
            if isinstance(status, OSError):  # port already in use. It just happens sometimes.
                port_fails += 1
                if port_fails > max_port_fails:
                    raise OSError("Too many 'Address already in use' errors.")
            else:
                raise ValueError(f"Failed to create node due to an error {status}, see traceback above")

    loop = asyncio.get_event_loop()
    me = hivemind.dht.node.DHTNode(initial_peers=random.sample(peers, 5), port=0)  # port=0 means os-specified port

    # test 1: find self
    nearest = loop.run_until_complete(me.find_nearest_nodes(query_id=me.node_id, k_nearest=1))
    assert len(nearest) == 1 and nearest[me.node_id] == (LOCALHOST, me.port)

    # test 2: find others
    for i in range(10):
        ref_endpoint, query_id = random.choice(list(dht.items()))
        nearest = loop.run_until_complete(me.find_nearest_nodes(query_id=query_id, k_nearest=1))
        assert len(nearest) == 1 and next(iter(nearest.items())) == (query_id, ref_endpoint)

    # test 3: find neighbors to random nodes
    accuracy_numerator = accuracy_denominator = 0  # top-1 nearest neighbor accuracy
    jaccard_numerator = jaccard_denominator = 0  # jaccard similarity aka intersection over union
    all_node_ids = list(dht.values())

    for i in range(100):
        query_id = DHTID.generate()
        k_nearest = random.randint(1, 20)
        exclude_self = random.random() > 0.5
        nearest = loop.run_until_complete(
            me.find_nearest_nodes(query_id=query_id, k_nearest=k_nearest, exclude_self=exclude_self))
        nearest_nodes = list(nearest)  # keys from ordered dict

        assert len(nearest_nodes) == k_nearest, "beam search must return exactly k_nearest results"
        assert me.node_id not in nearest_nodes or not exclude_self, "if exclude_self, results must not contain own node id"
        assert np.all(np.diff(query_id.xor_distance(nearest_nodes)) >= 0), "results must be sorted by distance"

        ref_nearest = heapq.nsmallest(k_nearest + 1, all_node_ids, key=query_id.xor_distance)
        if exclude_self and me.node_id in ref_nearest:
            ref_nearest.remove(me.node_id)
        if len(ref_nearest) > k_nearest:
            ref_nearest.pop()

        accuracy_numerator += nearest_nodes[0] == ref_nearest[0]
        accuracy_denominator += 1
        jaccard_numerator += len(set.intersection(set(nearest_nodes), set(ref_nearest)))
        jaccard_denominator += k_nearest

    accuracy = accuracy_numerator / accuracy_denominator
    print("Top-1 accuracy:", accuracy)  # should be 98-100%
    jaccard_index = jaccard_numerator / jaccard_denominator
    print("Jaccard index (intersection over union):", jaccard_index)  # should be 95-100%
    assert accuracy >= 0.9, f"Top-1 accuracy only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
    assert jaccard_index >= 0.9, f"Jaccard index only {jaccard_index} ({jaccard_numerator} / {jaccard_denominator})"

    # test 4: find all nodes
    nearest = loop.run_until_complete(
        me.find_nearest_nodes(query_id=DHTID.generate(), k_nearest=len(dht) + 100))
    assert len(nearest) == len(dht) + 1
    assert len(set.difference(set(nearest.keys()), set(all_node_ids) | {me.node_id})) == 0

    # test 5: node without peers
    other_node = hivemind.dht.node.DHTNode()
    nearest = loop.run_until_complete(other_node.find_nearest_nodes(DHTID.generate()))
    assert len(nearest) == 1 and nearest[other_node.node_id] == (LOCALHOST, other_node.port)
    nearest = loop.run_until_complete(other_node.find_nearest_nodes(DHTID.generate(), exclude_self=True))
    assert len(nearest) == 0

    # test 6: store and get value
    true_time = get_dht_time() + 1200
    assert loop.run_until_complete(me.store("mykey", ["Value", 10], true_time))
    val, expiration_time = loop.run_until_complete(me.get("mykey"))
    assert expiration_time == true_time, "Wrong time"
    assert val == ["Value", 10], "Wrong value"

    # terminate remaining processes
    for proc in processes:
        proc.terminate()


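# Exercises the high-level hivemind.DHT interface: declare_experts / get_experts across a small
# swarm of DHT processes, plus the first_k_active query.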
def test_hivemind_dht():
    peers = [hivemind.dht.DHT(start=True)]
    for i in range(10):
        neighbors_i = [('localhost', node.port) for node in random.sample(peers, min(3, len(peers)))]
        peers.append(hivemind.DHT(*neighbors_i, start=True))

    you: hivemind.dht.DHT = random.choice(peers)
    theguyshetoldyounottoworryabout: hivemind.dht.DHT = random.choice(peers)

    expert_uids = [str(uuid.uuid4()) for _ in range(110)]
    batch_size = 10
    for batch_start in range(0, len(expert_uids), batch_size):
        you.declare_experts(expert_uids[batch_start: batch_start + batch_size], 'localhost', 1234)

    found = theguyshetoldyounottoworryabout.get_experts(random.sample(expert_uids, 5) + ['foo', 'bar'])
    assert all(res is not None for res in found[:-2]), "Could not find some existing experts"
    assert all(res is None for res in found[-2:]), "Found non-existing experts"

    that_guys_expert, that_guys_port = str(uuid.uuid4()), random.randint(1000, 9999)
    theguyshetoldyounottoworryabout.declare_experts([that_guys_expert], 'that_host', that_guys_port)
    you_notfound, you_found = you.get_experts(['foobar', that_guys_expert])
    assert isinstance(you_found, hivemind.RemoteExpert)
    assert you_found.host == 'that_host' and you_found.port == that_guys_port

    # test first_k_active
    assert theguyshetoldyounottoworryabout.first_k_active(expert_uids, k=10) == expert_uids[:10]

    some_permuted_experts = random.sample(expert_uids, k=32)
    assert theguyshetoldyounottoworryabout.first_k_active(some_permuted_experts, k=32) == some_permuted_experts
    assert theguyshetoldyounottoworryabout.first_k_active(some_permuted_experts, k=1) == some_permuted_experts[:1]

    fake_and_real_experts = list(chain(*zip(
        [str(uuid.uuid4()) for _ in some_permuted_experts], some_permuted_experts)))
    assert theguyshetoldyounottoworryabout.first_k_active(fake_and_real_experts, k=9) == some_permuted_experts[:9]

    for peer in peers:
        peer.shutdown()


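# The remaining tests cover LocalStorage, the in-memory key-value store with per-key expiration
# that backs each DHT node.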
def test_store():
    d = LocalStorage()
    d.store("key", "val", get_dht_time() + 10)
    assert d.get("key")[0] == "val", "Wrong value"
    print("Test store passed")


def test_get_expired():
    d = LocalStorage()
    d.store("key", "val", get_dht_time() + 1)
    time.sleep(2)
    assert d.get("key") == (None, None), "Expired value must be deleted"
    print("Test get expired passed")


def test_get_empty():
    d = LocalStorage()
    assert d.get("key") == (None, None), "Non-existent key must return (None, None)"
    print("Test get empty passed")


def test_change_expiration_time():
    d = LocalStorage()
    d.store("key", "val1", get_dht_time() + 2)
    d.store("key", "val2", get_dht_time() + 200)
    time.sleep(4)
    assert d.get("key")[0] == "val2", "Value must be updated and kept in the table"
    print("Test change expiration time passed")


def test_maxsize_cache():
    d = LocalStorage(maxsize=1)
    d.store("key1", "val1", get_dht_time() + 1)
    d.store("key2", "val2", get_dht_time() + 200)
    assert d.get("key2")[0] == "val2", "Value with the greater expiration time must be kept"
    assert d.get("key1")[0] is None, "Value with the smaller expiration time must be evicted"