import argparse
import asyncio
import random
import time
import uuid
from typing import Tuple

import numpy as np
from tqdm import trange

import hivemind
from hivemind.utils.limits import increase_file_limit
from hivemind.utils.logging import get_logger, use_hivemind_log_handler

use_hivemind_log_handler("in_root_logger")
logger = get_logger(__name__)


class NodeKiller:
    """Auxiliary class that kills DHT nodes over a pre-defined schedule."""

    def __init__(self, shutdown_peers: list, shutdown_timestamps: list):
        self.shutdown_peers = set(shutdown_peers)
        self.shutdown_timestamps = shutdown_timestamps
        self.current_iter = 0  # global iteration counter shared by all benchmark workers
        self.timestamp_iter = 0  # index of the next scheduled shutdown timestamp
        self.lock = asyncio.Lock()

    async def check_and_kill(self):
        """Shut down one random scheduled peer if the current iteration matches the next timestamp."""
        async with self.lock:
            if (
                self.shutdown_timestamps is not None
                and self.timestamp_iter < len(self.shutdown_timestamps)
                and self.current_iter == self.shutdown_timestamps[self.timestamp_iter]
            ):
                # random.sample() on a set raises TypeError on Python 3.11+; pick from a list instead
                shutdown_peer = random.choice(list(self.shutdown_peers))
                shutdown_peer.shutdown()
                self.shutdown_peers.remove(shutdown_peer)
                self.timestamp_iter += 1
            self.current_iter += 1


async def store_and_get_task(
    peers: list,
    total_num_rounds: int,
    num_store_peers: int,
    num_get_peers: int,
    wait_after_iteration: float,
    delay: float,
    expiration: float,
    latest: bool,
    node_killer: NodeKiller,
) -> Tuple[list, list, list, list, int, int]:
    """
    Iteratively choose random peers to store data onto the DHT, then retrieve it
    with another random subset of peers.

    :param peers: pool of alive DHT nodes to sample store/get peers from
    :param expiration: key lifetime in seconds, added to the current DHT time on each store
    :returns: (store_times, get_times, successful_stores, successful_gets, total_stores, total_gets)
    """
    total_stores = total_gets = 0
    successful_stores = []
    successful_gets = []
    store_times = []
    get_times = []

    for _ in range(total_num_rounds):
        key = uuid.uuid4().hex

        store_start = time.perf_counter()
        store_peers = random.sample(peers, min(num_store_peers, len(peers)))
        store_subkeys = [uuid.uuid4().hex for _ in store_peers]
        store_values = {subkey: uuid.uuid4().hex for subkey in store_subkeys}
        store_tasks = [
            peer.store(
                key,
                subkey=subkey,
                value=store_values[subkey],
                expiration_time=hivemind.get_dht_time() + expiration,
                return_future=True,
            )
            for peer, subkey in zip(store_peers, store_subkeys)
        ]
        store_result = await asyncio.gather(*store_tasks)
        await node_killer.check_and_kill()
        store_times.append(time.perf_counter() - store_start)

        total_stores += len(store_result)
        successful_stores_per_iter = sum(store_result)
        successful_stores.append(successful_stores_per_iter)
        await asyncio.sleep(delay)

        get_start = time.perf_counter()
        get_peers = random.sample(peers, min(num_get_peers, len(peers)))
        get_tasks = [peer.get(key, latest, return_future=True) for peer in get_peers]
        get_result = await asyncio.gather(*get_tasks)
        get_times.append(time.perf_counter() - get_start)

        successful_gets_per_iter = 0
        total_gets += len(get_result)
        for result in get_result:
            if result is not None:
                # Unpack into fresh names: the original code clobbered the `expiration`
                # parameter with an absolute DHT timestamp (corrupting every subsequent
                # round's expiration_time) and reused `key` as the inner loop variable.
                attendees, _expiration_time = result
                # A get succeeds only if it sees exactly the successfully stored subkeys
                # of this round, each with the value that was stored.
                if len(attendees.keys()) == successful_stores_per_iter:
                    get_ok = True
                    for subkey in attendees:
                        if attendees[subkey][0] != store_values[subkey]:
                            get_ok = False
                            break
                    successful_gets_per_iter += get_ok

        successful_gets.append(successful_gets_per_iter)
        await asyncio.sleep(wait_after_iteration)

    return store_times, get_times, successful_stores, successful_gets, total_stores, total_gets


async def benchmark_dht(
    num_peers: int,
    initial_peers: int,
    random_seed: int,
    num_threads: int,
    total_num_rounds: int,
    num_store_peers: int,
    num_get_peers: int,
    wait_after_iteration: float,
    delay: float,
    wait_timeout: float,
    expiration: float,
    latest: bool,
    failure_rate: float,
):
    """
    Spawn a swarm of DHT peers, run `num_threads` concurrent store/get workers while
    randomly shutting down a `failure_rate` fraction of peers, then log latency,
    success-rate, and survival statistics.
    """
    random.seed(random_seed)

    logger.info("Creating peers...")
    peers = []
    for _ in trange(num_peers):
        # Each new peer bootstraps from the visible multiaddrs of up to `initial_peers` existing peers
        neighbors = sum(
            [peer.get_visible_maddrs() for peer in random.sample(peers, min(initial_peers, len(peers)))], []
        )
        peer = hivemind.DHT(initial_peers=neighbors, start=True, wait_timeout=wait_timeout)
        peers.append(peer)

    benchmark_started = time.perf_counter()
    logger.info("Creating store and get tasks...")

    # Choose which peers will be killed and when; workers only ever store/get via survivors
    shutdown_peers = random.sample(peers, min(int(failure_rate * num_peers), num_peers))
    assert len(shutdown_peers) != len(peers)  # at least one peer must survive
    remaining_peers = list(set(peers) - set(shutdown_peers))
    shutdown_timestamps = random.sample(
        range(0, num_threads * total_num_rounds), min(len(shutdown_peers), num_threads * total_num_rounds)
    )
    shutdown_timestamps.sort()
    node_killer = NodeKiller(shutdown_peers, shutdown_timestamps)

    task_list = [
        asyncio.create_task(
            store_and_get_task(
                remaining_peers,
                total_num_rounds,
                num_store_peers,
                num_get_peers,
                wait_after_iteration,
                delay,
                expiration,
                latest,
                node_killer,
            )
        )
        for _ in trange(num_threads)
    ]

    store_and_get_result = await asyncio.gather(*task_list)
    benchmark_total_time = time.perf_counter() - benchmark_started

    # Aggregate per-worker results
    total_store_times = []
    total_get_times = []
    total_successful_stores = []
    total_successful_gets = []
    total_stores = total_gets = 0
    for result in store_and_get_result:
        store_times, get_times, successful_stores, successful_gets, stores, gets = result
        total_store_times.extend(store_times)
        total_get_times.extend(get_times)
        total_successful_stores.extend(successful_stores)
        total_successful_gets.extend(successful_gets)
        total_stores += stores
        total_gets += gets

    alive_peers = [peer.is_alive() for peer in peers]
    logger.info(
        f"Store wall time (sec.): mean({np.mean(total_store_times):.3f}) "
        + f"std({np.std(total_store_times, ddof=1):.3f}) max({np.max(total_store_times):.3f})"
    )
    logger.info(
        f"Get wall time (sec.): mean({np.mean(total_get_times):.3f}) "
        + f"std({np.std(total_get_times, ddof=1):.3f}) max({np.max(total_get_times):.3f})"
    )
    logger.info(f"Average store time per worker: {sum(total_store_times) / num_threads:.3f} sec.")
    logger.info(f"Average get time per worker: {sum(total_get_times) / num_threads:.3f} sec.")
    logger.info(f"Total benchmark time: {benchmark_total_time:.5f} sec.")
    logger.info(
        "Store success rate: "
        + f"{sum(total_successful_stores) / total_stores * 100:.1f}% ({sum(total_successful_stores)}/{total_stores})"
    )
    logger.info(
        "Get success rate: "
        + f"{sum(total_successful_gets) / total_gets * 100:.1f}% ({sum(total_successful_gets)}/{total_gets})"
    )
    # Count peers that are actually alive; len(alive_peers) always equals len(peers),
    # so the original report was a constant 100%
    logger.info(f"Node survival rate: {sum(alive_peers) / len(peers) * 100:.3f}%")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_peers", type=int, default=16, required=False)
    parser.add_argument("--initial_peers", type=int, default=4, required=False)
    parser.add_argument("--random_seed", type=int, default=30, required=False)
    parser.add_argument("--num_threads", type=int, default=10, required=False)
    parser.add_argument("--total_num_rounds", type=int, default=16, required=False)
    parser.add_argument("--num_store_peers", type=int, default=8, required=False)
    parser.add_argument("--num_get_peers", type=int, default=8, required=False)
    parser.add_argument("--wait_after_iteration", type=float, default=0, required=False)
    parser.add_argument("--delay", type=float, default=0, required=False)
    parser.add_argument("--wait_timeout", type=float, default=5, required=False)
    parser.add_argument("--expiration", type=float, default=300, required=False)
    # NOTE(review): type=bool makes any non-empty string (including "False") truthy;
    # kept as-is to preserve the CLI contract, but consider action="store_true"/"store_false"
    parser.add_argument("--latest", type=bool, default=True, required=False)
    parser.add_argument("--failure_rate", type=float, default=0.1, required=False)
    parser.add_argument("--increase_file_limit", action="store_true")
    args = vars(parser.parse_args())

    if args.pop("increase_file_limit", False):
        increase_file_limit()

    asyncio.run(benchmark_dht(**args))