|
@@ -290,7 +290,7 @@ def test_progress_tracker():
|
|
|
@pytest.mark.forked
|
|
|
@pytest.mark.parametrize(
|
|
|
"grad_averager",
|
|
|
- [(GradientAverager.get_factory(),), (PowerEFGradientAverager.get_factory(averager_rank=1),)],
|
|
|
+ [GradientAverager.get_factory(), PowerEFGradientAverager.get_factory(averager_rank=1)],
|
|
|
)
|
|
|
def test_optimizer(
|
|
|
grad_averager: GradientAveragerFactory,
|
|
@@ -337,7 +337,7 @@ def test_optimizer(
|
|
|
delay_optimizer_step=delay_optimizer_step,
|
|
|
average_state_every=average_state_every,
|
|
|
client_mode=client_mode,
|
|
|
- grad_averager=GradientAverager,
|
|
|
+ grad_averager=grad_averager,
|
|
|
verbose=False,
|
|
|
)
|
|
|
optimizer.load_state_from_peers()
|