|
@@ -245,7 +245,7 @@ def test_allreduce_grid():
|
|
|
initial_group_bits=bin(i // 2)[2:].rjust(2, "0"),
|
|
|
start=True,
|
|
|
)
|
|
|
- for dht_instance in dhts
|
|
|
+ for i, dht_instance in enumerate(dhts)
|
|
|
]
|
|
|
|
|
|
[means0], [stds0] = compute_mean_std(averagers)
|
|
@@ -289,7 +289,10 @@ def test_allgather():
|
|
|
for i, averager in enumerate(averagers):
|
|
|
futures.append(averager.step(wait=False, gather=dict(batch_size=123 + i, foo="bar")))
|
|
|
|
|
|
- assert len(set(repr(sorted(future.result())) for future in futures)) == 2
|
|
|
+ gathered_data = [future.result() for future in futures]
|
|
|
+ gathered_data_reprs = [repr(sorted({peer_id.to_base58(): data for peer_id, data in result.items()}))
|
|
|
+ for result in gathered_data]
|
|
|
+ assert len(set(gathered_data_reprs)) == 2
|
|
|
|
|
|
reference_metadata = {
|
|
|
averager.endpoint: dict(batch_size=123 + i, foo="bar") for i, averager in enumerate(averagers)
|