
reduce test duration

justheuristic 3 years ago
commit 8f7810e71d

+ 1 - 1
hivemind/optim/experimental/optimizer.py

@@ -456,5 +456,5 @@ class Optimizer(torch.optim.Optimizer):
         logger.debug(f"{self.__class__.__name__} is shut down.")
 
     def __del__(self):
-        if self.is_alive() and self._parent_pid == os.getpid():
+        if self._parent_pid == os.getpid() and self.is_alive():
             self.shutdown()
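
The only change in this hunk is the order of the two guards in __del__: the cheap, side-effect-free pid comparison now runs first, so is_alive() is never reached when the destructor fires in a forked child process. A minimal sketch of the same pattern (a generic worker class for illustration, not hivemind's actual Optimizer):

    import os

    class BackgroundWorker:
        # Illustrative only -- shows why the pid check should come before is_alive().
        def __init__(self):
            self._parent_pid = os.getpid()  # remember the process that created the worker

        def is_alive(self) -> bool:
            return True  # a real worker would poll its background process/threads here

        def shutdown(self):
            pass  # a real worker would stop its background process/threads here

        def __del__(self):
            # The pid comparison has no side effects, so it runs first; forked children
            # short-circuit and never touch is_alive() during interpreter teardown.
            if self._parent_pid == os.getpid() and self.is_alive():
                self.shutdown()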

+ 1 - 1
hivemind/optim/experimental/state_averager.py

@@ -199,7 +199,7 @@ class TrainingStateAverager(DecentralizedAverager):
             logger.log(
                 self.status_loglevel,
                 "Initializing optimizer manually since it has no tensors in state dict. "
-                "To override this, please provide initialize_optimizer=False",
+                "To override this, provide initialize_optimizer=False",
             )
 
         if initialize_optimizer:

+ 85 - 0
tests/test_optimizer.py

@@ -2,6 +2,7 @@ import ctypes
 import multiprocessing as mp
 import time
 from functools import partial
+import random
 
 import numpy as np
 import pytest
@@ -14,6 +15,7 @@ from hivemind.averaging.control import AveragingStage
 from hivemind.optim.experimental.grad_averager import GradientAverager
 from hivemind.optim.experimental.progress_tracker import ProgressTracker
 from hivemind.optim.experimental.state_averager import TrainingStateAverager
+from hivemind.optim.experimental.optimizer import Optimizer
 from hivemind.utils.crypto import RSAPrivateKey
 
 
@@ -283,3 +285,86 @@ def test_progress_tracker():
         assert 0.5 * mean_step_time < step_time_deltas[i] < 0.95 * mean_step_time
     assert emas[1] < emas[2] < emas[3] < emas[4]
     assert tracker.performance_ema.samples_per_second < 1e-9
+
+
+def test_optimizer(num_peers: int = 2, num_clients: int = 1, target_batch_size: int = 64, total_epochs: int = 3,
+                   reuse_grad_buffers: bool = True, delay_grad_averaging: bool = True,
+                   delay_optimizer_step: bool = True, average_state_every: int = 1):
+    dht = hivemind.DHT(start=True)
+
+    features = torch.randn(100, 5)
+    targets = features @ torch.randn(5, 1)
+    optimizer = None
+
+    def run_trainer(batch_size: int, batch_time: float, client_mode: bool, verbose: bool):
+        nonlocal optimizer
+        model = nn.Linear(5, 1)
+
+        assert isinstance(model, torch.nn.Module), "model_arch must evaluate to a pytorch module"
+
+        optimizer = Optimizer(
+            prefix="test_run",
+            target_batch_size=target_batch_size,
+            batch_size_per_step=batch_size,
+            params=model.parameters(),
+            optimizer=partial(torch.optim.SGD, lr=0.1),
+            scheduler=partial(torch.optim.lr_scheduler.StepLR, gamma=0.5, step_size=1),
+            dht=hivemind.DHT(initial_peers=dht.get_visible_maddrs(), client_mode=client_mode, start=True),
+            tracker_opts=dict(private_key=RSAPrivateKey(), max_refresh_period=1.0),
+            averager_opts=dict(min_matchmaking_time=1.0, request_timeout=0.5),
+            matchmaking_time=1.0,
+            averaging_timeout=5.0,
+            reuse_grad_buffers=reuse_grad_buffers,
+            delay_grad_averaging=delay_grad_averaging,
+            delay_optimizer_step=delay_optimizer_step,
+            average_state_every=average_state_every,
+            client_mode=client_mode,
+            verbose=verbose,
+        )
+        optimizer.load_state_from_peers()
+
+        prev_time = time.perf_counter()
+
+        while optimizer.local_epoch < total_epochs:
+            time.sleep(max(0.0, prev_time + random.gauss(batch_time, 0.1) - time.perf_counter()))
+            batch = torch.randint(0, len(features), (batch_size,))
+
+            loss = F.mse_loss(model(features[batch]), targets[batch])
+            loss.backward()
+
+            optimizer.step()
+
+            if not reuse_grad_buffers:
+                optimizer.zero_grad()
+
+            prev_time = time.perf_counter()
+
+        time.sleep(1.0)
+        optimizer.shutdown()
+        return optimizer
+
+    peers = []
+
+    for index in range(num_peers):
+        peers.append(
+            mp.Process(
+                target=run_trainer,
+                name=f"trainer-{index}",
+                kwargs=dict(
+                    batch_size=4 + index,
+                    batch_time=0.3 + 0.2 * index,
+                    client_mode=(index >= num_peers - num_clients),
+                    verbose=(index == 0),
+                ),
+            )
+        )
+
+    for peer in peers[1:]:
+        peer.start()
+    peers[0].run()
+    for peer in peers[1:]:
+        peer.join()
+
+    assert isinstance(optimizer, Optimizer)
+    assert optimizer.local_epoch == total_epochs
+
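For reference, the new test is an ordinary pytest function whose parameters all have keyword defaults, so it can also be invoked directly with a shorter schedule when iterating on test duration. A minimal sketch, assuming the repository root is on the Python path so that tests.test_optimizer is importable:

    from tests.test_optimizer import test_optimizer

    # Same scenario as the committed defaults, but with fewer epochs and a smaller
    # target batch size to keep a local run quick; the argument values are examples.
    test_optimizer(num_peers=2, num_clients=1, target_batch_size=32, total_epochs=2)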