فهرست منبع

Make diff smaller

Aleksandr Borzunov 4 سال پیش
والد
کامیت
95574522ef
1فایلهای تغییر یافته به همراه37 افزوده شده و 37 حذف شده
  1. 37 37
      tests/test_averaging.py

+ 37 - 37
tests/test_averaging.py

@@ -74,11 +74,11 @@ def _test_allreduce_once(n_clients, n_aux):
         for i in range(len(tensors1))
     ]
 
-    dhts = launch_dht_instances(len(peer_tensors))
+    dht_instances = launch_dht_instances(len(peer_tensors))
     averagers = [
         hivemind.averaging.DecentralizedAverager(
             tensors,
-            dht=dht_instance,
+            dht=dht,
             target_group_size=4,
             averaging_expiration=15,
             prefix="mygroup",
@@ -86,7 +86,7 @@ def _test_allreduce_once(n_clients, n_aux):
             auxiliary=mode == AveragingMode.AUX,
             start=True,
         )
-        for tensors, dht_instance, mode in zip(peer_tensors, dhts, modes)
+        for tensors, dht, mode in zip(peer_tensors, dht_instances, modes)
     ]
 
     futures = []
@@ -103,7 +103,7 @@ def _test_allreduce_once(n_clients, n_aux):
                 for ref, our in zip(reference, averaged_tensors):
                     assert torch.allclose(ref, our, atol=1e-6)
 
-    for instance in averagers + dhts:
+    for instance in averagers + dht_instances:
         instance.shutdown()
 
 
@@ -131,18 +131,18 @@ def test_allreduce_weighted(n_client_mode_peers: int = 2):
     tensors3 = [-torch.rand(123), torch.arange(3).to(torch.float32)]
     tensors4 = [torch.randn(123) ** 3, torch.arange(3).to(torch.float32) / 2]
 
-    dhts = launch_dht_instances(4)
+    dht_instances = launch_dht_instances(4)
     averagers = [
         hivemind.averaging.DecentralizedAverager(
             tensors,
-            dht=dht_instance,
+            dht=dht,
             target_group_size=4,
             averaging_expiration=15,
             prefix="mygroup",
             client_mode=client_mode,
             start=True,
         )
-        for tensors, dht_instance, client_mode in zip([tensors1, tensors2, tensors3, tensors4], dhts, client_modes)
+        for tensors, dht, client_mode in zip([tensors1, tensors2, tensors3, tensors4], dht_instances, client_modes)
     ]
 
     weights = list(map(float, np.random.rand(len(averagers)) * 10 + 0.01))
@@ -163,7 +163,7 @@ def test_allreduce_weighted(n_client_mode_peers: int = 2):
             for ref, our in zip(reference, averaged_tensors):
                 assert torch.allclose(ref, our, atol=1e-6)
 
-    for instance in averagers + dhts:
+    for instance in averagers + dht_instances:
         instance.shutdown()
 
 
@@ -178,10 +178,10 @@ def test_allreduce_compression():
     FLOAT16, UINT8 = CompressionType.FLOAT16, CompressionType.UNIFORM_8BIT
 
     for compression_type_pair in [(FLOAT16, FLOAT16), (FLOAT16, UINT8), (UINT8, FLOAT16), (UINT8, UINT8)]:
-        dhts = launch_dht_instances(2)
+        dht_instances = launch_dht_instances(2)
         averager1 = hivemind.averaging.DecentralizedAverager(
             [x.clone() for x in tensors1],
-            dht=dhts[0],
+            dht=dht_instances[0],
             compression_type=compression_type_pair,
             client_mode=True,
             target_group_size=2,
@@ -190,7 +190,7 @@ def test_allreduce_compression():
         )
         averager2 = hivemind.averaging.DecentralizedAverager(
             [x.clone() for x in tensors2],
-            dht=dhts[1],
+            dht=dht_instances[1],
             compression_type=compression_type_pair,
             target_group_size=2,
             prefix="mygroup",
@@ -203,7 +203,7 @@ def test_allreduce_compression():
         with averager1.get_tensors() as averaged_tensors:
             results[compression_type_pair] = averaged_tensors
 
-        for instance in [averager1, averager2] + dhts:
+        for instance in [averager1, averager2] + dht_instances:
             instance.shutdown()
 
     assert torch.allclose(results[UINT8, FLOAT16][0], results[UINT8, UINT8][0])
@@ -236,17 +236,17 @@ def compute_mean_std(averagers, unbiased=True):
 
 @pytest.mark.forked
 def test_allreduce_grid():
-    dhts = launch_dht_instances(8)
+    dht_instances = launch_dht_instances(8)
     averagers = [
         hivemind.averaging.DecentralizedAverager(
             averaged_tensors=[torch.randn(3)],
-            dht=dht_instance,
+            dht=dht,
             target_group_size=2,
             prefix="mygroup",
             initial_group_bits=bin(i // 2)[2:].rjust(2, "0"),
             start=True,
         )
-        for i, dht_instance in enumerate(dhts)
+        for i, dht in enumerate(dht_instances)
     ]
 
     [means0], [stds0] = compute_mean_std(averagers)
@@ -266,24 +266,24 @@ def test_allreduce_grid():
         else:
             assert torch.allclose(stds, torch.zeros_like(stds), atol=1e-6, rtol=0)
 
-    for averager in averagers + dhts:
+    for averager in averagers + dht_instances:
         averager.shutdown()
 
 
 @pytest.mark.forked
 def test_allgather():
-    dhts = launch_dht_instances(8)
+    dht_instances = launch_dht_instances(8)
     averagers = [
         hivemind.averaging.DecentralizedAverager(
             [torch.ones(1)],
-            dht=dht_instance,
+            dht=dht,
             target_group_size=4,
             averaging_expiration=15,
             prefix="mygroup",
             initial_group_bits="000",
             start=True,
         )
-        for dht_instance in dhts
+        for dht in dht_instances
     ]
 
     futures = []
@@ -307,7 +307,7 @@ def test_allgather():
         for endpoint in gathered:
             assert gathered[endpoint] == reference_metadata[endpoint]
 
-    for averager in averagers + dhts:
+    for averager in averagers + dht_instances:
         averager.shutdown()
 
 
@@ -356,11 +356,11 @@ def test_load_balancing():
 
 @pytest.mark.forked
 def test_too_few_peers():
-    dhts = launch_dht_instances(4)
+    dht_instances = launch_dht_instances(4)
     averagers = [
         hivemind.averaging.DecentralizedAverager(
             averaged_tensors=[torch.randn(3)],
-            dht=dht_instance,
+            dht=dht,
             target_group_size=2,
             averaging_expiration=1,
             request_timeout=0.5,
@@ -368,23 +368,23 @@ def test_too_few_peers():
             initial_group_bits=bin(i)[2:].rjust(3, "0"),
             start=True,
         )
-        for i, dht_instance in enumerate(dhts)
+        for i, dht in enumerate(dht_instances)
     ]
     step_futures = [averager.step(wait=False) for averager in averagers]
     for future in step_futures:
         assert len(future.result()) == 2
 
-    for averager in averagers + dhts:
+    for averager in averagers + dht_instances:
         averager.shutdown()
 
 
 @pytest.mark.forked
 def test_overcrowded(num_peers=16):
-    dhts = launch_dht_instances(num_peers)
+    dht_instances = launch_dht_instances(num_peers)
     averagers = [
         hivemind.averaging.DecentralizedAverager(
             averaged_tensors=[torch.randn(3)],
-            dht=dht_instance,
+            dht=dht,
             target_group_size=2,
             averaging_expiration=1,
             request_timeout=0.5,
@@ -392,13 +392,13 @@ def test_overcrowded(num_peers=16):
             initial_group_bits="",
             start=True,
         )
-        for dht_instance in dhts
+        for dht in dht_instances
     ]
     for t in range(5):
         step_futures = [averager.step(wait=False, timeout=5) for averager in averagers]
         assert sum(len(future.result() or []) == 2 for future in step_futures) >= len(averagers) - 1
 
-    for averager in averagers + dhts:
+    for averager in averagers + dht_instances:
         averager.shutdown()
 
 
@@ -418,19 +418,19 @@ def test_load_state_from_peers():
             num_calls += 1
             return super_metadata, super_tensors
 
-    dhts = launch_dht_instances(2)
+    dht_instances = launch_dht_instances(2)
     averager1 = TestAverager(
         [torch.randn(3), torch.rand(5)],
-        dht=dhts[0],
+        dht=dht_instances[0],
         start=True,
         prefix="demo-run",
         target_group_size=2,
     )
 
-    dhts[1].get("demo-run.all_averagers")
+    dht_instances[1].get("demo-run.all_averagers")
     averager2 = TestAverager(
         [torch.randn(3), torch.rand(5)],
-        dht=dhts[1],
+        dht=dht_instances[1],
         start=True,
         prefix="demo-run",
         target_group_size=2,
@@ -459,7 +459,7 @@ def test_load_state_from_peers():
     assert num_calls == 3
     assert got_metadata == super_metadata
 
-    for instance in [averager1, averager2] + dhts:
+    for instance in [averager1, averager2] + dht_instances:
         instance.shutdown()
 
 
@@ -481,7 +481,7 @@ def test_getset_bits():
 def test_training_averager(n_steps: int = 10, n_dims: int = 16):
     torch.manual_seed(42)
 
-    dhts = launch_dht_instances(2)
+    dht_instances = launch_dht_instances(2)
     common_kwargs = {
         "start": True,
         "prefix": "demo-run",
@@ -495,7 +495,7 @@ def test_training_averager(n_steps: int = 10, n_dims: int = 16):
         average_gradients=True,
         average_parameters=True,
         average_opt_statistics=["exp_avg_sq"],
-        dht=dhts[0],
+        dht=dht_instances[0],
         **common_kwargs
     )
 
@@ -506,7 +506,7 @@ def test_training_averager(n_steps: int = 10, n_dims: int = 16):
         average_gradients=True,
         average_parameters=True,
         average_opt_statistics=["exp_avg_sq"],
-        dht=dhts[1],
+        dht=dht_instances[1],
         **common_kwargs
     )
     a = torch.ones(n_dims)
@@ -537,5 +537,5 @@ def test_training_averager(n_steps: int = 10, n_dims: int = 16):
         assert torch.allclose(opt1.state[x1]["exp_avg_sq"], stats_avg)
         assert torch.allclose(opt2.state[x2]["exp_avg_sq"], stats_avg)
 
-    for instance in [averager1, averager2] + dhts:
+    for instance in [averager1, averager2] + dht_instances:
         instance.shutdown()