Browse Source

Move launch_dht_instances() to test_utils.py

Aleksandr Borzunov 4 years ago
parent
commit
7eb91a3b25
3 changed files with 11 additions and 10 deletions
  1. 1 7
      tests/test_averaging.py
  2. 2 3
      tests/test_dht.py
  3. 8 0
      tests/test_utils/dht_swarms.py

+ 1 - 7
tests/test_averaging.py

@@ -12,6 +12,7 @@ from hivemind.averaging.key_manager import GroupKeyManager
 from hivemind.averaging.load_balancing import load_balance_peers
 from hivemind.p2p import PeerID
 from hivemind.proto.runtime_pb2 import CompressionType
+from test_utils.dht_swarms import launch_dht_instances
 
 
 @pytest.mark.forked
@@ -51,13 +52,6 @@ async def test_key_manager():
     assert len(q5) == 0
 
 
-def launch_dht_instances(n_peers: int, **kwargs) -> List[hivemind.DHT]:
-    instances = [hivemind.DHT(start=True, **kwargs)]
-    initial_peers = instances[0].get_visible_maddrs()
-    instances.extend(hivemind.DHT(initial_peers=initial_peers, start=True, **kwargs) for _ in range(n_peers - 1))
-    return instances
-
-
 def _test_allreduce_once(n_clients, n_aux):
     n_peers = 4
     modes = (

+ 2 - 3
tests/test_dht.py

@@ -6,13 +6,12 @@ import pytest
 from multiaddr import Multiaddr
 
 import hivemind
+from test_utils.dht_swarms import launch_dht_instances
 
 
 @pytest.mark.forked
 def test_get_store(n_peers=10):
-    peers = [hivemind.DHT(start=True)]
-    initial_peers = peers[0].get_visible_maddrs()
-    peers += [hivemind.DHT(initial_peers=initial_peers, start=True) for _ in range(n_peers - 1)]
+    peers = launch_dht_instances(n_peers)
 
     node1, node2 = random.sample(peers, 2)
     assert node1.store("key1", "value1", expiration_time=hivemind.get_dht_time() + 30)

+ 8 - 0
tests/test_utils/dht_swarms.py

@@ -7,6 +7,7 @@ from typing import Dict, List, Tuple
 
 from multiaddr import Multiaddr
 
+from hivemind.dht import DHT
 from hivemind.dht.node import DHTID, DHTNode
 from hivemind.p2p import PeerID
 
@@ -86,3 +87,10 @@ async def launch_star_shaped_swarm(n_peers: int, **kwargs) -> List[DHTNode]:
     initial_peers = await nodes[0].get_visible_maddrs()
     nodes += await asyncio.gather(*[DHTNode.create(initial_peers=initial_peers, **kwargs) for _ in range(n_peers - 1)])
     return nodes
+
+
+def launch_dht_instances(n_peers: int, **kwargs) -> List[DHT]:
+    instances = [DHT(start=True, **kwargs)]
+    initial_peers = instances[0].get_visible_maddrs()
+    instances.extend(DHT(initial_peers=initial_peers, start=True, **kwargs) for _ in range(n_peers - 1))
+    return instances