|
@@ -52,22 +52,24 @@ def test_store_get_experts(n_peers=10):
|
|
def test_beam_search(
|
|
def test_beam_search(
|
|
n_peers=20, total_experts=128, batch_size=32, beam_size=4, parallel_rpc=4, grid_dims=(32, 32, 32)
|
|
n_peers=20, total_experts=128, batch_size=32, beam_size=4, parallel_rpc=4, grid_dims=(32, 32, 32)
|
|
):
|
|
):
|
|
- dht = [hivemind.DHT(start=True)]
|
|
|
|
- initial_peers = dht[0].get_visible_maddrs()
|
|
|
|
- dht += [hivemind.DHT(initial_peers=initial_peers, start=True) for _ in range(n_peers - 1)]
|
|
|
|
|
|
+ dht_instances = [hivemind.DHT(start=True)]
|
|
|
|
+ initial_peers = dht_instances[0].get_visible_maddrs()
|
|
|
|
+ dht_instances += [hivemind.DHT(initial_peers=initial_peers, start=True) for _ in range(n_peers - 1)]
|
|
|
|
|
|
real_experts = sorted(
|
|
real_experts = sorted(
|
|
{"expert." + ".".join([str(random.randint(0, dim - 1)) for dim in grid_dims]) for _ in range(total_experts)}
|
|
{"expert." + ".".join([str(random.randint(0, dim - 1)) for dim in grid_dims]) for _ in range(total_experts)}
|
|
)
|
|
)
|
|
for batch_start in range(0, len(real_experts), batch_size):
|
|
for batch_start in range(0, len(real_experts), batch_size):
|
|
- dht_ = random.choice(dht)
|
|
|
|
|
|
+ dht = random.choice(dht_instances)
|
|
declare_experts(
|
|
declare_experts(
|
|
- dht_,
|
|
|
|
|
|
+ dht,
|
|
real_experts[batch_start : batch_start + batch_size],
|
|
real_experts[batch_start : batch_start + batch_size],
|
|
- peer_id=dht_.peer_id,
|
|
|
|
|
|
+ peer_id=dht.peer_id,
|
|
)
|
|
)
|
|
|
|
|
|
- neighbors = sum([peer.get_visible_maddrs() for peer in random.sample(dht, min(3, len(dht)))], [])
|
|
|
|
|
|
+ neighbors = sum(
|
|
|
|
+ [peer.get_visible_maddrs() for peer in random.sample(dht_instances, min(3, len(dht_instances)))], []
|
|
|
|
+ )
|
|
you = hivemind.DHT(start=True, initial_peers=neighbors, parallel_rpc=parallel_rpc)
|
|
you = hivemind.DHT(start=True, initial_peers=neighbors, parallel_rpc=parallel_rpc)
|
|
beam_search = MoEBeamSearcher(you, "expert.", grid_dims)
|
|
beam_search = MoEBeamSearcher(you, "expert.", grid_dims)
|
|
|
|
|