@@ -1,4 +1,5 @@
 import random
+from typing import List
 
 import numpy as np
 import pytest
@@ -50,6 +51,13 @@ async def test_key_manager():
     assert len(q5) == 0
 
 
+def launch_dht_instances(n_peers: int, **kwargs) -> List[hivemind.DHT]:
+    instances = [hivemind.DHT(start=True, **kwargs)]
+    initial_peers = instances[0].get_visible_maddrs()
+    instances.extend(hivemind.DHT(initial_peers=initial_peers, start=True, **kwargs) for _ in range(n_peers - 1))
+    return instances
+
+
 def _test_allreduce_once(n_clients, n_aux):
     n_peers = 4
     modes = (
@@ -71,14 +79,9 @@ def _test_allreduce_once(n_clients, n_aux):
         for i in range(len(tensors1))
     ]
 
-    dht_root = hivemind.DHT(start=True)
-    initial_peers = dht_root.get_visible_maddrs()
-    averagers = []
-    dhts = []
-    for tensors, mode in zip(peer_tensors, modes):
-        dht_instance = hivemind.DHT(start=True, initial_peers=initial_peers)
-        dhts.append(dht_instance)
-        averagers.append(hivemind.averaging.DecentralizedAverager(
+    dhts = launch_dht_instances(len(peer_tensors))
+    averagers = [
+        hivemind.averaging.DecentralizedAverager(
             tensors,
             dht=dht_instance,
             target_group_size=4,
@@ -87,7 +90,9 @@ def _test_allreduce_once(n_clients, n_aux):
             client_mode=mode == AveragingMode.CLIENT,
             auxiliary=mode == AveragingMode.AUX,
             start=True,
-        ))
+        )
+        for tensors, dht_instance, mode in zip(peer_tensors, dhts, modes)
+    ]
 
     futures = []
     for averager in averagers:
@@ -103,11 +108,8 @@ def _test_allreduce_once(n_clients, n_aux):
                 for ref, our in zip(reference, averaged_tensors):
                     assert torch.allclose(ref, our, atol=1e-6)
 
-    for averager in averagers:
-        averager.shutdown()
-    for instance in dhts:
+    for instance in averagers + dhts:
         instance.shutdown()
-    dht_root.shutdown()
 
 
 @pytest.mark.forked
@@ -125,8 +127,6 @@ def test_allreduce_once_edge_cases(n_clients, n_aux):
 
 @pytest.mark.forked
 def test_allreduce_weighted(n_client_mode_peers: int = 2):
-    dht = hivemind.DHT(start=True)
-
     n_peers = 4
     client_modes = [True] * n_client_mode_peers + [False] * (n_peers - n_client_mode_peers)
     random.shuffle(client_modes)
@@ -135,18 +135,21 @@ def test_allreduce_weighted(n_client_mode_peers: int = 2):
     tensors2 = [torch.rand(123), torch.ones(3)]
     tensors3 = [-torch.rand(123), torch.arange(3).to(torch.float32)]
     tensors4 = [torch.randn(123) ** 3, torch.arange(3).to(torch.float32) / 2]
+
+    dhts = launch_dht_instances(4)
     averagers = [
         hivemind.averaging.DecentralizedAverager(
             tensors,
-            dht=dht,
+            dht=dht_instance,
             target_group_size=4,
             averaging_expiration=15,
             prefix="mygroup",
             client_mode=client_mode,
             start=True,
         )
-        for tensors, client_mode in zip([tensors1, tensors2, tensors3, tensors4], client_modes)
+        for tensors, dht_instance, client_mode in zip([tensors1, tensors2, tensors3, tensors4], dhts, client_modes)
     ]
+
     weights = list(map(float, np.random.rand(len(averagers)) * 10 + 0.01))
     reference = [
         (tensors1[i] * weights[0] + tensors2[i] * weights[1] + tensors3[i] * weights[2] + tensors4[i] * weights[3])
@@ -165,15 +168,13 @@ def test_allreduce_weighted(n_client_mode_peers: int = 2):
             for ref, our in zip(reference, averaged_tensors):
                 assert torch.allclose(ref, our, atol=1e-6)
 
-    for averager in averagers:
-        averager.shutdown()
-    dht.shutdown()
+    for instance in averagers + dhts:
+        instance.shutdown()
 
 
 @pytest.mark.forked
 def test_allreduce_compression():
     """this test ensures that compression works correctly when multiple tensors have different compression types"""
-    dht = hivemind.DHT(start=True)
 
     tensors1 = [torch.linspace(0, 500, 1000) ** 0.5, torch.randn(1000)]
     tensors2 = [torch.linspace(300, 800, 1000) ** 0.5, torch.randn(1000)]
@@ -182,9 +183,10 @@ def test_allreduce_compression():
     FLOAT16, UINT8 = CompressionType.FLOAT16, CompressionType.UNIFORM_8BIT
 
     for compression_type_pair in [(FLOAT16, FLOAT16), (FLOAT16, UINT8), (UINT8, FLOAT16), (UINT8, UINT8)]:
+        dhts = launch_dht_instances(2)
         averager1 = hivemind.averaging.DecentralizedAverager(
             [x.clone() for x in tensors1],
-            dht=dht,
+            dht=dhts[0],
             compression_type=compression_type_pair,
             client_mode=True,
             target_group_size=2,
@@ -193,7 +195,7 @@ def test_allreduce_compression():
         )
         averager2 = hivemind.averaging.DecentralizedAverager(
             [x.clone() for x in tensors2],
-            dht=dht,
+            dht=dhts[1],
             compression_type=compression_type_pair,
             target_group_size=2,
             prefix="mygroup",
@@ -206,6 +208,9 @@ def test_allreduce_compression():
         with averager1.get_tensors() as averaged_tensors:
             results[compression_type_pair] = averaged_tensors
 
+        for instance in [averager1, averager2] + dhts:
+            instance.shutdown()
+
     assert torch.allclose(results[UINT8, FLOAT16][0], results[UINT8, UINT8][0])
     assert torch.allclose(results[UINT8, FLOAT16][1], results[FLOAT16, FLOAT16][1])
     assert torch.allclose(results[UINT8, UINT8][1], results[FLOAT16, UINT8][1])
@@ -236,17 +241,17 @@ def compute_mean_std(averagers, unbiased=True):
 
 @pytest.mark.forked
 def test_allreduce_grid():
-    dht = hivemind.DHT(start=True)
+    dhts = launch_dht_instances(8)
     averagers = [
         hivemind.averaging.DecentralizedAverager(
             averaged_tensors=[torch.randn(3)],
-            dht=dht,
+            dht=dht_instance,
             target_group_size=2,
             prefix="mygroup",
             initial_group_bits=bin(i // 2)[2:].rjust(2, "0"),
             start=True,
         )
-        for i in range(8)
+        for i, dht_instance in enumerate(dhts)
     ]
 
     [means0], [stds0] = compute_mean_std(averagers)
@@ -266,25 +271,24 @@ def test_allreduce_grid():
         else:
             assert torch.allclose(stds, torch.zeros_like(stds), atol=1e-6, rtol=0)
 
-    for averager in averagers:
+    for averager in averagers + dhts:
         averager.shutdown()
-    dht.shutdown()
 
 
 @pytest.mark.forked
 def test_allgather():
-    dht = hivemind.DHT(start=True)
+    dhts = launch_dht_instances(8)
     averagers = [
         hivemind.averaging.DecentralizedAverager(
             [torch.ones(1)],
-            dht=dht,
+            dht=dht_instance,
             target_group_size=4,
             averaging_expiration=15,
             prefix="mygroup",
             initial_group_bits="000",
             start=True,
         )
-        for _ in range(8)
+        for dht_instance in dhts
     ]
 
     futures = []
@@ -304,9 +308,8 @@ def test_allgather():
         for endpoint in gathered:
             assert gathered[endpoint] == reference_metadata[endpoint]
 
-    for averager in averagers:
+    for averager in averagers + dhts:
         averager.shutdown()
-    dht.shutdown()
 
 
 def get_cost(vector_size, partitions, bandwidths):
@@ -354,11 +357,11 @@ def test_load_balancing():
 
 @pytest.mark.forked
 def test_too_few_peers():
-    dht = hivemind.DHT(start=True)
+    dhts = launch_dht_instances(4)
    averagers = [
         hivemind.averaging.DecentralizedAverager(
             averaged_tensors=[torch.randn(3)],
-            dht=dht,
+            dht=dht_instance,
             target_group_size=2,
             averaging_expiration=1,
             request_timeout=0.5,
@@ -366,24 +369,23 @@ def test_too_few_peers():
             initial_group_bits=bin(i)[2:].rjust(3, "0"),
             start=True,
         )
-        for i in range(4)
+        for i, dht_instance in enumerate(dhts)
     ]
     step_futures = [averager.step(wait=False) for averager in averagers]
     for future in step_futures:
         assert len(future.result()) == 2
 
-    for averager in averagers:
+    for averager in averagers + dhts:
         averager.shutdown()
-    dht.shutdown()
 
 
 @pytest.mark.forked
 def test_overcrowded(num_peers=16):
-    dht = hivemind.DHT(start=True)
+    dhts = launch_dht_instances(num_peers)
     averagers = [
         hivemind.averaging.DecentralizedAverager(
             averaged_tensors=[torch.randn(3)],
-            dht=dht,
+            dht=dht_instance,
             target_group_size=2,
             averaging_expiration=1,
             request_timeout=0.5,
@@ -391,15 +393,14 @@ def test_overcrowded(num_peers=16):
             initial_group_bits="",
             start=True,
         )
-        for _ in range(num_peers)
+        for dht_instance in dhts
     ]
     for t in range(5):
         step_futures = [averager.step(wait=False, timeout=5) for averager in averagers]
         assert sum(len(future.result() or []) == 2 for future in step_futures) >= len(averagers) - 1
 
-    for averager in averagers:
+    for averager in averagers + dhts:
         averager.shutdown()
-    dht.shutdown()
 
 
 @pytest.mark.forked
@@ -418,22 +419,19 @@ def test_load_state_from_peers():
             num_calls += 1
             return super_metadata, super_tensors
 
-    dht_root = hivemind.DHT(start=True)
-    initial_peers = dht_root.get_visible_maddrs()
-    dht1 = hivemind.DHT(initial_peers=initial_peers, start=True)
+    dhts = launch_dht_instances(2)
     averager1 = TestAverager(
         [torch.randn(3), torch.rand(5)],
-        dht=dht1,
+        dht=dhts[0],
         start=True,
         prefix="demo-run",
         target_group_size=2,
     )
 
-    dht2 = hivemind.DHT(initial_peers=initial_peers, start=True)
-    dht2.get("demo-run.all_averagers")
+    dhts[1].get("demo-run.all_averagers")
     averager2 = TestAverager(
         [torch.randn(3), torch.rand(5)],
-        dht=dht2,
+        dht=dhts[1],
         start=True,
         prefix="demo-run",
         target_group_size=2,
@@ -462,6 +460,9 @@ def test_load_state_from_peers():
     assert num_calls == 3
     assert got_metadata == super_metadata
 
+    for instance in [averager1, averager2] + dhts:
+        instance.shutdown()
+
 
 @pytest.mark.forked
 def test_getset_bits():
@@ -481,9 +482,8 @@
 def test_training_averager(n_steps: int = 10, n_dims: int = 16):
     torch.manual_seed(42)
 
-    dht = hivemind.DHT(start=True)
+    dhts = launch_dht_instances(2)
     common_kwargs = {
-        "dht": dht,
         "start": True,
         "prefix": "demo-run",
         "target_group_size": 2,
@@ -492,13 +492,23 @@ def test_training_averager(n_steps: int = 10, n_dims: int = 16):
     x1 = torch.randn(n_dims, requires_grad=True)
     opt1 = torch.optim.Adam([x1], lr=0.05)
     averager1 = hivemind.averaging.TrainingAverager(
-        opt1, average_gradients=True, average_parameters=True, average_opt_statistics=["exp_avg_sq"], **common_kwargs
+        opt1,
+        average_gradients=True,
+        average_parameters=True,
+        average_opt_statistics=["exp_avg_sq"],
+        dht=dhts[0],
+        **common_kwargs
     )
 
     x2 = torch.randn(n_dims, requires_grad=True)
     opt2 = torch.optim.Adam([x2], lr=0.05)
     averager2 = hivemind.averaging.TrainingAverager(
-        opt2, average_gradients=True, average_parameters=True, average_opt_statistics=["exp_avg_sq"], **common_kwargs
+        opt2,
+        average_gradients=True,
+        average_parameters=True,
+        average_opt_statistics=["exp_avg_sq"],
+        dht=dhts[1],
+        **common_kwargs
     )
     a = torch.ones(n_dims)
 
@@ -528,6 +538,5 @@ def test_training_averager(n_steps: int = 10, n_dims: int = 16):
         assert torch.allclose(opt1.state[x1]["exp_avg_sq"], stats_avg)
         assert torch.allclose(opt2.state[x2]["exp_avg_sq"], stats_avg)
 
-    averager1.shutdown()
-    averager2.shutdown()
-    dht.shutdown()
+    for instance in [averager1, averager2] + dhts:
+        instance.shutdown()
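
A minimal usage sketch of the new launch_dht_instances helper, assuming the patched test module above; this is illustrative only and not part of the patch (the function name example_two_peer_average and the import path are hypothetical):

import hivemind
import torch

from test_averaging import launch_dht_instances  # hypothetical import of the helper added in this patch


def example_two_peer_average():
    # One bootstrap DHT plus one peer joined via the bootstrap node's visible multiaddrs.
    dhts = launch_dht_instances(2)
    averagers = [
        hivemind.averaging.DecentralizedAverager(
            [torch.full((3,), float(i))],
            dht=dht,
            prefix="example",
            target_group_size=2,
            start=True,
        )
        for i, dht in enumerate(dhts)
    ]

    # Each peer runs one all-reduce step and ends up holding the mean of the two tensors.
    futures = [averager.step(wait=False) for averager in averagers]
    for future in futures:
        future.result()

    # Averagers and DHT nodes both expose shutdown(), so a single loop cleans everything up.
    for instance in averagers + dhts:
        instance.shutdown()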