|
@@ -474,8 +474,8 @@ def test_lr_scheduler(n_steps: int = 100, n_dims: int = 16, time_to_wait: int =
|
|
|
**sgd_kwargs
|
|
|
)
|
|
|
assert sgd3.local_epoch == sgd2.local_epoch
|
|
|
- assert sgd3.local_epoch == sgd1.local_epoch
|
|
|
- assert all([x['lr'] == y['lr'] for x, y in zip(sgd1.opt.param_groups, sgd2.opt.param_groups)])
|
|
|
+ assert all([x['lr'] == y['lr'] for x, y in zip(sgd2.opt.param_groups, sgd3.opt.param_groups)])
|
|
|
|
|
|
sgd1.shutdown()
|
|
|
sgd2.shutdown()
|
|
|
+ sgd3.shutdown()
|