@@ -2,6 +2,7 @@ import ctypes
 import multiprocessing as mp
+import random
 import time
 from functools import partial
 
 import numpy as np
 import pytest
@@ -14,6 +15,7 @@ from hivemind.averaging.control import AveragingStage
 from hivemind.optim.experimental.grad_averager import GradientAverager
+from hivemind.optim.experimental.optimizer import Optimizer
 from hivemind.optim.experimental.progress_tracker import ProgressTracker
 from hivemind.optim.experimental.state_averager import TrainingStateAverager
 from hivemind.utils.crypto import RSAPrivateKey
 
 
@@ -283,3 +285,97 @@ def test_progress_tracker():
         assert 0.5 * mean_step_time < step_time_deltas[i] < 0.95 * mean_step_time
     assert emas[1] < emas[2] < emas[3] < emas[4]
     assert tracker.performance_ema.samples_per_second < 1e-9
+
+
+@pytest.mark.forked
+def test_optimizer(num_peers: int = 2, num_clients: int = 1, target_batch_size: int = 64, total_epochs: int = 3,
+                   reuse_grad_buffers: bool = True, delay_grad_averaging: bool = True,
+                   delay_optimizer_step: bool = True, average_state_every: int = 1):
+    dht = hivemind.DHT(start=True)
+
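+    # toy task: recover a fixed random linear mapping shared by all peers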
+    features = torch.randn(100, 5)
+    targets = features @ torch.randn(5, 1)
+    optimizer = None
+
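+    # each trainer builds its own model and Optimizer; `optimizer` is rebound via
+    # nonlocal so the trainer that runs in the main process exposes it for asserts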
+    def run_trainer(batch_size: int, batch_time: float, client_mode: bool, verbose: bool):
+        nonlocal optimizer
+        model = nn.Linear(5, 1)
+
+        assert isinstance(model, torch.nn.Module), "model_arch must evaluate to a pytorch module"
+
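+        # each trainer joins the run through a fresh DHT node bootstrapped from the main peer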
+        optimizer = Optimizer(
+            prefix="test_run",
+            target_batch_size=target_batch_size,
+            batch_size_per_step=batch_size,
+            params=model.parameters(),
+            optimizer=partial(torch.optim.SGD, lr=0.1),
+            scheduler=partial(torch.optim.lr_scheduler.StepLR, gamma=0.5, step_size=1),
+            dht=hivemind.DHT(initial_peers=dht.get_visible_maddrs(), client_mode=client_mode, start=True),
+            tracker_opts=dict(private_key=RSAPrivateKey(), max_refresh_period=1.0),
+            averager_opts=dict(min_matchmaking_time=1.0, request_timeout=0.5),
+            matchmaking_time=1.0,
+            averaging_timeout=5.0,
+            reuse_grad_buffers=reuse_grad_buffers,
+            delay_grad_averaging=delay_grad_averaging,
+            delay_optimizer_step=delay_optimizer_step,
+            average_state_every=average_state_every,
+            client_mode=client_mode,
+            verbose=verbose,
+        )
+        optimizer.load_state_from_peers()
+
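+        # train until the collaboration reaches total_epochs global epochs; the sleep
+        # simulates per-batch compute time, randomized so peers stay out of lockstep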
+        prev_time = time.perf_counter()
+
+        while optimizer.local_epoch < total_epochs:
+            time.sleep(max(0.0, prev_time + random.gauss(batch_time, 0.1) - time.perf_counter()))
+            batch = torch.randint(0, len(features), (batch_size,))
+
+            loss = F.mse_loss(model(features[batch]), targets[batch])
+            loss.backward()
+
+            optimizer.step()
+
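+            # with reuse_grad_buffers=True, the Optimizer manages and resets grads itself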
+            if not reuse_grad_buffers:
+                optimizer.zero_grad()
+
+            prev_time = time.perf_counter()
+
+        time.sleep(1.0)
+        optimizer.shutdown()
+        return optimizer
+
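+    # peers get different batch sizes and speeds; the last num_clients peers are client-mode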
+    peers = []
+
+    for index in range(num_peers):
+        peers.append(
+            mp.Process(
+                target=run_trainer,
+                name=f"trainer-{index}",
+                kwargs=dict(
+                    batch_size=4 + index,
+                    batch_time=0.3 + 0.2 * index,
+                    client_mode=(index >= num_peers - num_clients),
+                    verbose=(index == 0),
+                ),
+            )
+        )
+
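+    # run the first trainer in the main process so that `nonlocal optimizer`
+    # leaves the final Optimizer instance visible to the asserts below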
+    for peer in peers[1:]:
+        peer.start()
+    peers[0].run()
+    for peer in peers[1:]:
+        peer.join()
+
+    assert isinstance(optimizer, Optimizer)
+    assert optimizer.local_epoch == total_epochs
+