|
@@ -377,7 +377,8 @@ def test_optimizer(
|
|
assert isinstance(optimizer, Optimizer)
|
|
assert isinstance(optimizer, Optimizer)
|
|
assert optimizer.local_epoch == optimizer.tracker.global_epoch == total_epochs
|
|
assert optimizer.local_epoch == optimizer.tracker.global_epoch == total_epochs
|
|
expected_samples_accumulated = target_batch_size * total_epochs
|
|
expected_samples_accumulated = target_batch_size * total_epochs
|
|
- assert expected_samples_accumulated <= total_samples_accumulated.value <= expected_samples_accumulated * 1.2
|
|
|
|
|
|
+ assert expected_samples_accumulated <= total_samples_accumulated.value <= 2 * expected_samples_accumulated
|
|
|
|
+ assert 4 / 0.3 * 0.9 <= optimizer.tracker.performance_ema.samples_per_second <= 4 / 0.3 * 1.1
|
|
|
|
|
|
assert not optimizer.state_averager.is_alive()
|
|
assert not optimizer.state_averager.is_alive()
|
|
assert not optimizer.grad_averager.is_alive()
|
|
assert not optimizer.grad_averager.is_alive()
|