xtinkt il y a 4 ans
Parent
commit
0ac82c1972
1 fichiers modifiés avec 2 ajouts et 2 suppressions
  1. 2 2
      tests/test_averaging.py

+ 2 - 2
tests/test_averaging.py

@@ -474,8 +474,8 @@ def test_lr_scheduler(n_steps: int = 100, n_dims: int = 16, time_to_wait: int =
         **sgd_kwargs
     )
     assert sgd3.local_epoch == sgd2.local_epoch
-    assert sgd3.local_epoch == sgd1.local_epoch
-    assert all([x['lr'] == y['lr'] for x, y in zip(sgd1.opt.param_groups, sgd2.opt.param_groups)])
+    assert all([x['lr'] == y['lr'] for x, y in zip(sgd2.opt.param_groups, sgd3.opt.param_groups)])
 
     sgd1.shutdown()
     sgd2.shutdown()
+    sgd3.shutdown()