Переглянути джерело

Remove `endpoint` parameter of GroupKeyManager

Aleksandr Borzunov 4 роки тому
батько
коміт
2e51140cc0

+ 2 - 2
hivemind/averaging/key_manager.py

@@ -30,7 +30,6 @@ class GroupKeyManager:
     def __init__(
         self,
         dht: DHT,
-        endpoint: Endpoint,
         prefix: str,
         initial_group_bits: Optional[str],
         target_group_size: int,
@@ -44,7 +43,8 @@ class GroupKeyManager:
             search_result = dht.get(f"{prefix}.0b", latest=True)
             initial_group_nbits = self.get_suggested_nbits(search_result) or 0
             initial_group_bits = "".join(random.choice("01") for _ in range(initial_group_nbits))
-        self.dht, self.endpoint, self.prefix, self.group_bits = dht, endpoint, prefix, initial_group_bits
+        self.dht, self.prefix, self.group_bits = dht, prefix, initial_group_bits
+        self.endpoint = dht.peer_id
         self.target_group_size = target_group_size
         self.insufficient_size = insufficient_size or max(1, target_group_size // 2)
         self.excessive_size = excessive_size or target_group_size * 3

+ 1 - 1
hivemind/averaging/matchmaking.py

@@ -59,7 +59,7 @@ class Matchmaking:
         self._p2p = p2p
         self.endpoint = p2p.id
         self.schema_hash = schema_hash
-        self.group_key_manager = GroupKeyManager(dht, self.endpoint, prefix, initial_group_bits, target_group_size)
+        self.group_key_manager = GroupKeyManager(dht, prefix, initial_group_bits, target_group_size)
         self.target_group_size, self.min_group_size = target_group_size, min_group_size
         self.averaging_expiration, self.request_timeout = averaging_expiration, request_timeout
         self.client_mode = client_mode

+ 14 - 13
tests/test_averaging.py

@@ -18,39 +18,40 @@ from test_utils.dht_swarms import launch_dht_instances
 @pytest.mark.forked
 @pytest.mark.asyncio
 async def test_key_manager():
-    localhvost = PeerID(b"localhvost")
-    localhvost2 = PeerID(b"localhvost2")
-
+    dht = hivemind.DHT(start=True)
     key_manager = GroupKeyManager(
-        hivemind.DHT(start=True),
-        endpoint=localhvost,
+        dht,
         prefix="test_averaging",
         initial_group_bits="10110",
         target_group_size=2,
     )
+    alice = dht.peer_id
+    bob = PeerID(b"bob")
 
     t = hivemind.get_dht_time()
     key = key_manager.current_key
-    await key_manager.declare_averager(key, localhvost, expiration_time=t + 60)
-    await key_manager.declare_averager(key, localhvost2, expiration_time=t + 61)
+    await key_manager.declare_averager(key, alice, expiration_time=t + 60)
+    await key_manager.declare_averager(key, bob, expiration_time=t + 61)
 
     q1 = await key_manager.get_averagers(key, only_active=True)
 
-    await key_manager.declare_averager(key, localhvost, expiration_time=t + 66)
+    await key_manager.declare_averager(key, alice, expiration_time=t + 66)
     q2 = await key_manager.get_averagers(key, only_active=True)
 
-    await key_manager.declare_averager(key, localhvost2, expiration_time=t + 61, looking_for_group=False)
+    await key_manager.declare_averager(key, bob, expiration_time=t + 61, looking_for_group=False)
     q3 = await key_manager.get_averagers(key, only_active=True)
     q4 = await key_manager.get_averagers(key, only_active=False)
 
     q5 = await key_manager.get_averagers("nonexistent_key.0b0101", only_active=False)
 
-    assert len(q1) == 2 and (localhvost, t + 60) in q1 and (localhvost2, t + 61) in q1
-    assert len(q2) == 2 and (localhvost, t + 66) in q2 and (localhvost2, t + 61) in q2
-    assert len(q3) == 1 and (localhvost, t + 66) in q3
-    assert len(q4) == 2 and (localhvost, t + 66) in q4 and (localhvost2, t + 61) in q2
+    assert len(q1) == 2 and (alice, t + 60) in q1 and (bob, t + 61) in q1
+    assert len(q2) == 2 and (alice, t + 66) in q2 and (bob, t + 61) in q2
+    assert len(q3) == 1 and (alice, t + 66) in q3
+    assert len(q4) == 2 and (alice, t + 66) in q4 and (bob, t + 61) in q2
     assert len(q5) == 0
 
+    dht.shutdown()
+
 
 def _test_allreduce_once(n_clients, n_aux):
     n_peers = 4