浏览代码

fix tests

Artem Chumachenko 3 年之前
父节点
当前提交
54223ccf93
共有 1 个文件被更改,包括 2 次插入2 次删除
  1. 2 2
      tests/test_optimizer.py

+ 2 - 2
tests/test_optimizer.py

@@ -290,7 +290,7 @@ def test_progress_tracker():
 @pytest.mark.forked
 @pytest.mark.parametrize(
     "grad_averager",
-    [(GradientAverager.get_factory(),), (PowerEFGradientAverager.get_factory(averager_rank=1),)],
+    [GradientAverager.get_factory(), PowerEFGradientAverager.get_factory(averager_rank=1)],
 )
 def test_optimizer(
     grad_averager: GradientAveragerFactory,
@@ -337,7 +337,7 @@ def test_optimizer(
             delay_optimizer_step=delay_optimizer_step,
             average_state_every=average_state_every,
             client_mode=client_mode,
-            grad_averager=GradientAverager,
+            grad_averager=grad_averager,
             verbose=False,
         )
         optimizer.load_state_from_peers()