@@ -303,8 +303,9 @@ def test_optimizer(
     features = torch.randn(100, 5)
     targets = features @ torch.randn(5, 1)
     optimizer = None
+    total_samples_accumulated = mp.Value(ctypes.c_int32, 0)

-    def run_trainer(batch_size: int, batch_time: float, client_mode: bool, verbose: bool):
+    def run_trainer(batch_size: int, batch_time: float, client_mode: bool):
         nonlocal optimizer
         model = nn.Linear(5, 1)

@@ -327,7 +328,7 @@ def test_optimizer(
             delay_optimizer_step=delay_optimizer_step,
             average_state_every=average_state_every,
             client_mode=client_mode,
-            verbose=verbose,
+            verbose=False,
         )
         optimizer.load_state_from_peers()

@@ -342,6 +343,8 @@ def test_optimizer(

             optimizer.step()

+            total_samples_accumulated.value += batch_size
+
             if not reuse_grad_buffers:
                 optimizer.zero_grad()

@@ -362,7 +365,6 @@ def test_optimizer(
                     batch_size=4 + index,
                     batch_time=0.3 + 0.2 * index,
                     client_mode=(index >= num_peers - num_clients),
-                    verbose=(index == 0),
                 ),
             )
         )
@@ -374,4 +376,11 @@ def test_optimizer(
         peer.join()

     assert isinstance(optimizer, Optimizer)
-    assert optimizer.local_epoch == total_epochs
+    assert optimizer.local_epoch == optimizer.tracker.global_epoch == total_epochs
+    expected_samples_accumulated = target_batch_size * total_epochs
+    assert expected_samples_accumulated <= total_samples_accumulated.value <= expected_samples_accumulated * 1.2
+
+    assert not optimizer.state_averager.is_alive()
+    assert not optimizer.grad_averager.is_alive()
+    assert not optimizer.tracker.is_alive()
+    assert optimizer.scheduled_round is None or optimizer.scheduled_round.done()
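
Outside the patch, here is a minimal, self-contained sketch of the shared-counter pattern that the new total_samples_accumulated checks rely on: an mp.Value(ctypes.c_int32, 0) counter that worker processes increment and the parent reads back after join(). The worker function, step counts, and batch sizes below are illustrative only, not taken from the test.

import ctypes
import multiprocessing as mp


def worker(counter, batch_size: int, steps: int) -> None:
    # Hypothetical stand-in for a trainer's inner loop: tally processed samples.
    for _ in range(steps):
        # The synchronized wrapper exposes a lock; incrementing under it keeps
        # the read-modify-write atomic across processes.
        with counter.get_lock():
            counter.value += batch_size


if __name__ == "__main__":
    counter = mp.Value(ctypes.c_int32, 0)
    workers = [mp.Process(target=worker, args=(counter, 4 + i, 10)) for i in range(3)]
    for p in workers:
        p.start()
    for p in workers:
        p.join()
    print(counter.value)  # 10 * (4 + 5 + 6) = 150

In the patched test, the tally is only required to land between target_batch_size * total_epochs and 1.2x that value, presumably because peers can push a few extra batches before they observe an epoch transition.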