Fix test_averaging.py

Aleksandr Borzunov, 4 years ago
parent
commit
4f5acb5eee
2 changed files with 69 additions and 59 deletions
  1. hivemind/averaging/matchmaking.py (+3, -2)
  2. tests/test_averaging.py (+66, -57)

+ 3 - 2
hivemind/averaging/matchmaking.py

@@ -237,6 +237,7 @@ class Matchmaking:
         self, request: averaging_pb2.JoinRequest, _: P2PContext
     ) -> AsyncIterator[averaging_pb2.MessageFromLeader]:
         """accept or reject a join request from another averager; if accepted, run him through allreduce steps"""
+        request_endpoint = None
         try:
             async with self.lock_request_join_group:
                 reason_to_reject = self._check_reasons_to_reject(request)
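
Note on the hunk above: initializing request_endpoint before the try block is the usual guard against UnboundLocalError when later except/finally cleanup inspects the variable but the request was rejected before it was ever assigned. A minimal, self-contained sketch of that pattern (illustration only, not hivemind code):

from typing import Optional

def handle_join(request_ok: bool) -> Optional[str]:
    request_endpoint: Optional[str] = None  # safe default, mirroring the line added above
    try:
        if not request_ok:
            raise ValueError("rejected before the endpoint was parsed")
        request_endpoint = "QmPeer"  # in the real code, derived from the join request
    except ValueError:
        pass  # request rejected; nothing else to do in this sketch
    finally:
        # without the default, reading request_endpoint here would raise
        # UnboundLocalError whenever the exception fired before the assignment
        print(f"cleaning up follower state for endpoint={request_endpoint}")
    return request_endpoint

assert handle_join(request_ok=False) is None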
@@ -449,9 +450,9 @@ class PotentialLeaders:
             if maybe_next_leader is None or self.max_assured_time <= entry.expiration_time <= self.search_end_time:
                 self.update_triggered.set()
 
-            if maybe_next_leader is None or (entry.expiration_time, maybe_next_leader) > (
+            if maybe_next_leader is None or (entry.expiration_time, maybe_next_leader.to_base58()) > (
                 self.declared_expiration_time,
-                self.endpoint,
+                self.endpoint.to_base58(),
             ):
                 await asyncio.wait(
                     {self.update_finished.wait(), self.declared_expiration.wait()}, return_when=asyncio.FIRST_COMPLETED
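
The second hunk switches the tuple comparison from PeerID objects to their base58 strings, presumably because PeerID does not come with a stable ordering (or orders differently from the string form used elsewhere), so comparing strings keeps the (expiration_time, endpoint) tie-break deterministic. A minimal sketch with a hypothetical FakePeerID stand-in, not hivemind's actual class:

from dataclasses import dataclass

@dataclass(frozen=True)
class FakePeerID:
    # stand-in for a libp2p-style peer id; only .to_base58() is assumed here
    raw: str

    def to_base58(self) -> str:
        return self.raw

declared_expiration_time, own_endpoint = 17.0, FakePeerID("QmBbb")
entry_expiration_time, maybe_next_leader = 17.0, FakePeerID("QmAaa")

# With equal expiration times, the base58 strings break the tie the same way on every peer:
keep_waiting = (entry_expiration_time, maybe_next_leader.to_base58()) > (
    declared_expiration_time,
    own_endpoint.to_base58(),
)
assert not keep_waiting  # "QmAaa" < "QmBbb", so the await branch above would be skipped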

+ 66 - 57
tests/test_averaging.py

@@ -1,4 +1,5 @@
 import random
+from typing import List
 
 import numpy as np
 import pytest
@@ -50,6 +51,13 @@ async def test_key_manager():
     assert len(q5) == 0
 
 
+def launch_dht_instances(n_peers: int, **kwargs) -> List[hivemind.DHT]:
+    instances = [hivemind.DHT(start=True, **kwargs)]
+    initial_peers = instances[0].get_visible_maddrs()
+    instances.extend(hivemind.DHT(initial_peers=initial_peers, start=True, **kwargs) for _ in range(n_peers - 1))
+    return instances
+
+
 def _test_allreduce_once(n_clients, n_aux):
     n_peers = 4
     modes = (
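
The launch_dht_instances helper introduced above replaces the old pattern of a hand-built root DHT plus per-peer instances: the first instance bootstraps the swarm and every later one joins via its visible multiaddrs. A usage sketch, assuming hivemind is installed and the helper is defined as above:

dhts = launch_dht_instances(4)

# ... build DecentralizedAverager instances on top of `dhts` and run the test body ...

for dht in dhts:
    dht.shutdown()  # the refactored tests below tear down averagers and DHTs the same way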
@@ -71,14 +79,9 @@ def _test_allreduce_once(n_clients, n_aux):
         for i in range(len(tensors1))
     ]
 
-    dht_root = hivemind.DHT(start=True)
-    initial_peers = dht_root.get_visible_maddrs()
-    averagers = []
-    dhts = []
-    for tensors, mode in zip(peer_tensors, modes):
-        dht_instance = hivemind.DHT(start=True, initial_peers=initial_peers)
-        dhts.append(dht_instance)
-        averagers.append(hivemind.averaging.DecentralizedAverager(
+    dhts = launch_dht_instances(len(peer_tensors))
+    averagers = [
+        hivemind.averaging.DecentralizedAverager(
             tensors,
             dht=dht_instance,
             target_group_size=4,
@@ -87,7 +90,9 @@ def _test_allreduce_once(n_clients, n_aux):
             client_mode=mode == AveragingMode.CLIENT,
             auxiliary=mode == AveragingMode.AUX,
             start=True,
-        ))
+        )
+        for tensors, dht_instance, mode in zip(peer_tensors, dhts, modes)
+    ]
 
     futures = []
     for averager in averagers:
@@ -103,11 +108,8 @@ def _test_allreduce_once(n_clients, n_aux):
                 for ref, our in zip(reference, averaged_tensors):
                     assert torch.allclose(ref, our, atol=1e-6)
 
-    for averager in averagers:
-        averager.shutdown()
-    for instance in dhts:
+    for instance in averagers + dhts:
         instance.shutdown()
-    dht_root.shutdown()
 
 
 @pytest.mark.forked
@@ -125,8 +127,6 @@ def test_allreduce_once_edge_cases(n_clients, n_aux):
 
 @pytest.mark.forked
 def test_allreduce_weighted(n_client_mode_peers: int = 2):
-    dht = hivemind.DHT(start=True)
-
     n_peers = 4
     client_modes = [True] * n_client_mode_peers + [False] * (n_peers - n_client_mode_peers)
     random.shuffle(client_modes)
@@ -135,18 +135,21 @@ def test_allreduce_weighted(n_client_mode_peers: int = 2):
     tensors2 = [torch.rand(123), torch.ones(3)]
     tensors3 = [-torch.rand(123), torch.arange(3).to(torch.float32)]
     tensors4 = [torch.randn(123) ** 3, torch.arange(3).to(torch.float32) / 2]
+
+    dhts = launch_dht_instances(4)
     averagers = [
         hivemind.averaging.DecentralizedAverager(
             tensors,
-            dht=dht,
+            dht=dht_instance,
             target_group_size=4,
             averaging_expiration=15,
             prefix="mygroup",
             client_mode=client_mode,
             start=True,
         )
-        for tensors, client_mode in zip([tensors1, tensors2, tensors3, tensors4], client_modes)
+        for tensors, dht_instance, client_mode in zip([tensors1, tensors2, tensors3, tensors4], dhts, client_modes)
     ]
+
     weights = list(map(float, np.random.rand(len(averagers)) * 10 + 0.01))
     reference = [
         (tensors1[i] * weights[0] + tensors2[i] * weights[1] + tensors3[i] * weights[2] + tensors4[i] * weights[3])
@@ -165,15 +168,13 @@ def test_allreduce_weighted(n_client_mode_peers: int = 2):
             for ref, our in zip(reference, averaged_tensors):
                 assert torch.allclose(ref, our, atol=1e-6)
 
-    for averager in averagers:
-        averager.shutdown()
-    dht.shutdown()
+    for instance in averagers + dhts:
+        instance.shutdown()
 
 
 @pytest.mark.forked
 def test_allreduce_compression():
     """this test ensures that compression works correctly when multiple tensors have different compression types"""
-    dht = hivemind.DHT(start=True)
 
     tensors1 = [torch.linspace(0, 500, 1000) ** 0.5, torch.randn(1000)]
     tensors2 = [torch.linspace(300, 800, 1000) ** 0.5, torch.randn(1000)]
@@ -182,9 +183,10 @@ def test_allreduce_compression():
     FLOAT16, UINT8 = CompressionType.FLOAT16, CompressionType.UNIFORM_8BIT
 
     for compression_type_pair in [(FLOAT16, FLOAT16), (FLOAT16, UINT8), (UINT8, FLOAT16), (UINT8, UINT8)]:
+        dhts = launch_dht_instances(2)
         averager1 = hivemind.averaging.DecentralizedAverager(
             [x.clone() for x in tensors1],
-            dht=dht,
+            dht=dhts[0],
             compression_type=compression_type_pair,
             client_mode=True,
             target_group_size=2,
@@ -193,7 +195,7 @@ def test_allreduce_compression():
         )
         averager2 = hivemind.averaging.DecentralizedAverager(
             [x.clone() for x in tensors2],
-            dht=dht,
+            dht=dhts[1],
             compression_type=compression_type_pair,
             target_group_size=2,
             prefix="mygroup",
@@ -206,6 +208,9 @@ def test_allreduce_compression():
         with averager1.get_tensors() as averaged_tensors:
             results[compression_type_pair] = averaged_tensors
 
+        for instance in [averager1, averager2] + dhts:
+            instance.shutdown()
+
     assert torch.allclose(results[UINT8, FLOAT16][0], results[UINT8, UINT8][0])
     assert torch.allclose(results[UINT8, FLOAT16][1], results[FLOAT16, FLOAT16][1])
     assert torch.allclose(results[UINT8, UINT8][1], results[FLOAT16, UINT8][1])
@@ -236,17 +241,17 @@ def compute_mean_std(averagers, unbiased=True):
 
 @pytest.mark.forked
 def test_allreduce_grid():
-    dht = hivemind.DHT(start=True)
+    dhts = launch_dht_instances(8)
     averagers = [
         hivemind.averaging.DecentralizedAverager(
             averaged_tensors=[torch.randn(3)],
-            dht=dht,
+            dht=dht_instance,
             target_group_size=2,
             prefix="mygroup",
             initial_group_bits=bin(i // 2)[2:].rjust(2, "0"),
             start=True,
         )
-        for i in range(8)
+        for dht_instance in dhts
     ]
 
     [means0], [stds0] = compute_mean_std(averagers)
@@ -266,25 +271,24 @@ def test_allreduce_grid():
         else:
             assert torch.allclose(stds, torch.zeros_like(stds), atol=1e-6, rtol=0)
 
-    for averager in averagers:
+    for averager in averagers + dhts:
         averager.shutdown()
-    dht.shutdown()
 
 
 @pytest.mark.forked
 def test_allgather():
-    dht = hivemind.DHT(start=True)
+    dhts = launch_dht_instances(8)
     averagers = [
         hivemind.averaging.DecentralizedAverager(
             [torch.ones(1)],
-            dht=dht,
+            dht=dht_instance,
             target_group_size=4,
             averaging_expiration=15,
             prefix="mygroup",
             initial_group_bits="000",
             start=True,
         )
-        for _ in range(8)
+        for dht_instance in dhts
     ]
 
     futures = []
@@ -304,9 +308,8 @@ def test_allgather():
         for endpoint in gathered:
             assert gathered[endpoint] == reference_metadata[endpoint]
 
-    for averager in averagers:
+    for averager in averagers + dhts:
         averager.shutdown()
-    dht.shutdown()
 
 
 def get_cost(vector_size, partitions, bandwidths):
@@ -354,11 +357,11 @@ def test_load_balancing():
 
 @pytest.mark.forked
 def test_too_few_peers():
-    dht = hivemind.DHT(start=True)
+    dhts = launch_dht_instances(4)
     averagers = [
         hivemind.averaging.DecentralizedAverager(
             averaged_tensors=[torch.randn(3)],
-            dht=dht,
+            dht=dht_instance,
             target_group_size=2,
             averaging_expiration=1,
             request_timeout=0.5,
@@ -366,24 +369,23 @@ def test_too_few_peers():
             initial_group_bits=bin(i)[2:].rjust(3, "0"),
             start=True,
         )
-        for i in range(4)
+        for i, dht_instance in enumerate(dhts)
     ]
     step_futures = [averager.step(wait=False) for averager in averagers]
     for future in step_futures:
         assert len(future.result()) == 2
 
-    for averager in averagers:
+    for averager in averagers + dhts:
         averager.shutdown()
-    dht.shutdown()
 
 
 @pytest.mark.forked
 def test_overcrowded(num_peers=16):
-    dht = hivemind.DHT(start=True)
+    dhts = launch_dht_instances(num_peers)
     averagers = [
         hivemind.averaging.DecentralizedAverager(
             averaged_tensors=[torch.randn(3)],
-            dht=dht,
+            dht=dht_instance,
             target_group_size=2,
             averaging_expiration=1,
             request_timeout=0.5,
@@ -391,15 +393,14 @@ def test_overcrowded(num_peers=16):
             initial_group_bits="",
             start=True,
         )
-        for _ in range(num_peers)
+        for dht_instance in dhts
     ]
     for t in range(5):
         step_futures = [averager.step(wait=False, timeout=5) for averager in averagers]
         assert sum(len(future.result() or []) == 2 for future in step_futures) >= len(averagers) - 1
 
-    for averager in averagers:
+    for averager in averagers + dhts:
         averager.shutdown()
-    dht.shutdown()
 
 
 @pytest.mark.forked
@@ -418,22 +419,19 @@ def test_load_state_from_peers():
             num_calls += 1
             return super_metadata, super_tensors
 
-    dht_root = hivemind.DHT(start=True)
-    initial_peers = dht_root.get_visible_maddrs()
-    dht1 = hivemind.DHT(initial_peers=initial_peers, start=True)
+    dhts = launch_dht_instances(2)
     averager1 = TestAverager(
         [torch.randn(3), torch.rand(5)],
-        dht=dht1,
+        dht=dhts[0],
         start=True,
         prefix="demo-run",
         target_group_size=2,
     )
 
-    dht2 = hivemind.DHT(initial_peers=initial_peers, start=True)
-    dht2.get("demo-run.all_averagers")
+    dhts[1].get("demo-run.all_averagers")
     averager2 = TestAverager(
         [torch.randn(3), torch.rand(5)],
-        dht=dht2,
+        dht=dhts[1],
         start=True,
         prefix="demo-run",
         target_group_size=2,
@@ -462,6 +460,9 @@ def test_load_state_from_peers():
     assert num_calls == 3
     assert got_metadata == super_metadata
 
+    for instance in [averager1, averager2] + dhts:
+        instance.shutdown()
+
 
 @pytest.mark.forked
 def test_getset_bits():
@@ -481,9 +482,8 @@ def test_getset_bits():
 def test_training_averager(n_steps: int = 10, n_dims: int = 16):
     torch.manual_seed(42)
 
-    dht = hivemind.DHT(start=True)
+    dhts = launch_dht_instances(2)
     common_kwargs = {
-        "dht": dht,
         "start": True,
         "prefix": "demo-run",
         "target_group_size": 2,
@@ -492,13 +492,23 @@ def test_training_averager(n_steps: int = 10, n_dims: int = 16):
     x1 = torch.randn(n_dims, requires_grad=True)
     opt1 = torch.optim.Adam([x1], lr=0.05)
     averager1 = hivemind.averaging.TrainingAverager(
-        opt1, average_gradients=True, average_parameters=True, average_opt_statistics=["exp_avg_sq"], **common_kwargs
+        opt1,
+        average_gradients=True,
+        average_parameters=True,
+        average_opt_statistics=["exp_avg_sq"],
+        dht=dhts[0],
+        **common_kwargs
     )
 
     x2 = torch.randn(n_dims, requires_grad=True)
     opt2 = torch.optim.Adam([x2], lr=0.05)
     averager2 = hivemind.averaging.TrainingAverager(
-        opt2, average_gradients=True, average_parameters=True, average_opt_statistics=["exp_avg_sq"], **common_kwargs
+        opt2,
+        average_gradients=True,
+        average_parameters=True,
+        average_opt_statistics=["exp_avg_sq"],
+        dht=dhts[1],
+        **common_kwargs
     )
     a = torch.ones(n_dims)
 
@@ -528,6 +538,5 @@ def test_training_averager(n_steps: int = 10, n_dims: int = 16):
         assert torch.allclose(opt1.state[x1]["exp_avg_sq"], stats_avg)
         assert torch.allclose(opt2.state[x2]["exp_avg_sq"], stats_avg)
 
-    averager1.shutdown()
-    averager2.shutdown()
-    dht.shutdown()
+    for instance in [averager1, averager2] + dhts:
+        instance.shutdown()