|
@@ -18,39 +18,40 @@ from test_utils.dht_swarms import launch_dht_instances
|
|
|
@pytest.mark.forked
|
|
|
@pytest.mark.asyncio
|
|
|
async def test_key_manager():
|
|
|
- localhvost = PeerID(b"localhvost")
|
|
|
- localhvost2 = PeerID(b"localhvost2")
|
|
|
-
|
|
|
+ dht = hivemind.DHT(start=True)
|
|
|
key_manager = GroupKeyManager(
|
|
|
- hivemind.DHT(start=True),
|
|
|
- endpoint=localhvost,
|
|
|
+ dht,
|
|
|
prefix="test_averaging",
|
|
|
initial_group_bits="10110",
|
|
|
target_group_size=2,
|
|
|
)
|
|
|
+ alice = dht.peer_id
|
|
|
+ bob = PeerID(b"bob")
|
|
|
|
|
|
t = hivemind.get_dht_time()
|
|
|
key = key_manager.current_key
|
|
|
- await key_manager.declare_averager(key, localhvost, expiration_time=t + 60)
|
|
|
- await key_manager.declare_averager(key, localhvost2, expiration_time=t + 61)
|
|
|
+ await key_manager.declare_averager(key, alice, expiration_time=t + 60)
|
|
|
+ await key_manager.declare_averager(key, bob, expiration_time=t + 61)
|
|
|
|
|
|
q1 = await key_manager.get_averagers(key, only_active=True)
|
|
|
|
|
|
- await key_manager.declare_averager(key, localhvost, expiration_time=t + 66)
|
|
|
+ await key_manager.declare_averager(key, alice, expiration_time=t + 66)
|
|
|
q2 = await key_manager.get_averagers(key, only_active=True)
|
|
|
|
|
|
- await key_manager.declare_averager(key, localhvost2, expiration_time=t + 61, looking_for_group=False)
|
|
|
+ await key_manager.declare_averager(key, bob, expiration_time=t + 61, looking_for_group=False)
|
|
|
q3 = await key_manager.get_averagers(key, only_active=True)
|
|
|
q4 = await key_manager.get_averagers(key, only_active=False)
|
|
|
|
|
|
q5 = await key_manager.get_averagers("nonexistent_key.0b0101", only_active=False)
|
|
|
|
|
|
- assert len(q1) == 2 and (localhvost, t + 60) in q1 and (localhvost2, t + 61) in q1
|
|
|
- assert len(q2) == 2 and (localhvost, t + 66) in q2 and (localhvost2, t + 61) in q2
|
|
|
- assert len(q3) == 1 and (localhvost, t + 66) in q3
|
|
|
- assert len(q4) == 2 and (localhvost, t + 66) in q4 and (localhvost2, t + 61) in q2
|
|
|
+ assert len(q1) == 2 and (alice, t + 60) in q1 and (bob, t + 61) in q1
|
|
|
+ assert len(q2) == 2 and (alice, t + 66) in q2 and (bob, t + 61) in q2
|
|
|
+ assert len(q3) == 1 and (alice, t + 66) in q3
|
|
|
+ assert len(q4) == 2 and (alice, t + 66) in q4 and (bob, t + 61) in q2
|
|
|
assert len(q5) == 0
|
|
|
|
|
|
+ dht.shutdown()
|
|
|
+
|
|
|
|
|
|
def _test_allreduce_once(n_clients, n_aux):
|
|
|
n_peers = 4
|