|
@@ -40,7 +40,9 @@ def client_process(
|
|
|
) -> None:
|
|
|
torch.set_num_threads(1)
|
|
|
can_start.wait()
|
|
|
- experts = [hivemind.RemoteExpert(f"expert.{i}", server_peer_info=server_peer_info) for i in range(num_experts)]
|
|
|
+ experts = [
|
|
|
+ hivemind.RemoteExpert(f"expert.{i}", server_peer_info=server_peer_info) for i in range(num_experts)
|
|
|
+ ]
|
|
|
|
|
|
try:
|
|
|
dummy_batch = torch.randn(batch_size, hid_dim)
|
|
@@ -61,7 +63,7 @@ def benchmark_throughput(
|
|
|
num_batches_per_client=16,
|
|
|
expert_cls="ffn",
|
|
|
hid_dim=1024,
|
|
|
- batch_size=16,
|
|
|
+ batch_size=2048,
|
|
|
max_batch_size=None,
|
|
|
backprop=True,
|
|
|
device=None,
|