justheuristic 3 سال پیش
والد
کامیت
500715b19e
1فایلهای تغییر یافته به همراه124 افزوده شده و 0 حذف شده
  1. 124 0
      tests/test_optimizer.py

+ 124 - 0
tests/test_optimizer.py

@@ -1,17 +1,23 @@
 import ctypes
 import multiprocessing as mp
+import random
 import time
+from dataclasses import dataclass
 from functools import partial
+from typing import Callable
 
 import numpy as np
 import pytest
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torchvision
+from datasets import Dataset
 
 import hivemind
 from hivemind.averaging.control import AveragingStage
 from hivemind.optim.experimental.grad_averager import GradientAverager
+from hivemind.optim.experimental.optimizer import Optimizer
 from hivemind.optim.experimental.progress_tracker import ProgressTracker
 from hivemind.optim.experimental.state_averager import TrainingStateAverager
 from hivemind.utils.crypto import RSAPrivateKey
@@ -279,3 +285,121 @@ def test_progress_tracker():
         assert 0.5 * mean_step_time < step_time_deltas[i] < 0.95 * mean_step_time
     assert emas[1] < emas[2] < emas[3] < emas[4]
     assert tracker.performance_ema.samples_per_second < 1e-9
+
+
+@dataclass(frozen=True)
+class TrainingArguments:
+    seed: int = 42
+    prefix: str = "my_exp"
+
+    num_peers: int = 8
+    num_clients: int = 3
+    target_batch_size: int = 128
+    reuse_grad_buffers: bool = True
+
+    lr_base: float = 0.1
+    lr_gamma: int = 0.1
+    lr_step_size: int = 10
+    max_epoch: int = 25
+
+    batch_size_min: int = 2
+    batch_size_max: int = 16
+    batch_time_min: float = 1.0
+    batch_time_max: float = 4.5
+    batch_time_std: float = 0.5
+
+    matchmaking_time: float = 5.0
+    max_refresh_period: float = 5.0
+    averaging_timeout: float = 15.0
+    winddown_time: float = 5.0
+    verbose: bool = True
+
+    device = "cuda:0" if torch.cuda.is_available() else "cpu"
+    make_dataset: Callable[[], Dataset] = lambda: torchvision.datasets.MNIST(train=True, root=".", download=True)
+    make_model: Callable[[int, int], nn.Module] = lambda num_features, num_classes: nn.Sequential(
+        nn.Linear(num_features, 64), nn.ReLU(), nn.Linear(64, num_classes)
+    )
+
+
+def _run_training_with_swarm(args: TrainingArguments):
+    random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    torch.set_num_threads(1)
+
+    dht = hivemind.DHT(start=True)
+
+    train_dataset = args.make_dataset()
+    num_features = np.prod(train_dataset.data[0].shape)
+    num_classes = len(train_dataset.classes)
+    X_train = torch.as_tensor(train_dataset.data, dtype=torch.float32)
+    X_train = X_train.sub_(X_train.mean((0, 1, 2))).div_(X_train.std((0, 1, 2))).reshape((-1, num_features))
+    y_train = torch.as_tensor(train_dataset.targets, dtype=torch.int64)
+    del train_dataset
+
+    def run_trainer(batch_size: int, batch_time: float, client_mode: bool, verbose: bool):
+        model = args.make_model(num_features, num_classes).to(args.device)
+
+        assert isinstance(model, torch.nn.Module), "model_arch must evaluate to a pytorch module"
+
+        optimizer = Optimizer(
+            prefix=args.prefix,
+            target_batch_size=args.target_batch_size,
+            param_groups=model.parameters(),
+            optimizer=partial(torch.optim.SGD, lr=args.lr_base),
+            scheduler=partial(torch.optim.lr_scheduler.StepLR, gamma=args.lr_gamma, step_size=args.lr_step_size),
+            dht=hivemind.DHT(initial_peers=dht.get_visible_maddrs(), client_mode=client_mode, start=True),
+            tracker_opts=dict(private_key=RSAPrivateKey(), max_refresh_period=args.max_refresh_period),
+            matchmaking_time=args.matchmaking_time,
+            averaging_timeout=args.averaging_timeout,
+            reuse_grad_buffers=args.reuse_grad_buffers,
+            client_mode=client_mode,
+            verbose=verbose,
+        )
+
+        prev_time = time.perf_counter()
+
+        while optimizer.local_epoch < args.max_epoch:
+            time.sleep(max(0.0, prev_time + random.gauss(batch_time, args.batch_time_std) - time.perf_counter()))
+
+            batch = torch.randint(0, len(X_train), (batch_size,))
+            loss = F.cross_entropy(model(X_train[batch]), y_train[batch])
+            loss.backward()
+
+            optimizer.step(batch_size=batch_size)
+            if not args.reuse_grad_buffers:
+                optimizer.zero_grad()
+            prev_time = time.perf_counter()
+
+        time.sleep(args.winddown_time)
+        optimizer.shutdown()
+
+    peers = []
+
+    for index in range(args.num_peers):
+        batch_size = random.randint(args.batch_size_min, args.batch_size_max)
+        batch_time = random.uniform(args.batch_time_min, args.batch_time_max)
+        peers.append(
+            mp.Process(
+                target=run_trainer,
+                name=f"trainer-{index}",
+                kwargs=dict(
+                    batch_size=batch_size,
+                    batch_time=batch_time,
+                    client_mode=(index >= args.num_peers - args.num_clients),
+                    verbose=args.verbose and (index == 0),
+                ),
+            )
+        )
+
+    for peer in peers[1:]:
+        peer.start()
+    peers[0].run()
+    for peer in peers[1:]:
+        peer.join()
+
+
+@pytest.mark.forked
+def test_training_short():
+    _run_training_with_swarm(TrainingArguments(
+        num_peers=3, num_clients=1, target_batch_size=32, max_epoch=3, lr_step_size=1,
+    ))