|
@@ -33,7 +33,7 @@ async def test_partitioning():
|
|
|
|
|
|
# note: this test does _not_ use parameterization to reuse sampled tensors
|
|
|
for num_tensors in 1, 3, 5:
|
|
|
- for part_size_bytes in 31337, 2 ** 20, 10 ** 10:
|
|
|
+ for part_size_bytes in 31337, 2**20, 10**10:
|
|
|
for weights in [(1, 1), (0.333, 0.1667, 0.5003), (1.0, 0.0), [0.0, 0.4, 0.6, 0.0]]:
|
|
|
tensors = random.choices(all_tensors, k=num_tensors)
|
|
|
partition = TensorPartContainer(tensors, weights, part_size_bytes=part_size_bytes)
|
|
@@ -157,16 +157,16 @@ NODE, CLIENT, AUX = AveragingMode.NODE, AveragingMode.CLIENT, AveragingMode.AUX
|
|
|
@pytest.mark.parametrize(
|
|
|
"peer_modes, averaging_weights, peer_fractions, part_size_bytes",
|
|
|
[
|
|
|
- ((NODE, NODE, NODE, NODE), (1, 1, 1, 1), (1, 1, 1, 1), 2 ** 20),
|
|
|
- ((NODE, NODE, NODE, NODE), (0.1, 0.2, 0.3, 0.4), (1, 1, 1, 1), 2 ** 20),
|
|
|
- ((NODE, NODE, NODE, NODE), (1, 1, 1, 1), (1, 2, 3, 0), 2 ** 20),
|
|
|
- ((NODE, NODE, NODE, CLIENT), (1, 1, 1, 1), (1, 2, 3, 0), 2 ** 20),
|
|
|
- ((NODE, NODE, NODE, AUX), (1, 1, 1, 0), (1, 2, 3, 4), 2 ** 20),
|
|
|
- ((NODE, NODE, NODE, NODE), (0.15, 0.0, 0.35, 0.45), (1, 1, 1, 1), 2 ** 20),
|
|
|
- ((NODE, AUX, NODE, CLIENT), (0.15, 0.0, 0.35, 0.45), (150, 200, 67, 0), 2 ** 20),
|
|
|
+ ((NODE, NODE, NODE, NODE), (1, 1, 1, 1), (1, 1, 1, 1), 2**20),
|
|
|
+ ((NODE, NODE, NODE, NODE), (0.1, 0.2, 0.3, 0.4), (1, 1, 1, 1), 2**20),
|
|
|
+ ((NODE, NODE, NODE, NODE), (1, 1, 1, 1), (1, 2, 3, 0), 2**20),
|
|
|
+ ((NODE, NODE, NODE, CLIENT), (1, 1, 1, 1), (1, 2, 3, 0), 2**20),
|
|
|
+ ((NODE, NODE, NODE, AUX), (1, 1, 1, 0), (1, 2, 3, 4), 2**20),
|
|
|
+ ((NODE, NODE, NODE, NODE), (0.15, 0.0, 0.35, 0.45), (1, 1, 1, 1), 2**20),
|
|
|
+ ((NODE, AUX, NODE, CLIENT), (0.15, 0.0, 0.35, 0.45), (150, 200, 67, 0), 2**20),
|
|
|
((NODE, AUX, NODE, CLIENT), (0.15, 0.0, 0.35, 0.45), (150, 200, 67, 0), 256),
|
|
|
((NODE, AUX, NODE, CLIENT), (0.15, 0.0, 0.35, 0.45), (150, 200, 67, 0), 19),
|
|
|
- ((AUX, AUX, AUX, AUX), (0.0, 0.0, 0.0, 0.0), (1, 2, 3, 4), 2 ** 20),
|
|
|
+ ((AUX, AUX, AUX, AUX), (0.0, 0.0, 0.0, 0.0), (1, 2, 3, 4), 2**20),
|
|
|
],
|
|
|
)
|
|
|
@pytest.mark.forked
|