123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104 |
- import argparse
- import time
- import asyncio
- import multiprocessing as mp
- import random
- import hivemind
- from typing import List, Dict
- from hivemind import get_dht_time
- from hivemind.dht.node import DHTID, Endpoint, DHTNode, LOCALHOST
- def run_benchmark_node(node_id, port, peers, ready: mp.Event, request_perod,
- expiration_time, wait_before_read, time_to_test, statistics: mp.Queue, dht_loaded: mp.Event):
- if asyncio.get_event_loop().is_running():
- asyncio.get_event_loop().stop() # if we're in jupyter, get rid of its built-in event loop
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
- node = DHTNode(node_id, port, initial_peers=peers)
- await_forever = hivemind.run_forever(asyncio.get_event_loop().run_forever)
- ready.set()
- dht_loaded.wait()
- start = time.perf_counter()
- while time.perf_counter() < start + time_to_test:
- query_id = DHTID.generate()
- store_value = random.randint(0, 256)
- store_time = time.perf_counter()
- success_store = asyncio.run_coroutine_threadsafe(
- node.store(query_id, store_value, get_dht_time() + expiration_time), loop).result()
- store_time = time.perf_counter() - store_time
- if success_store:
- time.sleep(wait_before_read)
- get_time = time.perf_counter()
- get_value, get_time_expiration = asyncio.run_coroutine_threadsafe(node.get(query_id), loop).result()
- get_time = time.perf_counter() - get_time
- success_get = (get_value == store_value)
- statistics.put((success_store, store_time, success_get, get_time))
- else:
- statistics.put((success_store, store_time, None, None))
- await_forever.result() # process will exit only if event loop broke down
- if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument('--num_nodes', type=int, default=20, required=False)
- parser.add_argument('--request_perod', type=float, default=2, required=False)
- parser.add_argument('--expiration_time', type=float, default=10, required=False)
- parser.add_argument('--wait_before_read', type=float, default=1, required=False)
- parser.add_argument('--time_to_test', type=float, default=10, required=False)
- args = parser.parse_args()
- statistics = mp.Queue()
- dht: Dict[Endpoint, DHTID] = {}
- processes: List[mp.Process] = []
- num_nodes = args.num_nodes
- request_perod = args.request_perod
- expiration_time = args.expiration_time
- wait_before_read = args.wait_before_read
- time_to_test = args.time_to_test
- dht_loaded = mp.Event()
- for i in range(num_nodes):
- node_id = DHTID.generate()
- port = hivemind.find_open_port()
- peers = random.sample(dht.keys(), min(len(dht), 5))
- ready = mp.Event()
- proc = mp.Process(target=run_benchmark_node, args=(node_id, port, peers, ready, request_perod,
- expiration_time, wait_before_read, time_to_test, statistics,
- dht_loaded), daemon=True)
- proc.start()
- ready.wait()
- processes.append(proc)
- dht[(LOCALHOST, port)] = node_id
- dht_loaded.set()
- time.sleep(time_to_test)
- success_store = 0
- all_store = 0
- time_store = 0
- success_get = 0
- all_get = 0
- time_get = 0
- while not statistics.empty():
- success_store_i, store_time_i, success_get_i, get_time_i = statistics.get()
- all_store += 1
- time_store += store_time_i
- if success_store_i:
- success_store += 1
- all_get += 1
- success_get += 1 if success_get_i else 0
- time_get += get_time_i
- alive_nodes_count = 0
- loop = asyncio.new_event_loop()
- node = DHTNode(loop=loop)
- for addr, port in dht:
- if loop.run_until_complete(node.protocol.call_ping((addr, port))) is not None:
- alive_nodes_count += 1
- print("store success rate: ", success_store / all_store)
- print("mean store time: ", time_store / all_store)
- print("get success rate: ", success_get / all_get)
- print("mean get time: ", time_get / all_get)
- print("death rate: ", (num_nodes - alive_nodes_count) / num_nodes)
|