@@ -74,11 +74,11 @@ def _test_allreduce_once(n_clients, n_aux):
         for i in range(len(tensors1))
     ]
 
-    dhts = launch_dht_instances(len(peer_tensors))
+    dht_instances = launch_dht_instances(len(peer_tensors))
     averagers = [
         hivemind.averaging.DecentralizedAverager(
             tensors,
-            dht=dht_instance,
+            dht=dht,
             target_group_size=4,
             averaging_expiration=15,
             prefix="mygroup",
@@ -86,7 +86,7 @@ def _test_allreduce_once(n_clients, n_aux):
             auxiliary=mode == AveragingMode.AUX,
             start=True,
         )
-        for tensors, dht_instance, mode in zip(peer_tensors, dhts, modes)
+        for tensors, dht, mode in zip(peer_tensors, dht_instances, modes)
     ]
 
     futures = []
@@ -103,7 +103,7 @@ def _test_allreduce_once(n_clients, n_aux):
     for ref, our in zip(reference, averaged_tensors):
         assert torch.allclose(ref, our, atol=1e-6)
 
-    for instance in averagers + dhts:
+    for instance in averagers + dht_instances:
         instance.shutdown()
 
 
@@ -131,18 +131,18 @@ def test_allreduce_weighted(n_client_mode_peers: int = 2):
     tensors3 = [-torch.rand(123), torch.arange(3).to(torch.float32)]
     tensors4 = [torch.randn(123) ** 3, torch.arange(3).to(torch.float32) / 2]
 
-    dhts = launch_dht_instances(4)
+    dht_instances = launch_dht_instances(4)
     averagers = [
         hivemind.averaging.DecentralizedAverager(
             tensors,
-            dht=dht_instance,
+            dht=dht,
             target_group_size=4,
             averaging_expiration=15,
             prefix="mygroup",
             client_mode=client_mode,
             start=True,
         )
-        for tensors, dht_instance, client_mode in zip([tensors1, tensors2, tensors3, tensors4], dhts, client_modes)
+        for tensors, dht, client_mode in zip([tensors1, tensors2, tensors3, tensors4], dht_instances, client_modes)
     ]
 
     weights = list(map(float, np.random.rand(len(averagers)) * 10 + 0.01))
@@ -163,7 +163,7 @@ def test_allreduce_weighted(n_client_mode_peers: int = 2):
     for ref, our in zip(reference, averaged_tensors):
         assert torch.allclose(ref, our, atol=1e-6)
 
-    for instance in averagers + dhts:
+    for instance in averagers + dht_instances:
         instance.shutdown()
 
 
@@ -178,10 +178,10 @@ def test_allreduce_compression():
     FLOAT16, UINT8 = CompressionType.FLOAT16, CompressionType.UNIFORM_8BIT
 
     for compression_type_pair in [(FLOAT16, FLOAT16), (FLOAT16, UINT8), (UINT8, FLOAT16), (UINT8, UINT8)]:
-        dhts = launch_dht_instances(2)
+        dht_instances = launch_dht_instances(2)
         averager1 = hivemind.averaging.DecentralizedAverager(
             [x.clone() for x in tensors1],
-            dht=dhts[0],
+            dht=dht_instances[0],
             compression_type=compression_type_pair,
             client_mode=True,
             target_group_size=2,
@@ -190,7 +190,7 @@ def test_allreduce_compression():
         )
         averager2 = hivemind.averaging.DecentralizedAverager(
             [x.clone() for x in tensors2],
-            dht=dhts[1],
+            dht=dht_instances[1],
             compression_type=compression_type_pair,
             target_group_size=2,
             prefix="mygroup",
@@ -203,7 +203,7 @@ def test_allreduce_compression():
         with averager1.get_tensors() as averaged_tensors:
             results[compression_type_pair] = averaged_tensors
 
-        for instance in [averager1, averager2] + dhts:
+        for instance in [averager1, averager2] + dht_instances:
             instance.shutdown()
 
     assert torch.allclose(results[UINT8, FLOAT16][0], results[UINT8, UINT8][0])
@@ -236,17 +236,17 @@ def compute_mean_std(averagers, unbiased=True):
 
 @pytest.mark.forked
 def test_allreduce_grid():
-    dhts = launch_dht_instances(8)
+    dht_instances = launch_dht_instances(8)
     averagers = [
         hivemind.averaging.DecentralizedAverager(
             averaged_tensors=[torch.randn(3)],
-            dht=dht_instance,
+            dht=dht,
             target_group_size=2,
             prefix="mygroup",
             initial_group_bits=bin(i // 2)[2:].rjust(2, "0"),
             start=True,
         )
-        for i, dht_instance in enumerate(dhts)
+        for i, dht in enumerate(dht_instances)
     ]
 
     [means0], [stds0] = compute_mean_std(averagers)
@@ -266,24 +266,24 @@ def test_allreduce_grid():
         else:
             assert torch.allclose(stds, torch.zeros_like(stds), atol=1e-6, rtol=0)
 
-    for averager in averagers + dhts:
+    for averager in averagers + dht_instances:
         averager.shutdown()
 
 
 @pytest.mark.forked
 def test_allgather():
-    dhts = launch_dht_instances(8)
+    dht_instances = launch_dht_instances(8)
     averagers = [
         hivemind.averaging.DecentralizedAverager(
             [torch.ones(1)],
-            dht=dht_instance,
+            dht=dht,
             target_group_size=4,
             averaging_expiration=15,
             prefix="mygroup",
             initial_group_bits="000",
             start=True,
         )
-        for dht_instance in dhts
+        for dht in dht_instances
     ]
 
     futures = []
@@ -307,7 +307,7 @@ def test_allgather():
     for endpoint in gathered:
         assert gathered[endpoint] == reference_metadata[endpoint]
 
-    for averager in averagers + dhts:
+    for averager in averagers + dht_instances:
         averager.shutdown()
 
 
@@ -356,11 +356,11 @@ def test_load_balancing():
 
 @pytest.mark.forked
 def test_too_few_peers():
-    dhts = launch_dht_instances(4)
+    dht_instances = launch_dht_instances(4)
     averagers = [
         hivemind.averaging.DecentralizedAverager(
             averaged_tensors=[torch.randn(3)],
-            dht=dht_instance,
+            dht=dht,
             target_group_size=2,
             averaging_expiration=1,
             request_timeout=0.5,
@@ -368,23 +368,23 @@ def test_too_few_peers():
             initial_group_bits=bin(i)[2:].rjust(3, "0"),
             start=True,
         )
-        for i, dht_instance in enumerate(dhts)
+        for i, dht in enumerate(dht_instances)
     ]
     step_futures = [averager.step(wait=False) for averager in averagers]
     for future in step_futures:
         assert len(future.result()) == 2
 
-    for averager in averagers + dhts:
+    for averager in averagers + dht_instances:
         averager.shutdown()
 
 
 @pytest.mark.forked
 def test_overcrowded(num_peers=16):
-    dhts = launch_dht_instances(num_peers)
+    dht_instances = launch_dht_instances(num_peers)
     averagers = [
         hivemind.averaging.DecentralizedAverager(
             averaged_tensors=[torch.randn(3)],
-            dht=dht_instance,
+            dht=dht,
             target_group_size=2,
             averaging_expiration=1,
             request_timeout=0.5,
@@ -392,13 +392,13 @@ def test_overcrowded(num_peers=16):
             initial_group_bits="",
             start=True,
         )
-        for dht_instance in dhts
+        for dht in dht_instances
     ]
     for t in range(5):
         step_futures = [averager.step(wait=False, timeout=5) for averager in averagers]
         assert sum(len(future.result() or []) == 2 for future in step_futures) >= len(averagers) - 1
 
-    for averager in averagers + dhts:
+    for averager in averagers + dht_instances:
         averager.shutdown()
 
 
@@ -418,19 +418,19 @@ def test_load_state_from_peers():
             num_calls += 1
             return super_metadata, super_tensors
 
-    dhts = launch_dht_instances(2)
+    dht_instances = launch_dht_instances(2)
     averager1 = TestAverager(
         [torch.randn(3), torch.rand(5)],
-        dht=dhts[0],
+        dht=dht_instances[0],
         start=True,
         prefix="demo-run",
         target_group_size=2,
     )
 
-    dhts[1].get("demo-run.all_averagers")
+    dht_instances[1].get("demo-run.all_averagers")
     averager2 = TestAverager(
         [torch.randn(3), torch.rand(5)],
-        dht=dhts[1],
+        dht=dht_instances[1],
         start=True,
         prefix="demo-run",
         target_group_size=2,
@@ -459,7 +459,7 @@ def test_load_state_from_peers():
     assert num_calls == 3
     assert got_metadata == super_metadata
 
-    for instance in [averager1, averager2] + dhts:
+    for instance in [averager1, averager2] + dht_instances:
         instance.shutdown()
 
 
@@ -481,7 +481,7 @@ def test_getset_bits():
 def test_training_averager(n_steps: int = 10, n_dims: int = 16):
     torch.manual_seed(42)
 
-    dhts = launch_dht_instances(2)
+    dht_instances = launch_dht_instances(2)
     common_kwargs = {
         "start": True,
         "prefix": "demo-run",
@@ -495,7 +495,7 @@ def test_training_averager(n_steps: int = 10, n_dims: int = 16):
         average_gradients=True,
         average_parameters=True,
         average_opt_statistics=["exp_avg_sq"],
-        dht=dhts[0],
+        dht=dht_instances[0],
         **common_kwargs
     )
 
@@ -506,7 +506,7 @@ def test_training_averager(n_steps: int = 10, n_dims: int = 16):
         average_gradients=True,
         average_parameters=True,
         average_opt_statistics=["exp_avg_sq"],
-        dht=dhts[1],
+        dht=dht_instances[1],
         **common_kwargs
     )
     a = torch.ones(n_dims)
@@ -537,5 +537,5 @@ def test_training_averager(n_steps: int = 10, n_dims: int = 16):
     assert torch.allclose(opt1.state[x1]["exp_avg_sq"], stats_avg)
     assert torch.allclose(opt2.state[x2]["exp_avg_sq"], stats_avg)
 
-    for instance in [averager1, averager2] + dhts:
+    for instance in [averager1, averager2] + dht_instances:
         instance.shutdown()
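
After the rename, the setup/teardown pattern that all of these tests share reads roughly as below. This is a minimal sketch, not part of the diff: it assumes the `launch_dht_instances` helper importable from this repo's test utilities (path assumed) and uses only the `DecentralizedAverager` arguments and methods that appear in the hunks above.

```python
import torch
import hivemind

from test_utils.dht import launch_dht_instances  # assumed import path for the test helper

# One DHT node per peer; these are background processes and must be shut down explicitly.
dht_instances = launch_dht_instances(4)

averagers = [
    hivemind.averaging.DecentralizedAverager(
        [torch.randn(3)],      # each peer's local tensors to be averaged
        dht=dht,               # comprehension variable is now `dht`, one node per averager
        target_group_size=4,
        prefix="mygroup",
        start=True,
    )
    for dht in dht_instances
]

# Launch one all-reduce round concurrently, then wait for every peer to finish.
step_futures = [averager.step(wait=False) for averager in averagers]
for future in step_futures:
    future.result()

# Tear down averagers and DHT nodes together, as in the tests above.
for instance in averagers + dht_instances:
    instance.shutdown()
```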