|
@@ -1,5 +1,4 @@
|
|
|
import argparse
|
|
|
-import asyncio
|
|
|
import multiprocessing as mp
|
|
|
import random
|
|
|
import sys
|
|
@@ -113,6 +112,7 @@ def benchmark_throughput(
|
|
|
for client in clients:
|
|
|
client.start()
|
|
|
|
|
|
+
|
|
|
timestamps["launched_clients"] = timestamps["began_launching_server"] = time.perf_counter()
|
|
|
|
|
|
device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
|
@@ -129,6 +129,7 @@ def benchmark_throughput(
|
|
|
)
|
|
|
timestamps["created_experts"] = time.perf_counter()
|
|
|
|
|
|
+
|
|
|
server = hivemind.moe.Server(
|
|
|
dht=server_dht,
|
|
|
expert_backends=experts,
|
|
@@ -137,12 +138,17 @@ def benchmark_throughput(
|
|
|
)
|
|
|
server.start()
|
|
|
server.ready.wait()
|
|
|
+ print("Joining client")
|
|
|
+
|
|
|
timestamps["server_ready"] = time.perf_counter()
|
|
|
can_start.set()
|
|
|
|
|
|
for client in clients:
|
|
|
+ print("client finished")
|
|
|
client.join()
|
|
|
|
|
|
+ print("Clients joined")
|
|
|
+
|
|
|
timestamps["clients_finished"] = time.perf_counter()
|
|
|
|
|
|
except BaseException as e:
|
|
@@ -154,7 +160,9 @@ def benchmark_throughput(
|
|
|
client.terminate()
|
|
|
server.shutdown()
|
|
|
timestamps["server_shutdown_finished"] = time.perf_counter()
|
|
|
+ print("Joining server")
|
|
|
server.join()
|
|
|
+ print("Joined server")
|
|
|
|
|
|
sys.stdout.flush()
|
|
|
sys.stderr.flush()
|
|
@@ -165,37 +173,37 @@ def benchmark_throughput(
|
|
|
)
|
|
|
total_examples = batch_size * num_clients * num_batches_per_client
|
|
|
|
|
|
- logger.info("Benchmark finished, status:" + ["Success", "Failure"][benchmarking_failed.is_set()])
|
|
|
- logger.info(
|
|
|
+ print("Benchmark finished, status:" + ["Success", "Failure"][benchmarking_failed.is_set()])
|
|
|
+
|
|
|
+ print(
|
|
|
f"Server parameters: num_experts={num_experts}, num_handlers={num_handlers}, "
|
|
|
f"max_batch_size={max_batch_size}, expert_cls={expert_cls}, hid_dim={hid_dim}, device={device}"
|
|
|
)
|
|
|
- logger.info(
|
|
|
+ print(
|
|
|
f"Client parameters: num_clients={num_clients}, num_batches_per_client={num_batches_per_client}, "
|
|
|
f"batch_size={batch_size}, backprop={backprop}"
|
|
|
)
|
|
|
|
|
|
- logger.info("Results: ")
|
|
|
- logger.info(
|
|
|
+ print("Results: ")
|
|
|
+ print(
|
|
|
f"\tServer startup took {time_between('began_launching_server', 'server_ready') :.3f} s. "
|
|
|
f"({time_between('began_launching_server', 'created_experts') :.3f} s. experts + "
|
|
|
f"{time_between('created_experts', 'server_ready') :.3f} s. networking)"
|
|
|
)
|
|
|
- logger.info(f"\tProcessed {total_examples} examples in {time_between('server_ready', 'clients_finished') :.3f}")
|
|
|
- logger.info(
|
|
|
+ print(f"\tProcessed {total_examples} examples in {time_between('server_ready', 'clients_finished') :.3f}")
|
|
|
+ print(
|
|
|
f"\tThroughput for {'forward + backward' if backprop else 'forward'} passes: "
|
|
|
f"{total_examples / time_between('server_ready', 'clients_finished') :.3f} samples / s."
|
|
|
)
|
|
|
- logger.info(f"\tBenchmarking took {time_between('started', 'server_shutdown_finished') :.3f} s.")
|
|
|
+ print(f"\tBenchmarking took {time_between('started', 'server_shutdown_finished') :.3f} s.")
|
|
|
if benchmarking_failed.is_set():
|
|
|
- logger.info("Note: benchmark code failed, timing/memory results only indicate time till failure!")
|
|
|
+ print("Note: benchmark code failed, timing/memory results only indicate time till failure!")
|
|
|
print_device_info(device)
|
|
|
sys.stdout.flush()
|
|
|
sys.stderr.flush()
|
|
|
|
|
|
assert not benchmarking_failed.is_set()
|
|
|
|
|
|
-
|
|
|
if __name__ == "__main__":
|
|
|
parser = argparse.ArgumentParser()
|
|
|
parser.add_argument("--preset", type=str, default="default", required=False)
|