|
@@ -30,23 +30,20 @@ def print_device_info(device=None):
|
|
|
|
|
|
|
|
|
def client_process(
|
|
|
- can_start,
|
|
|
- benchmarking_failed,
|
|
|
- server_peer_info,
|
|
|
- num_experts,
|
|
|
- batch_size,
|
|
|
- hid_dim,
|
|
|
- num_batches,
|
|
|
- backprop=True,
|
|
|
+ can_start,
|
|
|
+ benchmarking_failed,
|
|
|
+ server_peer_info,
|
|
|
+ num_experts,
|
|
|
+ batch_size,
|
|
|
+ hid_dim,
|
|
|
+ num_batches,
|
|
|
+ backprop=True,
|
|
|
) -> None:
|
|
|
torch.set_num_threads(1)
|
|
|
can_start.wait()
|
|
|
|
|
|
- p2p = hivemind.moe.client.expert._RemoteModuleCall.run_coroutine(
|
|
|
- hivemind.P2P.create()
|
|
|
- )
|
|
|
experts = [
|
|
|
- hivemind.RemoteExpert(f"expert.{i}", server_peer_info=server_peer_info, p2p=p2p) for i in range(num_experts)
|
|
|
+ hivemind.RemoteExpert(f"expert.{i}", server_peer_info=server_peer_info) for i in range(num_experts)
|
|
|
]
|
|
|
|
|
|
try:
|
|
@@ -62,21 +59,21 @@ def client_process(
|
|
|
|
|
|
|
|
|
def benchmark_throughput(
|
|
|
- num_experts=16,
|
|
|
- num_handlers=None,
|
|
|
- num_clients=128,
|
|
|
- num_batches_per_client=16,
|
|
|
- expert_cls="ffn",
|
|
|
- hid_dim=1024,
|
|
|
- batch_size=2048,
|
|
|
- max_batch_size=None,
|
|
|
- backprop=True,
|
|
|
- device=None,
|
|
|
+ num_experts=16,
|
|
|
+ num_handlers=None,
|
|
|
+ num_clients=128,
|
|
|
+ num_batches_per_client=16,
|
|
|
+ expert_cls="ffn",
|
|
|
+ hid_dim=1024,
|
|
|
+ batch_size=2048,
|
|
|
+ max_batch_size=None,
|
|
|
+ backprop=True,
|
|
|
+ device=None,
|
|
|
):
|
|
|
assert (
|
|
|
- not hasattr(torch.cuda, "is_initialized")
|
|
|
- or not torch.cuda.is_initialized()
|
|
|
- or torch.device(device) == torch.device("cpu")
|
|
|
+ not hasattr(torch.cuda, "is_initialized")
|
|
|
+ or not torch.cuda.is_initialized()
|
|
|
+ or torch.device(device) == torch.device("cpu")
|
|
|
)
|
|
|
assert expert_cls in layers.name_to_block
|
|
|
max_batch_size = max_batch_size or batch_size * 4
|
|
@@ -245,7 +242,8 @@ if __name__ == "__main__":
|
|
|
)
|
|
|
elif args.preset == "minimalistic":
|
|
|
benchmark_throughput(
|
|
|
- num_experts=1, num_clients=1, num_handlers=1, num_batches_per_client=args.num_batches_per_client
|
|
|
+ num_experts=1, num_clients=1, num_handlers=1, num_batches_per_client=args.num_batches_per_client,
|
|
|
+ batch_size=512,
|
|
|
)
|
|
|
elif args.preset == "nop":
|
|
|
benchmark_throughput(expert_cls="nop", backprop=False, num_batches_per_client=args.num_batches_per_client)
|