瀏覽代碼

fix benchmark

Denis Mazur 4 年之前
父節點
當前提交
f49e0daea1
共有 1 個文件被更改,包括 3 次插入1 次删除
  1. 3 1
      benchmarks/benchmark_throughput_p2p.py

+ 3 - 1
benchmarks/benchmark_throughput_p2p.py

@@ -8,6 +8,7 @@ import time
 import torch
 
 import hivemind
+from hivemind import P2P
 from hivemind.dht import DHT
 from hivemind.moe.server import layers
 from hivemind.utils.limits import increase_file_limit
@@ -42,8 +43,9 @@ def client_process(
     torch.set_num_threads(1)
     can_start.wait()
 
+    p2p = hivemind.moe.client.expert._RemoteModuleCall.run_coroutine(P2P.create())
     experts = [
-        hivemind.RemoteExpert(f"expert.{i}", server_peer_info=server_peer_info) for i in range(num_experts)
+        hivemind.RemoteExpert(f"expert.{i}", server_peer_info=server_peer_info, p2p=p2p) for i in range(num_experts)
     ]
 
     try: