Pārlūkot izejas kodu

update p2p benchmark

Denis Mazur 4 gadi atpakaļ
vecāks
revīzija
ceadddf836
1 mainītis faili ar 24 papildinājumiem un 26 dzēšanām
  1. 24 26
      benchmarks/benchmark_throughput_p2p.py

+ 24 - 26
benchmarks/benchmark_throughput_p2p.py

@@ -30,23 +30,20 @@ def print_device_info(device=None):
 
 
 def client_process(
-    can_start,
-    benchmarking_failed,
-    server_peer_info,
-    num_experts,
-    batch_size,
-    hid_dim,
-    num_batches,
-    backprop=True,
+        can_start,
+        benchmarking_failed,
+        server_peer_info,
+        num_experts,
+        batch_size,
+        hid_dim,
+        num_batches,
+        backprop=True,
 ) -> None:
     torch.set_num_threads(1)
     can_start.wait()
 
-    p2p = hivemind.moe.client.expert._RemoteModuleCall.run_coroutine(
-        hivemind.P2P.create()
-    )
     experts = [
-        hivemind.RemoteExpert(f"expert.{i}", server_peer_info=server_peer_info, p2p=p2p) for i in range(num_experts)
+        hivemind.RemoteExpert(f"expert.{i}", server_peer_info=server_peer_info) for i in range(num_experts)
     ]
 
     try:
@@ -62,21 +59,21 @@ def client_process(
 
 
 def benchmark_throughput(
-    num_experts=16,
-    num_handlers=None,
-    num_clients=128,
-    num_batches_per_client=16,
-    expert_cls="ffn",
-    hid_dim=1024,
-    batch_size=2048,
-    max_batch_size=None,
-    backprop=True,
-    device=None,
+        num_experts=16,
+        num_handlers=None,
+        num_clients=128,
+        num_batches_per_client=16,
+        expert_cls="ffn",
+        hid_dim=1024,
+        batch_size=2048,
+        max_batch_size=None,
+        backprop=True,
+        device=None,
 ):
     assert (
-        not hasattr(torch.cuda, "is_initialized")
-        or not torch.cuda.is_initialized()
-        or torch.device(device) == torch.device("cpu")
+            not hasattr(torch.cuda, "is_initialized")
+            or not torch.cuda.is_initialized()
+            or torch.device(device) == torch.device("cpu")
     )
     assert expert_cls in layers.name_to_block
     max_batch_size = max_batch_size or batch_size * 4
@@ -245,7 +242,8 @@ if __name__ == "__main__":
         )
     elif args.preset == "minimalistic":
         benchmark_throughput(
-            num_experts=1, num_clients=1, num_handlers=1, num_batches_per_client=args.num_batches_per_client
+            num_experts=1, num_clients=1, num_handlers=1, num_batches_per_client=args.num_batches_per_client,
+            batch_size=512,
         )
     elif args.preset == "nop":
         benchmark_throughput(expert_cls="nop", backprop=False, num_batches_per_client=args.num_batches_per_client)