|
@@ -40,9 +40,7 @@ def client_process(
|
|
|
) -> None:
|
|
|
torch.set_num_threads(1)
|
|
|
can_start.wait()
|
|
|
- experts = [
|
|
|
- hivemind.RemoteExpert(f"expert.{i}", server_peer_info=server_peer_info) for i in range(num_experts)
|
|
|
- ]
|
|
|
+ experts = [hivemind.RemoteExpert(f"expert.{i}", server_peer_info=server_peer_info) for i in range(num_experts)]
|
|
|
|
|
|
try:
|
|
|
dummy_batch = torch.randn(batch_size, hid_dim)
|
|
@@ -57,9 +55,9 @@ def client_process(
|
|
|
|
|
|
|
|
|
def benchmark_throughput(
|
|
|
- num_experts=1,
|
|
|
+ num_experts=16,
|
|
|
num_handlers=None,
|
|
|
- num_clients=1,
|
|
|
+ num_clients=128,
|
|
|
num_batches_per_client=16,
|
|
|
expert_cls="ffn",
|
|
|
hid_dim=1024,
|