Artem Chumachenko 3 vuotta sitten
vanhempi
commit
54223ccf93
1 muutettua tiedostoa jossa 2 lisäystä ja 2 poistoa
  1. 2 2
      tests/test_optimizer.py

+ 2 - 2
tests/test_optimizer.py

@@ -290,7 +290,7 @@ def test_progress_tracker():
 @pytest.mark.forked
 @pytest.mark.parametrize(
     "grad_averager",
-    [(GradientAverager.get_factory(),), (PowerEFGradientAverager.get_factory(averager_rank=1),)],
+    [GradientAverager.get_factory(), PowerEFGradientAverager.get_factory(averager_rank=1)],
 )
 def test_optimizer(
     grad_averager: GradientAveragerFactory,
@@ -337,7 +337,7 @@ def test_optimizer(
             delay_optimizer_step=delay_optimizer_step,
             average_state_every=average_state_every,
             client_mode=client_mode,
-            grad_averager=GradientAverager,
+            grad_averager=grad_averager,
             verbose=False,
         )
         optimizer.load_state_from_peers()