@@ -1,11 +1,14 @@
 import argparse
+import asyncio
 import random
 import time
+import uuid
+from typing import Tuple
 
+import numpy as np
 from tqdm import trange
 
 import hivemind
-from hivemind.moe.server import declare_experts, get_experts
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
@@ -13,23 +16,121 @@ use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
 
 
-def random_endpoint() -> hivemind.Endpoint:
-    return (
-        f"{random.randint(0, 256)}.{random.randint(0, 256)}.{random.randint(0, 256)}."
-        f"{random.randint(0, 256)}:{random.randint(0, 65535)}"
-    )
+class NodeKiller:
+    """Auxiliary class that kills DHT nodes on a pre-defined schedule"""
+
+    def __init__(self, shutdown_peers: list, shutdown_timestamps: list):
+        self.shutdown_peers = set(shutdown_peers)
+        self.shutdown_timestamps = shutdown_timestamps
+        self.current_iter = 0
+        self.timestamp_iter = 0
+        self.lock = asyncio.Lock()
+
+    async def check_and_kill(self):
+        async with self.lock:
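+            # all workers share one global round counter; when it reaches the next
+            # scheduled timestamp, shut down one random peer from the failure set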
+            if (
+                self.shutdown_timestamps is not None
+                and self.timestamp_iter < len(self.shutdown_timestamps)
+                and self.current_iter == self.shutdown_timestamps[self.timestamp_iter]
+            ):
+                shutdown_peer = random.choice(list(self.shutdown_peers))
+                shutdown_peer.shutdown()
+                self.shutdown_peers.remove(shutdown_peer)
+                self.timestamp_iter += 1
+            self.current_iter += 1
+
+
+async def store_and_get_task(
+    peers: list,
+    total_num_rounds: int,
+    num_store_peers: int,
+    num_get_peers: int,
+    wait_after_iteration: float,
+    delay: float,
+    expiration: float,
+    latest: bool,
+    node_killer: NodeKiller,
+) -> Tuple[list, list, list, list, int, int]:
+    """Iteratively choose random peers to store data on the DHT, then retrieve it with another random subset of peers"""
+
+    total_stores = total_gets = 0
+    successful_stores = []
+    successful_gets = []
+    store_times = []
+    get_times = []
+
+    for _ in range(total_num_rounds):
+        key = uuid.uuid4().hex
+
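+        # store phase: every sampled peer writes its own subkey-value pair under the shared key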
+        store_start = time.perf_counter()
+        store_peers = random.sample(peers, min(num_store_peers, len(peers)))
+        store_subkeys = [uuid.uuid4().hex for _ in store_peers]
+        store_values = {subkey: uuid.uuid4().hex for subkey in store_subkeys}
+        store_tasks = [
+            peer.store(
+                key,
+                subkey=subkey,
+                value=store_values[subkey],
+                expiration_time=hivemind.get_dht_time() + expiration,
+                return_future=True,
+            )
+            for peer, subkey in zip(store_peers, store_subkeys)
+        ]
+        store_result = await asyncio.gather(*store_tasks)
+        await node_killer.check_and_kill()
+
+        store_times.append(time.perf_counter() - store_start)
+
+        total_stores += len(store_result)
+        successful_stores_per_iter = sum(store_result)
+        successful_stores.append(successful_stores_per_iter)
+        await asyncio.sleep(delay)
+
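+        # get phase: an independently sampled subset of peers reads the same key back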
+        get_start = time.perf_counter()
+        get_peers = random.sample(peers, min(num_get_peers, len(peers)))
+        get_tasks = [peer.get(key, latest=latest, return_future=True) for peer in get_peers]
+        get_result = await asyncio.gather(*get_tasks)
+        get_times.append(time.perf_counter() - get_start)
+
+        successful_gets_per_iter = 0
+
+        total_gets += len(get_result)
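+        # a get succeeds only if it returns every subkey stored successfully this round, with matching values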
+        for result in get_result:
+            if result is not None:
+                attendees, expiration_time = result
+                if len(attendees) == successful_stores_per_iter:
+                    get_ok = True
+                    for subkey in attendees:
+                        if attendees[subkey][0] != store_values[subkey]:
+                            get_ok = False
+                            break
+                    successful_gets_per_iter += get_ok
+
+        successful_gets.append(successful_gets_per_iter)
+        await asyncio.sleep(wait_after_iteration)
+
-def benchmark_dht(
+    return store_times, get_times, successful_stores, successful_gets, total_stores, total_gets
+
+
+async def benchmark_dht(
     num_peers: int,
     initial_peers: int,
-    num_experts: int,
-    expert_batch_size: int,
     random_seed: int,
-    wait_after_request: float,
-    wait_before_read: float,
+    num_threads: int,
+    total_num_rounds: int,
+    num_store_peers: int,
+    num_get_peers: int,
+    wait_after_iteration: float,
+    delay: float,
     wait_timeout: float,
     expiration: float,
+    latest: bool,
+    failure_rate: float,
 ):
     random.seed(random_seed)
 
@@ -42,88 +143,94 @@ def benchmark_dht(
         peer = hivemind.DHT(initial_peers=neighbors, start=True, wait_timeout=wait_timeout)
         peers.append(peer)
 
-    store_peer, get_peer = peers[-2:]
-
-    expert_uids = list(
-        set(
-            f"expert.{random.randint(0, 999)}.{random.randint(0, 999)}.{random.randint(0, 999)}"
-            for _ in range(num_experts)
-        )
-    )
-    logger.info(f"Sampled {len(expert_uids)} unique ids (after deduplication)")
-    random.shuffle(expert_uids)
-
-    logger.info(f"Storing experts to dht in batches of {expert_batch_size}...")
-    successful_stores = total_stores = total_store_time = 0
     benchmark_started = time.perf_counter()
-    endpoints = []
-
-    for start in trange(0, num_experts, expert_batch_size):
-        store_start = time.perf_counter()
-        endpoints.append(random_endpoint())
-        store_ok = declare_experts(
-            store_peer, expert_uids[start : start + expert_batch_size], endpoints[-1], expiration=expiration
+    logger.info("Creating store and get tasks...")
+    shutdown_peers = random.sample(peers, min(int(failure_rate * num_peers), num_peers))
+    assert len(shutdown_peers) != len(peers), "at least one peer must survive the failure schedule"
+    remaining_peers = list(set(peers) - set(shutdown_peers))
+    shutdown_timestamps = random.sample(
+        range(0, num_threads * total_num_rounds), min(len(shutdown_peers), num_threads * total_num_rounds)
+    )
+    shutdown_timestamps.sort()
+    node_killer = NodeKiller(shutdown_peers, shutdown_timestamps)
+    task_list = [
+        asyncio.create_task(
+            store_and_get_task(
+                remaining_peers,
+                total_num_rounds,
+                num_store_peers,
+                num_get_peers,
+                wait_after_iteration,
+                delay,
+                expiration,
+                latest,
+                node_killer,
+            )
         )
-        successes = store_ok.values()
-        total_store_time += time.perf_counter() - store_start
-
-        total_stores += len(successes)
-        successful_stores += sum(successes)
-        time.sleep(wait_after_request)
+        for _ in trange(num_threads)
+    ]
+
+    store_and_get_result = await asyncio.gather(*task_list)
+    benchmark_total_time = time.perf_counter() - benchmark_started
+    total_store_times = []
+    total_get_times = []
+    total_successful_stores = []
+    total_successful_gets = []
+    total_stores = total_gets = 0
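+    # merge per-worker timings and success counts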
+    for result in store_and_get_result:
+        store_times, get_times, successful_stores, successful_gets, stores, gets = result
+
+        total_store_times.extend(store_times)
+        total_get_times.extend(get_times)
+        total_successful_stores.extend(successful_stores)
+        total_successful_gets.extend(successful_gets)
+        total_stores += stores
+        total_gets += gets
 
+    alive_peers = [peer for peer in peers if peer.is_alive()]
     logger.info(
-        f"Store success rate: {successful_stores / total_stores * 100:.1f}% ({successful_stores} / {total_stores})"
+        f"Store wall time (sec.): mean({np.mean(total_store_times):.3f}) "
+        + f"std({np.std(total_store_times, ddof=1):.3f}) max({np.max(total_store_times):.3f})"
     )
-    logger.info(f"Mean store time: {total_store_time / total_stores:.5}, Total: {total_store_time:.5}")
-    time.sleep(wait_before_read)
-
-    if time.perf_counter() - benchmark_started > expiration:
-        logger.warning("All keys expired before benchmark started getting them. Consider increasing expiration_time")
-
-    successful_gets = total_get_time = 0
-
-    for start in trange(0, len(expert_uids), expert_batch_size):
-        get_start = time.perf_counter()
-        get_result = get_experts(get_peer, expert_uids[start : start + expert_batch_size])
-        total_get_time += time.perf_counter() - get_start
-
-        for i, expert in enumerate(get_result):
-            if (
-                expert is not None
-                and expert.uid == expert_uids[start + i]
-                and expert.endpoint == endpoints[start // expert_batch_size]
-            ):
-                successful_gets += 1
-
-    if time.perf_counter() - benchmark_started > expiration:
-        logger.warning(
-            "keys expired midway during get requests. If that isn't desired, increase expiration_time param"
-        )
-
     logger.info(
-        f"Get success rate: {successful_gets / len(expert_uids) * 100:.1f} ({successful_gets} / {len(expert_uids)})"
+        f"Get wall time (sec.): mean({np.mean(total_get_times):.3f}) "
+        + f"std({np.std(total_get_times, ddof=1):.3f}) max({np.max(total_get_times):.3f})"
+    )
+    logger.info(f"Average store time per worker: {sum(total_store_times) / num_threads:.3f} sec.")
+    logger.info(f"Average get time per worker: {sum(total_get_times) / num_threads:.3f} sec.")
+    logger.info(f"Total benchmark time: {benchmark_total_time:.5f} sec.")
+    logger.info(
+        "Store success rate: "
+        + f"{sum(total_successful_stores) / total_stores * 100:.1f}% ({sum(total_successful_stores)}/{total_stores})"
+    )
+    logger.info(
+        "Get success rate: "
+        + f"{sum(total_successful_gets) / total_gets * 100:.1f}% ({sum(total_successful_gets)}/{total_gets})"
     )
-    logger.info(f"Mean get time: {total_get_time / len(expert_uids):.5f}, Total: {total_get_time:.5f}")
-
-    alive_peers = [peer.is_alive() for peer in peers]
     logger.info(f"Node survival rate: {len(alive_peers) / len(peers) * 100:.3f}%")
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--num_peers", type=int, default=32, required=False)
-    parser.add_argument("--initial_peers", type=int, default=1, required=False)
-    parser.add_argument("--num_experts", type=int, default=256, required=False)
-    parser.add_argument("--expert_batch_size", type=int, default=32, required=False)
-    parser.add_argument("--expiration", type=float, default=300, required=False)
-    parser.add_argument("--wait_after_request", type=float, default=0, required=False)
-    parser.add_argument("--wait_before_read", type=float, default=0, required=False)
+    parser.add_argument("--num_peers", type=int, default=16, required=False)
+    parser.add_argument("--initial_peers", type=int, default=4, required=False)
+    parser.add_argument("--random_seed", type=int, default=30, required=False)
+    parser.add_argument("--num_threads", type=int, default=10, required=False)
+    parser.add_argument("--total_num_rounds", type=int, default=16, required=False)
+    parser.add_argument("--num_store_peers", type=int, default=8, required=False)
+    parser.add_argument("--num_get_peers", type=int, default=8, required=False)
+    parser.add_argument("--wait_after_iteration", type=float, default=0, required=False)
+    parser.add_argument("--delay", type=float, default=0, required=False)
     parser.add_argument("--wait_timeout", type=float, default=5, required=False)
-    parser.add_argument("--random_seed", type=int, default=random.randint(1, 1000))
+    parser.add_argument("--expiration", type=float, default=300, required=False)
+    parser.add_argument("--latest", type=lambda x: x.lower() == "true", default=True, required=False)
+    parser.add_argument("--failure_rate", type=float, default=0.1, required=False)
    parser.add_argument("--increase_file_limit", action="store_true")
     args = vars(parser.parse_args())
 
     if args.pop("increase_file_limit", False):
         increase_file_limit()
 
-    benchmark_dht(**args)
+    asyncio.run(benchmark_dht(**args))