Aleksandr Borzunov, 3 years ago
parent commit dfec847069
1 changed file with 13 additions and 4 deletions

tests/test_optimizer.py (+13, -4)

@@ -303,8 +303,9 @@ def test_optimizer(
     features = torch.randn(100, 5)
     targets = features @ torch.randn(5, 1)
     optimizer = None
+    total_samples_accumulated = mp.Value(ctypes.c_int32, 0)
 
-    def run_trainer(batch_size: int, batch_time: float, client_mode: bool, verbose: bool):
+    def run_trainer(batch_size: int, batch_time: float, client_mode: bool):
         nonlocal optimizer
         model = nn.Linear(5, 1)
 
@@ -327,7 +328,7 @@ def test_optimizer(
             delay_optimizer_step=delay_optimizer_step,
             average_state_every=average_state_every,
             client_mode=client_mode,
-            verbose=verbose,
+            verbose=False,
         )
         optimizer.load_state_from_peers()
 
@@ -342,6 +343,8 @@ def test_optimizer(
 
             optimizer.step()
 
+            total_samples_accumulated.value += batch_size
+
             if not reuse_grad_buffers:
                 optimizer.zero_grad()
 
@@ -362,7 +365,6 @@ def test_optimizer(
                     batch_size=4 + index,
                     batch_time=0.3 + 0.2 * index,
                     client_mode=(index >= num_peers - num_clients),
-                    verbose=(index == 0),
                 ),
             )
         )
@@ -374,4 +376,11 @@ def test_optimizer(
         peer.join()
 
     assert isinstance(optimizer, Optimizer)
-    assert optimizer.local_epoch == total_epochs
+    assert optimizer.local_epoch == optimizer.tracker.global_epoch == total_epochs
+    expected_samples_accumulated = target_batch_size * total_epochs
+    assert expected_samples_accumulated <= total_samples_accumulated.value <= expected_samples_accumulated * 1.2
+
+    assert not optimizer.state_averager.is_alive()
+    assert not optimizer.grad_averager.is_alive()
+    assert not optimizer.tracker.is_alive()
+    assert optimizer.scheduled_round is None or optimizer.scheduled_round.done()
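The new `total_samples_accumulated` counter is shared across the trainer processes through `multiprocessing.Value`. Below is a minimal standalone sketch of that pattern; the `run_trainer_sketch` function, step counts, and batch sizes are hypothetical stand-ins, not hivemind code. One caveat worth noting: the bare `total_samples_accumulated.value += batch_size` in the diff is a read-modify-write that is not atomic across processes. The test's assertion leaves 20% slack, but an exact count requires holding the Value's lock, as shown here.

import ctypes
import multiprocessing as mp

def run_trainer_sketch(counter, batch_size, steps):
    # Hypothetical stand-in for run_trainer: report every processed batch.
    for _ in range(steps):
        # Unlike the bare `+=` in the diff, hold the Value's lock so the
        # read-modify-write is atomic across processes.
        with counter.get_lock():
            counter.value += batch_size

if __name__ == "__main__":
    total_samples_accumulated = mp.Value(ctypes.c_int32, 0)
    peers = [
        mp.Process(target=run_trainer_sketch, args=(total_samples_accumulated, 4 + i, 10))
        for i in range(4)
    ]
    for peer in peers:
        peer.start()
    for peer in peers:
        peer.join()

    # Ten steps per peer at batch sizes 4, 5, 6 and 7 -> exactly 220 samples.
    assert total_samples_accumulated.value == 220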
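The final assertions verify that the optimizer's background components (`state_averager`, `grad_averager`, `tracker`) are no longer running, presumably after a shutdown performed elsewhere in the test, and that any prefetched averaging round has resolved. The following generic sketch illustrates the same "clean shutdown" checks, using a plain thread and a future as stand-ins for hivemind's averagers and tracker; none of these names come from hivemind itself.

import threading
from concurrent.futures import Future

class BackgroundWorker(threading.Thread):
    # Stand-in for a background component that exposes the same
    # Thread-style is_alive() check used by the test's assertions.
    def __init__(self):
        super().__init__(daemon=True)
        self._stop = threading.Event()

    def run(self):
        while not self._stop.wait(timeout=0.05):
            pass  # background averaging / progress reporting would go here

    def shutdown(self):
        self._stop.set()
        self.join()

worker = BackgroundWorker()
worker.start()

scheduled_round = Future()
scheduled_round.set_result(None)  # pretend the prefetched round finished

worker.shutdown()
assert not worker.is_alive()   # mirrors the is_alive() asserts in the diff
assert scheduled_round.done()  # mirrors `optimizer.scheduled_round.done()`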