@@ -83,6 +83,28 @@ def test_grad_averager(grad_averager_factory: GradientAveragerFactory):
     assert not torch.allclose(model2.w.grad, ref_average)
 
 
+@pytest.mark.forked
+@pytest.mark.parametrize(
+    "grad_averager_factory",
+    [GradientAverager, partial(PowerSGDGradientAverager, averager_rank=1)],
+)
+def test_grad_averager_wrong_shape(grad_averager_factory: GradientAveragerFactory):
+    parameter_shape = (5, 5)
+    model = nn.ParameterDict({"w": nn.Parameter(torch.zeros(parameter_shape))})
+    dht = hivemind.DHT(start=True)
+
+    with pytest.raises(ValueError):
+        grad_averager_factory(
+            model.parameters(),
+            dht=dht,
+            prefix="test_fail",
+            target_group_size=2,
+            reuse_grad_buffers=False,
+            start=True,
+            averaged_grads=[torch.zeros(parameter_shape + (1,))],
+        )
+
+
 @pytest.mark.forked
 @pytest.mark.parametrize(
     "offload_optimizer, reuse_tensors, sync_epoch_when_averaging",