
Merge branch 'master' into rfc_optimizer

justheuristic 3 years ago
parent commit b755f8a8f5
42 files changed with 2883 additions and 507 deletions
  1. +1 -1      .github/workflows/check-style.yml
  2. +35 -0     .github/workflows/run-benchmarks.yml
  3. +1 -1      .github/workflows/run-tests.yml
  4. +1 -1      README.md
  5. +178 -77   benchmarks/benchmark_dht.py
  6. +163 -0    benchmarks/benchmark_optimizer.py
  7. +26 -4     docs/modules/optim.rst
  8. +34 -31    docs/user/quickstart.md
  9. +2 -0      hivemind/__init__.py
  10. +139 -75  hivemind/averaging/allreduce.py
  11. +105 -45  hivemind/averaging/averager.py
  12. +24 -7    hivemind/averaging/control.py
  13. +21 -4    hivemind/averaging/key_manager.py
  14. +1 -1     hivemind/averaging/load_balancing.py
  15. +25 -12   hivemind/averaging/matchmaking.py
  16. +57 -19   hivemind/averaging/partition.py
  17. +3 -0     hivemind/compression/base.py
  18. +1 -1     hivemind/dht/dht.py
  19. +1 -1     hivemind/dht/node.py
  20. +1 -1     hivemind/dht/routing.py
  21. +1 -1     hivemind/hivemind_cli/run_server.py
  22. +2 -2     hivemind/moe/client/moe.py
  23. +2 -2     hivemind/moe/server/__init__.py
  24. +2 -1     hivemind/optim/__init__.py
  25. +15 -7    hivemind/optim/collaborative.py
  26. +29 -22   hivemind/optim/experimental/grad_averager.py
  27. +758 -0   hivemind/optim/experimental/optimizer.py
  28. +358 -0   hivemind/optim/experimental/progress_tracker.py
  29. +260 -115 hivemind/optim/experimental/state_averager.py
  30. +76 -45   hivemind/optim/grad_scaler.py
  31. +8 -4     hivemind/optim/simple.py
  32. +1 -1     hivemind/optim/training_averager.py
  33. +1 -0     hivemind/utils/__init__.py
  34. +23 -9    hivemind/utils/asyncio.py
  35. +10 -3    hivemind/utils/logging.py
  36. +5 -3     hivemind/utils/mpfuture.py
  37. +2 -2     hivemind/utils/serializer.py
  38. +1 -0     requirements-dev.txt
  39. +1 -0     requirements-docs.txt
  40. +213 -0   tests/test_allreduce_fault_tolerance.py
  41. +76 -2    tests/test_averaging.py
  42. +220 -7   tests/test_optimizer.py

+ 1 - 1
.github/workflows/check-style.yml

@@ -1,6 +1,6 @@
 name: Check style
 
-on: [ push ]
+on: [ push, pull_request ]
 
 jobs:
   black:

+ 35 - 0
.github/workflows/run-benchmarks.yml

@@ -0,0 +1,35 @@
+name: Benchmarks
+
+on: [ push, pull_request ]
+
+
+jobs:
+  run_benchmarks:
+
+    runs-on: ubuntu-latest
+    timeout-minutes: 10
+    steps:
+      - uses: actions/checkout@v2
+      - name: Set up Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: 3.9
+      - name: Cache dependencies
+        uses: actions/cache@v2
+        with:
+          path: ~/.cache/pip
+          key: Key-v1-3.9-${{ hashFiles('requirements.txt') }}-${{ hashFiles('requirements-dev.txt') }}
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install -r requirements.txt
+          pip install -r requirements-dev.txt
+      - name: Build hivemind
+        run: |
+          pip install .
+      - name: Benchmark
+        run: |
+          cd benchmarks
+          python benchmark_throughput.py --preset minimalistic
+          python benchmark_tensor_compression.py
+          python benchmark_dht.py

+ 1 - 1
.github/workflows/run-tests.yml

@@ -1,6 +1,6 @@
 name: Tests
 
-on: [ push ]
+on: [ push, pull_request ]
 
 
 jobs:

+ 1 - 1
README.md

@@ -93,7 +93,7 @@ documentation improvements to entirely new features, is appreciated.
 
 If you want to contribute to hivemind but don't know where to start, take a look at the
 unresolved [issues](https://github.com/learning-at-home/hivemind/issues). Open a new issue or
-join [our chat room](https://discord.gg/xC7ucM8j) in case you want to discuss new functionality or report a possible
+join [our chat room](https://discord.gg/uGugx9zYvN) in case you want to discuss new functionality or report a possible
 bug. Bug fixes are always welcome, but new features should be preferably discussed with maintainers beforehand.
 
 If you want to start contributing to the source code of hivemind, please see

+ 178 - 77
benchmarks/benchmark_dht.py

@@ -1,11 +1,15 @@
 import argparse
+import asyncio
 import random
 import time
+import uuid
+from logging import shutdown
+from typing import Tuple
 
+import numpy as np
 from tqdm import trange
 
 import hivemind
-from hivemind.moe.server import declare_experts, get_experts
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 
@@ -13,23 +17,116 @@ use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
 
 
-def random_endpoint() -> hivemind.Endpoint:
-    return (
-        f"{random.randint(0, 256)}.{random.randint(0, 256)}.{random.randint(0, 256)}."
-        f"{random.randint(0, 256)}:{random.randint(0, 65535)}"
-    )
+class NodeKiller:
+    """Auxiliary class that kills dht nodes over a pre-defined schedule"""
+
+    def __init__(self, shutdown_peers: list, shutdown_timestamps: list):
+        self.shutdown_peers = set(shutdown_peers)
+        self.shutdown_timestamps = shutdown_timestamps
+        self.current_iter = 0
+        self.timestamp_iter = 0
+        self.lock = asyncio.Lock()
+
+    async def check_and_kill(self):
+        async with self.lock:
+            if (
+                self.shutdown_timestamps != None
+                and self.timestamp_iter < len(self.shutdown_timestamps)
+                and self.current_iter == self.shutdown_timestamps[self.timestamp_iter]
+            ):
+                shutdown_peer = random.sample(self.shutdown_peers, 1)[0]
+                shutdown_peer.shutdown()
+                self.shutdown_peers.remove(shutdown_peer)
+                self.timestamp_iter += 1
+            self.current_iter += 1
+
+
+async def store_and_get_task(
+    peers: list,
+    total_num_rounds: int,
+    num_store_peers: int,
+    num_get_peers: int,
+    wait_after_iteration: float,
+    delay: float,
+    expiration: float,
+    latest: bool,
+    node_killer: NodeKiller,
+) -> Tuple[list, list, list, list, int, int]:
+    """Iteratively choose random peers to store data onto the dht, then retreive with another random subset of peers"""
+
+    total_stores = total_gets = 0
+    successful_stores = []
+    successful_gets = []
+    store_times = []
+    get_times = []
+
+    for _ in range(total_num_rounds):
+        key = uuid.uuid4().hex
+
+        store_start = time.perf_counter()
+        store_peers = random.sample(peers, min(num_store_peers, len(peers)))
+        store_subkeys = [uuid.uuid4().hex for _ in store_peers]
+        store_values = {subkey: uuid.uuid4().hex for subkey in store_subkeys}
+        store_tasks = [
+            peer.store(
+                key,
+                subkey=subkey,
+                value=store_values[subkey],
+                expiration_time=hivemind.get_dht_time() + expiration,
+                return_future=True,
+            )
+            for peer, subkey in zip(store_peers, store_subkeys)
+        ]
+        store_result = await asyncio.gather(*store_tasks)
+        await node_killer.check_and_kill()
+
+        store_times.append(time.perf_counter() - store_start)
+
+        total_stores += len(store_result)
+        successful_stores_per_iter = sum(store_result)
+        successful_stores.append(successful_stores_per_iter)
+        await asyncio.sleep(delay)
+
+        get_start = time.perf_counter()
+        get_peers = random.sample(peers, min(num_get_peers, len(peers)))
+        get_tasks = [peer.get(key, latest, return_future=True) for peer in get_peers]
+        get_result = await asyncio.gather(*get_tasks)
+        get_times.append(time.perf_counter() - get_start)
+
+        successful_gets_per_iter = 0
+
+        total_gets += len(get_result)
+        for result in get_result:
+            if result != None:
+                attendees, expiration = result
+                if len(attendees.keys()) == successful_stores_per_iter:
+                    get_ok = True
+                    for key in attendees:
+                        if attendees[key][0] != store_values[key]:
+                            get_ok = False
+                            break
+                    successful_gets_per_iter += get_ok
 
+        successful_gets.append(successful_gets_per_iter)
+        await asyncio.sleep(wait_after_iteration)
 
-def benchmark_dht(
+    return store_times, get_times, successful_stores, successful_gets, total_stores, total_gets
+
+
+async def benchmark_dht(
     num_peers: int,
     initial_peers: int,
-    num_experts: int,
-    expert_batch_size: int,
     random_seed: int,
-    wait_after_request: float,
-    wait_before_read: float,
+    num_threads: int,
+    total_num_rounds: int,
+    num_store_peers: int,
+    num_get_peers: int,
+    wait_after_iteration: float,
+    delay: float,
     wait_timeout: float,
     expiration: float,
+    latest: bool,
+    failure_rate: float,
 ):
     random.seed(random_seed)
 
@@ -42,88 +139,92 @@ def benchmark_dht(
         peer = hivemind.DHT(initial_peers=neighbors, start=True, wait_timeout=wait_timeout)
         peers.append(peer)
 
-    store_peer, get_peer = peers[-2:]
-
-    expert_uids = list(
-        set(
-            f"expert.{random.randint(0, 999)}.{random.randint(0, 999)}.{random.randint(0, 999)}"
-            for _ in range(num_experts)
-        )
-    )
-    logger.info(f"Sampled {len(expert_uids)} unique ids (after deduplication)")
-    random.shuffle(expert_uids)
-
-    logger.info(f"Storing experts to dht in batches of {expert_batch_size}...")
-    successful_stores = total_stores = total_store_time = 0
     benchmark_started = time.perf_counter()
-    endpoints = []
-
-    for start in trange(0, num_experts, expert_batch_size):
-        store_start = time.perf_counter()
-        endpoints.append(random_endpoint())
-        store_ok = declare_experts(
-            store_peer, expert_uids[start : start + expert_batch_size], endpoints[-1], expiration=expiration
+    logger.info("Creating store and get tasks...")
+    shutdown_peers = random.sample(peers, min(int(failure_rate * num_peers), num_peers))
+    assert len(shutdown_peers) != len(peers)
+    remaining_peers = list(set(peers) - set(shutdown_peers))
+    shutdown_timestamps = random.sample(
+        range(0, num_threads * total_num_rounds), min(len(shutdown_peers), num_threads * total_num_rounds)
+    )
+    shutdown_timestamps.sort()
+    node_killer = NodeKiller(shutdown_peers, shutdown_timestamps)
+    task_list = [
+        asyncio.create_task(
+            store_and_get_task(
+                remaining_peers,
+                total_num_rounds,
+                num_store_peers,
+                num_get_peers,
+                wait_after_iteration,
+                delay,
+                expiration,
+                latest,
+                node_killer,
+            )
         )
-        successes = store_ok.values()
-        total_store_time += time.perf_counter() - store_start
-
-        total_stores += len(successes)
-        successful_stores += sum(successes)
-        time.sleep(wait_after_request)
+        for _ in trange(num_threads)
+    ]
+
+    store_and_get_result = await asyncio.gather(*task_list)
+    benchmark_total_time = time.perf_counter() - benchmark_started
+    total_store_times = []
+    total_get_times = []
+    total_successful_stores = []
+    total_successful_gets = []
+    total_stores = total_gets = 0
+    for result in store_and_get_result:
+        store_times, get_times, successful_stores, successful_gets, stores, gets = result
+
+        total_store_times.extend(store_times)
+        total_get_times.extend(get_times)
+        total_successful_stores.extend(successful_stores)
+        total_successful_gets.extend(successful_gets)
+        total_stores += stores
+        total_gets += gets
 
+    alive_peers = [peer.is_alive() for peer in peers]
     logger.info(
-        f"Store success rate: {successful_stores / total_stores * 100:.1f}% ({successful_stores} / {total_stores})"
+        f"Store wall time (sec.): mean({np.mean(total_store_times):.3f}) "
+        + f"std({np.std(total_store_times, ddof=1):.3f}) max({np.max(total_store_times):.3f})"
     )
-    logger.info(f"Mean store time: {total_store_time / total_stores:.5}, Total: {total_store_time:.5}")
-    time.sleep(wait_before_read)
-
-    if time.perf_counter() - benchmark_started > expiration:
-        logger.warning("All keys expired before benchmark started getting them. Consider increasing expiration_time")
-
-    successful_gets = total_get_time = 0
-
-    for start in trange(0, len(expert_uids), expert_batch_size):
-        get_start = time.perf_counter()
-        get_result = get_experts(get_peer, expert_uids[start : start + expert_batch_size])
-        total_get_time += time.perf_counter() - get_start
-
-        for i, expert in enumerate(get_result):
-            if (
-                expert is not None
-                and expert.uid == expert_uids[start + i]
-                and expert.endpoint == endpoints[start // expert_batch_size]
-            ):
-                successful_gets += 1
-
-    if time.perf_counter() - benchmark_started > expiration:
-        logger.warning(
-            "keys expired midway during get requests. If that isn't desired, increase expiration_time param"
-        )
-
     logger.info(
-        f"Get success rate: {successful_gets / len(expert_uids) * 100:.1f} ({successful_gets} / {len(expert_uids)})"
+        f"Get wall time (sec.): mean({np.mean(total_get_times):.3f}) "
+        + f"std({np.std(total_get_times, ddof=1):.3f}) max({np.max(total_get_times):.3f})"
+    )
+    logger.info(f"Average store time per worker: {sum(total_store_times) / num_threads:.3f} sec.")
+    logger.info(f"Average get time per worker: {sum(total_get_times) / num_threads:.3f} sec.")
+    logger.info(f"Total benchmark time: {benchmark_total_time:.5f} sec.")
+    logger.info(
+        "Store success rate: "
+        + f"{sum(total_successful_stores) / total_stores * 100:.1f}% ({sum(total_successful_stores)}/{total_stores})"
+    )
+    logger.info(
+        "Get success rate: "
+        + f"{sum(total_successful_gets) / total_gets * 100:.1f}% ({sum(total_successful_gets)}/{total_gets})"
     )
-    logger.info(f"Mean get time: {total_get_time / len(expert_uids):.5f}, Total: {total_get_time:.5f}")
-
-    alive_peers = [peer.is_alive() for peer in peers]
     logger.info(f"Node survival rate: {len(alive_peers) / len(peers) * 100:.3f}%")
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--num_peers", type=int, default=32, required=False)
-    parser.add_argument("--initial_peers", type=int, default=1, required=False)
-    parser.add_argument("--num_experts", type=int, default=256, required=False)
-    parser.add_argument("--expert_batch_size", type=int, default=32, required=False)
-    parser.add_argument("--expiration", type=float, default=300, required=False)
-    parser.add_argument("--wait_after_request", type=float, default=0, required=False)
-    parser.add_argument("--wait_before_read", type=float, default=0, required=False)
+    parser.add_argument("--num_peers", type=int, default=16, required=False)
+    parser.add_argument("--initial_peers", type=int, default=4, required=False)
+    parser.add_argument("--random_seed", type=int, default=30, required=False)
+    parser.add_argument("--num_threads", type=int, default=10, required=False)
+    parser.add_argument("--total_num_rounds", type=int, default=16, required=False)
+    parser.add_argument("--num_store_peers", type=int, default=8, required=False)
+    parser.add_argument("--num_get_peers", type=int, default=8, required=False)
+    parser.add_argument("--wait_after_iteration", type=float, default=0, required=False)
+    parser.add_argument("--delay", type=float, default=0, required=False)
     parser.add_argument("--wait_timeout", type=float, default=5, required=False)
-    parser.add_argument("--random_seed", type=int, default=random.randint(1, 1000))
+    parser.add_argument("--expiration", type=float, default=300, required=False)
+    parser.add_argument("--latest", type=bool, default=True, required=False)
+    parser.add_argument("--failure_rate", type=float, default=0.1, required=False)
     parser.add_argument("--increase_file_limit", action="store_true")
     args = vars(parser.parse_args())
 
     if args.pop("increase_file_limit", False):
         increase_file_limit()
 
-    benchmark_dht(**args)
+    asyncio.run(benchmark_dht(**args))
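
A minimal sketch of invoking the rewritten benchmark programmatically rather than via argparse; the keyword arguments mirror the defaults above, and the import path assumes the call is made from the benchmarks/ directory:

```python
import asyncio

from benchmark_dht import benchmark_dht  # hypothetical import path, assuming cwd == benchmarks/

# Run the fault-tolerant DHT benchmark: 10 concurrent store/get workers over 16 peers,
# with ~10% of peers killed mid-run to exercise fault tolerance.
asyncio.run(
    benchmark_dht(
        num_peers=16,
        initial_peers=4,
        random_seed=30,
        num_threads=10,
        total_num_rounds=16,
        num_store_peers=8,
        num_get_peers=8,
        wait_after_iteration=0.0,
        delay=0.0,
        wait_timeout=5.0,
        expiration=300.0,
        latest=True,
        failure_rate=0.1,
    )
)
```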

+ 163 - 0
benchmarks/benchmark_optimizer.py

@@ -0,0 +1,163 @@
+import multiprocessing as mp
+import random
+import time
+from contextlib import nullcontext
+from dataclasses import dataclass
+from functools import partial
+from typing import Callable
+
+import numpy as np
+import torch
+import torchvision
+from torch import nn as nn
+from torch.nn import functional as F
+from torch.utils.data import Dataset
+
+import hivemind
+from hivemind.optim.experimental.optimizer import Optimizer
+from hivemind.utils.crypto import RSAPrivateKey
+
+
+@dataclass(frozen=True)
+class TrainingArguments:
+    seed: int = 42
+    run_id: str = "my_exp"
+
+    num_peers: int = 8
+    num_clients: int = 3
+    target_batch_size: int = 256
+    reuse_grad_buffers: bool = True
+    delay_grad_averaging: bool = True
+    delay_optimizer_step: bool = True
+    average_state_every: int = 1
+    use_amp: bool = False
+
+    lr_base: float = 0.1
+    lr_gamma: int = 0.1
+    lr_step_size: int = 10
+    max_epoch: int = 25
+
+    batch_size_min: int = 2
+    batch_size_max: int = 16
+    batch_time_min: float = 1.0
+    batch_time_max: float = 4.5
+    batch_time_std: float = 0.5
+
+    matchmaking_time: float = 5.0
+    max_refresh_period: float = 5.0
+    averaging_timeout: float = 15.0
+    winddown_time: float = 5.0
+    verbose: bool = True
+
+    device: str = "cpu"
+    make_dataset: Callable[[], Dataset] = lambda: torchvision.datasets.MNIST(train=True, root=".", download=True)
+    make_model: Callable[[int, int], nn.Module] = lambda num_features, num_classes: nn.Sequential(
+        nn.Linear(num_features, 64), nn.ReLU(), nn.Linear(64, num_classes)
+    )
+
+
+def benchmark_optimizer(args: TrainingArguments):
+    random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    torch.set_num_threads(1)
+
+    dht = hivemind.DHT(start=True)
+
+    train_dataset = args.make_dataset()
+    num_features = train_dataset.data[0].numel()
+    num_classes = len(train_dataset.classes)
+    X_train = torch.as_tensor(train_dataset.data, dtype=torch.float32)
+    X_train = X_train.sub_(X_train.mean((0, 1, 2))).div_(X_train.std((0, 1, 2))).reshape((-1, num_features))
+    y_train = torch.as_tensor(train_dataset.targets, dtype=torch.int64)
+    del train_dataset
+
+    def run_trainer(batch_size: int, batch_time: float, client_mode: bool, verbose: bool):
+        model = args.make_model(num_features, num_classes).to(args.device)
+
+        assert isinstance(model, torch.nn.Module), "model_arch must evaluate to a pytorch module"
+
+        optimizer = Optimizer(
+            run_id=args.run_id,
+            target_batch_size=args.target_batch_size,
+            batch_size_per_step=batch_size,
+            params=model.parameters(),
+            optimizer=partial(torch.optim.SGD, lr=args.lr_base),
+            scheduler=partial(torch.optim.lr_scheduler.StepLR, gamma=args.lr_gamma, step_size=args.lr_step_size),
+            dht=hivemind.DHT(initial_peers=dht.get_visible_maddrs(), client_mode=client_mode, start=True),
+            tracker_opts=dict(private_key=RSAPrivateKey(), max_refresh_period=args.max_refresh_period),
+            matchmaking_time=args.matchmaking_time,
+            averaging_timeout=args.averaging_timeout,
+            reuse_grad_buffers=args.reuse_grad_buffers,
+            delay_grad_averaging=args.delay_grad_averaging,
+            delay_optimizer_step=args.delay_optimizer_step,
+            average_state_every=args.average_state_every,
+            client_mode=client_mode,
+            verbose=verbose,
+        )
+
+        if args.use_amp and args.reuse_grad_buffers:
+            grad_scaler = hivemind.GradScaler()
+        else:
+            # check that hivemind.Optimizer supports regular PyTorch grad scaler as well
+            grad_scaler = torch.cuda.amp.GradScaler(enabled=args.use_amp)
+
+        prev_time = time.perf_counter()
+
+        while optimizer.local_epoch < args.max_epoch:
+            time.sleep(max(0.0, prev_time + random.gauss(batch_time, args.batch_time_std) - time.perf_counter()))
+
+            batch = torch.randint(0, len(X_train), (batch_size,))
+
+            with torch.cuda.amp.autocast() if args.use_amp else nullcontext():
+                loss = F.cross_entropy(model(X_train[batch].to(args.device)), y_train[batch].to(args.device))
+                grad_scaler.scale(loss).backward()
+
+            grad_scaler.unscale_(optimizer)
+
+            if args.use_amp:
+                grad_scaler.step(optimizer)
+            else:
+                optimizer.step()
+
+            grad_scaler.update()
+
+            if not args.reuse_grad_buffers:
+                optimizer.zero_grad()
+
+            prev_time = time.perf_counter()
+
+        time.sleep(args.winddown_time)
+        optimizer.shutdown()
+
+    peers = []
+
+    for index in range(args.num_peers):
+        batch_size = random.randint(args.batch_size_min, args.batch_size_max)
+        batch_time = random.uniform(args.batch_time_min, args.batch_time_max)
+        peers.append(
+            mp.Process(
+                target=run_trainer,
+                name=f"trainer-{index}",
+                daemon=False,
+                kwargs=dict(
+                    batch_size=batch_size,
+                    batch_time=batch_time,
+                    client_mode=(index >= args.num_peers - args.num_clients),
+                    verbose=args.verbose and (index == 0),
+                ),
+            )
+        )
+
+    try:
+        for peer in peers[1:]:
+            peer.start()
+        peers[0].run()
+        for peer in peers[1:]:
+            peer.join()
+    finally:
+        for peer in peers[1:]:
+            peer.kill()
+
+
+if __name__ == "__main__":
+    benchmark_optimizer(TrainingArguments())
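
Since TrainingArguments is a frozen dataclass, a smaller configuration is built by overriding fields at construction time; a hedged sketch (import path assumed, as above):

```python
from benchmark_optimizer import TrainingArguments, benchmark_optimizer  # hypothetical import path

# A quicker local run: fewer peers, one client-mode peer, fewer epochs.
quick_args = TrainingArguments(
    num_peers=4,
    num_clients=1,
    target_batch_size=128,
    max_epoch=5,
    use_amp=False,
    device="cpu",
)
benchmark_optimizer(quick_args)
```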

+ 26 - 4
docs/modules/optim.rst

@@ -1,14 +1,36 @@
 **hivemind.optim**
 ==================
 
-.. automodule:: hivemind.optim
-.. currentmodule:: hivemind.optim
-
 .. raw:: html
 
-  This module contains decentralized optimizers that wrap regular pytorch optimizers to collaboratively train a shared model. Depending on the exact type, optimizer may average model parameters with peers, exchange gradients, or follow a more complicated distributed training strategy.
+  This module contains decentralized optimizers that wrap your regular PyTorch Optimizer to train with peers.
+  Depending on the exact configuration, Optimizer may perform large synchronous updates equivalent
+  to training with a larger batch size, or perform asynchronous local updates and average model parameters.
+
   <br><br>
 
+.. automodule:: hivemind.optim.experimental.optimizer
+.. currentmodule:: hivemind.optim.experimental.optimizer
+
+**hivemind.Optimizer**
+----------------------
+
+.. autoclass:: Optimizer
+   :members: step, local_epoch, zero_grad, load_state_from_peers, param_groups, shutdown
+   :member-order: bysource
+
+.. currentmodule:: hivemind.optim.grad_scaler
+.. autoclass:: GradScaler
+   :member-order: bysource
+
+
+**CollaborativeOptimizer**
+--------------------------
+
+
+.. automodule:: hivemind.optim.collaborative
+.. currentmodule:: hivemind.optim
+
 .. autoclass:: CollaborativeOptimizer
    :members: step
    :member-order: bysource
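
For illustration, a hedged sketch of the two configurations described in the module docstring above, reusing parameter names from the quickstart and benchmark in this commit (`dht`, `model`, and `opt` are assumed to already exist):

```python
from functools import partial

import torch
import hivemind

# Large synchronous updates: accumulate gradients until peers jointly reach target_batch_size,
# then take one global optimizer step (as if training with a much larger batch).
sync_opt = hivemind.Optimizer(
    dht=dht, run_id="my_run",
    params=model.parameters(), optimizer=partial(torch.optim.SGD, lr=0.1),
    batch_size_per_step=32, target_batch_size=4096,
)

# Asynchronous local updates: step on local gradients immediately,
# average model parameters with peers in the background.
local_opt = hivemind.Optimizer(
    dht=dht, run_id="my_run", optimizer=opt,  # `opt` is a regular pre-built torch optimizer
    batch_size_per_step=32, target_batch_size=4096,
    use_local_updates=True,
)
```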

+ 34 - 31
docs/user/quickstart.md

@@ -47,26 +47,27 @@ model = nn.Sequential(nn.Conv2d(3, 16, (5, 5)), nn.MaxPool2d(2, 2), nn.ReLU(),
                       nn.Flatten(), nn.Linear(32 * 5 * 5, 10))
 opt = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
 
-
 # Create DHT: a decentralized key-value storage shared between peers
 dht = hivemind.DHT(start=True)
 print("To join the training, use initial_peers =", [str(addr) for addr in dht.get_visible_maddrs()])
 
 # Set up a decentralized optimizer that will average with peers in background
-opt = hivemind.optim.DecentralizedOptimizer(
-    opt,                      # wrap the SGD optimizer defined above
-    dht,                      # use a DHT that is connected with other peers
-    average_parameters=True,  # periodically average model weights in opt.step
-    average_gradients=False,  # do not average accumulated gradients
-    prefix='my_cifar_run',    # unique identifier of this collaborative run
-    target_group_size=16,     # maximum concurrent peers for this run
+opt = hivemind.Optimizer(
+    dht=dht,                  # use a DHT that is connected with other peers
+    run_id='my_cifar_run',    # unique identifier of this collaborative run
+    batch_size_per_step=32,   # each call to opt.step adds this many samples towards the next epoch
+    target_batch_size=10000,  # after peers collectively process this many samples, average weights and begin the next epoch 
+    optimizer=opt,            # wrap the SGD optimizer defined above
+    use_local_updates=True,   # perform optimizer steps with local gradients, average parameters in background
+    matchmaking_time=3.0,     # when averaging parameters, gather peers in background for up to this many seconds
+    averaging_timeout=10.0,   # give up on averaging if not successful in this many seconds
    verbose=True              # print logs incessantly
 )
-# Note: if you intend to use GPU, switch to it only after the decentralized optimizer is created
 
+# Note: if you intend to use GPU, switch to it only after the decentralized optimizer is created
 with tqdm() as progressbar:
     while True:
-        for x_batch, y_batch in torch.utils.data.DataLoader(trainset, shuffle=True, batch_size=256):
+        for x_batch, y_batch in torch.utils.data.DataLoader(trainset, shuffle=True, batch_size=32):
             opt.zero_grad()
             loss = F.cross_entropy(model(x_batch), y_batch)
             loss.backward()
@@ -78,7 +79,7 @@ with tqdm() as progressbar:
 
 
 As you can see, this code is regular PyTorch with one notable exception: it wraps your regular optimizer with a
-`DecentralizedOptimizer`. This optimizer uses `DHT` to find other peers and tries to exchange weights them. When you run
+`hivemind.Optimizer`. This optimizer uses `DHT` to find other peers and tries to exchange parameters with them. When you run
 the code (please do so), you will see the following output:
 
 ```shell
@@ -86,7 +87,7 @@ To join the training, use initial_peers = ['/ip4/127.0.0.1/tcp/XXX/p2p/YYY']
 [...] Starting a new averaging round with current parameters.
 ```
 
-This is `DecentralizedOptimizer` telling you that it's looking for peers. Since there are no peers, we'll need to create 
+This is `hivemind.Optimizer` telling you that it's looking for peers. Since there are no peers, we'll need to create 
 them ourselves.
 
 Copy the entire script (or notebook) and modify this line:
@@ -123,26 +124,28 @@ model = nn.Sequential(nn.Conv2d(3, 16, (5, 5)), nn.MaxPool2d(2, 2), nn.ReLU(),
 opt = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
 
 # Create DHT: a decentralized key-value storage shared between peers
-dht = hivemind.DHT(initial_peers=[COPY_FROM_ANOTHER_PEER_OUTPUTS], start=True)
+dht = hivemind.DHT(initial_peers=[COPY_FROM_OTHER_PEERS_OUTPUTS], start=True)
 print("To join the training, use initial_peers =", [str(addr) for addr in dht.get_visible_maddrs()])
 
 # Set up a decentralized optimizer that will average with peers in background
-opt = hivemind.optim.DecentralizedOptimizer(
-    opt,                      # wrap the SGD optimizer defined above
-    dht,                      # use a DHT that is connected with other peers
-    average_parameters=True,  # periodically average model weights in opt.step
-    average_gradients=False,  # do not average accumulated gradients
-    prefix='my_cifar_run',    # unique identifier of this collaborative run
-    target_group_size=16,     # maximum concurrent peers for this run
+opt = hivemind.Optimizer(
+    dht=dht,                  # use a DHT that is connected with other peers
+    run_id='my_cifar_run',    # unique identifier of this collaborative run
+    batch_size_per_step=32,   # each call to opt.step adds this many samples towards the next epoch
+    target_batch_size=10000,  # after peers collectively process this many samples, average weights and begin the next epoch
+    optimizer=opt,            # wrap the SGD optimizer defined above
+    use_local_updates=True,   # perform optimizer steps with local gradients, average parameters in background
+    matchmaking_time=3.0,     # when averaging parameters, gather peers in background for up to this many seconds
+    averaging_timeout=10.0,   # give up on averaging if not successful in this many seconds
    verbose=True              # print logs incessantly
 )
 
-opt.averager.load_state_from_peers()
+opt.load_state_from_peers()
 
-# Note: if you intend to use GPU, switch to it only after the decentralized optimizer is created
+# Note: if you intend to use GPU, switch to it only after the optimizer is created
 with tqdm() as progressbar:
     while True:
-        for x_batch, y_batch in torch.utils.data.DataLoader(trainset, shuffle=True, batch_size=256):
+        for x_batch, y_batch in torch.utils.data.DataLoader(trainset, shuffle=True, batch_size=32):
             opt.zero_grad()
             loss = F.cross_entropy(model(x_batch), y_batch)
             loss.backward()
@@ -166,22 +169,22 @@ This message means that the optimizer has averaged model parameters with another
 during one of the calls to `opt.step()`. You can start more peers by replicating the same code as the second peer,
 using either the first or second peer as `initial_peers`.
 
-The only issue with this code is that each new peer starts with a different untrained network blends its un-trained
-parameters with other peers, reseting their progress. You can see this effect as a spike increase in training loss
-immediately after new peer joins training. To avoid this problem, the second peer can download the
-current model/optimizer state from an existing peer right before it begins training on minibatches:
+Each new peer starts with an untrained network and must download the latest training state before it can contribute.
+By default, a peer will automatically detect that it is out of sync and start ``Downloading parameters from peer <...>``.
+To avoid wasting the first optimizer step, one can manually download the latest model/optimizer state right before training on minibatches begins:
 ```python
-opt.averager.load_state_from_peers()
+opt.load_state_from_peers()
 ```
 
 Congrats, you've just started a pocket-sized experiment with decentralized deep learning!
 
-However, this is just the bare minimum of what hivemind can do. In [this example](https://github.com/learning-at-home/hivemind/tree/master/examples/albert),
+However, this covers only the basics of what hivemind can do. In [this example](https://github.com/learning-at-home/hivemind/tree/master/examples/albert),
 we show how to use a more advanced version of DecentralizedOptimizer to collaboratively train a large Transformer over the internet.
 
 If you want to learn more about each individual component,
 - Learn how to use `hivemind.DHT` using this basic [DHT tutorial](https://learning-at-home.readthedocs.io/en/latest/user/dht.html),
-- Learn the underlying math behind DecentralizedOptimizer in
-  [(Li et al. 2020)](https://arxiv.org/abs/2005.00124) and [(Ryabinin et al. 2021)](https://arxiv.org/abs/2103.03239).
+- Read more on how to use `hivemind.Optimizer` in its [documentation page](https://learning-at-home.readthedocs.io/en/latest/modules/optim.html), 
+- Learn the underlying math behind hivemind.Optimizer in [Diskin et al., (2021)](https://arxiv.org/abs/2106.10207), 
+  [Li et al. (2020)](https://arxiv.org/abs/2005.00124) and [Ryabinin et al. (2021)](https://arxiv.org/abs/2103.03239).
 - Read about setting up Mixture-of-Experts training in [this guide](https://learning-at-home.readthedocs.io/en/latest/user/moe.html),
  

+ 2 - 0
hivemind/__init__.py

@@ -16,6 +16,8 @@ from hivemind.optim import (
     DecentralizedOptimizer,
     DecentralizedOptimizerBase,
     DecentralizedSGD,
+    GradScaler,
+    Optimizer,
     TrainingAverager,
 )
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, PeerInfo
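
With these exports, both classes are reachable from the package root; a short sketch mirroring how benchmark_optimizer.py above uses them:

```python
import hivemind

scaler = hivemind.GradScaler()   # hivemind-aware gradient scaler, used together with reuse_grad_buffers and AMP
OptimizerCls = hivemind.Optimizer  # also importable from hivemind.optim.experimental.optimizer
```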

+ 139 - 75
hivemind/averaging/allreduce.py

@@ -1,6 +1,6 @@
 import asyncio
 from enum import Enum
-from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Type
+from typing import Any, AsyncIterator, Dict, Optional, Sequence, Set, Tuple, Type
 
 import torch
 
@@ -11,8 +11,7 @@ from hivemind.proto import averaging_pb2
 from hivemind.utils import get_logger
 from hivemind.utils.asyncio import (
     achain,
-    aenumerate,
-    afirst,
+    aiter_with_timeout,
     amap_in_executor,
     anext,
     as_aiter,
@@ -52,6 +51,10 @@ class AllReduceRunner(ServicerBase):
       (the actual number of values by peer will be nearly proportional, but there are no exact guarantees)
     :param modes: AveragingMode for each peer in ordered_peer_ids (normal, client-only or auxiliary)
     :param gathered: additional user-defined data collected from this group
+    :param sender_timeout: during all_reduce, any sender that fails to send tensor chunk within this many seconds from
+      previous chunk will be marked as failed and excluded from averaging. default: equal to next_chunk_timeout
+    :param reducer_timeout: during all_reduce, any reducer that fails to send results chunk within this many seconds
+      from previous chunk will be marked as failed and excluded from averaging. default: 2 x sender_timeout
     :param kwargs: additional parameters (e.g. part_size_bytes) will be passed to TensorPartContainer
     :note: Full-mode peers send and receive tensor parts concurrently, assuming a full-duplex TCP stream. In turn,
       non-averaging peers receive results only after they finish sending, which helps them avoid
@@ -71,11 +74,18 @@ class AllReduceRunner(ServicerBase):
         peer_fractions: Tuple[float, ...],
         modes: Optional[Sequence[AveragingMode]] = None,
         gathered: Optional[Dict[PeerID, Any]] = None,
+        sender_timeout: Optional[float] = None,
+        reducer_timeout: Optional[float] = None,
         **kwargs,
     ):
         self._p2p = p2p
         self.peer_id = p2p.peer_id
         assert self.peer_id in ordered_peer_ids, "peer_id is not a part of the group"
+        if reducer_timeout is not None and (sender_timeout is None or reducer_timeout <= sender_timeout):
+            raise ValueError(
+                "If reducer_timeout is enabled, sender_timeout must be shorter than reducer_timeout. "
+                "Otherwise, there is a chance that reducers will be banned while they await senders."
+            )
 
         if not issubclass(servicer_type, ServicerBase):
             raise TypeError("`servicer_type` is expected to be a ServicerBase subclass")
@@ -102,8 +112,19 @@ class AllReduceRunner(ServicerBase):
             if mode != AveragingMode.AUX:
                 self.sender_peer_ids.append(peer_id)
 
+        self.sender_timeout, self.reducer_timeout = sender_timeout, reducer_timeout
+        self.all_senders_started = asyncio.Event()
+        self.banned_senders: Set[PeerID] = set()  # peers that did not send data by next_chunk_timeout
+        self.banlock = asyncio.Lock()
+
+        self.active_senders: Set[PeerID] = set()  # peers that began sending data via rpc_aggregate_part
+        if self.peer_id in self.sender_peer_ids:
+            self.active_senders.add(self.peer_id)
+        if len(self.active_senders) == len(self.sender_peer_ids):
+            self.all_senders_started.set()
+
         peer_id_index = self.ordered_peer_ids.index(self.peer_id)
-        self.tensor_part_container = TensorPartContainer(tensors, peer_fractions, **kwargs)
+        self.tensor_part_container = TensorPartContainer(tensors, peer_fractions, return_deltas=True, **kwargs)
         self.parts_for_local_averaging = self.tensor_part_container.get_raw_input_parts(peer_id_index)
         self.tensor_part_reducer = TensorPartReducer(
             tuple(part.shape for part in self.parts_for_local_averaging),
@@ -132,6 +153,10 @@ class AllReduceRunner(ServicerBase):
     async def run(self) -> AsyncIterator[torch.Tensor]:
         """Run all-reduce, return differences between averaged and original tensors as they are computed"""
         pending_tasks = set()
+
+        if self.tensor_part_container.num_parts_by_peer[self.ordered_peer_ids.index(self.peer_id)] != 0:
+            pending_tasks.add(asyncio.create_task(self._handle_missing_senders()))
+
         try:
             if len(self.sender_peer_ids) == 0:
                 logger.debug(f"{self} - finished all-reduce early: all peers are auxiliaries ({self.modes})")
@@ -144,6 +169,7 @@ class AllReduceRunner(ServicerBase):
 
                 async for averaged_tensor_delta in self.tensor_part_container.iterate_output_tensors():
                     yield averaged_tensor_delta  # delta = averaged_tensor - original_tensor
+
                 self.finalize()
 
             else:  # auxiliary peer
@@ -156,6 +182,24 @@ class AllReduceRunner(ServicerBase):
                 task.cancel()
             raise
 
+        finally:
+            for task in pending_tasks:
+                try:
+                    await task
+                except asyncio.CancelledError:
+                    pass
+                except Exception as inner_exc:
+                    logger.debug(f"Task {task} failed with {inner_exc}", exc_info=True)
+
+    async def _handle_missing_senders(self):
+        """Detect senders that should have sent tensors for averaging, but did not send anything within timeout"""
+        try:
+            await asyncio.wait_for(self.all_senders_started.wait(), self.sender_timeout)
+        except asyncio.TimeoutError:
+            for peer_id in self.sender_peer_ids:
+                if peer_id not in self.active_senders and peer_id not in self.banned_senders:
+                    await self._ban_sender(peer_id)
+
     async def _communicate_with_peer(self, peer_id: PeerID):
         """Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors"""
         peer_index = self.ordered_peer_ids.index(peer_id)
@@ -168,25 +212,39 @@ class AllReduceRunner(ServicerBase):
                 self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part - tensor_part)
 
         else:
-            code = None
-            stream = await self._get_peer_stub(peer_id).rpc_aggregate_part(self._generate_input_for_peer(peer_index))
-            async for part_index, (averaged_part_delta, msg) in aenumerate(
-                amap_in_executor(
-                    lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg),
-                    stream,
+            try:
+                done_sending = asyncio.Event()
+                inputs_aiter = attach_event_on_finished(self._generate_input_for_peer(peer_index), done_sending)
+                stream = await self._get_peer_stub(peer_id).rpc_aggregate_part(inputs_aiter)
+
+                if self.should_delay_results(self.peer_id):
+                    await done_sending.wait()
+
+                part_index = 0
+
+                def _try_deserialize(msg):
+                    if msg.code != averaging_pb2.AVERAGED_PART:
+                        raise AllreduceException(f"{peer_id} sent {averaging_pb2.MessageCode.Name(msg.code)}")
+                    return deserialize_torch_tensor(msg.tensor_part), msg
+
+                async for delta, msg in amap_in_executor(
+                    _try_deserialize,
+                    aiter_with_timeout(stream, self.reducer_timeout),
                     max_prefetch=self.tensor_part_container.prefetch,
-                )
-            ):
-                if code is None:
-                    code = msg.code
-                self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part_delta)
-
-            if code != averaging_pb2.AVERAGED_PART:
-                raise AllreduceException(
-                    f"peer {peer_id} returned {averaging_pb2.MessageCode.Name(code)} "
-                    f"instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)}"
-                    f", allreduce failed"
-                )
+                ):
+                    self.tensor_part_container.register_processed_part(peer_index, part_index, delta)
+                    part_index += 1
+
+                if part_index != self.tensor_part_container.num_parts_by_peer[peer_index]:
+                    raise AllreduceException(
+                        f"peer {peer_id} sent {part_index} parts, but we expected "
+                        f"{self.tensor_part_container.num_parts_by_peer[peer_index]}"
+                    )
+            except BaseException as e:
+                if isinstance(e, Exception):
+                    logger.warning(f"Caught {repr(e)} when communicating to {peer_id}")
+                self.tensor_part_container.register_failed_reducer(peer_index)
+                raise
 
     async def _generate_input_for_peer(self, peer_index: int) -> AsyncIterator[averaging_pb2.AveragingData]:
         parts_aiter = self.tensor_part_container.iterate_input_parts_for(peer_index)
@@ -204,18 +262,22 @@ class AllReduceRunner(ServicerBase):
         self, stream: AsyncIterator[averaging_pb2.AveragingData], context: P2PContext
     ) -> AsyncIterator[averaging_pb2.AveragingData]:
         """a peer sends us a part of his tensor; we should average it with other peers and return the difference"""
-        request: averaging_pb2.AveragingData = await anext(stream)
-        reason_to_reject = self._check_reasons_to_reject(request)
-        if reason_to_reject:
-            yield reason_to_reject
-            return
-
-        elif request.code == averaging_pb2.PART_FOR_AVERAGING:
-            try:
-                sender_index = self.sender_peer_ids.index(context.remote_id)
+        sender_index = self.sender_peer_ids.index(context.remote_id)
+        self.active_senders.add(context.remote_id)
+        if len(self.active_senders) == len(self.sender_peer_ids):
+            self.all_senders_started.set()
 
+        try:
+            request: averaging_pb2.AveragingData = await asyncio.wait_for(anext(stream), self.sender_timeout)
+            reason_to_reject = self._check_reasons_to_reject(request, context)
+            if reason_to_reject:
+                yield reason_to_reject
+                return
+
+            elif request.code == averaging_pb2.PART_FOR_AVERAGING:
+                stream = aiter_with_timeout(achain(as_aiter(request), stream), self.sender_timeout)
                 if not self.should_delay_results(context.remote_id):
-                    async for msg in self._accumulate_parts_streaming(achain(as_aiter(request), stream), sender_index):
+                    async for msg in self._accumulate_parts_streaming(stream, sender_index):
                         yield msg
 
                 else:
@@ -223,10 +285,13 @@ class AllReduceRunner(ServicerBase):
                     delayed_results = asyncio.Queue()
 
                     async def _accumulate_parts():
-                        inputs_aiter = attach_event_on_finished(achain(as_aiter(request), stream), done_receiving)
-                        async for msg in self._accumulate_parts_streaming(inputs_aiter, sender_index):
-                            delayed_results.put_nowait(msg)
-                        delayed_results.put_nowait(None)
+                        try:
+                            async for msg in self._accumulate_parts_streaming(
+                                attach_event_on_finished(stream, done_receiving), sender_index
+                            ):
+                                delayed_results.put_nowait(msg)
+                        finally:
+                            delayed_results.put_nowait(None)
 
                     accumulate_task = asyncio.create_task(_accumulate_parts())
 
@@ -239,60 +304,61 @@ class AllReduceRunner(ServicerBase):
                         yield next_result
                     await accumulate_task
 
-            except Exception as e:
-                self.finalize(exception=e)
+            else:
                 yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
-        else:
-            error_code = averaging_pb2.MessageCode.Name(request.code)
-            logger.debug(f"{self} - peer {context.remote_id} sent {error_code}, allreduce cannot continue")
-            self.finalize(exception=AllreduceException(f"peer {context.remote_id} sent {error_code}."))
-            yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
+                raise AllreduceException(f"{context.remote_id} sent {averaging_pb2.MessageCode.Name(request.code)}")
+
+        except BaseException as e:
+            await self._ban_sender(context.remote_id)
+            if isinstance(e, Exception):
+                logger.warning(f"Caught {repr(e)} when communicating with {context.remote_id}")
+                yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
+            else:
+                raise  # CancelledError, StopIteration and similar
+
+    async def _ban_sender(self, peer_id: PeerID):
+        async with self.banlock:
+            if peer_id not in self.banned_senders:
+                self.banned_senders.add(peer_id)
+                self.tensor_part_reducer.on_sender_failed(self.sender_peer_ids.index(peer_id))
 
-    def _check_reasons_to_reject(self, request: averaging_pb2.AveragingData) -> Optional[averaging_pb2.AveragingData]:
+    def _check_reasons_to_reject(
+        self, request: averaging_pb2.AveragingData, context: P2PContext
+    ) -> Optional[averaging_pb2.AveragingData]:
         if request.group_id != self.group_id:
             return averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
         elif self._future.cancelled():
             return averaging_pb2.AveragingData(code=averaging_pb2.CANCELLED)
         elif self._future.done():
             return averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
+        elif context.remote_id not in self.sender_peer_ids:
+            return averaging_pb2.AveragingData(code=averaging_pb2.PROTOCOL_VIOLATION)
 
     async def _accumulate_parts_streaming(self, stream: AsyncIterator[averaging_pb2.AveragingData], sender_index: int):
-        loop = asyncio.get_event_loop()
-        async for part_index, (tensor_part, weight, part_compression) in aenumerate(
-            amap_in_executor(
+        part_index = 0
+        try:
+            loop = asyncio.get_event_loop()
+            async for tensor_part, weight, part_compression in amap_in_executor(
                 lambda msg: (deserialize_torch_tensor(msg.tensor_part), msg.weight, msg.tensor_part.compression),
                 stream,
                 max_prefetch=self.tensor_part_container.prefetch,
-            )
-        ):
-            averaged_part = await self.tensor_part_reducer.accumulate_part(
-                sender_index, part_index, tensor_part, weight=weight
-            )
-
-            serialized_delta = await loop.run_in_executor(
-                None, lambda: serialize_torch_tensor(averaged_part - tensor_part, part_compression)
-            )
-            yield averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART, tensor_part=serialized_delta)
+            ):
+                averaged_part = await self.tensor_part_reducer.accumulate_part(
+                    sender_index, part_index, tensor_part, weight=weight
+                )
+                part_index += 1
 
-    async def _send_error_to_peer(self, peer_id: PeerID, code: averaging_pb2.MessageCode):
-        error = averaging_pb2.AveragingData(group_id=self.group_id, code=code)
-        await afirst(await self._get_peer_stub(peer_id).rpc_aggregate_part(as_aiter(error)))
+                serialized_delta = await loop.run_in_executor(
+                    None, lambda: serialize_torch_tensor(averaged_part - tensor_part, part_compression)
+                )
+                yield averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART, tensor_part=serialized_delta)
+        finally:
+            if part_index != self.tensor_part_reducer.num_parts:
+                await self._ban_sender(self.sender_peer_ids[sender_index])
 
     def finalize(self, *, cancel: bool = False, exception: Optional[BaseException] = None):
         """finish or terminate AllReduceRunner, propagate any errors / cancellations to peers."""
         assert not cancel or not exception, "finalize accepts either exception or cancel, but not both"
-        pending_tasks = set()
-        if cancel or exception:
-            # propagate error to peers
-            if cancel or isinstance(exception, asyncio.CancelledError):
-                code = averaging_pb2.CANCELLED
-            else:
-                code = averaging_pb2.INTERNAL_ERROR
-            logger.debug(f"{self} - notifying peers about {averaging_pb2.MessageCode.Name(code)}")
-            for peer_id, mode in zip(self.ordered_peer_ids, self.modes):
-                if peer_id != self.peer_id and mode != AveragingMode.CLIENT:
-                    pending_tasks.add(asyncio.create_task(self._send_error_to_peer(peer_id, code)))
-
         if not self._future.done():
             if cancel:
                 logger.debug(f"{self} - cancelled")
@@ -305,7 +371,5 @@ class AllReduceRunner(ServicerBase):
                 self._future.set_result(None)
             self.tensor_part_container.finalize()
             self.tensor_part_reducer.finalize()
-            return pending_tasks
         else:
-            logger.debug(f"{self} - could not finish: allreduce is already finished: {self._future}")
-            return pending_tasks
+            logger.debug(f"{self} - attempted to finalize allreduce that is already finished: {self._future}")
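
The two new timeouts are intentionally coupled: the constructor above rejects reducer_timeout <= sender_timeout, and the averager diff below defaults sender_timeout to next_chunk_timeout and reducer_timeout to twice that, so a reducer is never banned merely for waiting on a slow sender. A small sketch of that defaulting logic (the function name is illustrative, not part of the API):

```python
from typing import Optional, Tuple

def resolve_allreduce_timeouts(
    next_chunk_timeout: Optional[float],
    sender_timeout: Optional[float] = None,
    reducer_timeout: Optional[float] = None,
) -> Tuple[Optional[float], Optional[float]]:
    # Mirrors the defaulting in DecentralizedAverager.__init__ (see averager.py below).
    if sender_timeout is None:
        sender_timeout = next_chunk_timeout
    if reducer_timeout is None:
        reducer_timeout = 2 * sender_timeout if sender_timeout is not None else None
    return sender_timeout, reducer_timeout

assert resolve_allreduce_timeouts(30.0) == (30.0, 60.0)  # reducers get twice the sender budget
assert resolve_allreduce_timeouts(None) == (None, None)  # timeouts are disabled by default
```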

+ 105 - 45
hivemind/averaging/averager.py

@@ -7,6 +7,7 @@ import contextlib
 import ctypes
 import multiprocessing as mp
 import os
+import random
 import threading
 import weakref
 from dataclasses import asdict
@@ -30,10 +31,12 @@ from hivemind.compression import (
 )
 from hivemind.dht import DHT, DHTID
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
+from hivemind.p2p.p2p_daemon_bindings.utils import ControlFailure, DispatchFailure
 from hivemind.proto import averaging_pb2
 from hivemind.utils import MPFuture, TensorDescriptor, get_logger
 from hivemind.utils.asyncio import (
     achain,
+    afirst,
     aiter_with_timeout,
     anext,
     as_aiter,
@@ -67,7 +70,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     :param compression: optionally compress tensors with this compression algorithm before running all-reduce
     :param state_compression: a separate compression strategy for load_state_from_peers (default = no compression)
     :param tensor_infos: CompressionInfo for each respective tensor; this determines how the tensor will be compressed
-    :param allreduce_timeout: spend at most this many seconds for allreduce (after group is formed)
     :param averaging_alpha: optional "learning rate" for averaging. If specified, local parameters will be shifted
       towards the (estimated) average by this coefficient. By default, local parameters are set equal to average.
     :param request_timeout: when looking for group, wait for a response from leader for at most this many seconds.
@@ -84,6 +86,13 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     :param allow_state_sharing: if set to True, other peers can download this peer's state. Can be overwritten
       with averager.allow_state_sharing = True / False
     :param declare_state_period: re-declare averager as a donor for load_state_from_peers every this many seconds
+    :param allreduce_timeout: spend at most this many seconds for allreduce (after group is formed)
+    :param next_chunk_timeout: during all-reduce and load_state_from_peers, if peer does not send next data chunk in
+      this number of seconds, consider it failed and proceed with remaining peers. default: no timeout
+    :param sender_timeout: during all_reduce, any sender that fails to send tensor chunk within this many seconds from
+      previous chunk will be marked as failed and excluded from averaging. default: equal to next_chunk_timeout
+    :param reducer_timeout: during all_reduce, any reducer that fails to send results chunk within this many seconds
+      from previous chunk will be marked as failed and excluded from averaging. default: 2 * sender_timeout
     :param shutdown_timeout: when calling .shutdown, wait for up to this many seconds before terminating
 
     Example:
@@ -112,7 +121,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         *,
         start: bool,
         prefix: str,
-        target_group_size: int,
+        target_group_size: Optional[int] = None,
         min_group_size: int = 2,
         initial_group_bits: str = "",
         averaging_expiration: Optional[float] = None,
@@ -121,6 +130,9 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         averaging_alpha: float = 1.0,
         part_size_bytes: int = DEFAULT_PART_SIZE_BYTES,
         allreduce_timeout: Optional[float] = None,
+        next_chunk_timeout: Optional[float] = None,
+        sender_timeout: Optional[float] = None,
+        reducer_timeout: Optional[float] = None,
         compression: CompressionBase = NoCompression(),
         state_compression: CompressionBase = NoCompression(),
         tensor_infos: Optional[Sequence[CompressionInfo]] = None,
@@ -137,8 +149,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         assert bandwidth is None or (
             bandwidth >= 0 and np.isfinite(np.float32(bandwidth))
         ), "bandwidth must be a non-negative float32"
-        if not is_power_of_two(target_group_size):
-            logger.warning("It is recommended to set target_group_size to a power of 2.")
         assert all(bit in "01" for bit in initial_group_bits)
         assert not client_mode or not auxiliary, "auxiliary peers must accept incoming connections"
 
@@ -153,6 +163,11 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
         if client_mode is None:
             client_mode = dht.client_mode
+        if sender_timeout is None:
+            sender_timeout = next_chunk_timeout
+        if reducer_timeout is None:
+            reducer_timeout = 2 * sender_timeout if sender_timeout is not None else None
+
         self.client_mode = client_mode
 
         self._parent_pid = os.getpid()
@@ -166,13 +181,13 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
         self._averaged_tensors = tuple(averaged_tensors)
         self.lock_averaged_tensors = mp.Lock()
-        self.last_updated: DHTExpiration = -float("inf")
         for tensor in self._averaged_tensors:
             assert tensor.grad_fn is None, "averaged_tensors must be either parameters or leaf tensors"
             tensor.share_memory_()
         self.total_size = sum(map(torch.Tensor.numel, self._averaged_tensors))
         self.schema_hash = compute_schema_hash(self._averaged_tensors)
         self.shutdown_timeout = shutdown_timeout
+        self.next_chunk_timeout = next_chunk_timeout
         self.bandwidth = bandwidth
 
         self.matchmaking_kwargs = dict(
@@ -188,6 +203,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             compression=compression,
             part_size_bytes=part_size_bytes,
             min_vector_size=min_vector_size,
+            sender_timeout=sender_timeout,
+            reducer_timeout=reducer_timeout,
         )
         self._averaging_alpha, self._allreduce_timeout = averaging_alpha, allreduce_timeout
         self._running_groups: Dict[GroupID, AllReduceRunner] = {}  # one or more assembled groups that run all-reduce
@@ -195,6 +212,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)  # a control pipe used to communicate with daemon
 
         self._allow_state_sharing = mp.Value(ctypes.c_bool, 0)
+        self._state_sharing_priority = mp.Value(ctypes.c_double, 0)
+
         if allow_state_sharing is None:
             allow_state_sharing = not client_mode and not auxiliary
         self.allow_state_sharing = allow_state_sharing
@@ -221,9 +240,29 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     @allow_state_sharing.setter
     def allow_state_sharing(self, value: bool):
         if value and self.client_mode:
-            raise ValueError("Cannot allow state sharing: averager in client mode cannot share its state.")
+            raise ValueError("Cannot allow state sharing: averager in client mode cannot share its state")
+        else:
+            old_value, self._allow_state_sharing.value = self._allow_state_sharing.value, value
+            if value != old_value:
+                self._outer_pipe.send(("_trigger_declare_load_state", [], {}))
+
+    @property
+    def state_sharing_priority(self) -> float:
+        """Others will preferentially downloading state from peers with highest priority."""
+        return float(self._state_sharing_priority.value)
+
+    @state_sharing_priority.setter
+    def state_sharing_priority(self, value: float):
+        if value and self.client_mode:
+            raise ValueError("State sharing priority is unused: averager in client mode cannot share its state")
         else:
-            self._allow_state_sharing.value = value
+            old_value, self._state_sharing_priority.value = self._state_sharing_priority.value, value
+            if self.allow_state_sharing and value != old_value:
+                self._outer_pipe.send(("_trigger_declare_load_state", [], {}))
+
+    async def _trigger_declare_load_state(self):
+        # note: we previously tried an mp.Event here instead; awaiting it in an executor caused performance degradation in py39
+        self._state_updated.set()
 
     @property
     def peer_id(self) -> PeerID:
@@ -257,7 +296,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                 if not self.client_mode:
                     await self.add_p2p_handlers(self._p2p, namespace=self.prefix)
                 else:
-                    logger.debug(f"The averager is running in client mode.")
+                    logger.debug("The averager is running in client mode")
 
                 self._matchmaking = Matchmaking(
                     self._p2p,
@@ -318,7 +357,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             self._inner_pipe.send(("_SHUTDOWN", None))  # shut down background thread in master
             self.join(self.shutdown_timeout)
             if self.is_alive():
-                logger.warning("Averager did not shut down within the grace period; terminating it the hard way.")
+                logger.warning("Averager did not shut down within the grace period; terminating it the hard way")
                 self.terminate()
         else:
             logger.exception("Averager shutdown has no effect: the process is already not alive")
@@ -359,7 +398,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         :returns: on success, update averaged_tensors and return group info; on failure, return None
         """
         if self.mode == AveragingMode.AUX and weight is not None:
-            logger.warning("Averager is running in auxiliary mode, weight is unused.")
+            logger.warning("Averager is running in auxiliary mode, weight is unused")
         if scheduled_time is None:
             scheduled_time = get_dht_time() + self.matchmaking_kwargs["min_matchmaking_time"]
         if weight is None:
@@ -379,31 +418,45 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             data_for_gather=data_for_gather,
         )
 
-        future_for_trigger = MPFuture()
-        self._outer_pipe.send(("_step", [], dict(step=step, future_for_trigger=future_for_trigger)))
-        step.attach_trigger(future_for_trigger.result())
+        future_for_init = MPFuture()
+        self._outer_pipe.send(("_step", [], dict(step=step, future_for_init=future_for_init)))
+        step.attach(*future_for_init.result())
 
         if not require_trigger:
             step.allow_allreduce()
         return step.result() if wait else step
 
-    async def _step(self, *, step: StepControl, future_for_trigger: MPFuture):
+    async def _step(self, *, step: StepControl, future_for_init: MPFuture):
         try:
-            trigger = MPFuture()
-            step.attach_trigger(trigger)
-            future_for_trigger.set_result(trigger)
+            trigger, cancel = MPFuture(), MPFuture()
+            step.attach(trigger, cancel)
+            future_for_init.set_result((trigger, cancel))
+
+            async def find_peers_or_notify_cancel():
+                group_info = await self._matchmaking.look_for_group(step)
+                if not step.triggered:
+                    step.stage = AveragingStage.AWAITING_TRIGGER
+                    await step.wait_for_trigger()
+                return group_info
 
             while not step.done():
                 try:
                     self._pending_group_assembled.clear()
                     step.stage = AveragingStage.LOOKING_FOR_GROUP
-                    group_info = await self._matchmaking.look_for_group(step)
-                    if group_info is None:
-                        raise AllreduceException("Averaging step failed: could not find a group.")
+                    matchmaking_task = asyncio.create_task(find_peers_or_notify_cancel())
+                    check_cancel_task = asyncio.create_task(step.wait_for_cancel())
 
-                    if not step.triggered:
-                        step.stage = AveragingStage.AWAITING_TRIGGER
-                        await step.wait_for_trigger()
+                    await asyncio.wait({matchmaking_task, check_cancel_task}, return_when=asyncio.FIRST_COMPLETED)
+                    if step.cancelled():
+                        matchmaking_task.cancel()
+                        raise asyncio.CancelledError()
+                    else:
+                        check_cancel_task.cancel()
+
+                    group_info = await matchmaking_task
+
+                    if group_info is None:
+                        raise AllreduceException("Averaging step failed: could not find a group")
 
                     step.stage = AveragingStage.RUNNING_ALLREDUCE
 
@@ -425,10 +478,14 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     asyncio.CancelledError,
                     asyncio.InvalidStateError,
                     P2PHandlerError,
+                    DispatchFailure,
+                    ControlFailure,
                 ) as e:
-                    if not step.allow_retries or get_dht_time() >= step.deadline:
-                        logger.exception(e)
-                        step.set_exception(e)
+                    if step.done() or not step.allow_retries or get_dht_time() >= step.deadline:
+                        if not step.cancelled():
+                            logger.exception(e)
+                        if not step.done():
+                            step.set_exception(e)
                     else:
                         logger.warning(f"{self.__class__.__name__} caught {repr(e)}, retrying")
 
@@ -477,11 +534,11 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
                 with self.register_allreduce_group(group_info.group_id, allreduce):
                     if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
-                        async for tensor, update in azip(as_aiter(*local_tensors), allreduce):
+                        iter_results = allreduce.run()
+                        async for tensor, update in azip(as_aiter(*local_tensors), iter_results):
                             # all-reduce is performed asynchronously while iterating
                             tensor.add_(update, alpha=self._averaging_alpha)
-                            self.last_updated = get_dht_time()
-                            self._state_updated.set()
+                        self._state_updated.set()
 
                     else:
                         async for _ in allreduce:  # trigger all-reduce by iterating
@@ -489,7 +546,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
                 return allreduce.gathered
         except BaseException as e:
-            logger.exception(e)
+            if isinstance(e, Exception):
+                logger.exception(e)
             raise MatchmakingException(f"Unable to run All-Reduce: {e}")
 
     @contextlib.contextmanager
@@ -540,24 +598,29 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
     async def _declare_for_download_periodically(self):
         download_key = f"{self._matchmaking.group_key_manager.prefix}.all_averagers"
+        sharing_was_allowed = self.allow_state_sharing
         while True:
-            if self.allow_state_sharing:
-                self._state_updated.clear()
-                expiration_time = get_dht_time() + self.declare_state_period
+            expiration_time = get_dht_time() + self.declare_state_period
+            if self.allow_state_sharing or sharing_was_allowed:
+                # notify either if sharing is allowed or if it was just switched off (to overwrite previous message)
                 asyncio.create_task(
                     asyncio.wait_for(
                         self.dht.store(
                             download_key,
                             subkey=self.peer_id.to_bytes(),
-                            value=self.last_updated,
+                            value=self.state_sharing_priority if self.allow_state_sharing else None,
                             expiration_time=expiration_time,
                             return_future=True,
                         ),
-                        timeout=expiration_time - self.request_timeout,
+                        timeout=expiration_time - get_dht_time(),
                     )
                 )
+                sharing_was_allowed = self.allow_state_sharing
+
+            # report again either after declare_state_period seconds or once the field is changed by the user
+            self._state_updated.clear()
             try:
-                await asyncio.wait_for(self._state_updated.wait(), self.declare_state_period - self.request_timeout)
+                await asyncio.wait_for(self._state_updated.wait(), timeout=max(0.0, expiration_time - get_dht_time()))
             except asyncio.TimeoutError:
                 pass
 
@@ -618,17 +681,19 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         return future.result(timeout=timeout) if wait else future
 
     async def _load_state_from_peers(self, future: MPFuture, timeout: Optional[float] = None):
+        if timeout is not None:
+            timeout = self.next_chunk_timeout if self.next_chunk_timeout is not None else self.request_timeout
         try:
             key_manager = self._matchmaking.group_key_manager
             peer_priority, _ = self.dht.get(f"{key_manager.prefix}.all_averagers", latest=True) or ({}, None)
             peer_priority = {
-                PeerID(peer_id): float(info.value)
+                PeerID(peer_id): (float(info.value), random.random())  # using randomness as a tie breaker
                 for peer_id, info in peer_priority.items()
                 if isinstance(info, ValueWithExpiration) and isinstance(info.value, (float, int))
             }
 
             if not isinstance(peer_priority, dict) or len(peer_priority) == 0:
-                logger.info(f"Averager could not load state from peers: peer dict empty or corrupted {peer_priority}.")
+                logger.info(f"Averager could not load state from peers: peer dict empty or corrupted {peer_priority}")
                 future.set_result(None)
                 return
 
@@ -641,7 +706,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                         stream = await stub.rpc_download_state(averaging_pb2.DownloadRequest())
                         current_tensor_parts, tensors = [], []
 
-                        async for message in aiter_with_timeout(stream, timeout=timeout or self.request_timeout):
+                        async for message in aiter_with_timeout(stream, timeout=timeout):
                             if message.metadata:
                                 metadata = self.serializer.loads(message.metadata)
                             if message.tensor_part.dtype and current_tensor_parts:
@@ -653,7 +718,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                             tensors.append(deserialize_torch_tensor(combine_from_streaming(current_tensor_parts)))
 
                         if not metadata:
-                            logger.debug(f"Peer {peer} did not send its state.")
+                            logger.debug(f"Peer {peer} did not send its state")
                             continue
 
                         logger.info(f"Finished downloading state from {peer}")
@@ -697,11 +762,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                 future.set_exception(e)
 
 
-def is_power_of_two(n):
-    """Check whether n is a power of 2"""
-    return (n != 0) and (n & (n - 1) == 0)
-
-
 def _background_thread_fetch_current_state(
     serializer: SerializerBase, pipe: mp.connection.Connection, get_current_state_ref: weakref.WeakMethod
 ):
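Note: the trigger/cancel protocol introduced above can be driven from user code roughly as follows. This is a minimal sketch, assuming a running DecentralizedAverager named `averager` and other peers matchmaking on the same prefix:

    control = averager.step(wait=False, require_trigger=True)  # start matchmaking in the background
    # ... do useful work (e.g. accumulate more gradients) while the group is being assembled ...
    control.allow_allreduce()        # set the attached trigger: all-reduce may begin once a group is found
    group_info = control.result()    # block until averaging finishes; averaged tensors are updated in place
    # calling control.cancel() instead would set the cancel future and abort matchmaking cleanly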

+ 24 - 7
hivemind/averaging/control.py

@@ -1,3 +1,4 @@
+import os
 import struct
 from enum import Enum
 from typing import Optional
@@ -43,6 +44,7 @@ class StepControl(MPFuture):
         super().__init__()
         self._data_for_gather, self._deadline, self._allow_retries = data_for_gather, deadline, allow_retries
         self._trigger: Optional[MPFuture] = None
+        self._cancel: Optional[MPFuture] = None
 
         # Buffer contents:
         # scheduled_time (double) | weight (double) | stage (AveragingStage, 1 byte) | began_allreduce: (bool, 1 byte)
@@ -52,12 +54,12 @@ class StepControl(MPFuture):
         self.weight = weight
         self.began_allreduce = False
 
-    def attach_trigger(self, trigger: MPFuture):
-        assert self._trigger is None, "Trigger is already attached"
-        self._trigger = trigger
+    def attach(self, trigger: MPFuture, cancel: MPFuture):
+        assert self._trigger is None and self._cancel is None, "Futures are already attached"
+        self._trigger, self._cancel = trigger, cancel
 
     def allow_allreduce(self):
-        """Allow averager to begin allreduce when it finds a group. Meant to be triggered by user."""
+        """Allow averager to begin all-reduce when it finds a group. Meant to be triggered by user."""
         assert self._trigger is not None, "StepControl does not have an attached trigger"
         if self._trigger.done():
             logger.warning("Trigger is already set")
@@ -82,7 +84,7 @@ class StepControl(MPFuture):
         if self.began_allreduce:
             logger.warning("Changing scheduled time has no effect after all-reduce has already started")
         if scheduled_time >= self.deadline:
-            logger.warning("Changing scheduled time to after deadline, averaging will likely fail due to timeout.")
+            logger.warning("Changing scheduled time to after deadline, averaging will likely fail due to timeout")
         struct.pack_into("d", self._shared_buffer[StepControl._SCHEDULED_TIME].numpy().data, 0, float(scheduled_time))
 
     @property
@@ -133,16 +135,31 @@ class StepControl(MPFuture):
         return dict(
             super().__getstate__(),
             _trigger=self._trigger,
+            _cancel=self._cancel,
             _shared_buffer=self._shared_buffer,
             immutable_params=(self._data_for_gather, self._deadline, self._allow_retries),
         )
 
     def __setstate__(self, state):
         super().__setstate__(state)
-        self._trigger, self._shared_buffer = state["_trigger"], state["_shared_buffer"]
+        self._trigger, self._cancel, self._shared_buffer = state["_trigger"], state["_cancel"], state["_shared_buffer"]
         self._data_for_gather, self._deadline, self._allow_retries = state["immutable_params"]
 
+    def __del__(self):
+        if os.getpid() == self._origin_pid and not self.triggered:
+            logger.warning(
+                "Deleted an averaging StepControl, but the step was not triggered. This may cause other "
+                "peers to fail an averaging round via TimeoutError."
+            )
+        super().__del__()
+
     def cancel(self) -> bool:
         if self._trigger is not None:
             self._trigger.cancel()
-        return self.cancel()
+        if self._cancel is not None:
+            self._cancel.set_result(None)
+        return super().cancel()
+
+    async def wait_for_cancel(self):
+        """Await for step to be cancelled by the user. Should be called from insider the averager."""
+        await self._cancel

+ 21 - 4
hivemind/averaging/key_manager.py

@@ -11,6 +11,7 @@ from hivemind.utils import DHTExpiration, ValueWithExpiration, get_dht_time, get
 
 GroupKey = str
 GROUP_PATTERN = re.compile("^(([^.])+)[.]0b[01]*$")  # e.g. bert_exp4_averaging.0b01001101
+DEFAULT_NUM_BUCKETS = 256
 logger = get_logger(__name__)
 
 
@@ -29,9 +30,12 @@ class GroupKeyManager:
         dht: DHT,
         prefix: str,
         initial_group_bits: str,
-        target_group_size: int,
+        target_group_size: Optional[int],
     ):
         assert all(bit in "01" for bit in initial_group_bits)
+        if target_group_size is not None and not is_power_of_two(target_group_size):
+            logger.warning("It is recommended to set target_group_size to a power of 2")
+
         self.dht, self.prefix, self.group_bits = dht, prefix, initial_group_bits
         self.target_group_size = target_group_size
         self.peer_id = dht.peer_id
@@ -76,7 +80,7 @@ class GroupKeyManager:
         assert is_valid_group(group_key), f"Group key {group_key} is invalid, must follow {GROUP_PATTERN}"
         result = await self.dht.get(group_key, latest=True, return_future=True)
         if result is None or not isinstance(result.value, dict):
-            logger.debug(f"Allreduce group not found: {group_key}, creating new group.")
+            logger.debug(f"Allreduce group not found: {group_key}, creating new group")
             return []
         averagers = []
         for key, looking_for_group in result.value.items():
@@ -92,8 +96,11 @@ class GroupKeyManager:
         """this function is triggered every time an averager finds an allreduce group"""
         rng = random.Random(group_info.group_id)
         index = group_info.peer_ids.index(self.peer_id)
-        generalized_index = rng.sample(range(self.target_group_size), group_info.group_size)[index]
-        nbits = int(np.ceil(np.log2(self.target_group_size)))
+        num_buckets = self.target_group_size
+        if num_buckets is None:
+            num_buckets = next_power_of_two(group_info.group_size)
+        generalized_index = rng.sample(range(num_buckets), group_info.group_size)[index]
+        nbits = int(np.ceil(np.log2(num_buckets)))
         new_bits = bin(generalized_index)[2:].rjust(nbits, "0")
         self.group_bits = (self.group_bits + new_bits)[-len(self.group_bits) :] if self.group_bits else ""
         logger.debug(f"{self.peer_id} - updated group key to {self.group_bits}")
@@ -101,3 +108,13 @@ class GroupKeyManager:
     async def update_key_on_not_enough_peers(self):
         """this function is triggered whenever averager fails to assemble group within timeout"""
         pass  # to be implemented in subclasses
+
+
+def is_power_of_two(n):
+    """Check whether n is a power of 2"""
+    return (n != 0) and (n & (n - 1) == 0)
+
+
+def next_power_of_two(n):
+    """Round n up to the nearest power of 2"""
+    return 1 if n == 0 else 2 ** (n - 1).bit_length()
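A quick sanity check of the two helpers above (illustrative only):

    assert is_power_of_two(8) and not is_power_of_two(6)
    assert next_power_of_two(5) == 8 and next_power_of_two(8) == 8
    # with target_group_size=None, update_key_on_group_assembled now buckets peers into
    # next_power_of_two(group_info.group_size) buckets instead of a fixed target size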

+ 1 - 1
hivemind/averaging/load_balancing.py

@@ -80,7 +80,7 @@ def optimize_parts_lp(vector_size: int, bandwidths: np.ndarray, min_size: int =
             peer_scores[peer_scores < min_size / float(vector_size)] = 0.0
         peer_scores = np.round(peer_scores, LOAD_BALANCING_LP_DECIMALS)
     else:
-        logger.error(f"Failed to solve load-balancing for bandwidths {bandwidths}.")
+        logger.error(f"Failed to solve load-balancing for bandwidths {bandwidths}")
         peer_scores = np.ones(group_size, c.dtype)
 
     return peer_scores[np.argsort(permutation)]

+ 25 - 12
hivemind/averaging/matchmaking.py

@@ -16,6 +16,7 @@ from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.key_manager import GroupKey, GroupKeyManager
 from hivemind.dht import DHT, DHTID, DHTExpiration
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
+from hivemind.p2p.p2p_daemon_bindings.utils import ControlFailure, DispatchFailure
 from hivemind.proto import averaging_pb2
 from hivemind.utils import TimedStorage, get_dht_time, get_logger, timed_storage
 from hivemind.utils.asyncio import anext, cancel_and_wait
@@ -44,7 +45,7 @@ class Matchmaking:
         *,
         servicer_type: Type[ServicerBase],
         prefix: str,
-        target_group_size: int,
+        target_group_size: Optional[int],
         min_group_size: int,
         min_matchmaking_time: float,
         request_timeout: float,
@@ -88,9 +89,11 @@ class Matchmaking:
     async def looking_for_group(self, step_control: StepControl):
         async with self.lock_looking_for_group:
             assert self.step_control is None
-            self.step_control = step_control
-            yield
-            self.step_control = None
+            try:
+                self.step_control = step_control
+                yield
+            finally:
+                self.step_control = None
 
     @property
     def is_looking_for_group(self):
@@ -225,7 +228,10 @@ class Matchmaking:
                     if suggested_leader != self.peer_id:
                         logger.debug(f"{self} - leader disbanded group and redirected us to {suggested_leader}")
                         self.current_leader = None
-                        await stream.aclose()
+                        try:
+                            await stream.aclose()
+                        except RuntimeError as e:
+                            logger.debug(e, exc_info=True)
                         return await self._request_join_group(suggested_leader)
                 logger.debug(f"{self} - leader disbanded group")
                 return None
@@ -235,15 +241,18 @@ class Matchmaking:
         except asyncio.TimeoutError:
             logger.debug(f"{self} - potential leader {leader} did not respond within {self.request_timeout}")
             return None
-        except (P2PHandlerError, StopAsyncIteration) as e:
-            logger.exception(f"{self} - failed to request potential leader {leader}:")
+        except (P2PHandlerError, ControlFailure, DispatchFailure, StopAsyncIteration) as e:
+            logger.debug(f"{self} - failed to request potential leader {leader}:")
             return None
 
         finally:
             self.was_accepted_to_group.clear()
             self.current_leader = None
             if stream is not None:
-                await stream.aclose()
+                try:
+                    await stream.aclose()
+                except RuntimeError as e:
+                    logger.debug(e, exc_info=True)
 
     def get_request_expiration_time(self) -> float:
         """Returns the averager's current expiration time, which is used to send join requests to leaders"""
@@ -267,7 +276,11 @@ class Matchmaking:
                 self.current_followers[context.remote_id] = request
                 yield averaging_pb2.MessageFromLeader(code=averaging_pb2.ACCEPTED)
 
-                if len(self.current_followers) + 1 >= self.target_group_size and not self.assembled_group.done():
+                if (
+                    self.target_group_size is not None
+                    and len(self.current_followers) + 1 >= self.target_group_size
+                    and not self.assembled_group.done()
+                ):
                     # outcome 1: we have assembled a full group and are ready for allreduce
                     await self.leader_assemble_group()
 
@@ -353,7 +366,7 @@ class Matchmaking:
             )
         elif context.remote_id == self.peer_id or context.remote_id in self.current_followers:
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.DUPLICATE_PEER_ID)
-        elif len(self.current_followers) + 1 >= self.target_group_size:
+        elif self.target_group_size is not None and len(self.current_followers) + 1 >= self.target_group_size:
             return averaging_pb2.MessageFromLeader(code=averaging_pb2.GROUP_IS_FULL)
         else:
             return None
@@ -372,7 +385,7 @@ class Matchmaking:
             for peer_id in ordered_peer_ids
         )
 
-        logger.debug(f"{self.peer_id} - assembled group of {len(ordered_peer_ids)} peers.")
+        logger.debug(f"{self.peer_id} - assembled group of {len(ordered_peer_ids)} peers")
         group_info = GroupInfo(group_id, tuple(ordered_peer_ids), gathered)
         await self.group_key_manager.update_key_on_group_assembled(group_info, is_leader=True)
         self.assembled_group.set_result(group_info)
@@ -389,7 +402,7 @@ class Matchmaking:
         assert self.peer_id in ordered_peer_ids, "Leader sent us group_peer_ids that does not contain us!"
         assert len(ordered_peer_ids) == len(msg.gathered)
 
-        logger.debug(f"{self.peer_id} - follower assembled group with leader {leader}.")
+        logger.debug(f"{self.peer_id} - follower assembled group with leader {leader}")
         group_info = GroupInfo(group_id, tuple(ordered_peer_ids), tuple(msg.gathered))
         await self.group_key_manager.update_key_on_group_assembled(group_info)
         self.assembled_group.set_result(group_info)

+ 57 - 19
hivemind/averaging/partition.py

@@ -10,21 +10,24 @@ import torch
 
 from hivemind.compression import CompressionBase, CompressionInfo, NoCompression
 from hivemind.proto import runtime_pb2
-from hivemind.utils.asyncio import amap_in_executor
+from hivemind.utils import amap_in_executor, as_aiter, get_logger
 
 T = TypeVar("T")
 DEFAULT_PART_SIZE_BYTES = 2 ** 19
+logger = get_logger(__name__)
 
 
 class TensorPartContainer:
     """
     Auxiliary data structure for averaging, responsible for splitting tensors into parts and reassembling them.
     The class is designed to avoid excessive memory allocation and run all heavy computation in background
+
     :param tensors: local tensors to be split and aggregated
     :param peer_fractions: for each peer, a target fraction of vector elements that this peer should average
     :param compression: optionally compress tensors with this compression algorithm before sending them to peers
     :param part_size_bytes: greedily split tensors into parts of up to this many bytes (after compression)
     :param tensor_infos: CompressionInfo for each respective tensor; this determines how the tensor will be compressed
+    :param return_deltas: if True, output tensors are differences (aggregated tensor - local tensor)
     :param prefetch: when compressing, pre-compute this many compressed tensors in background
     """
 
@@ -35,7 +38,8 @@ class TensorPartContainer:
         compression: CompressionBase = NoCompression(),
         part_size_bytes: int = DEFAULT_PART_SIZE_BYTES,
         tensor_infos: Optional[Sequence[CompressionInfo]] = None,
-        prefetch: int = 5,
+        return_deltas: bool = True,
+        prefetch: int = 1,
     ):
         if tensor_infos is None:
             tensor_infos = tuple(CompressionInfo.from_tensor(x, key=i) for i, x in enumerate(tensors))
@@ -43,6 +47,8 @@ class TensorPartContainer:
         self.local_tensors, self.peer_fractions, self.group_size = tensors, peer_fractions, len(peer_fractions)
         self.compression, self.part_size_bytes, self.tensor_infos = compression, part_size_bytes, tensor_infos
         self.total_size = sum(tensor.numel() for tensor in tensors)
+        self.failed_size = 0
+        self.return_deltas = return_deltas
         self.prefetch = prefetch
 
         self._input_parts_by_peer = [deque() for _ in range(self.group_size)]
@@ -91,7 +97,6 @@ class TensorPartContainer:
         assert not self._inputs_consumed_by_peer[peer_index], "input parts of a given peer are already deallocated."
         self._inputs_consumed_by_peer[peer_index] = True
         input_parts = tuple(part for part, compression in self._input_parts_by_peer[peer_index])
-        self._input_parts_by_peer[peer_index].clear()
         return input_parts
 
     @torch.no_grad()
@@ -99,13 +104,9 @@ class TensorPartContainer:
         """iterate serialized tensor parts for a peer at a given index. Run serialization in background."""
         assert not self._inputs_consumed_by_peer[peer_index], "input parts of a given peer are already deallocated."
         self._inputs_consumed_by_peer[peer_index] = True
-
-        async def _aiterate_parts():
-            for _ in range(self.num_parts_by_peer[peer_index]):
-                yield self._input_parts_by_peer[peer_index].popleft()
-
+        parts_aiter = as_aiter(*self._input_parts_by_peer[peer_index])
         async for serialized_part in amap_in_executor(
-            lambda x_and_info: self.compression.compress(*x_and_info), _aiterate_parts(), max_prefetch=self.prefetch
+            lambda x_and_info: self.compression.compress(*x_and_info), parts_aiter, max_prefetch=self.prefetch
         ):
             yield serialized_part
 
@@ -123,6 +124,16 @@ class TensorPartContainer:
         self._outputs_registered_by_peer[peer_index] += 1
         self._output_part_available[peer_index].set()
 
+    def register_failed_reducer(self, peer_index: int):
+        """
+        A given peer failed to aggregate a certain part: use our local part instead and keep track of the failed parts
+        """
+        for part_index in range(self._outputs_registered_by_peer[peer_index], self.num_parts_by_peer[peer_index]):
+            part_and_info = self._input_parts_by_peer[peer_index][part_index]
+            part_result_or_delta = torch.zeros_like(part_and_info[0]) if self.return_deltas else part_and_info[0]
+            self.register_processed_part(peer_index, part_index, part_result_or_delta)
+            self.failed_size += part_result_or_delta.numel()
+
     async def iterate_output_tensors(self) -> AsyncIterable[torch.Tensor]:
         """iterate over the outputs of averaging (whether they are average, delta or other aggregation result)"""
         assert not self._outputs_consumed, "output tensors are already iterated and no longer available."
@@ -139,7 +150,7 @@ class TensorPartContainer:
                     self._output_part_available[peer_index].clear()
                     await self._output_part_available[peer_index].wait()
                     if self.finished.is_set():
-                        raise AllreduceException("All-reduce was terminated during iteration.")
+                        raise AllreduceException("All-reduce was terminated during iteration")
 
                 tensor_parts.append(self._output_parts_by_peer[peer_index].popleft())
                 num_parts_processed += 1
@@ -155,9 +166,11 @@ class TensorPartContainer:
         if not self.finished.is_set():
             for peer_index in range(self.group_size):
                 self._inputs_consumed_by_peer[peer_index] = True
+                self._output_part_available[peer_index].set()
                 self._input_parts_by_peer[peer_index].clear()
                 self._output_parts_by_peer[peer_index].clear()
-                self._output_part_available[peer_index].set()
+            if self.failed_size != 0:
+                logger.warning(f"Averaging: received {(1. - self.failed_size / self.total_size) * 100:.1f}% results")
             self._outputs_consumed = True
             self.finished.set()
 
@@ -178,11 +191,16 @@ class TensorPartReducer:
         self.denominator = 0.0  # total weight accumulated from all peers for current part
         self.current_part_future = asyncio.Future()
         self.finished = asyncio.Event()
+
+        self.num_parts_received = [0 for _ in range(self.num_senders)]
+        self.sender_failed_after = [float("inf") for _ in range(self.num_senders)]
+        self.num_current_senders = self.num_senders
+
         self.reset_accumulators()
 
     def reset_accumulators(self):
         """(re)create averaging buffers for the next part in line, prepopulate with local tensor part"""
-        assert self.current_part_accumulated_from == self.num_senders or self.current_part_index == -1
+        assert self.current_part_accumulated_from == self.num_current_senders or self.current_part_index == -1
         if self.current_part_index >= self.num_parts - 1:
             self.finalize()
             return
@@ -190,6 +208,9 @@ class TensorPartReducer:
         self.current_part_index += 1
         self.current_part_accumulated_from = 0
         self.current_part_future = asyncio.Future()
+        self.num_current_senders = sum(
+            self.current_part_index < failed_index for failed_index in self.sender_failed_after
+        )
         self.accumulator = torch.zeros(self.part_shapes[self.current_part_index])
         self.denominator = 0.0
 
@@ -199,6 +220,7 @@ class TensorPartReducer:
         """Add vector part to accumulator, wait for all other vectors to be added, then return the average part"""
         assert 0 <= sender_index < self.num_senders, "invalid sender index"
         assert 0 <= part_index < self.num_parts, "invalid part index"
+        self.num_parts_received[sender_index] += 1
 
         while part_index > self.current_part_index:
             # wait for previous parts to finish processing ...
@@ -209,15 +231,25 @@ class TensorPartReducer:
 
         current_part_future = self.current_part_future
 
-        self.accumulator.add_(tensor_part, alpha=weight)
-        self.current_part_accumulated_from += 1
-        self.denominator += weight
+        if part_index < self.sender_failed_after[sender_index]:
+            self.accumulator.add_(tensor_part, alpha=weight)
+            self.current_part_accumulated_from += 1
+            self.denominator += weight
+            self.check_current_part_finished()
+        return await current_part_future
 
-        assert self.current_part_accumulated_from <= self.num_senders
-        if self.current_part_accumulated_from == self.num_senders:
-            current_part_future.set_result(self.accumulator.div_(self.denominator))
+    def on_sender_failed(self, sender_index: int):
+        """Exclude that sender's data for averaging any parts that it did not submit yet."""
+        self.sender_failed_after[sender_index] = self.num_parts_received[sender_index]
+        if self.current_part_index == self.num_parts_received[sender_index]:
+            self.num_current_senders -= 1
+            self.check_current_part_finished()
+
+    def check_current_part_finished(self):
+        assert self.current_part_accumulated_from <= self.num_current_senders
+        if self.current_part_accumulated_from == self.num_current_senders:
+            self.current_part_future.set_result(self.accumulator.div_(self.denominator))
             self.reset_accumulators()
-        return await current_part_future
 
     def finalize(self):
         if not self.finished.is_set():
@@ -226,6 +258,12 @@ class TensorPartReducer:
                 del self.accumulator
             self.finished.set()
 
+            if self.num_parts != 0 and self.num_senders != 0:
+                parts_expected = self.num_parts * self.num_senders
+                parts_received = sum(self.num_parts_received)
+                if parts_expected != parts_received:
+                    logger.info(f"Reducer: received {parts_received / parts_expected * 100:.1f}% of input tensors")
+
     def __del__(self):
         self.finalize()
 

+ 3 - 0
hivemind/compression/base.py

@@ -65,6 +65,9 @@ class CompressionBase(ABC):
         """Estimate the compression ratio without doing the actual compression; lower ratio = better compression"""
         ...
 
+    def __repr__(self):
+        return f"hivemind.{self.__class__.__name__}()"
+
 
 class NoCompression(CompressionBase):
     """A dummy compression strategy that preserves the original tensor as is."""

+ 1 - 1
hivemind/dht/dht.py

@@ -151,7 +151,7 @@ class DHT(mp.Process):
             self._outer_pipe.send(("_shutdown", [], {}))
             self.join(self.shutdown_timeout)
             if self.is_alive():
-                logger.warning("DHT did not shut down within the grace period; terminating it the hard way.")
+                logger.warning("DHT did not shut down within the grace period; terminating it the hard way")
                 self.terminate()
 
     async def _shutdown(self):

+ 1 - 1
hivemind/dht/node.py

@@ -717,7 +717,7 @@ class DHTNode:
         """Add key to a refresh queue, refresh at :refresh_time: or later"""
         if self.cache_refresh_task is None or self.cache_refresh_task.done() or self.cache_refresh_task.cancelled():
             self.cache_refresh_task = asyncio.create_task(self._refresh_stale_cache_entries())
-            logger.debug("Spawned cache refresh task.")
+            logger.debug("Spawned cache refresh task")
         earliest_key, earliest_item = self.cache_refresh_queue.top()
         if earliest_item is None or refresh_time < earliest_item.expiration_time:
             self.cache_refresh_evt.set()  # if the new element is now the earliest, notify the cache queue

+ 1 - 1
hivemind/dht/routing.py

@@ -217,7 +217,7 @@ class KBucket:
 
     def __delitem__(self, node_id: DHTID):
         if not (node_id in self.nodes_to_peer_id or node_id in self.replacement_nodes):
-            raise KeyError(f"KBucket does not contain node id={node_id}.")
+            raise KeyError(f"KBucket does not contain node id={node_id}")
 
         if node_id in self.replacement_nodes:
             del self.replacement_nodes[node_id]

+ 1 - 1
hivemind/hivemind_cli/run_server.py

@@ -29,7 +29,7 @@ def main():
                         help="specify the exact list of expert uids to create. Use either this or num_experts"
                              " and expert_pattern, not both")
     parser.add_argument('--expert_cls', type=str, default='ffn', required=False,
-                        help="expert type from test_utils.layers, e.g. 'ffn', 'transformer', 'det_dropout' or 'nop'.")
+                        help="expert type from test_utils.layers, e.g. 'ffn', 'transformer', 'det_dropout' or 'nop'")
     parser.add_argument('--hidden_dim', type=int, default=1024, required=False, help='main dimension for expert_cls')
 
     parser.add_argument('--num_handlers', type=int, default=None, required=False,

+ 2 - 2
hivemind/moe/client/moe.py

@@ -238,7 +238,7 @@ class _RemoteCallMany(torch.autograd.Function):
             pending_tasks, num_samples, k_min, forward_timeout, timeout_after_k_min, detect_anomalies
         )
         if len(responded_inds) < k_min:
-            raise TimeoutError(f"Forward pass: less than {k_min} responded within timeout.")
+            raise TimeoutError(f"Forward pass: less than {k_min} responded within timeout")
 
         if not isinstance(info["outputs_schema"], tuple):
             outputs_schema = (info["outputs_schema"],)
@@ -330,7 +330,7 @@ class _RemoteCallMany(torch.autograd.Function):
             pending_tasks, num_samples, backward_k_min, backward_timeout, timeout_after_k_min, detect_anomalies
         )
         if len(survivor_inds) < backward_k_min:
-            raise TimeoutError(f"Backward pass: less than {backward_k_min} experts responded within timeout.")
+            raise TimeoutError(f"Backward pass: less than {backward_k_min} experts responded within timeout")
 
         # assemble responses
         batch_inds, expert_inds = map(

+ 2 - 2
hivemind/moe/server/__init__.py

@@ -333,7 +333,7 @@ def background_server(*args, shutdown_timeout=5, **kwargs) -> Tuple[hivemind.End
         if runner.is_alive():
             logger.info("Server failed to shutdown gracefully, terminating it the hard way...")
             runner.kill()
-            logger.info("Server terminated.")
+            logger.info("Server terminated")
 
 
 def _server_runner(pipe, *args, **kwargs):
@@ -353,4 +353,4 @@ def _server_runner(pipe, *args, **kwargs):
         logger.info("Shutting down server...")
         server.shutdown()
         server.join()
-        logger.info("Server shut down.")
+        logger.info("Server shut down")

+ 2 - 1
hivemind/optim/__init__.py

@@ -1,6 +1,7 @@
 from hivemind.optim.adaptive import CollaborativeAdaptiveOptimizer
 from hivemind.optim.base import DecentralizedOptimizerBase
 from hivemind.optim.collaborative import CollaborativeOptimizer
-from hivemind.optim.grad_scaler import HivemindGradScaler
+from hivemind.optim.experimental.optimizer import Optimizer
+from hivemind.optim.grad_scaler import GradScaler, HivemindGradScaler
 from hivemind.optim.simple import DecentralizedAdam, DecentralizedOptimizer, DecentralizedSGD
 from hivemind.optim.training_averager import TrainingAverager
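With these exports, the new optimizer is importable alongside the legacy classes. A minimal sketch, assuming the package-level __init__ re-exports Optimizer and GradScaler (as the hivemind.Optimizer docstring examples further below suggest):

    import torch
    import hivemind

    model = torch.nn.Linear(16, 16)
    dht = hivemind.DHT(start=True)  # first peer of a new run; other peers would pass initial_peers=[...]
    opt = hivemind.Optimizer(
        dht=dht, run_id="demo_run", target_batch_size=4096, batch_size_per_step=4,
        params=model.parameters(), optimizer=lambda params: torch.optim.Adam(params),
    )
    scaler = hivemind.GradScaler()  # hivemind-aware drop-in for torch.cuda.amp.GradScaler (optional)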

+ 15 - 7
hivemind/optim/collaborative.py

@@ -57,6 +57,10 @@ class TrainingProgressSchema(BaseModel):
 
 class CollaborativeOptimizer(DecentralizedOptimizerBase):
     """
+    :note: **For new projects, please use hivemind.Optimizer**. CollaborativeOptimizer is an older version of it.
+      Currently, hivemind.Optimizer supports all the features of CollaborativeOptimizer and many more advanced ones.
+      CollaborativeOptimizer will still be supported for a while, but it will eventually be deprecated.
+
     An optimizer that performs model updates after collaboratively accumulating a target (large) batch size across peers
 
     These optimizers use DHT to track how much progress did the collaboration make towards target batch size.
@@ -229,7 +233,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
         """
         if grad_scaler is not None and not isinstance(grad_scaler, HivemindGradScaler):
-            raise ValueError("CollaborativeOptimizer requires a hivemind-aware gradient scaler (HivemindGradScaler).")
+            raise ValueError("CollaborativeOptimizer requires a hivemind-aware gradient scaler (HivemindGradScaler)")
         if self.batch_size_per_step is None:
             if batch_size is None:
                 raise ValueError("Please either set batch_size_per_step parameter at init or when calling .step")
@@ -238,12 +242,12 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
         batch_size = batch_size if batch_size is not None else self.batch_size_per_step
 
         if not self.is_synchronized and not self.is_within_tolerance:
-            logger.log(self.status_loglevel, "Peer is out of sync.")
+            logger.log(self.status_loglevel, "Peer is out of sync")
             self.load_state_from_peers()
             return
         elif not self.is_synchronized and self.is_within_tolerance:
             self.averager.local_step = self.collaboration_state.optimizer_step
-            logger.log(self.status_loglevel, f"Catching up with collaboration step {self.local_step}.")
+            logger.log(self.status_loglevel, f"Catching up with collaboration step {self.local_step}")
 
         if grad_scaler is not None and not grad_scaler.are_grads_finite(self):
             logger.log(self.status_loglevel, "Encountered incorrect value in fp16 grads, resetting local gradients")
@@ -300,12 +304,12 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
                                 logger.warning(f"Peer {peer} sent malformed data about current step: {peer_step}")
 
                 except BaseException as e:
-                    logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}.")
+                    logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}")
 
             else:
                 logger.log(
                     self.status_loglevel,
-                    f"Skipped averaging: collaboration consists of " f"{self.collaboration_state.num_peers} peer(s).",
+                    f"Skipped averaging: collaboration consists of " f"{self.collaboration_state.num_peers} peer(s)",
                 )
 
             if grad_scaler is not None:
@@ -320,10 +324,14 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             self.averager.local_step = current_step + 1
             self.collaboration_state_updated.set()
             self.update_scheduler()
+
             if grad_scaler is not None:
                 with grad_scaler.running_global_step():
                     assert grad_scaler.update()
 
+            if not self.averager.client_mode:
+                self.averager.state_sharing_priority = self.local_step
+
         logger.log(self.status_loglevel, f"Optimizer step: done!")
 
         return group_info
@@ -357,7 +365,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
                         else:
                             logger.warning(f"Peer {peer} sent malformed data about current step: {peer_step}")
             except BaseException as e:
-                logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}.")
+                logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}")
 
             self.collaboration_state.register_step(current_step + 1)
             self.averager.local_step = current_step + 1
@@ -544,7 +552,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             value=None,
             expiration_time=get_dht_time() + self.metadata_expiration,
         )
-        logger.debug(f"{self.__class__.__name__} is shut down.")
+        logger.debug(f"{self.__class__.__name__} is shut down")
 
     def __del__(self):
         self.shutdown()
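Since state_sharing_priority is now set to the local step after each successful update, a lagging peer will prefer to download state from whoever has advanced the furthest. An illustrative sketch of the (priority, random tie-break) ordering used when loading state, not the actual implementation:

    import random

    peer_priority = {
        "peer_A": (120.0, random.random()),  # finished optimizer step 120
        "peer_B": (120.0, random.random()),  # also at step 120; the random value breaks the tie
        "peer_C": (95.0, random.random()),   # still catching up
    }
    best_first = sorted(peer_priority, key=peer_priority.get, reverse=True)
    assert best_first[-1] == "peer_C"  # the least advanced peer is tried last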

+ 29 - 22
hivemind/optim/experimental/grad_averager.py

@@ -6,7 +6,7 @@ import torch
 import hivemind
 from hivemind.averaging import DecentralizedAverager
 from hivemind.averaging.control import StepControl
-from hivemind.utils import DHTExpiration, get_logger
+from hivemind.utils import DHTExpiration, get_dht_time, get_logger
 
 logger = get_logger(__name__)
 
@@ -80,7 +80,7 @@ class GradientAverager(DecentralizedAverager):
         if reuse_grad_buffers and accumulate_grads_on is not None:
             logger.warning("Setting 'accumulate_grads_on' has no effect if reuse_grad_buffers=True")
         client_mode = client_mode if client_mode is not None else dht.client_mode
-        self._parameters = tuple(parameters)
+        self.parameters = tuple(parameters)
         self.reuse_grad_buffers = reuse_grad_buffers
         self.warn = warn
         self.local_samples_accumulated = 0
@@ -102,7 +102,7 @@ class GradientAverager(DecentralizedAverager):
 
     def _grads_from_parameters(self) -> Iterator[torch.Tensor]:
         """gradient buffers associated with parameters"""
-        for param in self._parameters:
+        for param in self.parameters:
             if param.grad is None:
                 param.grad = torch.zeros_like(param)
             yield param.grad
@@ -119,7 +119,7 @@ class GradientAverager(DecentralizedAverager):
         if self._accumulators_used_in_step and self.warn:
             logger.warning(
                 "[warn=True] Gradient accumulators were not reset since the last averaging round. Please "
-                "call .reset_accumulated_grads_ after every step or use .step(reset_accumulators=True)."
+                "call .reset_accumulated_grads_ after every step or use .step(reset_accumulators=True)"
             )
             self._accumulators_used_in_step = False  # warn once per round
         if self._anchor_batch_size is None:
@@ -152,6 +152,7 @@ class GradientAverager(DecentralizedAverager):
         weight: Optional[float] = None,
         reset_accumulators: bool = True,
         control: Optional[StepControl] = None,
+        timeout: Optional[float] = None,
         wait: bool = True,
         **kwargs,
     ):
@@ -161,33 +162,34 @@ class GradientAverager(DecentralizedAverager):
         :param weight: overrides the averaging weight; by default, weight equals the number of accumulated samples
         :param reset_accumulators: by default, set local gradient accumulators to zeros after averaging succeeds
         :param control: reuse a pre-arranged group of peers (or a matchmaking in progress) from averager.schedule_step
+        :param timeout: if specified, wait for the averaging round for at most this many seconds (if wait=True)
         :param wait: if True, await for the step to finish (or fail), otherwise run all-reduce in background
         """
         if control is None:
-            control = self.schedule_step(**kwargs)
+            control = self.schedule_step(timeout=timeout, **kwargs)
         elif len(kwargs) > 0:
-            RuntimeError(f"Averaging with a pre-scheduled group, parameters {kwargs} will have no effect.")
-        assert not control.triggered, f"This {type(control)} instance was already used."
-        self._load_accumulators_into_averager_()
+            raise RuntimeError(f"Averaging with a pre-scheduled group, parameters {kwargs} will have no effect")
+        assert not control.triggered, f"This {type(control)} instance was already used"
+        if self._new_averaged_grads and self.warn:
+            logger.warning(
+                "[warn=True] Starting new averaging round, but previous round results were not used. "
+                "This may be a sign of incorrect optimizer behavior"
+            )
+
+        self.load_accumulators_into_averager_()
         self._accumulators_used_in_step = True
         self._new_averaged_grads = True
 
         control.weight = self.local_samples_accumulated if weight is None else weight
         if reset_accumulators:
             self.reset_accumulated_grads_()
-
         control.allow_allreduce()
-        return control.result() if wait else control
+
+        return control.result(timeout) if wait else control
 
     @torch.no_grad()
-    def _load_accumulators_into_averager_(self):
+    def load_accumulators_into_averager_(self):
         """load locally accumulated gradients into the averager for aggregation"""
-        if self._new_averaged_grads and self.warn:
-            logger.warning(
-                "[warn=True] Starting new averaging round, but previous round results were not used."
-                "This may be a sign of incorrect optimizer behavior."
-            )
-            self._new_averaged_grads = False  # warn once per round
         # divide locally accumulated gradients by the number of times they were accumulated
         grad_scale = (1.0 / self.local_times_accumulated) if self.local_times_accumulated != 0 else 0.0
         with self.get_tensors() as averaged_grads:
@@ -206,14 +208,19 @@ class GradientAverager(DecentralizedAverager):
     @contextlib.contextmanager
     @torch.no_grad()
     def use_averaged_gradients(self):
+        """Substitute model's main gradients with averaged gradients (does not respect device placement)"""
         self._new_averaged_grads = False
         with self.get_tensors() as averaged_grads:
+            assert len(averaged_grads) == len(self.parameters)
             try:
-                assert len(averaged_grads) == len(self._parameters)
-                old_grads = [param.grad for param in self._parameters]
-                for param, new_grad in zip(self._parameters, averaged_grads):
+                old_grads = [param.grad for param in self.parameters]
+                for param, new_grad in zip(self.parameters, averaged_grads):
                     param.grad = new_grad
-                yield
+                yield averaged_grads
             finally:
-                for param, old_grad in zip(self._parameters, old_grads):
+                for param, old_grad in zip(self.parameters, old_grads):
                     param.grad = old_grad
+
+    def notify_used_averaged_gradients(self):
+        """Notify averager that the results of a previous averaging round are accounted for"""
+        self._new_averaged_grads = False
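A hedged sketch of how the updated GradientAverager surface (step with an optional timeout, use_averaged_gradients yielding the averaged tensors) fits into a training loop. It assumes at least one other peer runs the same code with the same prefix, and uses accumulate_grads_ from the same module (not shown in this hunk); hyperparameters are placeholders:

    import torch
    from hivemind import DHT
    from hivemind.optim.experimental.grad_averager import GradientAverager

    model = torch.nn.Linear(8, 1)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    dht = DHT(start=True)
    averager = GradientAverager(model.parameters(), dht=dht, prefix="demo_grad_averager", start=True)

    for _ in range(4):  # accumulate several local micro-batches
        loss = model(torch.randn(32, 8)).pow(2).mean()
        loss.backward()
        averager.accumulate_grads_(batch_size=32)

    averager.step(timeout=60.0)               # average accumulated gradients with the group; resets accumulators
    with averager.use_averaged_gradients():   # param.grad temporarily points at the averaged tensors
        opt.step()
    opt.zero_grad()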

+ 758 - 0
hivemind/optim/experimental/optimizer.py

@@ -0,0 +1,758 @@
+from __future__ import annotations
+
+import logging
+import os
+import time
+from functools import partial
+from typing import Callable, Optional, Sequence, Union
+
+import torch
+
+from hivemind.averaging.control import AveragingStage, StepControl
+from hivemind.compression import CompressionBase, NoCompression
+from hivemind.dht import DHT
+from hivemind.optim.experimental.grad_averager import GradientAverager
+from hivemind.optim.experimental.progress_tracker import ProgressTracker
+from hivemind.optim.experimental.state_averager import (
+    LRSchedulerBase,
+    OptimizerFactory,
+    Parameters,
+    ParamGroups,
+    SchedulerFactory,
+    TorchOptimizer,
+    TrainingStateAverager,
+)
+from hivemind.optim.grad_scaler import GradScaler
+from hivemind.utils import PerformanceEMA, get_dht_time, get_logger
+
+logger = get_logger(__name__)
+
+
+class Optimizer(torch.optim.Optimizer):
+    """
+    hivemind.Optimizer wraps your regular PyTorch Optimizer for training collaboratively with peers.
+
+    By default, Optimizer is configured to be exactly **equivalent to synchronous training** with target_batch_size.
+    There are advanced options that make training semi-asynchronous (delay_optimizer_step and delay_grad_averaging)
+    or even fully asynchronous (use_local_updates=True).
+
+    :example: The Optimizer can be used as a drop-in replacement for a regular PyTorch Optimizer:
+
+    >>> model = transformers.AutoModel("albert-xxlarge-v2")
+    >>> dht = hivemind.DHT(initial_peers=INITIAL_PEERS, start=True)
+    >>> opt = hivemind.Optimizer(dht=dht, run_id="run_42", batch_size_per_step=4, target_batch_size=4096,
+    >>>                          params=model.parameters(), optimizer=lambda params: torch.optim.Adam(params))
+    >>> while True:
+    >>>     loss = compute_loss_on_batch(model, batch_size=4)
+    >>>     opt.zero_grad()
+    >>>     loss.backward()
+    >>>     opt.step()  # <-- train collaboratively with any peers that use the same prefix (run_42)
+
+    By default, peers will perform the following steps:
+
+     * accumulate a minibatch of gradients towards the (global) target batch size, without updating parameters yet;
+     * after peers collectively accumulate target_batch_size, average gradients with peers and perform optimizer step;
+     * if your peer lags behind the rest of the swarm, it will download parameters and optimizer state from others;
+
+    Unlike regular training, your device may join midway through training, when other peers have already made progress.
+    For this reason, any learning rate schedulers, curricula and other **time-dependent features should be based on**
+    ``optimizer.local_epoch`` (and not the number of calls to opt.step). Otherwise, peers that joined training late
+    may end up with different learning rates. To handle this automatically, specify the ``scheduler=...`` parameter below.
+
+    :What is an epoch?: Optimizer uses the term ``epoch`` to describe intervals between synchronizations. One epoch
+      corresponds to processing a certain number of training samples (``target_batch_size``) in total across all peers.
+      Like in a PyTorch LR scheduler, **an epoch does not necessarily correspond to a full pass over the training data.**
+      At the end of an epoch, peers perform synchronous actions such as averaging gradients for a global optimizer update,
+      updating the learning rate scheduler or simply averaging parameters (if using local updates).
+      The purpose of this is to ensure that changing the number of peers does not require changing hyperparameters.
+      For instance, if the number of peers doubles, they will run all-reduce more frequently to adjust for faster training.
+
+    :Configuration guide: This guide will help you set up your first collaborative training run. It covers the most
+      important basic options, but ignores features that require significant changes to the training code.
+
+    >>> dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=IF_BEHIND_FIREWALL_OR_VERY_UNRELIABLE, start=True)
+    >>> opt = hivemind.Optimizer(
+    >>>    dht=dht, run_id="a_unique_name_that_every_participant_will_see_when_training",
+    >>>    batch_size_per_step=ACTUAL_BATCH_SIZE_OF_THIS_PEER, target_batch_size=LARGE_GLOBAL_BATCH,
+    >>>    # ^--- Each global optimizer step will use gradients from 1x-1.1x of target_batch_size (due to latency);
+    >>>    # It is recommended to train with very large batch sizes to reduce the % of time spent on communication.
+    >>>
+    >>>    params=params, optimizer=lambda params: AnyPyTorchOptimizer(params, **hyperparams_for_target_batch_size),
+    >>>    # tune learning rate for your target_batch_size. Here's a good reference: https://arxiv.org/abs/1904.00962
+    >>>    scheduler=lambda opt: AnyPyTorchScheduler(opt, **hyperparams_for_target_batch_size),
+    >>>    # scheduler.step will be called automatically each time when peers collectively accumulate target_batch_size
+    >>>
+    >>>    offload_optimizer=True,  # saves GPU memory, but increases RAM usage; Generally a good practice to use this.
+    >>>    delay_grad_averaging=OPTIONAL, delay_optimizer_step=OPTIONAL, # train faster, but with 1 round of staleness;
+    >>>    # setting both to True is equivalent to Delayed Parameter Updates (see https://arxiv.org/abs/2101.06840)
+    >>>
+    >>>    grad_compression=hivemind.Float16Compression(),  state_averaging_compression=hivemind.Float16Compression(),
+    >>>    # ^-- it is usually fine to use pure 16-bit or even lower precision during communication without extra precautions;
+    >>>    # See hivemind/examples/albert for a working example of mixed 8/16-bit compression.
+    >>>
+    >>>    matchmaking_time=15.0, # 3-5s for small local runs, 10-15s for training over the internet or with many peers
+    >>>    averaging_timeout=60.0,  # around 2x the actual time it takes to run all-reduce
+    >>>    verbose=True  # periodically report the training progress to the console (e.g. "Averaged with N peers")
+    >>> )  # and you're done!
+
+
+    :param dht: a running hivemind.DHT instance connected to other peers.
+    :param run_id: a unique identifier of this training run, used as a common prefix for all DHT keys.
+      **Note:** peers with the same run_id should *generally* train the same model and use compatible configurations.
+      Some options can be safely changed by individual peers: ``batch_size_per_step``, ``client_mode``, ``auxiliary``,
+      ``reuse_grad_buffers``, ``offload_optimizer``, and ``verbose``. In some cases, other options may also be tuned
+      individually by each peer, but they should be changed with caution to avoid deadlocks or convergence issues.
+
+    :param target_batch_size: global batch size that must be accumulated before the swarm transitions to the next epoch.
+      The actual batch may be *slightly* larger due to asynchrony (e.g. peers submit more gradients in the last second).
+    :param batch_size_per_step: you should accumulate gradients over this many samples between calls to optimizer.step.
+
+    :param params: parameters or param groups for the optimizer; required if optimizer is a callable(params).
+    :param optimizer: a callable(parameters) -> pytorch.optim.Optimizer or a pre-initialized PyTorch optimizer.
+      **Note:** some advanced options like offload_optimizer, delay_optimizer_step, or delay_grad_averaging require
+      the callable form and will not work if hivemind.Optimizer is created with a pre-existing PyTorch Optimizer.
+    :param scheduler: callable(optimizer) -> PyTorch LRScheduler or a pre-initialized PyTorch scheduler.
+      The learning rate scheduler will adjust learning rate based on global epoch, not the number of
+      local calls to optimizer.step; this is required to keep different peers synchronized.
+
+    :param matchmaking_time: when looking for a group, wait for peers to join for up to this many seconds.
+      Increase it if you see "averaged gradients with N peers" where N is below 0.9x the actual number of peers
+      on >=25% of epochs. When training over a low-latency network, decreasing matchmaking_time allows smaller batch sizes.
+    :param averaging_timeout: if an averaging step hangs for this long, it will be cancelled automatically.
+      Increase averaging_timeout if you see "Proceeding with local gradients" at least 25% of the time.
+      Do not set this timeout too high, as it may cause your optimizer to hang after some types of network errors.
+    :param allreduce_timeout: timeout for a single attempt to run all-reduce, default: equal to averaging_timeout.
+    :param load_state_timeout: wait for at most this many seconds before giving up on load_state_from_peers.
+    :param reuse_grad_buffers: if True, use model's .grad buffers for gradient accumulation.
+      This is more memory efficient, but it requires that the user does *NOT* call model/opt zero_grad at all (see the sketch after this docstring).
+
+    :param offload_optimizer: offload the optimizer to host memory, saving GPU memory for parameters and gradients
+    :param delay_optimizer_step: run optimizer in background, apply results in future .step; requires offload_optimizer
+    :param delay_grad_averaging: average gradients in background; requires offload_optimizer and delay_optimizer_step
+
+    :param delay_state_averaging: if enabled (default), average parameters and extra tensors in a background thread;
+      if set to False, average parameters synchronously within the corresponding hivemind.Optimizer.step call.
+
+    :param average_state_every: average state (parameters, chosen opt tensors) with peers every this many **epochs**.
+      Increasing this value reduces the communication overhead but can cause parameters to diverge if set too large.
+      The maximal safe value of average_state_every depends on how quickly peers diverge from each other. If peers
+      hardly ever skip averaging rounds, they can afford to average state less frequently. In turn, network failures,
+      lossy gradient compression and local_updates cause parameters to diverge faster and require more frequent averaging.
+
+    :param use_local_updates: if enabled, peers will update parameters on each .step using local gradients;
+      if not enabled (default), accumulate gradients to target_batch_size, and then call .step with averaged gradients.
+      Even if use_local_updates=True, learning rate scheduler will still be called once per target_batch_size.
+
+    :param client_mode: if True, this peer will not accept incoming connections (firewall-compatible mode)
+    :param auxiliary: if True, optimizer.step will only assist other peers in averaging (for cpu-only workers)
+
+    :param grad_compression: compression strategy used for averaging gradients, default = no compression
+    :param state_averaging_compression: compression for averaging params and state tensors, default = no compression
+    :param load_state_compression: compression strategy for loading state from peers, default = no compression
+    :param average_opt_statistics: names of optimizer statistics from state dict that should be averaged with peers
+    :param extra_tensors: if specified, these extra tensors will also be averaged and shared in load_state_from_peers.
+
+    :param averager_opts: additional keyword arguments forwarded to both GradientAverager and TrainingStateAverager
+    :param tracker_opts: additional keyword arguments forwarded to ProgressTracker
+    :param performance_ema_alpha: moving average alpha in ProgressTracker, TrainingStateAverager and Optimizer
+    :param verbose: if True, report internal events such as accumulating gradients and running background tasks
+
+    :note: in large-scale training, peers will inevitably fail and you will see error messages. hivemind.Optimizer
+      is designed to recover from such failures, but will sometimes need a minute or two to re-adjust.
+
+    """
+
+    def __init__(
+        self,
+        *,
+        dht: DHT,
+        run_id: str,
+        target_batch_size: int,
+        batch_size_per_step: Optional[int] = None,
+        optimizer: Union[TorchOptimizer, OptimizerFactory],
+        params: Optional[Union[Parameters, ParamGroups]] = None,
+        scheduler: Optional[Union[LRSchedulerBase, SchedulerFactory]] = None,
+        matchmaking_time: Optional[float] = 15.0,
+        averaging_timeout: Optional[float] = 60.0,
+        allreduce_timeout: Optional[float] = None,
+        next_chunk_timeout: Optional[float] = None,
+        load_state_timeout: float = 600.0,
+        reuse_grad_buffers: bool = False,
+        offload_optimizer: Optional[bool] = None,
+        delay_optimizer_step: Optional[bool] = None,
+        delay_grad_averaging: bool = False,
+        delay_state_averaging: bool = True,
+        average_state_every: int = 1,
+        use_local_updates: bool = False,
+        client_mode: bool = None,
+        auxiliary: bool = False,
+        grad_compression: CompressionBase = NoCompression(),
+        state_averaging_compression: CompressionBase = NoCompression(),
+        load_state_compression: CompressionBase = NoCompression(),
+        average_opt_statistics: Sequence[str] = (),
+        extra_tensors: Sequence[torch.Tensor] = (),
+        averager_opts: Optional[dict] = None,
+        tracker_opts: Optional[dict] = None,
+        performance_ema_alpha: float = 0.1,
+        shutdown_timeout: float = 5,
+        verbose: bool = False,
+    ):
+        client_mode = client_mode if client_mode is not None else dht.client_mode
+        delay_optimizer_step = delay_optimizer_step if delay_optimizer_step is not None else delay_grad_averaging
+        offload_optimizer = offload_optimizer if offload_optimizer is not None else (params is not None)
+        allreduce_timeout = allreduce_timeout if allreduce_timeout is not None else averaging_timeout
+        next_chunk_timeout = next_chunk_timeout if next_chunk_timeout is not None else matchmaking_time
+        assert not delay_grad_averaging or delay_optimizer_step, "delay_grad_averaging requires delay_optimizer_step"
+        assert not (client_mode and auxiliary), "Client-mode peers cannot serve as auxiliaries"
+        assert not auxiliary or batch_size_per_step is None, "Auxiliary peers should not accumulate batches"
+        if callable(optimizer) and params is not None:
+            if scheduler is not None and (not callable(scheduler) or isinstance(scheduler, LRSchedulerBase)):
+                raise ValueError("For this mode, please provide scheduler factory: callable(optimizer) -> scheduler")
+        elif all(hasattr(optimizer, attr) for attr in ("param_groups", "step", "zero_grad")):
+            if offload_optimizer or delay_optimizer_step or delay_grad_averaging:
+                raise ValueError(
+                    "To enable offload_optimizer or delayed updates, please initialize Optimizer as "
+                    "hivemind.Optimizer(..., params=params, optimizer=lambda params: create_opt(params)"
+                )
+        else:
+            raise ValueError(
+                "Please initialize the optimizer in one of the following two ways:\n"
+                "(A) hivemind.Optimizer(..., params=params, optimizer=lambda params: create_opt(params)\n"
+                "(B) hivemind.Optimizer(..., optimizer=pre_initialize_optimizer)"
+            )
+        if use_local_updates:
+            assert not reuse_grad_buffers, "if use_local_updates is True, gradients will not be accumulated"
+            assert not delay_grad_averaging, "if use_local_updates is True, gradients will not be averaged"
+
+        self.dht, self.run_id, self.client_mode, self.auxiliary = dht, run_id, client_mode, auxiliary
+        self.batch_size_per_step, self.target_batch_size = batch_size_per_step, target_batch_size
+        self.delay_state_averaging, self.average_state_every = delay_state_averaging, average_state_every
+        self.matchmaking_time, self.offload_optimizer = matchmaking_time, offload_optimizer
+        self.delay_grad_averaging, self.delay_optimizer_step = delay_grad_averaging, delay_optimizer_step
+
+        self.averaging_timeout, self.allreduce_timeout = averaging_timeout, allreduce_timeout
+        self.load_state_timeout, self.shutdown_timeout = load_state_timeout, shutdown_timeout
+        self.next_chunk_timeout = next_chunk_timeout
+
+        self.status_loglevel = logging.INFO if verbose else logging.DEBUG
+        self.scheduled_grads: Optional[StepControl] = None
+        self.scheduled_state: Optional[StepControl] = None
+
+        self.tracker = self._make_progress_tracker(
+            target_batch_size, performance_ema_alpha=performance_ema_alpha, **tracker_opts or {}
+        )
+        self.state_averager = self._make_state_averager(
+            optimizer=optimizer,
+            params=params,
+            scheduler=scheduler,
+            delta_rule_averaging=use_local_updates and self.delay_state_averaging,
+            compression=state_averaging_compression,
+            state_compression=load_state_compression,
+            average_opt_statistics=average_opt_statistics,
+            performance_ema_alpha=performance_ema_alpha,
+            extra_tensors=extra_tensors,
+            **averager_opts or {},
+        )
+        if not use_local_updates:
+            self.grad_averager = self._make_gradient_averager(
+                reuse_grad_buffers=reuse_grad_buffers, compression=grad_compression, **averager_opts or {}
+            )
+        else:
+            self.grad_averager = None
+
+        self._should_check_synchronization_on_update = True  # used in self.should_load_state_from_peers
+        self._schema_hash = self._compute_schema_hash()
+        self._parent_pid = os.getpid()
+
+        self.delay_before_state_averaging = PerformanceEMA(alpha=performance_ema_alpha)
+        # measures the average time from the beginning of self._update_global_epoch to the call to state_averager
+        # used for pre-scheduling the averaging round in state_averager
+
+        self._step_supports_amp_scaling = reuse_grad_buffers
+        # note: the line above is used by pytorch AMP GradScaler to enable custom behavior needed when reusing gradient
+        # buffers over multiple steps (to avoid repeated unscaling). Without reuse_grad_buffers, this is not needed.
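To illustrate why this flag matters, below is a rough sketch of the mixed-precision call pattern with the hivemind-aware scaler imported at the top of this file. ``model``, ``data_loader`` and ``opt`` are assumed to exist (with ``reuse_grad_buffers=True``), and the loss-scale bookkeeping is intentionally left out.

```python
import torch
import torch.nn.functional as F

from hivemind.optim.grad_scaler import GradScaler  # hivemind-aware counterpart of torch.cuda.amp.GradScaler

scaler = GradScaler()
# assumptions: `model` is on a CUDA device, `opt` is a hivemind.Optimizer with reuse_grad_buffers=True
for x, y in data_loader:
    with torch.cuda.amp.autocast():
        loss = F.cross_entropy(model(x), y)
    scaler.scale(loss).backward()  # scaled gradients accumulate in the reused .grad buffers
    opt.step(grad_scaler=scaler)   # Optimizer.step unscales through the scaler before averaging
```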
+
+    def _make_state_averager(self, **kwargs) -> TrainingStateAverager:
+        return TrainingStateAverager(
+            dht=self.dht,
+            prefix=f"{self.run_id}_state_averager",
+            min_matchmaking_time=self.matchmaking_time,
+            allreduce_timeout=self.allreduce_timeout,
+            shutdown_timeout=self.shutdown_timeout,
+            offload_optimizer=self.offload_optimizer,
+            custom_gradients=self.offload_optimizer,
+            status_loglevel=self.status_loglevel,
+            next_chunk_timeout=self.next_chunk_timeout,
+            client_mode=self.client_mode,
+            auxiliary=self.auxiliary,
+            start=True,
+            **kwargs,
+        )
+
+    def _make_gradient_averager(self, **kwargs) -> GradientAverager:
+        assert hasattr(self, "state_averager"), "must initialize state averager first"
+        grad_averager = GradientAverager(
+            dht=self.dht,
+            prefix=f"{self.run_id}_grad_averager",
+            parameters=self.state_averager.main_parameters,
+            min_matchmaking_time=self.matchmaking_time,
+            allreduce_timeout=self.allreduce_timeout,
+            shutdown_timeout=self.shutdown_timeout,
+            next_chunk_timeout=self.next_chunk_timeout,
+            client_mode=self.client_mode,
+            auxiliary=self.auxiliary,
+            start=True,
+            **kwargs,
+        )
+        if self.offload_optimizer:
+            optimized_param_groups = self.state_averager.optimizer.param_groups
+            optimized_parameters = [param for group in optimized_param_groups for param in group["params"]]
+            with grad_averager.get_tensors() as averaged_gradients:
+                assert len(averaged_gradients) == len(optimized_parameters)
+                for opt_param, averaged_grad in zip(optimized_parameters, averaged_gradients):
+                    opt_param.grad = averaged_grad
+        return grad_averager
+
+    def _make_progress_tracker(self, target_batch_size: int, **kwargs) -> ProgressTracker:
+        return ProgressTracker(
+            dht=self.dht,
+            prefix=self.run_id,
+            target_batch_size=target_batch_size,
+            client_mode=self.client_mode,
+            status_loglevel=self.status_loglevel,
+            start=True,
+            **kwargs,
+        )
+
+    def _compute_schema_hash(self) -> int:
+        optimized_param_groups = self.state_averager.optimizer.param_groups
+        optimized_parameters = [param for group in optimized_param_groups for param in group["params"]]
+        param_shapes = tuple(tuple(param.shape) for param in optimized_parameters)
+
+        # offloaded optimizer requires that gradient tensors are reused between iterations
+        grad_ids = tuple(id(param.grad) for param in optimized_parameters) if self.offload_optimizer else None
+        return hash((grad_ids, param_shapes))
+
+    def is_alive(self) -> bool:
+        return self.state_averager.is_alive()
+
+    @property
+    def local_epoch(self) -> int:
+        """
+        This worker's current epoch, kept synchronized with peers. If peer's local_epoch lags behind others, it will
+        automatically re-synchronize by downloading state from another peer.
+        An epoch corresponds to accumulating target_batch_size across all active devices.
+        """
+        return self.state_averager.local_epoch
+
+    @property
+    def use_local_updates(self) -> bool:
+        return self.grad_averager is None
+
+    @property
+    def use_gradient_averaging(self) -> bool:
+        return self.grad_averager is not None
+
+    def step(
+        self,
+        closure: Optional[Callable[[], torch.Tensor]] = None,
+        batch_size: Optional[int] = None,
+        grad_scaler: Optional[GradScaler] = None,
+    ):
+        """
+        Update training progress after accumulating another local batch size. Depending on the configuration, this will
+        report progress to peers, run global or local optimizer step, average parameters or schedule background tasks.
+
+        :param closure: A closure that reevaluates the model and returns the loss.
+        :param batch_size: optional override for batch_size_per_step from init.
+        :param grad_scaler: if amp is enabled, this **must** be a hivemind-aware gradient scaler.
+        :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
+        """
+        if grad_scaler is not None and not isinstance(grad_scaler, GradScaler):
+            raise ValueError("hivemind.Optimizer requires a hivemind-aware gradient scaler (hivemind.GradScaler)")
+        if self.batch_size_per_step is None and batch_size is None and not self.auxiliary:
+            raise ValueError("Please either set batch_size_per_step parameter at init or when calling .step")
+        if self.auxiliary and (closure is not None or batch_size is not None or grad_scaler is not None):
+            raise ValueError("Auxiliary peers should not have batch size, run closures, or use grad_scaler")
+        batch_size = batch_size if batch_size is not None else self.batch_size_per_step
+
+        # if delayed updates finished before step, apply these updates; otherwise do nothing
+        self.state_averager.step(apply_delayed_updates=True)
+
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        if not self.auxiliary and self.should_load_state_from_peers():
+            logger.log(self.status_loglevel, "Peer is out of sync")
+            self.load_state_from_peers()
+            return loss  # local gradients were computed with out-of-sync parameters, must start over
+
+        if self.use_gradient_averaging:
+            # accumulate gradients toward target batch size, then aggregate with peers and run optimizer
+            if not self.auxiliary:
+                grads_are_valid = self._check_and_accumulate_gradients(batch_size, grad_scaler)
+                if not grads_are_valid:
+                    return loss  # local gradients were reset due to overflow, must start over
+
+            self._maybe_schedule_gradient_averaging()
+            self._maybe_schedule_state_averaging()
+
+        else:
+            # use_local_updates=True: update parameters on every step independently of other peers
+            if not self.auxiliary:
+                if grad_scaler is not None:
+                    with grad_scaler.running_global_step():
+                        assert grad_scaler.unscale_(self)
+
+                new_samples_accumulated = self.tracker.local_progress.samples_accumulated + batch_size
+                self.tracker.report_local_progress(self.local_epoch, new_samples_accumulated)
+                self._maybe_schedule_state_averaging()
+
+                self.state_averager.step(
+                    increment_epoch=False,
+                    optimizer_step=True,
+                    delay_optimizer_step=self.delay_optimizer_step,
+                    grad_scaler=grad_scaler,
+                )
+
+        if self.tracker.ready_to_update_epoch:
+            self._update_global_epoch(grad_scaler)
+
+        return loss
+
+    def _update_global_epoch(self, grad_scaler: Optional[GradScaler]) -> None:
+        """Depending on the configuration: aggregate gradients and/or parameters, perform global optimizer step"""
+        assert self._schema_hash == self._compute_schema_hash(), "parameters or gradients changed during iteration"
+        _epoch_start_time = time.perf_counter()
+
+        with self.tracker.pause_updates():
+            wait_for_trigger = None
+
+            if self.use_gradient_averaging:
+                logger.log(self.status_loglevel, f"Beginning optimizer step #{self.local_epoch}")
+                if self.delay_optimizer_step:
+                    self.state_averager.step(wait_for_delayed_updates=True)
+
+                began_averaging_gradients = self._begin_averaging_gradients(grad_scaler)
+                if not began_averaging_gradients:
+                    pass  # failed to start gradient averaging due to an internal error
+                elif self.delay_grad_averaging:
+                    # if using delayed grad averaging, send this to state_averager as a pre-condition for optimizer step
+                    wait_for_trigger = partial(self._average_gradients_and_load_into_optimizer, self.scheduled_grads)
+                else:
+                    # delay_grad_averaging=False, average gradients immediately
+                    self._average_gradients_and_load_into_optimizer(self.scheduled_grads)
+
+            next_epoch = max(self.local_epoch + 1, self.tracker.global_epoch)
+            swarm_not_empty = self.tracker.global_progress.num_peers > 1
+            should_perform_optimizer_step = not self.auxiliary and not self.use_local_updates
+            should_average_state = (
+                swarm_not_empty
+                and next_epoch % self.average_state_every == 0
+                and not self.state_averager.averaging_in_progress
+            )
+
+            if should_average_state and self.scheduled_state is not None:
+                if self.scheduled_state.triggered or self.scheduled_state.done():
+                    logger.log(
+                        self.status_loglevel,
+                        f"Not using pre-scheduled group for state averaging because it"
+                        f"was already used elsewhere: {self.scheduled_state}",
+                    )
+                    self.scheduled_state = None
+                self.delay_before_state_averaging.update(task_size=1, interval=time.perf_counter() - _epoch_start_time)
+
+            self.state_averager.step(
+                increment_epoch=True,
+                wait_for_trigger=wait_for_trigger,
+                optimizer_step=should_perform_optimizer_step,
+                delay_optimizer_step=self.delay_optimizer_step and should_perform_optimizer_step,
+                grad_scaler=grad_scaler,
+                averaging_round=should_average_state,
+                delay_averaging=self.delay_state_averaging and not self.auxiliary,
+                averaging_control=self.scheduled_state if should_average_state else None,
+                averaging_opts=dict(timeout=self.averaging_timeout) if should_average_state else None,
+            )
+
+            if not should_average_state and self.scheduled_state is not None and not self.scheduled_state.done():
+                self.scheduled_state.cancel()
+            self.scheduled_state = None
+
+            self.tracker.update_epoch(new_epoch=self.state_averager.local_epoch)
+            self._should_check_synchronization_on_update = True
+            # the above line ensures that peers check for *strict* synchronization once per epoch
+
+            if not self.client_mode:
+                self.state_averager.state_sharing_priority = self.local_epoch
+
+            if self.use_gradient_averaging and not self.auxiliary:
+                self.grad_averager.reset_accumulated_grads_()
+                if not self.client_mode:
+                    self.grad_averager.state_sharing_priority = self.local_epoch
+
+            logger.log(self.status_loglevel, f"Transitioning to epoch {self.local_epoch}")
+
+    def _begin_averaging_gradients(self, grad_scaler: Optional[GradScaler]) -> bool:
+        """Begin an all-reduce round to average gradients; return True if succeeded, False if failed"""
+        if grad_scaler is not None:
+            with grad_scaler.running_global_step():
+                assert grad_scaler.unscale_(self)
+
+        began_averaging_gradients = False
+        if self.scheduled_grads is not None and (self.scheduled_grads.triggered or self.scheduled_grads.done()):
+            logger.log(
+                self.status_loglevel,
+                f"Not using pre-scheduled group for state averaging because it"
+                f"was already used elsewhere: {self.scheduled_state}",
+            )
+            self.scheduled_grads = None
+
+        elif self.tracker.global_progress.num_peers > 1:
+            try:
+                self.scheduled_grads = self.grad_averager.step(
+                    control=self.scheduled_grads, reset_accumulators=True, wait=False
+                )
+                began_averaging_gradients = True
+            except BaseException as e:
+                logger.exception(e)
+
+        if not began_averaging_gradients and self.scheduled_grads is not None and not self.scheduled_grads.done():
+            logger.log(self.status_loglevel, f"Tagging along for a pre-scheduled gradient averaging round")
+            self._tag_along_with_zero_weight(self.scheduled_grads)
+            self.scheduled_grads = None
+        return began_averaging_gradients
+
+    def _check_and_accumulate_gradients(self, batch_size: int, grad_scaler: Optional[GradScaler]) -> bool:
+        """Check if gradients are valid, accumulate and return True; otherwise, reset and return False"""
+        assert not self.use_local_updates and not self.auxiliary
+        if grad_scaler is not None and not grad_scaler.are_grads_finite(self):
+            logger.log(self.status_loglevel, "Encountered incorrect value in fp16 grads, resetting local gradients")
+            self.tracker.report_local_progress(self.local_epoch, samples_accumulated=0)
+            self.grad_averager.reset_accumulated_grads_()
+            return False
+
+        self.grad_averager.accumulate_grads_(batch_size)
+        self.tracker.report_local_progress(self.local_epoch, self.grad_averager.local_samples_accumulated)
+        return True
+
+    def _maybe_schedule_gradient_averaging(self) -> None:
+        """If next epoch is coming soon, schedule the next gradient averaging round at the estimated end of epoch"""
+        assert self.use_gradient_averaging
+        if self.tracker.estimated_next_update_time - get_dht_time() <= self.matchmaking_time:
+            if self.scheduled_grads is None or self.scheduled_grads.triggered or self.scheduled_grads.done():
+                eta_seconds = self.tracker.estimated_next_update_time - get_dht_time()
+                eta_seconds = max(eta_seconds, self.grad_averager.matchmaking_kwargs["min_matchmaking_time"])
+                logger.log(self.status_loglevel, f"Pre-scheduling gradient averaging round in {eta_seconds:.2f} sec")
+                self.scheduled_grads = self.grad_averager.schedule_step(timeout=self.averaging_timeout)
+
+    def _maybe_schedule_state_averaging(self) -> None:
+        """If next epoch is coming soon, schedule the next state averaging at estimated parameter averaging start"""
+        next_epoch = max(self.local_epoch + 1, self.tracker.global_epoch)
+        if next_epoch % self.average_state_every != 0:
+            return  # averaging is not performed at this epoch
+        if self.state_averager.averaging_in_progress:
+            return  # previous run is still in progress
+        if self.delay_before_state_averaging.num_updates == 0:
+            return  # not enough data to accurately pre-schedule
+
+        estimated_time = self.tracker.estimated_next_update_time
+        estimated_time += self.delay_before_state_averaging.ema_seconds_per_sample
+        estimated_time += self.state_averager.delay_before_averaging.ema_seconds_per_sample
+        eta_seconds_to_averaging = estimated_time - get_dht_time()
+
+        if eta_seconds_to_averaging <= self.matchmaking_time:
+            if self.scheduled_state is None or self.scheduled_state.triggered or self.scheduled_state.done():
+
+                min_matchmaking_time = self.state_averager.matchmaking_kwargs["min_matchmaking_time"]
+                actual_seconds = max(eta_seconds_to_averaging, min_matchmaking_time)
+                logger.log(self.status_loglevel, f"Pre-scheduling state averaging round in {actual_seconds:.2f} sec")
+                self.scheduled_state = self.state_averager.schedule_step(
+                    gather=next_epoch, timeout=self.averaging_timeout
+                )
+
+    def _average_gradients_and_load_into_optimizer(self, maybe_step_control: Optional[StepControl]):
+        """Run gradient averaging; on success, feed averaged gradients into optimizer; else, use local gradients"""
+        assert self.use_gradient_averaging and (maybe_step_control is None or maybe_step_control.triggered)
+        averaged_gradients = False
+
+        try:
+            if maybe_step_control is not None:
+                group_info = maybe_step_control.result(self.averaging_timeout)
+                logger.log(self.status_loglevel, f"Averaged gradients with {len(group_info)} peers")
+                self._load_averaged_gradients_into_optimizer_()
+                averaged_gradients = True
+            else:
+                logger.log(self.status_loglevel, f"Skipped averaging: there are no other peers")
+        except BaseException as e:
+            logger.log(self.status_loglevel, f"Averaging gradients failed with {repr(e)}")
+
+        if not averaged_gradients:
+            logger.log(self.status_loglevel, f"Proceeding with local gradients")
+            self.grad_averager.load_accumulators_into_averager_()
+            self._load_averaged_gradients_into_optimizer_()
+
+    def _load_averaged_gradients_into_optimizer_(self):
+        """If required, load averaged gradients into optimizer; otherwise simply notify grad averager"""
+        assert self.use_gradient_averaging
+
+        if self.offload_optimizer:
+            pass  # averaged gradients are already baked into optimizer, see _make_gradient_averager
+        else:
+            # copy averaged gradients into optimizer .grad buffers
+            optimized_param_groups = self.state_averager.optimizer.param_groups
+            optimized_parameters = [param for group in optimized_param_groups for param in group["params"]]
+            with torch.no_grad(), self.grad_averager.get_tensors() as averaged_gradients:
+                assert len(averaged_gradients) == len(optimized_parameters)
+                for opt_param, averaged_grad in zip(optimized_parameters, averaged_gradients):
+                    opt_param.grad.copy_(averaged_grad, non_blocking=True)
+
+        self.grad_averager.notify_used_averaged_gradients()
+
+    def zero_grad(self, set_to_none: bool = False):
+        """Reset gradients from model. If reuse_grad_buffers=True, this will raise an error."""
+        if self.use_gradient_averaging and self.grad_averager.reuse_grad_buffers:
+            raise ValueError(
+                f"When running {self.__class__.__name__} with reuse_grad_buffers=True, user should never "
+                f"call zero_grad manually. Gradients will be refreshed internally"
+            )
+        for param_group in self.param_groups:
+            for param in param_group["params"]:
+                if param.grad is None:
+                    pass
+                elif set_to_none:
+                    param.grad = None
+                else:
+                    param.grad.zero_()
+
+    def should_load_state_from_peers(self) -> bool:
+        """
+        If this returns True, the peer will discard local progress and attempt to download state from peers.
+        This method allows a peer to continue training in two cases:
+         - the peer is on the same epoch as other collaborators - keep training normally;
+         - the peer was on the same epoch and accumulated some grads, but some collaborators
+             have just transitioned to the next epoch - this peer should also transition.
+
+        :note: The latter case occurs due to the lack of network synchrony: the first peer to detect that enough
+        samples were accumulated will transition to the next epoch and start counting samples anew. Other peers
+        may take some time before they check the DHT and observe that
+          - the global epoch is technically one epoch ahead of their current one, and
+          - the remaining (non-transitioned) peers no longer have target_batch_size between them.
+        In this case, the peer should transition to the next epoch but does *not* need to re-load state.
+        """
+        if self._should_check_synchronization_on_update and self.tracker.fetched_global_progress_this_epoch.is_set():
+            self._should_check_synchronization_on_update = False
+            return self.local_epoch != self.tracker.global_epoch  # require exact synchronization once per step
+        return self.local_epoch < self.tracker.global_epoch - 1  # catch up if a peer just switched to next epoch
+
+    def load_state_from_peers(self, **kwargs):
+        """
+        Attempt to load the newest collaboration state from other peers within the same run_id.
+
+        If successful, this will update parameters, optimizer state, local epoch and learning rate schedule in-place.
+        """
+        # note: we tag along for the next all-reduce because the run may have already started and cancelling it
+        # will cause peers to restart matchmaking and may stall the entire collaboration for a few seconds.
+        if self.scheduled_grads is not None and not self.scheduled_grads.done():
+            self._tag_along_with_zero_weight(self.scheduled_grads)
+            self.scheduled_grads = None
+        self.state_averager.step(wait_for_delayed_updates=True)
+
+        with self.tracker.pause_updates():
+            while True:
+                try:
+                    self.state_averager.load_state_from_peers(timeout=self.load_state_timeout, **kwargs)
+                    break
+                except KeyboardInterrupt:
+                    raise
+                except BaseException as e:
+                    logger.exception(f"Failed to load state from peers: {e}, retrying ...")
+                    continue
+
+            if self.tracker.global_epoch - 1 <= self.local_epoch < self.tracker.global_epoch:
+                logger.log(self.status_loglevel, f"Catching up with collaboration step {self.tracker.global_epoch}")
+                self.state_averager.local_epoch = self.tracker.global_epoch
+
+            self.tracker.report_local_progress(local_epoch=self.local_epoch, samples_accumulated=0)
+
+            if not self.client_mode:
+                self.state_averager.state_sharing_priority = self.local_epoch
+
+            if self.use_gradient_averaging:
+                self.grad_averager.reset_accumulated_grads_()
+                if not self.client_mode:
+                    self.grad_averager.state_sharing_priority = self.local_epoch
+
+    def state_dict(self) -> dict:
+        state_dict = self.state_averager.optimizer.state_dict()
+        state_dict["state"]["local_epoch"] = self.local_epoch
+        return state_dict
+
+    def load_state_dict(self, state_dict: dict):
+        if "local_epoch" in state_dict["state"]:
+            self.state_averager.local_epoch = state_dict["state"].pop("local_epoch")
+        return self.state_averager.optimizer.load_state_dict(state_dict)
+
+    @property
+    def state(self):
+        return dict(self.state_averager.optimizer.state, local_epoch=self.local_epoch)
+
+    @property
+    def opt(self) -> TorchOptimizer:
+        return self.state_averager.optimizer
+
+    @property
+    def param_groups(self) -> ParamGroups:
+        next_index = 0
+        param_groups = tuple(dict(param_group) for param_group in self.state_averager.optimizer.param_groups)
+        for param_group in param_groups:
+            num_params = len(param_group["params"])
+            main_params_for_group = self.state_averager.main_parameters[next_index : next_index + num_params]
+            param_group["params"] = main_params_for_group
+            next_index += num_params
+        assert next_index == len(self.state_averager.main_parameters)
+        return param_groups
+
+    def add_param_group(self, param_group: dict) -> None:
+        raise ValueError(
+            f"{self.__class__.__name__} does not support calling add_param_group after creation. "
+            f"Please provide all parameter groups at init"
+        )
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(prefix={self.run_id}, epoch={self.local_epoch})"
+
+    def _tag_along_with_zero_weight(self, control: StepControl):
+        """Wait for a running averaging round to finish with zero weight."""
+        if not control.triggered:
+            control.weight = 0
+            control.allow_allreduce()
+        if not control.done():
+            try:
+                control.result(self.averaging_timeout)
+            except BaseException as e:
+                logger.exception(e)
+                if not control.done():
+                    control.cancel()
+
+    def shutdown(self):
+        logger.log(self.status_loglevel, "Sending goodbye to peers...")
+        self.tracker.shutdown(self.shutdown_timeout)
+        self.state_averager.step(wait_for_delayed_updates=True)
+        for scheduled_round in self.scheduled_grads, self.scheduled_state:
+            if scheduled_round is not None:
+                if scheduled_round.stage == AveragingStage.LOOKING_FOR_GROUP:
+                    scheduled_round.cancel()
+                else:
+                    self._tag_along_with_zero_weight(scheduled_round)
+
+        logger.log(self.status_loglevel, "Shutting down averagers...")
+        self.state_averager.shutdown()
+        if self.use_gradient_averaging:
+            self.grad_averager.shutdown()
+        logger.log(self.status_loglevel, f"{self.__class__.__name__} is shut down")
+
+    def __del__(self):
+        if self._parent_pid == os.getpid() and self.is_alive():
+            self.shutdown()
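Because ``__del__`` only triggers shutdown in the parent process and only if garbage collection happens to run, it is safer to tear the optimizer down explicitly at the end of training. A sketch of that pattern, with the training loop and constructor left as placeholders:

```python
# illustrative teardown pattern; make_hivemind_optimizer() and train_one_batch() are placeholders
opt = make_hivemind_optimizer()
try:
    while True:
        train_one_batch(opt)
except KeyboardInterrupt:
    pass
finally:
    opt.shutdown()      # say goodbye to peers, finalize scheduled rounds, stop the averagers and tracker
    opt.dht.shutdown()  # stop the DHT once nothing else is using it
```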

+ 358 - 0
hivemind/optim/experimental/progress_tracker.py

@@ -0,0 +1,358 @@
+import asyncio
+import contextlib
+import logging
+import threading
+from dataclasses import dataclass
+from typing import Dict, Optional
+
+import numpy as np
+from pydantic import BaseModel, StrictBool, StrictFloat, confloat, conint
+
+from hivemind.dht import DHT
+from hivemind.dht.schema import BytesWithPublicKey, RSASignatureValidator, SchemaValidator
+from hivemind.utils import DHTExpiration, ValueWithExpiration, enter_asynchronously, get_dht_time, get_logger
+from hivemind.utils.crypto import RSAPrivateKey
+from hivemind.utils.performance_ema import PerformanceEMA
+
+logger = get_logger(__name__)
+
+
+@dataclass(frozen=False)
+class GlobalTrainingProgress:
+    epoch: int
+    samples_accumulated: int
+    target_batch_size: int
+    num_peers: int
+    num_clients: int
+    eta_next_epoch: float
+    next_fetch_time: float
+
+
+class LocalTrainingProgress(BaseModel):
+    peer_id: bytes
+    epoch: conint(ge=0, strict=True)
+    samples_accumulated: conint(ge=0, strict=True)
+    samples_per_second: confloat(ge=0.0, strict=True)
+    time: StrictFloat
+    client_mode: StrictBool
+
+
+class TrainingProgressSchema(BaseModel):
+    progress: Dict[BytesWithPublicKey, Optional[LocalTrainingProgress]]
+
+
+class ProgressTracker(threading.Thread):
+    """
+    Auxiliary class that keeps track of local & global training progress, measured in epochs.
+    An epoch is incremented after the collaboration accumulates a set number of samples (target_batch_size).
+    As with the PyTorch LR scheduler, an epoch may correspond to a single optimizer update or to many local updates.
+
+    :param min_refresh_period: wait for at least this many seconds before fetching new collaboration state
+    :param max_refresh_period: wait for at most this many seconds before fetching new collaboration state
+    :param default_refresh_period: if no peers are detected, attempt to fetch collaboration state this often (seconds)
+    :param expected_drift_peers: assume that this many new peers can join between epochs
+    :param expected_drift_rate: assume that this fraction of the current collaboration can join/leave between epochs
+    :note: The expected collaboration drift parameters are used to adjust the frequency with which this tracker will
+      refresh the collaboration-wide statistics (to avoid missing the moment when peers transition to the next epoch)
+    :param performance_ema_alpha: smoothing value used to estimate this peer's performance (samples per second)
+    :param metadata_expiration: peer's metadata (e.g. samples processed) is stored onto DHT for this many seconds
+
+    Example:
+
+    >>> tracker = ProgressTracker(hivemind.DHT(...), prefix="my_experiment_with_several_peers", target_batch_size=100)
+    >>> local_epoch, local_samples = 0, 0
+    >>> while True:
+    >>>     accumulate_gradients(batch_size=32)
+    >>>     local_samples += 32
+    >>>     tracker.report_local_progress(local_epoch, local_samples)
+    >>>     if local_epoch < tracker.global_progress.epoch:
+    >>>         download_state_from_peers()  # if peer is out of sync, synchronize it with the swarm
+    >>>     if tracker.accumulated_enough_samples:
+    >>>         with tracker.pause_updates():
+    >>>             aggregate_gradients_with_peers()
+    >>>             update_model_parameters()
+    >>>             local_epoch = tracker.update_epoch(local_epoch + 1)
+    >>>             local_samples = 0
+    """
+
+    def __init__(
+        self,
+        dht: DHT,
+        prefix: str,
+        target_batch_size: int,
+        *,
+        client_mode: Optional[bool] = None,
+        min_refresh_period: float = 0.5,
+        max_refresh_period: float = 10,
+        default_refresh_period: float = 3,
+        expected_drift_peers: float = 3,
+        expected_drift_rate: float = 0.2,
+        performance_ema_alpha: float = 0.1,
+        metadata_expiration: float = 60.0,
+        status_loglevel: int = logging.DEBUG,
+        private_key: Optional[RSAPrivateKey] = None,
+        daemon: bool = True,
+        start: bool,
+    ):
+        client_mode = client_mode if client_mode is not None else dht.client_mode
+        self.dht, self.prefix, self.client_mode = dht, prefix, client_mode
+        self.training_progress_key = f"{self.prefix}_progress"
+        self.target_batch_size = target_batch_size
+        self.min_refresh_period, self.max_refresh_period = min_refresh_period, max_refresh_period
+        self.default_refresh_period = default_refresh_period
+        self.expected_drift_peers, self.expected_drift_rate = expected_drift_peers, expected_drift_rate
+        self.status_loglevel = status_loglevel
+        self.performance_ema = PerformanceEMA(alpha=performance_ema_alpha)
+        self.metadata_expiration = metadata_expiration
+
+        signature_validator = RSASignatureValidator(private_key)
+        self._local_public_key = signature_validator.local_public_key
+        dht.add_validators([SchemaValidator(TrainingProgressSchema, prefix=prefix), signature_validator])
+
+        # report the collaboration progress periodically or in background
+        self.local_progress = self._get_local_progress(local_epoch=0, samples_accumulated=0)
+        metadata, _expiration = self.dht.get(self.training_progress_key, latest=True) or (None, -float("inf"))
+        self.global_progress = self._parse_swarm_progress_data(metadata)
+        self.lock_global_progress, self.global_state_updated = threading.Lock(), threading.Event()
+        self.should_report_progress, self.fetched_global_progress_this_epoch = threading.Event(), threading.Event()
+        self.shutdown_triggered, self.shutdown_complete = threading.Event(), threading.Event()
+        super().__init__(name=f"{self.__class__.__name__}({self.prefix})", daemon=daemon)
+        if start:
+            self.start()
+
+    @property
+    def global_epoch(self) -> int:
+        return self.global_progress.epoch
+
+    @property
+    def ready_to_update_epoch(self) -> bool:
+        """Whether or not this peer can increment epoch right away."""
+        return (
+            self.global_epoch > self.local_progress.epoch
+            or self.global_progress.samples_accumulated >= self.target_batch_size
+            or get_dht_time() >= self.global_progress.eta_next_epoch
+        )
+
+    @property
+    def estimated_next_update_time(self) -> DHTExpiration:
+        """Estimate (absolute) time when this peer should increment epoch"""
+        if self.ready_to_update_epoch:
+            return get_dht_time()
+        return self.global_progress.eta_next_epoch
+
+    def _get_local_progress(self, local_epoch: int, samples_accumulated: int):
+        return LocalTrainingProgress(
+            peer_id=self.dht.peer_id.to_bytes(),
+            epoch=local_epoch,
+            samples_accumulated=samples_accumulated,
+            samples_per_second=self.performance_ema.samples_per_second,
+            time=get_dht_time(),
+            client_mode=self.client_mode,
+        )
+
+    def report_local_progress(self, local_epoch: int, samples_accumulated: int, update_global_samples: bool = True):
+        """Update the number of locally accumulated samples and notify to other peers about this."""
+        extra_samples = samples_accumulated - self.local_progress.samples_accumulated
+        if update_global_samples and local_epoch == self.local_progress.epoch == self.global_progress.epoch:
+            self.global_progress.samples_accumulated += extra_samples
+            # note: the above line can decrease the number of samples, e.g. if forced to reset due to overflow
+
+        if extra_samples > 0:
+            self.performance_ema.update(task_size=extra_samples)
+            logger.debug(f"Updated performance EMA: {self.performance_ema.samples_per_second:.5f}")
+        else:
+            logger.debug("Resetting performance timestamp to current time (progress was reset)")
+            self.performance_ema.reset_timer()
+
+        self.local_progress = self._get_local_progress(local_epoch, samples_accumulated)
+        self.should_report_progress.set()
+
+    @contextlib.contextmanager
+    def pause_updates(self):
+        """Temporarily stop progress tracker from updating global training state"""
+        with self.lock_global_progress, self.performance_ema.pause():
+            yield
+
+    def update_epoch(self, new_epoch: Optional[int] = None) -> int:
+        """Update the local epoch, reset the number of sample accumulated, reset local progress, return new epoch"""
+        assert self.lock_global_progress.locked(), "ProgressTracker must be paused when incrementing epoch"
+        if new_epoch is None:
+            new_epoch = self.local_progress.epoch + 1
+        if new_epoch > self.global_progress.epoch:
+            self.global_progress.epoch = new_epoch
+            self.global_progress.samples_accumulated = 0
+            self.global_progress.eta_next_epoch = float("inf")
+        self.report_local_progress(new_epoch, samples_accumulated=0)
+        self.fetched_global_progress_this_epoch.clear()
+        return new_epoch
+
+    def run(self):
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        loop.run_until_complete(asyncio.gather(self._progress_reporter(), self._progress_fetcher()))
+        self.shutdown_complete.set()
+
+    async def _progress_reporter(self):
+        """Periodically publish metadata and the current number of samples accumulated towards the next epoch"""
+        last_report_time = -float("inf")
+        store_task = None
+        try:
+            while not self.shutdown_triggered.is_set():
+                wait_timeout = max(0.0, last_report_time - get_dht_time() + self.metadata_expiration / 2)
+                logger.debug(f"Will report progress again in {wait_timeout} seconds or on user command")
+                await asyncio.get_event_loop().run_in_executor(None, self.should_report_progress.wait, wait_timeout)
+                if self.should_report_progress.is_set():
+                    logger.debug(f"Progress update triggered by report_local_progress")
+                    self.should_report_progress.clear()
+                else:
+                    logger.debug(f"Progress update triggered by metadata_expiration")
+
+                local_progress = self.local_progress
+                last_report_time = get_dht_time()
+
+                store_task = asyncio.create_task(
+                    asyncio.wait_for(
+                        self.dht.store(
+                            key=self.training_progress_key,
+                            subkey=self._local_public_key,
+                            value=local_progress.dict(),
+                            expiration_time=last_report_time + self.metadata_expiration,
+                            return_future=True,
+                        ),
+                        timeout=self.metadata_expiration,
+                    )
+                )
+        finally:
+            logger.log(self.status_loglevel, f"No longer reporting progress for {self.prefix}")
+            if store_task is not None:
+                store_task.cancel()
+
+    async def _progress_fetcher(self):
+        """
+        Periodically check the training progress from all peers. Trigger update after target_batch_size total samples
+        """
+        loop = asyncio.get_event_loop()
+        shutdown_checker = asyncio.create_task(
+            asyncio.wait_for(loop.run_in_executor(None, self.shutdown_triggered.wait), None)
+        )
+
+        async def _fetch_progress_unless_shutdown_triggered():
+            """Fetch progress, avoid deadlocks if DHT was shut down before this get finished."""
+            getter = asyncio.create_task(
+                asyncio.wait_for(self.dht.get(self.training_progress_key, latest=True, return_future=True), None)
+            )
+            await asyncio.wait({getter, shutdown_checker}, return_when=asyncio.FIRST_COMPLETED)
+            if self.shutdown_triggered.is_set():
+                return
+            return await getter
+
+        try:
+            while not self.shutdown_triggered.is_set():
+                time_to_next_update = max(0.0, self.global_progress.next_fetch_time - get_dht_time())
+                state_updated_externally = await loop.run_in_executor(
+                    None, self.global_state_updated.wait, time_to_next_update
+                )
+                if state_updated_externally:
+                    self.global_state_updated.clear()
+                    continue
+
+                async with enter_asynchronously(self.lock_global_progress):
+                    maybe_metadata = await _fetch_progress_unless_shutdown_triggered()
+                    if self.shutdown_triggered.is_set():
+                        break
+                    metadata = maybe_metadata.value if isinstance(maybe_metadata, ValueWithExpiration) else None
+                    self.global_progress = self._parse_swarm_progress_data(metadata)
+                    self.fetched_global_progress_this_epoch.set()
+
+        finally:
+            logger.log(self.status_loglevel, f"No longer fetching {self.training_progress_key}")
+
+    def _parse_swarm_progress_data(self, metadata: TrainingProgressSchema) -> GlobalTrainingProgress:
+        """Read performance statistics reported by peers, estimate progress towards next batch"""
+        current_time = get_dht_time()
+
+        if not isinstance(metadata, dict) or len(metadata) == 0:
+            logger.log(self.status_loglevel, f"Found no active peers: {metadata}")
+            samples_remaining_to_next_epoch = max(0, self.target_batch_size - self.local_progress.samples_accumulated)
+            local_eta_next_epoch = samples_remaining_to_next_epoch / self.performance_ema.samples_per_second
+
+            return GlobalTrainingProgress(
+                self.local_progress.epoch,
+                self.local_progress.samples_accumulated,
+                self.target_batch_size,
+                num_peers=0,
+                num_clients=0,
+                eta_next_epoch=current_time + local_eta_next_epoch,
+                next_fetch_time=current_time + self.default_refresh_period,
+            )
+
+        valid_peer_entries = [
+            LocalTrainingProgress.parse_obj(peer_state.value)
+            for peer_state in metadata.values()
+            if peer_state.value is not None
+        ]
+
+        num_peers = len(valid_peer_entries)
+        num_clients = sum(peer.client_mode for peer in valid_peer_entries)
+
+        global_epoch = self.local_progress.epoch
+        for peer in valid_peer_entries:
+            if not peer.client_mode:
+                global_epoch = max(global_epoch, peer.epoch)
+
+        total_samples_accumulated = estimated_current_samples = 0
+        total_samples_per_second = self.performance_ema.eps
+
+        for peer in valid_peer_entries:
+            total_samples_per_second += peer.samples_per_second
+            if peer.epoch == global_epoch:
+                total_samples_accumulated += peer.samples_accumulated
+                estimated_current_samples += (
+                    peer.samples_accumulated + max(0.0, current_time - peer.time) * peer.samples_per_second
+                )
+            # note: we deliberately count samples only from peers at the current global epoch, but count performance
+            # (samples per second) from all peers; the rationale is that outdated peers will synchronize and begin
+            # contributing shortly.
+
+        estimated_samples_remaining = self.target_batch_size - estimated_current_samples
+        estimated_time_to_next_epoch = max(0, estimated_samples_remaining) / total_samples_per_second
+
+        expected_max_peers = max(num_peers + self.expected_drift_peers, num_peers * (1 + self.expected_drift_rate))
+        time_to_next_fetch = float(
+            np.clip(
+                a=estimated_time_to_next_epoch * num_peers / expected_max_peers,
+                a_min=self.min_refresh_period,
+                a_max=self.max_refresh_period,
+            )
+        )
+        logger.log(
+            self.status_loglevel,
+            f"{self.prefix} accumulated {total_samples_accumulated} samples for epoch #{global_epoch} from "
+            f"{num_peers} peers. ETA {estimated_time_to_next_epoch:.2f} sec (refresh in {time_to_next_fetch:.2f} sec)",
+        )
+        return GlobalTrainingProgress(
+            global_epoch,
+            total_samples_accumulated,
+            target_batch_size=self.target_batch_size,
+            num_peers=num_peers,
+            num_clients=num_clients,
+            eta_next_epoch=current_time + estimated_time_to_next_epoch,
+            next_fetch_time=current_time + time_to_next_fetch,
+        )
+
+    def shutdown(self, timeout: Optional[float] = None):
+        """Permanently disable all tracking activity"""
+        self.shutdown_triggered.set()
+        self.should_report_progress.set()
+        self.global_state_updated.set()
+        self.shutdown_complete.wait(timeout)
+        self.dht.store(
+            self.training_progress_key,
+            subkey=self._local_public_key,
+            value=None,
+            expiration_time=get_dht_time() + self.metadata_expiration,
+            return_future=True,
+        )
+
+    def __del__(self):
+        if self.is_alive():
+            self.shutdown()

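As a side note on the arithmetic above: the tracker extrapolates each up-to-date peer's progress from its last report, divides the remaining samples by the swarm's combined throughput to get an ETA, and then clips the refresh interval so that larger or faster-growing swarms are polled more often. The sketch below re-implements that math in isolation; the helper name, the peer dictionaries and the default constants are illustrative placeholders, not part of hivemind's API.

import numpy as np

def estimate_progress(target_batch_size, peers, local_eps, now,
                      expected_drift_peers=3, expected_drift_rate=0.2,
                      min_refresh=0.5, max_refresh=30.0):
    """Illustrative stand-alone version of the ETA / refresh-period math above.

    Each peer is a dict with keys: epoch, samples_accumulated, samples_per_second, time, client_mode.
    """
    global_epoch = max((p["epoch"] for p in peers if not p["client_mode"]), default=0)
    total_sps = local_eps + sum(p["samples_per_second"] for p in peers)
    # extrapolate each up-to-date peer's progress from its last report to the current time
    estimated_samples = sum(
        p["samples_accumulated"] + max(0.0, now - p["time"]) * p["samples_per_second"]
        for p in peers if p["epoch"] == global_epoch
    )
    eta_next_epoch = max(0, target_batch_size - estimated_samples) / total_sps
    # poll the DHT sooner when many new peers are expected to join before the epoch ends
    expected_max_peers = max(len(peers) + expected_drift_peers, len(peers) * (1 + expected_drift_rate))
    next_fetch = float(np.clip(eta_next_epoch * len(peers) / expected_max_peers, min_refresh, max_refresh))
    return eta_next_epoch, next_fetch

peers = [
    dict(epoch=5, samples_accumulated=800, samples_per_second=100.0, time=10.0, client_mode=False),
    dict(epoch=5, samples_accumulated=600, samples_per_second=50.0, time=12.0, client_mode=True),
]
print(estimate_progress(target_batch_size=4096, peers=peers, local_eps=1e-9, now=14.0))
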
+ 260 - 115
hivemind/optim/experimental/state_averager.py

@@ -1,18 +1,20 @@
 """ An extension of averager that supports common optimization use cases. """
 import logging
-from asyncio import Future
+import threading
+import time
 from concurrent.futures import ThreadPoolExecutor
+from contextlib import nullcontext
 from itertools import chain
-from threading import Event
 from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Sequence, Tuple, Union
 
 import torch
 
 import hivemind
-from hivemind import nested_compare
 from hivemind.averaging import DecentralizedAverager
+from hivemind.averaging.control import StepControl
 from hivemind.compression import CompressionInfo, TensorRole
-from hivemind.utils import get_logger, nested_flatten, nested_map, nested_pack
+from hivemind.optim.grad_scaler import GradScaler
+from hivemind.utils import DHTExpiration, PerformanceEMA, get_dht_time, get_logger, nested_flatten, nested_pack
 
 logger = get_logger(__name__)
 
@@ -36,7 +38,7 @@ class TrainingStateAverager(DecentralizedAverager):
 
     Example:
 
-    >>> avgr = TrainingStateAverager(optimizer=torch.optim.Adam, param_groups=model.parameters(), ...)
+    >>> avgr = TrainingStateAverager(optimizer=torch.optim.Adam, params=model.parameters(), ...)
     >>> # alternative interface: TrainingStateAverager(optimizer=torch.optim.Adam(model.parameters()), ...)
     >>> avgr.load_state_from_peers()
     >>> for i, batch in enumerate(training_dataloader):
@@ -49,7 +51,7 @@ class TrainingStateAverager(DecentralizedAverager):
       TrainingStateAverager.step(..., optimizer_step=True)
 
    :param optimizer: PyTorch Optimizer or a callable that creates an optimizer from param groups
-    :param param_groups: optional, a list/tuple of parameters or structured param groups for the optimizer
+    :param params: optional, a list/tuple of parameters or structured param groups for the optimizer
     :param scheduler: optional learning rate scheduler or callable that creates one from optimizer instance
     :note: if provided, scheduler will be updated based on averager.local_epoch, not the number of step cycles
     :param initialize_optimizer: if True, run a speculative optimizer step with zero gradients to initialize all
@@ -60,8 +62,11 @@ class TrainingStateAverager(DecentralizedAverager):
       This assumes that offloaded gradients will be populated externally, e.g. by the user or by hivemind.Optimizer.
     :param reuse_tensors: if True, reuse parameters and optimizer statistics as averaged_tensors for allreduce.
       For this to work, all parameters must be on CPU and have the appropriate dtype for use in DecentralizedAverager
+      Defaults to True if offload_optimizer, False otherwise.
+    :param delta_rule_averaging: if True, averaging will use delta rule to allow running local optimizer steps
+      while averaging. Delta rule: `state_tensor := state_tensor + averaging_result - state_tensor_before_averaging`
     :param sync_epoch_when_averaging: if True, update local epoch to the latest epoch among averaging peers
-    :param parameter_names: optionally provide parameter names in the same order as param_groups
+    :param parameter_names: optionally provide parameter names in the same order as in params
     :param average_opt_statistics: names of optimizer statistics from state dict that should be averaged with peers
     :param extra_tensors: if specified, these extra tensors will also be averaged and shared in load_state_from_peers.
    :note: you can use extra_tensors for any tensors not used by the optimizer (e.g. batchnorm statistics)
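
For intuition, here is a minimal, self-contained illustration of the delta rule described for delta_rule_averaging above; all tensor values are made up.

import torch

local_tensor = torch.tensor([1.0, 2.0, 3.0])       # the trainer's live parameter
state_before_averaging = local_tensor.clone()      # snapshot taken when all-reduce begins
averaging_result = torch.tensor([0.5, 2.5, 3.5])   # hypothetical outcome of the all-reduce

local_tensor += 0.1                                # a local optimizer step made while averaging was running

# delta rule: apply only the change produced by averaging, so the local step is preserved
local_tensor += averaging_result - state_before_averaging
print(local_tensor)                                # tensor([0.6000, 2.6000, 3.6000])
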
@@ -73,12 +78,14 @@ class TrainingStateAverager(DecentralizedAverager):
         *,
         dht: hivemind.DHT,
         optimizer: Union[TorchOptimizer, OptimizerFactory],
-        param_groups: Optional[Union[Parameters, ParamGroups]] = None,
+        params: Optional[Union[Parameters, ParamGroups]] = None,
         scheduler: Optional[Union[LRSchedulerBase, SchedulerFactory]] = None,
         initialize_optimizer: Optional[bool] = None,
         offload_optimizer: bool = False,
         custom_gradients: bool = False,
-        reuse_tensors: bool = False,
+        reuse_tensors: Optional[bool] = None,
+        delta_rule_averaging: bool = False,
+        performance_ema_alpha: float = 0.1,
         sync_epoch_when_averaging: bool = False,
         parameter_names: Optional[Sequence[str]] = None,
         average_opt_statistics: Sequence[str] = (),
@@ -88,20 +95,22 @@ class TrainingStateAverager(DecentralizedAverager):
     ):
         average_opt_statistics = tuple(average_opt_statistics)
         assert all(isinstance(key, str) for key in average_opt_statistics)
-        if offload_optimizer and reuse_tensors:
-            logger.warning("Setting offload_optimizer=True has no effect because reuse_parameters=True")
+        if reuse_tensors is None:
+            reuse_tensors = offload_optimizer and not delta_rule_averaging
         if custom_gradients and not offload_optimizer:
             logger.warning("Setting custom_gradients=True has no effect because the optimizer is not offloaded")
+        if reuse_tensors and delta_rule_averaging:
+            raise ValueError("reuse_tensors and delta_rule_averaging are mutually exclusive")
 
-        param_groups, main_parameters, parameter_names = self._check_params(optimizer, param_groups, parameter_names)
+        param_groups, main_parameters, parameter_names = self._check_params(optimizer, params, parameter_names)
 
         self.status_loglevel = status_loglevel
-        self.reuse_tensors = reuse_tensors
-        self.offload_optimizer = offload_optimizer
-        self.custom_gradients = custom_gradients
+        self.offload_optimizer, self.custom_gradients = offload_optimizer, custom_gradients
+        self.reuse_tensors, self.delta_rule_averaging = reuse_tensors, delta_rule_averaging
+        self._old_tensors: Optional[Sequence[torch.Tensor]] = None  # for delta rule
 
-        self._main_parameters, self._parameter_names = main_parameters, parameter_names
-        self._averaged_parameters = tuple(map(self._make_host_tensor, main_parameters))
+        self.main_parameters, self.parameter_names = main_parameters, parameter_names
+        self._averaged_parameters = self._make_averaged_parameters(main_parameters)
         self.optimizer, self.scheduler = self._init_components(
             param_groups, optimizer, scheduler, initialize_optimizer
         )
@@ -109,11 +118,13 @@ class TrainingStateAverager(DecentralizedAverager):
         self.sync_epoch_when_averaging = sync_epoch_when_averaging
         self.local_epoch = 0
 
-        self.step_executor = ThreadPoolExecutor(max_workers=1)
-        self.finished_optimizer_step = Event()
-        self.finished_averaging_round = Event()
-        self.pending_update = Future()
-        self.pending_update.set_result(None)
+        self.delay_before_averaging = PerformanceEMA(alpha=performance_ema_alpha)
+        self.step_executor = ThreadPoolExecutor(max_workers=2 if self.delta_rule_averaging else 1)
+        self.finished_optimizer_step = threading.Event()
+        self.finished_averaging_round = threading.Event()
+        self.lock_optimizer = threading.Lock()
+        self.lock_averaging = threading.Lock()
+        self.pending_updates = set()
 
         super().__init__(
             dht=dht, averaged_tensors=self._init_averaged_tensors(), tensor_infos=self._init_tensor_infos(), **kwargs
@@ -143,10 +154,15 @@ class TrainingStateAverager(DecentralizedAverager):
         assert len(set(parameters)) == len(parameters), "Found duplicate parameters in param_groups"
         return param_groups, parameters, parameter_names
 
-    def _make_host_tensor(self, source_tensor: torch.Tensor) -> torch.Tensor:
+    def _make_averaged_parameters(self, main_parameters: Sequence[torch.Tensor]):
+        """Initialize averaged parameters based on the optimizer and averaging mode"""
+        return tuple(self._make_host_tensor(param, force_copy=self.offload_optimizer) for param in main_parameters)
+
+    def _make_host_tensor(self, source_tensor: torch.Tensor, force_copy: bool = False) -> torch.Tensor:
         """Create a new tensor for averaging or reuse the existing one"""
-        if self.reuse_tensors:
-            assert source_tensor.device == torch.device("cpu") and source_tensor.dtype == torch.float32
+        if self.reuse_tensors and not force_copy:
+            if source_tensor.device != torch.device("cpu"):
+                raise ValueError("reuse_tensors is only supported if all averaged tensors are on CPU")
             if not source_tensor.is_shared():
                 source_tensor.share_memory_()
             return source_tensor
@@ -173,19 +189,26 @@ class TrainingStateAverager(DecentralizedAverager):
         # create optimizer
         if optimizer_is_factory:
             if self.offload_optimizer:
-                for param in self._averaged_parameters:
-                    if param.grad is None:
-                        param.grad = torch.zeros_like(param)
+                if self.reuse_tensors:
+                    parameters_for_optimizer = self._averaged_parameters
+                else:
+                    parameters_for_optimizer = tuple(
+                        tensor.detach().clone().requires_grad_(tensor.requires_grad)
+                        for tensor in self._averaged_parameters
+                    )
 
                 next_index = 0
                 param_groups_for_optimizer = []
                 for param_group in param_groups:
                     num_params = len(param_group["params"])
-                    averaged_params_for_group = self._averaged_parameters[next_index : next_index + num_params]
+                    averaged_params_for_group = parameters_for_optimizer[next_index : next_index + num_params]
                     param_groups_for_optimizer.append(dict(param_group, params=averaged_params_for_group))
                     next_index += num_params
-                assert next_index == len(self._averaged_parameters)
+                assert next_index == len(parameters_for_optimizer)
 
+                for param in parameters_for_optimizer:
+                    if param.grad is None:
+                        param.grad = torch.zeros_like(param)
             else:
                 param_groups_for_optimizer = param_groups
             optimizer = optimizer_or_factory(param_groups_for_optimizer)
@@ -197,8 +220,8 @@ class TrainingStateAverager(DecentralizedAverager):
             initialize_optimizer = not any(isinstance(x, torch.Tensor) for x in nested_flatten(optimizer.state_dict()))
             logger.log(
                 self.status_loglevel,
-                "Initializing optimizer manually since it has no tensors in state dict"
-                "To override this, please provide initialize_optimizer=False",
+                "Initializing optimizer manually since it has no tensors in state dict. "
+                "To override this, provide initialize_optimizer=False",
             )
 
         if initialize_optimizer:
@@ -213,7 +236,7 @@ class TrainingStateAverager(DecentralizedAverager):
 
         # verify optimizer and scheduler
         assert isinstance(optimizer, TorchOptimizer) and len(optimizer.param_groups) == len(list(param_groups))
-        if self.offload_optimizer or self.reuse_tensors:
+        if self.reuse_tensors:
             for param_group in optimizer.param_groups:
                 for param in param_group["params"]:
                     assert param.is_shared()
@@ -250,19 +273,19 @@ class TrainingStateAverager(DecentralizedAverager):
         for local_tensor, averaged_tensor in zip(local_tensors, averaged_tensors):
             assert local_tensor.shape == averaged_tensor.shape
             if averaged_tensor.grad is not None:
-                logger.debug(self.status_loglevel, "setting gradients for averaged tensor to None")
+                logger.log(self.status_loglevel, "setting gradients for averaged tensor to None")
 
         return averaged_tensors
 
     def _init_tensor_infos(self) -> Sequence[CompressionInfo]:
         """Get CompressionInfo for each state tensor, accounting for its role and specification"""
         tensor_infos = []
-        for param, param_name in zip(self._main_parameters, self._parameter_names):
+        for param, param_name in zip(self.main_parameters, self.parameter_names):
             tensor_infos.append(CompressionInfo.from_tensor(param, key=param_name, role=TensorRole.PARAMETER))
         for stats_name in self.opt_keys_for_averaging:
             opt_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
-            assert len(opt_parameters) == len(self._parameter_names)
-            for param, param_name in zip(opt_parameters, self._parameter_names):
+            assert len(opt_parameters) == len(self.parameter_names)
+            for param, param_name in zip(opt_parameters, self.parameter_names):
                 tensor_infos.append(
                     CompressionInfo.from_tensor(
                         self.optimizer.state[param][stats_name],
@@ -274,9 +297,22 @@ class TrainingStateAverager(DecentralizedAverager):
             tensor_infos.append(CompressionInfo.from_tensor(extra_tensor, key=i, role=TensorRole.UNSPECIFIED))
         return tuple(tensor_infos)
 
+    def schedule_step(self, scheduled_time: Optional[DHTExpiration] = None, **kwargs) -> StepControl:
+        """
+        Begin matchmaking: look for a group of peers and prepare for averaging gradients at a specified time.
+
+        :param scheduled_time: expected time when to perform all-reduce. Can be changed using control.scheduled_time
+        :param kwargs: any additional keyword args from DecentralizedAverager.step, such as gather, allow_retries, etc
+        :note: setting weight at this stage is not supported, please leave this parameter as None
+        :returns: step_control - a handle that can be passed into TrainingStateAverager.step to use pre-scheduled group
+        :note: in the current implementation, each step_control can only be used in one step.
+        """
+        assert kwargs.get("weight") is None, "setting weight in schedule_step is not supported"
+        return super().step(scheduled_time=scheduled_time, wait=False, require_trigger=True, **kwargs)
+
     def step(
         self,
-        wait_for_delayed_update: bool = None,
+        wait_for_delayed_updates: bool = None,
         apply_delayed_updates: bool = True,
         increment_epoch: bool = False,
         optimizer_step: bool = False,
@@ -284,131 +320,217 @@ class TrainingStateAverager(DecentralizedAverager):
         delay_optimizer_step: bool = False,
         averaging_round: bool = False,
         delay_averaging: Optional[bool] = None,
-        averaging_kwargs: Optional[Dict[str, Any]] = None,
+        averaging_control: Optional[StepControl] = None,
+        wait_for_trigger: Optional[Callable[[], Any]] = None,
+        grad_scaler: Optional[GradScaler] = None,
+        averaging_opts: Optional[Dict[str, Any]] = None,
     ):
         """
         Perform one or several possible actions, depending on the specified keyword args.
         The actions will be performed in the same order as specified below:
 
-        :param wait_for_delayed_update: if there are background averaging rounds, wait for them to finish
+        :param wait_for_delayed_updates: if there are background averaging rounds, wait for them to finish
           by default, await delayed updates when scheduling the next optimizer step, otherwise do not update
         :param apply_delayed_updates: apply any averaging rounds that have finished but were not applied yet
         :param increment_epoch: increment .local_epoch and update the learning rate scheduler (if present)
+        :note: if specified, it is guaranteed that epoch is incremented immediately regardless of other options
         :param optimizer_step: perform a single optimizer step and update local parameters (without changing scheduler)
         :param zero_grad: if True, reset local gradients after performing optimizer step
         :param delay_optimizer_step: if True, run optimizer step in background and apply results in a future step
         :param averaging_round: average parameters, chosen optimizer keys and extra tensors with a group of peers
         :param delay_averaging: if True, perform averaging in background and apply results in a future step
           by default, delay averaging if the optimizer step is also delayed. Set to true to delay only this phase.
-        :param averaging_kwargs: a dict of keyword arguments forwarded into averaging round
+        :param averaging_control: if specified, use this as a pre-scheduled averaging round. Should require_trigger.
+        :param wait_for_trigger: wait for this (non-asyncio) function to finish before running optimizer step
+        :note: if wait_for_trigger fails with any exception, it will abort optimizer step, zero grad and averaging
+        :param grad_scaler: when using hivemind.GradScaler, one must forward it to step after calling .unscale_
+        :param averaging_opts: a dict of keyword arguments forwarded into averaging round
         """
         if delay_averaging is None:
             delay_averaging = delay_optimizer_step
-        if wait_for_delayed_update is None:
-            wait_for_delayed_update = optimizer_step or zero_grad or averaging_round
+        should_wait = averaging_round or optimizer_step or zero_grad if self.delta_rule_averaging else averaging_round
+        if wait_for_delayed_updates is None:
+            wait_for_delayed_updates = should_wait
+        if should_wait and not (wait_for_delayed_updates and apply_delayed_updates):
+            raise ValueError("Should wait for background operation to finish before scheduling new one")
         assert not delay_optimizer_step or delay_averaging, "Delayed optimizer step requires delayed averaging"
-        if optimizer_step or averaging_round or zero_grad:
-            assert wait_for_delayed_update, "Must wait for background updates to finish before scheduling new ones"
         if delay_optimizer_step:
             assert self.offload_optimizer, "Delayed optimizer step is only available with offload_optimizer"
             assert not averaging_round or delay_averaging, "Averaging after delayed optimizer should also be delayed"
-        if averaging_kwargs and not averaging_round:
-            logger.warning(f"Averaging parameters not used because averaging_round=False: {averaging_kwargs}")
+        if averaging_opts and not averaging_round:
+            logger.warning(f"Averaging parameters not used because averaging_round=False: {averaging_opts}")
+        if averaging_control is not None:
+            assert averaging_round, "averaging_control is unused if averaging_round is not performed"
+        if wait_for_trigger is not None:
+            assert optimizer_step or zero_grad or averaging_round, "trigger is only used for updating parameters"
+            if not (self.reuse_tensors or self.custom_gradients):
+                # averager was asked to wait_for_trigger in background, but it is not clear which version of gradients
+                # should be used for optimizer step (e.g. the gradients that were present during the call to .step or
+                # the possibly different gradients when wait_for_trigger has finished).
+                raise ValueError(
+                    "wait_for_trigger is a low-level option that requires manual gradient manipulation. "
+                    "If you know what you're doing, please refer to the comments in the source code for details"
+                )
         output = None
 
-        if wait_for_delayed_update:
-            if not self.pending_update.done():
-                logger.log(self.status_loglevel, "Waiting for delayed updates to finish...")
-                output = self.pending_update.result()
-
-        if self.pending_update.done() and self.pending_update.exception():
-            logger.warning(f"Background update failed with {self.pending_update.exception()} and will be ignored")
+        if wait_for_delayed_updates:
+            for pending_update in self.pending_updates:
+                try:
+                    timeout = (averaging_opts or {}).get("averaging_timeout", self._allreduce_timeout)
+                    logger.log(self.status_loglevel, "Waiting for delayed updates to finish...")
+                    output = pending_update.result(timeout)
+                except BaseException:
+                    # exception will be reported below
+                    if not pending_update.done():
+                        pending_update.cancel()
+
+        # remove finished updates, log any exceptions
+        finished_updates = {pending_update for pending_update in self.pending_updates if pending_update.done()}
+        self.pending_updates = {pending_update for pending_update in self.pending_updates if not pending_update.done()}
+        for finished_update in finished_updates:
+            if finished_update.cancelled() or finished_update.exception():
+                logger.log(self.status_loglevel, f"Background update failed: {finished_update}")
 
         if apply_delayed_updates:
             if self.finished_averaging_round.is_set():
                 if not self.reuse_tensors:
                     self._apply_averaging_results_()
-                logger.log(self.status_loglevel, "Received results from background averaging round")
+                if self.offload_optimizer and not self.finished_optimizer_step.is_set():
+                    self._apply_optimizer_parameters_()
+                logger.log(self.status_loglevel, "Received parameters from background averaging round")
                 self.finished_averaging_round.clear()
 
             if self.finished_optimizer_step.is_set():
                 if self.offload_optimizer:
-                    self._apply_optimizer_results_()
-                logger.log(self.status_loglevel, "Received results from background optimizer step")
+                    self._apply_optimizer_parameters_()
+                logger.debug("Received parameters from background optimizer step")
                 self.finished_optimizer_step.clear()
 
         if increment_epoch:
             self.local_epoch += 1
-            logger.log(self.status_loglevel, f"Switching to epoch {self.local_epoch}")
-            self._update_scheduler()
 
         if optimizer_step or zero_grad or averaging_round:
-            assert self.pending_update.done(), "Tried to perform a new update but previous update is still running"
-
             if self.offload_optimizer and not self.custom_gradients:
                 self._load_local_grads_into_optimizer_()
 
-            self.pending_update = self.step_executor.submit(
+            pending_update = self.step_executor.submit(
                 self._do,
+                wait_for_trigger,
                 optimizer_step,
                 zero_grad,
                 averaging_round,
-                **averaging_kwargs or {},
+                averaging_control,
+                grad_scaler,
+                **averaging_opts or {},
             )
+            self.pending_updates.add(pending_update)
+
+            should_await_optimizer = (optimizer_step or zero_grad) and not delay_optimizer_step
+            should_await_averaging = averaging_round and not delay_averaging
 
-            if (optimizer_step or zero_grad) and not delay_optimizer_step:
+            if should_await_optimizer:
                 self.finished_optimizer_step.wait()
                 self.finished_optimizer_step.clear()
-                if self.offload_optimizer:
-                    self._apply_optimizer_results_()
-                logger.log(self.status_loglevel, "Finished optimizer step")
+                if self.offload_optimizer and not should_await_averaging:
+                    self._apply_optimizer_parameters_()
+                logger.debug("Finished optimizer step")
 
-            if averaging_round and not delay_averaging:
+            if should_await_averaging:
                 self.finished_averaging_round.wait()
                 self.finished_averaging_round.clear()
                 if not self.reuse_tensors:
                     self._apply_averaging_results_()
+                if self.offload_optimizer:
+                    self._apply_optimizer_parameters_()
                 logger.log(self.status_loglevel, "Finished averaging round")
 
-            if not delay_averaging:
+            async_averaging = averaging_round and delay_averaging
+            async_optimizer = (optimizer_step or zero_grad) and delay_optimizer_step
+
+            if not (async_averaging or async_optimizer):
                 try:
-                    output = self.pending_update.result()
+                    output = pending_update.result()
                 finally:
-                    self.finished_averaging_round.clear()
-                    self.finished_optimizer_step.clear()
+                    self.pending_updates.remove(pending_update)
+
         return output
 
-    def _do(self, optimizer_step: bool, zero_grad: bool, averaging_round: bool, **kwargs):
+    def _do(
+        self,
+        wait_for_trigger: Optional[Callable[[], Any]],
+        optimizer_step: bool,
+        zero_grad: bool,
+        averaging_round: bool,
+        averaging_control: Optional[StepControl],
+        grad_scaler: Optional[GradScaler],
+        timeout: Optional[float] = None,
+        **kwargs,
+    ):
         """
        Run the optimizer step, followed by a scheduler step and an averaging round; each stage is optional.
         This method is meant to be called in the background executor.
         """
+        if averaging_control is not None and (averaging_control.triggered or averaging_control.done()):
+            logger.log(self.status_loglevel, f"Discarding failed matchmaking results: {averaging_control}")
+            averaging_control = None
+
+        start_time = time.perf_counter()
+        began_running = False
+
         try:
-            if optimizer_step:
-                logger.log(self.status_loglevel, f"Running optimizer step")
-                self.optimizer.step()
-            if zero_grad:
-                logger.log(self.status_loglevel, f"Running zero grad")
-                self.optimizer.zero_grad()
-                if self.offload_optimizer:
-                    for parameter in self._main_parameters:
-                        if parameter.grad is not None:
-                            parameter.grad.zero_()
+            if averaging_round and averaging_control is None:
+                averaging_control = super().step(
+                    gather=self.local_epoch,
+                    require_trigger=True,
+                    timeout=timeout,
+                    wait=False,
+                    **kwargs,
+                )
 
-            self.finished_optimizer_step.set()
+            if wait_for_trigger is not None:
+                wait_for_trigger()
+            began_running = True
+
+            with self.lock_optimizer:
+                if optimizer_step:
+                    with self.lock_averaged_tensors if self.reuse_tensors else nullcontext():
+                        logger.debug(f"Running optimizer step")
+                        if grad_scaler is None:
+                            self.optimizer.step()
+                        else:
+                            with grad_scaler.running_global_step():
+                                assert grad_scaler.step(self.optimizer)
+
+                if zero_grad:
+                    logger.debug(f"Running zero grad")
+                    self.optimizer.zero_grad()
+                    if self.offload_optimizer:
+                        for parameter in self.main_parameters:
+                            if parameter.grad is not None:
+                                parameter.grad.zero_()
+
+                self._update_scheduler()
+                self.finished_optimizer_step.set()
 
             if averaging_round:
-                if not self.reuse_tensors:
-                    self._load_local_tensors_into_averager_()
-                try:
-                    gathered = super().step(gather=self.local_epoch, **kwargs)
-                    logger.log(self.status_loglevel, f"Averaged parameters with {len(gathered)} peers")
-                except BaseException as e:
-                    logger.log(self.status_loglevel, f"Averaging failed with {type(e)}")
-                    self.finished_averaging_round.set()
-                    gathered = {}
+                with self.lock_averaging:
+                    if not self.reuse_tensors:
+                        self._load_local_tensors_into_averager_()
+                    if self.delta_rule_averaging:
+                        # remember tensors before averaging, update by (new_averaged_tensors - old_averaged_tensors)
+                        with torch.no_grad(), self.get_tensors() as averaged_tensors:
+                            self._old_tensors = tuple(x.cpu().clone() for x in averaged_tensors)
+
+                    self.delay_before_averaging.update(task_size=1, interval=time.perf_counter() - start_time)
+                    try:
+                        averaging_control.allow_allreduce()
+                        gathered = averaging_control.result(timeout=timeout)
+                        logger.log(self.status_loglevel, f"Averaged parameters with {len(gathered)} peers")
+                    except BaseException as e:
+                        logger.log(self.status_loglevel, f"Averaging failed with {type(e)}")
+                        gathered = {}
 
-                self.finished_averaging_round.set()
+                    self.finished_averaging_round.set()
 
                 if self.sync_epoch_when_averaging:
                     old_epoch = self.local_epoch
@@ -419,7 +541,12 @@ class TrainingStateAverager(DecentralizedAverager):
                         self._update_scheduler()
 
         except Exception as e:
+            if not began_running:
+                logger.error(f"Aborted {self.__class__.__name__}.step because wait_for_trigger raised exception")
             logger.exception(e)
+            if averaging_control is not None and not averaging_control.done():
+                logger.error(f"Cancelled scheduled state averaging round")
+                averaging_control.cancel()
             self.finished_optimizer_step.set()
             self.finished_averaging_round.set()
 
@@ -428,19 +555,18 @@ class TrainingStateAverager(DecentralizedAverager):
         """Copy local gradients into the gradient buffers of the offloaded optimizer"""
         assert self.offload_optimizer, "Loading into offloaded optimizer requires using offloaded optimizer"
         opt_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
-        for main_param, opt_param in zip(self._main_parameters, opt_parameters):
+        for main_param, opt_param in zip(self.main_parameters, opt_parameters):
             if main_param.grad is not None:
                 opt_param.grad.copy_(main_param.grad, non_blocking=True)
 
     @torch.no_grad()
-    def _apply_optimizer_results_(self):
+    def _apply_optimizer_parameters_(self):
         """Copy parameters from offloaded optimizer to the main model"""
         assert self.offload_optimizer, "Applying offloaded optimizer updates requires offloaded optimizer"
-        with self.lock_averaged_tensors:
-            offloaded_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
-            assert len(offloaded_parameters) == len(self._main_parameters), "opt parameters changed during training"
-            for main_param, offloaded_param in zip(self._main_parameters, offloaded_parameters):
-                main_param.copy_(offloaded_param, non_blocking=True)
+        offloaded_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
+        assert len(offloaded_parameters) == len(self.main_parameters), "Optimizer parameters changed during training"
+        for main_param, offloaded_param in zip(self.main_parameters, offloaded_parameters):
+            main_param.copy_(offloaded_param, non_blocking=True)
 
     @torch.no_grad()
     def _load_local_tensors_into_averager_(self):
@@ -454,24 +580,36 @@ class TrainingStateAverager(DecentralizedAverager):
     def _apply_averaging_results_(self):
         """Copy averaged tensors into their respective local tensors"""
         assert not self.reuse_tensors, "No need to update averaged tensors since they reuse the same memory"
+        if self.delta_rule_averaging and self._old_tensors is None:
+            logger.warning("Using delta_rule_averaging, but old tensors were not found. Averaging may have failed")
         with self.get_tensors() as averaged_tensors:
             local_tensors = list(self._local_tensors())
             assert len(local_tensors) == len(averaged_tensors), "Tensor structure changed during training"
-            for local_tensor, averaged_tensor in zip(local_tensors, averaged_tensors):
-                local_tensor.copy_(averaged_tensor, non_blocking=True)
+            if not self.delta_rule_averaging or self._old_tensors is None:
+                for local_tensor, averaged_tensor in zip(local_tensors, averaged_tensors):
+                    local_tensor.copy_(averaged_tensor, non_blocking=True)
+            else:
+                assert len(self._old_tensors) == len(local_tensors)
+                for local_tensor, new_tensor, old_tensor in zip(local_tensors, averaged_tensors, self._old_tensors):
+                    delta = torch.sub(new_tensor, old_tensor, out=old_tensor)  # using old tensors as buffers
+                    local_tensor.add_(delta.to(device=local_tensor.device, dtype=local_tensor.dtype))
+
+    @property
+    def averaging_in_progress(self) -> bool:
+        return self.lock_averaging.locked()
 
     def get_current_state(self):
         """
        Get current model/optimizer state to be sent when requested by a newbie peer; executed in the host process.
         :returns: a tuple of (serializable_small_metadata, sequence of torch tensors)
         """
-        with torch.no_grad():
+        with torch.no_grad(), self.lock_averaged_tensors:
             optimized_parameters = tuple(
                 param.detach().cpu() for param_group in self.optimizer.param_groups for param in param_group["params"]
             )
             parameter_infos = [
                 CompressionInfo.from_tensor(param, key=key, role=TensorRole.PARAMETER)
-                for param, key in zip(optimized_parameters, self._parameter_names)
+                for param, key in zip(optimized_parameters, self.parameter_names)
             ]
             extra_tensors = tuple(tensor.detach().cpu() for tensor in self.extra_tensors)
             extra_infos = [
@@ -496,8 +634,9 @@ class TrainingStateAverager(DecentralizedAverager):
         Attempt to download the latest optimizer state from peers and update trainer parameters/statistics.
        :returns: whether or not the averager succeeded in loading parameters
         """
-        parameters_and_extras = tuple(chain(self._main_parameters, self.extra_tensors))
-        num_parameters_and_extras = len(parameters_and_extras)
+        opt_parameters = tuple(param for param_group in self.optimizer.param_groups for param in param_group["params"])
+        main_parameters_and_extras = tuple(chain(opt_parameters, self.extra_tensors))
+        num_parameters_and_extras = len(main_parameters_and_extras)
 
         loaded_state = super().load_state_from_peers(**kwargs)
         if loaded_state is None:
@@ -511,18 +650,24 @@ class TrainingStateAverager(DecentralizedAverager):
         loaded_parameters_and_extras = flat_tensors[:num_parameters_and_extras]
         loaded_opt_tensors = flat_tensors[num_parameters_and_extras:]
         if num_parameters_and_extras != len(loaded_parameters_and_extras):
-            logger.error("Failed to load state from peer, received parameters, extras or metadata.")
+            logger.error("Failed to load state from peer, received parameters, extras or metadata")
             return
 
-        try:
-            load_optimizer_state(self.optimizer, metadata["optimizer_metadata"], loaded_opt_tensors)
-        except StopIteration:
-            logger.warning("Failed to load state from peer, received inconsistent number of optimizer statistics")
-            return
+        with torch.no_grad(), self.lock_averaged_tensors:
+            try:
+                load_optimizer_state(self.optimizer, metadata["optimizer_metadata"], loaded_opt_tensors)
+            except StopIteration:
+                logger.warning("Failed to load state from peer, received inconsistent number of optimizer statistics")
+                return
 
-        with torch.no_grad():
-            for local_param, loaded_param in zip(parameters_and_extras, loaded_parameters_and_extras):
+            for local_param, loaded_param in zip(main_parameters_and_extras, loaded_parameters_and_extras):
                 local_param.copy_(loaded_param, non_blocking=True)
+
+        if self.offload_optimizer:
+            self._apply_optimizer_parameters_()
+        if not self.reuse_tensors:
+            self._load_local_tensors_into_averager_()
+
         self.local_epoch = metadata["epoch"]
         self._update_scheduler()
 

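Putting the new schedule_step / step interface together, a rough usage sketch follows. It assumes the keyword names from this diff (params, averaging_control, etc.) plus the usual DecentralizedAverager arguments (prefix, target_group_size, start) and a reachable DHT with at least one other peer running the same code; treat it as a sketch rather than a verified recipe.

import torch
import hivemind
from hivemind.optim.experimental.state_averager import TrainingStateAverager

dht = hivemind.DHT(start=True)                      # assumed: a running, reachable DHT
model = torch.nn.Linear(16, 4)

averager = TrainingStateAverager(
    dht=dht,
    optimizer=torch.optim.Adam,
    params=model.parameters(),
    prefix="demo_state_averager",                   # averaging key prefix (placeholder)
    target_group_size=2,
    start=True,
)

# pre-schedule matchmaking a few seconds ahead, keep training meanwhile, then run the
# optimizer step and the pre-scheduled averaging round in a single call
control = averager.schedule_step(scheduled_time=hivemind.get_dht_time() + 5)
# ... run forward/backward passes and accumulate gradients here ...
averager.step(optimizer_step=True, zero_grad=True, averaging_round=True, averaging_control=control)
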
+ 76 - 45
hivemind/optim/grad_scaler.py

@@ -1,83 +1,114 @@
 import contextlib
+import threading
+from copy import deepcopy
 from typing import Dict, Optional
 
 import torch
 from torch.cuda.amp import GradScaler as TorchGradScaler
-from torch.cuda.amp.grad_scaler import _refresh_per_optimizer_state
-from torch.optim import Optimizer
+from torch.cuda.amp.grad_scaler import OptState, _refresh_per_optimizer_state
+from torch.optim import Optimizer as TorchOptimizer
 
-from hivemind.optim.base import DecentralizedOptimizerBase
+import hivemind
 from hivemind.utils.logging import get_logger
 
 logger = get_logger(__name__)
 
 
-class HivemindGradScaler(TorchGradScaler):
+class GradScaler(TorchGradScaler):
     """
-    A thin wrapper over pytorch GradScaler that supports hivemind-style training with CollaborativeOptimizer, namely:
+    A wrapper over pytorch GradScaler made specifically for training hivemind.Optimizer with reuse_grad_buffers=True.
+
+    :note: if not using reuse_grad_buffers=True, one can and *should* train normally without this class, e.g. using
+      standard PyTorch AMP or Apex. This custom GradScaler is more memory-efficient, but requires custom training code.
+
+    hivemind.GradScaler makes 3 modifications to the regular PyTorch AMP:
+
     - bypass .unscale_ and .update calls in order to accumulate gradients over several steps
     - limit increasing gradient scale to only immediately after global optimizer steps
-    - allow training with some or all master parameters in fp16
+    - allow training with some or all master parameters in float16
+
+    :note: The above modifications will be enabled automatically. One can (and should) use hivemind.GradScaler exactly
+      as the regular ``torch.amp.GradScaler``.
     """
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self._is_running_global_step = False
+        self._is_ready_to_update = False
         self._optimizer_states_to_reset = set()
+        self._lock = threading.RLock()
 
     @contextlib.contextmanager
     def running_global_step(self):
-        was_running, self._is_running_global_step = self._is_running_global_step, True
-        try:
-            yield
-        finally:
-            self._is_running_global_step = was_running
-
-    def unscale_(self, optimizer: Optimizer) -> bool:
-        assert isinstance(optimizer, DecentralizedOptimizerBase)
-        if self._is_running_global_step:
-            super().unscale_(optimizer.opt)
-            return True
-        else:
-            self._check_inf_per_device(optimizer.opt)
-            self._optimizer_states_to_reset.add(id(optimizer))
-            return False
+        with self._lock:
+            was_running, self._is_running_global_step = self._is_running_global_step, True
+            try:
+                yield
+            finally:
+                self._is_running_global_step = was_running
 
-    def step(self, optimizer: Optimizer, *args, **kwargs) -> bool:
-        assert isinstance(optimizer, DecentralizedOptimizerBase)
-        if self._is_running_global_step:
-            if self.are_grads_finite(optimizer):
-                super().step(optimizer.opt, *args, **kwargs)
+    def unscale_(self, optimizer: TorchOptimizer) -> bool:
+        with self._lock:
+            assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
+            if self._is_running_global_step:
+                super().unscale_(optimizer)
+                self._per_optimizer_states[id(optimizer.opt)] = deepcopy(self._per_optimizer_states[id(optimizer)])
+                return True
             else:
-                logger.warning("Skipping global step due to gradient over/underflow")
-            return True
+                self._check_inf_per_device(optimizer)
+                self._optimizer_states_to_reset.add(id(optimizer))
+                return False
+
+    def step(self, optimizer: TorchOptimizer, *args, **kwargs) -> bool:
+        if self._is_running_global_step and not isinstance(optimizer, hivemind.Optimizer):
+            # ^-- invoked privately within hivemind optimizer
+            with self._lock:
+                if self._is_ready_to_update:
+                    logger.warning("Please call grad_scaler.update() after each step")
+                assert (
+                    self._per_optimizer_states[id(optimizer)]["stage"] == OptState.UNSCALED
+                ), "InternalError: Optimizer should have called .unscale internally before invoking grad_scaler.step."
+                if self.are_grads_finite(optimizer, use_cached=True):
+                    super().step(optimizer, *args, **kwargs)
+                else:
+                    logger.warning("Skipping global step due to gradient over/underflow")
+                self._is_ready_to_update = True
+                return True
         else:
             super().step(optimizer)
             self._optimizer_states_to_reset.add(id(optimizer))
             return False
 
     def update(self, new_scale: Optional[float] = None) -> bool:
-        total_infs = 0
-        for optimizer_state in self._per_optimizer_states.values():
-            total_infs += sum(v.item() for v in optimizer_state["found_inf_per_device"].values())
-
-        if self._is_running_global_step or total_infs != 0:
-            # note: we update either during actual optimizer step or if we need to reduce scale due to NaN
-            super().update(new_scale)
-            return True
-        else:
-            for opt_id in self._optimizer_states_to_reset:
-                self._per_optimizer_states[opt_id] = _refresh_per_optimizer_state()
-            self._optimizer_states_to_reset.clear()
-            return False
+        with self._lock:
+            total_infs = 0
+            for optimizer_state in self._per_optimizer_states.values():
+                total_infs += sum(v.item() for v in optimizer_state["found_inf_per_device"].values())
+
+            if self._is_ready_to_update or total_infs != 0:
+                # note: we update either during actual optimizer step or if we need to reduce scale due to NaN
+                super().update(new_scale)
+                self._is_ready_to_update = False
+                return True
+            else:
+                for opt_id in self._optimizer_states_to_reset:
+                    self._per_optimizer_states[opt_id] = _refresh_per_optimizer_state()
+                self._optimizer_states_to_reset.clear()
+                return False
 
     def _unscale_grads_(
-        self, optimizer: Optimizer, inv_scale: torch.Tensor, found_inf: torch.Tensor, allow_fp16: bool
+        self, optimizer: TorchOptimizer, inv_scale: torch.Tensor, found_inf: torch.Tensor, allow_fp16: bool
     ) -> Dict[torch.device, torch.Tensor]:
         # note: the code below sets allow_fp16=True to allow training with master weights (partially) in fp16
         # inspired by: https://github.com/facebookresearch/fairscale/blob/945b9666/fairscale/optim/grad_scaler.py
         return super()._unscale_grads_(optimizer, inv_scale, found_inf, allow_fp16=True)
 
-    def are_grads_finite(self, optimizer: DecentralizedOptimizerBase) -> bool:
-        assert isinstance(optimizer, DecentralizedOptimizerBase)
-        return not sum(v.item() for v in self._check_inf_per_device(optimizer.opt).values())
+    def are_grads_finite(self, optimizer: TorchOptimizer, use_cached: bool = False) -> bool:
+        opt_dict = self._found_inf_per_device(optimizer) if use_cached else self._check_inf_per_device(optimizer)
+        return not sum(v.item() for v in opt_dict.values())
+
+
+class HivemindGradScaler(GradScaler):
+    def __init__(self, *args, **kwargs):
+        logger.warning("HivemindGradScaler was renamed to hivemind.GradScaler, this reference will be removed in v1.1")
+        super().__init__(*args, **kwargs)

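For context, the training-loop pattern this scaler is built for (per the docstring above) looks roughly like the sketch below. The hivemind.Optimizer constructor arguments are assumptions based on its public quickstart and may not match this in-progress branch exactly; the model and data are synthetic, and a CUDA device is assumed for AMP.

import torch
import hivemind

model = torch.nn.Linear(16, 4).cuda()               # assumes a CUDA device for AMP
opt = hivemind.Optimizer(                           # kwargs below are assumed, not verified against this branch
    dht=hivemind.DHT(start=True),
    run_id="demo_run",
    params=model.parameters(),
    optimizer=torch.optim.Adam,
    target_batch_size=4096,
    batch_size_per_step=32,
    reuse_grad_buffers=True,                        # the case this GradScaler is designed for
    start=True,
)
scaler = hivemind.GradScaler()

data = [(torch.randn(32, 16), torch.randn(32, 4)) for _ in range(10)]
for x, y in data:
    with torch.cuda.amp.autocast():
        loss = torch.nn.functional.mse_loss(model(x.cuda()), y.cuda())
    scaler.scale(loss).backward()
    scaler.unscale_(opt)    # no-op on accumulation steps; the real unscale happens inside the global step
    scaler.step(opt)        # hivemind.Optimizer decides whether this is a local or a global step
    scaler.update()         # the scale only grows right after a global optimizer step
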
+ 8 - 4
hivemind/optim/simple.py

@@ -86,6 +86,10 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
             if self.local_step % self.averaging_step_period == 0:
                 self.update_event.set()
             self.averager.pending_updates_done.wait()
+
+            if not self.averager.client_mode:
+                self.averager.state_sharing_priority = get_dht_time()
+
             return loss
         finally:
             self.lock_parameters.acquire()
@@ -127,16 +131,16 @@ class DecentralizedOptimizer(DecentralizedOptimizerBase):
                 time.sleep(time_to_nearest_interval)
 
             if verbose:
-                logger.info(f"Starting a new averaging round with current parameters.")
+                logger.info(f"Starting a new averaging round with current parameters")
             try:
                 group_info = averager.step(lock_parameters, **kwargs)
                 if verbose:
                     if group_info is not None:
-                        logger.info(f"Finished averaging round in with {len(group_info)} peers.")
+                        logger.info(f"Finished averaging round with {len(group_info)} peers")
                     else:
-                        logger.warning(f"Averaging round failed: could not find group.")
+                        logger.warning(f"Averaging round failed: could not find group")
             except Exception as e:
-                logger.error(f"Averaging round failed: caught {e}.")
+                logger.error(f"Averaging round failed: caught {e}")
 
 
 class DecentralizedSGD(DecentralizedOptimizer):

+ 1 - 1
hivemind/optim/training_averager.py

@@ -101,7 +101,7 @@ class TrainingAverager(DecentralizedAverager):
                 self.pending_updates_done.clear()
                 with data_lock, self.get_tensors() as averaged_tensors:
                     if len(averaged_tensors) != len(local_tensors):
-                        raise RuntimeError("The number of optimized parameters should not change.")
+                        raise RuntimeError("The number of optimized parameters should not change")
 
                     if use_old_local_tensors:
                         # since tensors might have changed, we subtract old_local_tensor and add averaged. This prevents

+ 1 - 0
hivemind/utils/__init__.py

@@ -5,6 +5,7 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from hivemind.utils.mpfuture import *
 from hivemind.utils.nested import *
 from hivemind.utils.networking import *
+from hivemind.utils.performance_ema import PerformanceEMA
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
 from hivemind.utils.tensor_descr import BatchTensorDescriptor, TensorDescriptor
 from hivemind.utils.timed_storage import *

+ 23 - 9
hivemind/utils/asyncio.py

@@ -114,9 +114,15 @@ async def amap_in_executor(
     queue = asyncio.Queue(max_prefetch)
 
     async def _put_items():
-        async for args in azip(*iterables):
-            await queue.put(loop.run_in_executor(executor, func, *args))
-        await queue.put(None)
+        try:
+            async for args in azip(*iterables):
+                await queue.put(loop.run_in_executor(executor, func, *args))
+            await queue.put(None)
+        except Exception as e:
+            future = asyncio.Future()
+            future.set_exception(e)
+            await queue.put(future)
+            raise
 
     task = asyncio.create_task(_put_items())
     try:
@@ -124,13 +130,21 @@ async def amap_in_executor(
         while future is not None:
             yield await future
             future = await queue.get()
-        await task
     finally:
-        if not task.done():
-            task.cancel()
-
-
-async def aiter_with_timeout(iterable: AsyncIterable[T], timeout: float) -> AsyncIterator[T]:
+        task.cancel()
+        try:
+            await task
+        except asyncio.CancelledError:
+            pass
+        except Exception as e:
+            logger.debug(f"Caught {e} while iterating over inputs", exc_info=True)
+        while not queue.empty():
+            future = queue.get_nowait()
+            if future is not None:
+                future.cancel()
+
+
+async def aiter_with_timeout(iterable: AsyncIterable[T], timeout: Optional[float]) -> AsyncIterator[T]:
     """Iterate over an async iterable, raise TimeoutError if another portion of data does not arrive within timeout"""
     # based on https://stackoverflow.com/a/50245879
     iterator = iterable.__aiter__()

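A small usage sketch for the two helpers touched above, assuming their signatures are as shown in the hunks (amap_in_executor(func, *iterables, max_prefetch=...) and aiter_with_timeout(iterable, timeout)); the demo coroutine and numbers are invented.

import asyncio
from hivemind.utils.asyncio import aiter_with_timeout, amap_in_executor, as_aiter

async def demo():
    # run func in a background executor while prefetching up to two inputs ahead
    async for square in amap_in_executor(lambda x: x * x, as_aiter(1, 2, 3), max_prefetch=2):
        print(square)  # 1, 4, 9

    async def slow_producer():
        for i in range(3):
            await asyncio.sleep(1.0)
            yield i

    try:
        # raises TimeoutError because each item takes ~1 s but the limit is 0.5 s
        async for item in aiter_with_timeout(slow_producer(), timeout=0.5):
            print(item)
    except asyncio.TimeoutError:
        print("producer was too slow")

asyncio.run(demo())
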
+ 10 - 3
hivemind/utils/logging.py

@@ -15,6 +15,9 @@ if _env_colors is not None:
 else:
     use_colors = sys.stderr.isatty()
 
+_env_log_caller = os.getenv("HIVEMIND_ALWAYS_LOG_CALLER")
+always_log_caller = _env_log_caller is not None and _env_log_caller.lower() == "true"
+
 
 class HandlerMode(Enum):
     NOWHERE = 0
@@ -65,8 +68,12 @@ class CustomFormatter(logging.Formatter):
             record.created = record.origin_created
             record.msecs = (record.created - int(record.created)) * 1000
 
-        if not hasattr(record, "caller"):
-            record.caller = f"{record.name}.{record.funcName}:{record.lineno}"
+        if record.levelno != logging.INFO or always_log_caller:
+            if not hasattr(record, "caller"):
+                record.caller = f"{record.name}.{record.funcName}:{record.lineno}"
+            record.caller_block = f" [{TextStyle.BOLD}{record.caller}{TextStyle.RESET}]"
+        else:
+            record.caller_block = ""
 
         # Aliases for the format argument
         record.levelcolor = self._LEVEL_TO_COLOR[record.levelno]
@@ -84,7 +91,7 @@ def _initialize_if_necessary():
             return
 
         formatter = CustomFormatter(
-            fmt="{asctime}.{msecs:03.0f} [{bold}{levelcolor}{levelname}{reset}] [{bold}{caller}{reset}] {message}",
+            fmt="{asctime}.{msecs:03.0f} [{bold}{levelcolor}{levelname}{reset}]{caller_block} {message}",
             style="{",
             datefmt="%b %d %H:%M:%S",
         )

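To see the effect of the new caller_block logic: INFO lines now omit the [module.func:lineno] block unless HIVEMIND_ALWAYS_LOG_CALLER=true is exported before hivemind is imported, while warnings and errors keep it. A hedged sketch follows; the handler-mode string is taken from hivemind's docs and should be double-checked against this branch.

import os
os.environ["HIVEMIND_ALWAYS_LOG_CALLER"] = "true"   # must be set before importing hivemind

from hivemind.utils.logging import get_logger, use_hivemind_log_handler

use_hivemind_log_handler("in_root_logger")          # route hivemind's formatter to the root logger
logger = get_logger(__name__)

logger.info("with the env var set, INFO lines include the [caller] block too")
logger.warning("WARNING and above include it regardless")
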
+ 5 - 3
hivemind/utils/mpfuture.py

@@ -138,7 +138,9 @@ class MPFuture(base.Future, Generic[ResultType]):
         async def _event_setter():
             self._aio_event.set()
 
-        if self._loop.is_running() and running_loop == self._loop:
+        if self._loop.is_closed():
+            return  # do nothing, the loop is already closed
+        elif self._loop.is_running() and running_loop == self._loop:
             asyncio.create_task(_event_setter())
         elif self._loop.is_running() and running_loop != self._loop:
             asyncio.run_coroutine_threadsafe(_event_setter(), self._loop)
@@ -201,8 +203,8 @@ class MPFuture(base.Future, Generic[ResultType]):
         try:
             with MPFuture._update_lock if self._use_lock else nullcontext():
                 self._sender_pipe.send((self._uid, update_type, payload))
-        except (ConnectionError, BrokenPipeError, EOFError) as e:
-            logger.debug(f"No updates were sent: pipe to origin process was broken ({e}).", exc_info=True)
+        except (ConnectionError, BrokenPipeError, EOFError, OSError) as e:
+            logger.debug(f"No updates were sent: pipe to origin process was broken ({e})", exc_info=True)
 
     def set_result(self, result: ResultType):
         if os.getpid() == self._origin_pid:

+ 2 - 2
hivemind/utils/serializer.py

@@ -35,7 +35,7 @@ class MSGPackSerializer(SerializerBase):
                 getattr(wrapped_type, "unpackb", None)
             ), f"Every ext_type must have 2 methods: packb(self) -> bytes and classmethod unpackb(cls, bytes)"
             if type_code in cls._ext_type_codes:
-                logger.warning(f"{cls.__name__}: type {type_code} is already registered, overwriting.")
+                logger.warning(f"{cls.__name__}: type {type_code} is already registered, overwriting")
             cls._ext_type_codes[type_code], cls._ext_types[wrapped_type] = wrapped_type, type_code
             return wrapped_type
 
@@ -60,7 +60,7 @@ class MSGPackSerializer(SerializerBase):
         elif type_code == cls._TUPLE_EXT_TYPE_CODE:
             return tuple(msgpack.unpackb(data, ext_hook=cls._decode_ext_types, raw=False))
 
-        logger.warning(f"Unknown ExtType code: {type_code}, leaving it as is.")
+        logger.warning(f"Unknown ExtType code: {type_code}, leaving it as is")
         return data
 
     @classmethod

+ 1 - 0
requirements-dev.txt

@@ -4,6 +4,7 @@ pytest-asyncio
 pytest-cov
 tqdm
 scikit-learn
+torchvision
 black==21.6b0
 isort
 psutil

+ 1 - 0
requirements-docs.txt

@@ -1,3 +1,4 @@
 recommonmark==0.5.0
 sphinx_rtd_theme==0.4.3
+docutils==0.16
 sphinx==4.2.0

+ 213 - 0
tests/test_allreduce_fault_tolerance.py

@@ -0,0 +1,213 @@
+from __future__ import annotations
+
+import asyncio
+from enum import Enum, auto
+from typing import AsyncIterator
+
+import pytest
+import torch
+
+import hivemind
+from hivemind.averaging.allreduce import AllReduceRunner, AveragingMode
+from hivemind.averaging.averager import *
+from hivemind.averaging.group_info import GroupInfo
+from hivemind.averaging.load_balancing import load_balance_peers
+from hivemind.averaging.matchmaking import MatchmakingException
+from hivemind.proto import averaging_pb2
+from hivemind.utils.asyncio import aenumerate, as_aiter, azip, enter_asynchronously
+from hivemind.utils.logging import get_logger
+
+logger = get_logger(__name__)
+
+
+class Fault(Enum):
+    NONE = auto()
+    FAIL_BEFORE = auto()
+    FAIL_SENDING = auto()
+    SLOW_SENDING = auto()
+    FAIL_REDUCING = auto()
+    SLOW_REDUCING = auto()
+    CANCEL = auto()
+
+
+class FaultyAverager(hivemind.DecentralizedAverager):
+    def __init__(self, *args, fault: Fault = Fault.NONE, **kwargs):
+        self.fault = fault
+        super().__init__(*args, **kwargs)
+
+    async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
+        """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
+        try:
+            bandwidths, mode_ids, user_gathered_bytes = zip(*map(self.serializer.loads, group_info.gathered))
+            user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered_bytes)))
+            modes = tuple(map(AveragingMode, mode_ids))
+            download_bandwidths = [
+                thr if mode != AveragingMode.CLIENT else 0.0 for thr, mode in zip(bandwidths, modes)
+            ]
+            peer_fractions = await asyncio.get_event_loop().run_in_executor(
+                None, load_balance_peers, self.total_size, download_bandwidths, min_vector_size
+            )
+
+            if self.fault == Fault.FAIL_BEFORE:
+                raise Exception("Oops, I failed!")
+
+            async with enter_asynchronously(self.get_tensors()) as local_tensors:
+                allreduce = FaultyAllReduceRunner(
+                    p2p=self._p2p,
+                    servicer_type=type(self),
+                    prefix=self.prefix,
+                    group_id=group_info.group_id,
+                    tensors=local_tensors,
+                    ordered_peer_ids=group_info.peer_ids,
+                    peer_fractions=peer_fractions,
+                    gathered=user_gathered,
+                    modes=modes,
+                    fault=self.fault,
+                    **kwargs,
+                )
+
+                with self.register_allreduce_group(group_info.group_id, allreduce):
+                    if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
+                        async for tensor, update in azip(as_aiter(*local_tensors), allreduce):
+                            # all-reduce is performed asynchronously while iterating
+                            tensor.add_(update, alpha=self._averaging_alpha)
+                        self._state_updated.set()
+
+                    else:
+                        async for _ in allreduce:  # trigger all-reduce by iterating
+                            raise ValueError("aux peers should not receive averaged tensors")
+
+                return allreduce.gathered
+        except BaseException as e:
+            logger.exception(e)
+            raise MatchmakingException(f"Unable to run All-Reduce: {e}")
+
+
+class FaultyAllReduceRunner(AllReduceRunner):
+    def __init__(self, *args, fault: Fault, **kwargs):
+        self.fault = fault
+        super().__init__(*args, **kwargs)
+
+    async def rpc_aggregate_part(self, stream, context) -> AsyncIterator[averaging_pb2.AveragingData]:
+        if self.fault in (Fault.FAIL_REDUCING, Fault.SLOW_REDUCING):
+            async for i, message in aenumerate(super().rpc_aggregate_part(stream, context)):
+                yield message
+                if i == 2:
+                    if self.fault == Fault.FAIL_REDUCING:
+                        yield averaging_pb2.AveragingData(code=averaging_pb2.INTERNAL_ERROR)
+                        break
+                    else:
+                        await asyncio.sleep(10)
+
+        elif self.fault == Fault.CANCEL:
+            yield averaging_pb2.AveragingData(code=averaging_pb2.CANCELLED)
+        else:
+            async for message in super().rpc_aggregate_part(stream, context):
+                yield message
+
+    async def _generate_input_for_peer(self, peer_index: int) -> AsyncIterator[averaging_pb2.AveragingData]:
+        parts_aiter = self.tensor_part_container.iterate_input_parts_for(peer_index)
+
+        first_part = await anext(parts_aiter)
+        yield averaging_pb2.AveragingData(
+            code=averaging_pb2.PART_FOR_AVERAGING,
+            group_id=self.group_id,
+            tensor_part=first_part,
+            weight=self.weight,
+        )
+        if self.fault in (Fault.FAIL_SENDING, Fault.SLOW_SENDING):
+            last_reducer_index = self.group_size - 1 - (self.tensor_part_container.num_parts_by_peer[-1] == 0)
+            if peer_index == last_reducer_index:
+                if self.fault == Fault.FAIL_SENDING:
+                    raise Exception("Oops, I failed!")
+                else:
+                    await asyncio.sleep(10)
+        async for part in parts_aiter:
+            yield averaging_pb2.AveragingData(tensor_part=part, weight=self.weight)
+
+
+@pytest.mark.forked
+@pytest.mark.parametrize(
+    "fault0, fault1",
+    [
+        (Fault.NONE, Fault.FAIL_BEFORE),
+        (Fault.FAIL_BEFORE, Fault.FAIL_BEFORE),
+        (Fault.SLOW_SENDING, Fault.FAIL_SENDING),
+        (Fault.FAIL_SENDING, Fault.FAIL_BEFORE),
+        (Fault.SLOW_REDUCING, Fault.FAIL_SENDING),
+        (Fault.FAIL_REDUCING, Fault.FAIL_REDUCING),
+        (Fault.NONE, Fault.CANCEL),
+    ],
+)
+def test_fault_tolerance(fault0: Fault, fault1: Fault):
+    def _make_tensors():
+        return [torch.rand(16, 1024), -torch.rand(3, 8192), 2 * torch.randn(4, 4, 4), torch.randn(1024, 1024)]
+
+    dht = hivemind.DHT(start=True)
+
+    averagers = []
+    for i in range(5):
+        averager = FaultyAverager(
+            _make_tensors(),
+            hivemind.DHT(initial_peers=dht.get_visible_maddrs(), start=True),
+            prefix="test",
+            request_timeout=0.3,
+            min_matchmaking_time=1.0,
+            next_chunk_timeout=0.5,
+            allreduce_timeout=5,
+            part_size_bytes=2 ** 16,
+            client_mode=(i == 1),
+            start=True,
+            fault=fault0 if i == 0 else fault1 if i == 1 else Fault.NONE,
+        )
+        averagers.append(averager)
+
+    ref_numerators = [0, 0, 0, 0]
+    ref_denominator = 0
+
+    for averager in averagers:
+        if averager.fault not in (Fault.FAIL_BEFORE, Fault.CANCEL):
+            with averager.get_tensors() as tensors:
+                for i, tensor in enumerate(tensors):
+                    ref_numerators[i] = ref_numerators[i] + tensor.clone()
+                ref_denominator += 1
+
+    ref_tensors = [ref_numerator / ref_denominator for ref_numerator in ref_numerators]
+    flat_ref = torch.cat(list(map(torch.flatten, ref_tensors)))
+
+    flat_local_tensors = []
+    for averager in averagers:
+        with averager.get_tensors() as tensors:
+            flat_local_tensors.append(torch.cat(list(map(torch.flatten, tensors))))
+
+    futures = [averager.step(timeout=5, wait=False, allow_retries=False) for averager in averagers]
+    for i, averager in enumerate(averagers):
+        if averager.fault == Fault.CANCEL:
+            futures[i].cancel()
+
+    for future in futures[2:]:
+        assert future.result()
+
+    for averager, prev_local_tensors in zip(averagers[2:], flat_local_tensors[2:]):
+        with averager.get_tensors() as tensors:
+            flat_tensors = torch.cat(list(map(torch.flatten, tensors)))
+
+        diff_with_reference = abs(flat_ref - flat_tensors)
+
+        if all(fault in (Fault.FAIL_SENDING, Fault.SLOW_SENDING) for fault in (fault0, fault1)):
+            assert fault0 != Fault.FAIL_REDUCING and fault1 != Fault.FAIL_REDUCING
+            assert diff_with_reference[: len(diff_with_reference) // 2].max() < 1e-5
+        elif all(fault in (Fault.FAIL_REDUCING, Fault.SLOW_REDUCING) for fault in (fault0, fault1)):
+            diff_to_reference = abs(flat_ref - flat_tensors)
+            diff_to_local = abs(prev_local_tensors - flat_tensors)
+            assert (diff_with_reference < 1e-5).numpy().mean() > 0.5
+            assert torch.all(torch.minimum(diff_to_reference, diff_to_local) < 1e-5).item()
+        elif any(fault == Fault.CANCEL for fault in (fault0, fault1)):
+            pass  # late cancel may result in an arbitrary mix of averaging results with and without the cancelled peer
+        elif fault0 == Fault.NONE:  # only peer1 in client mode may have failed
+            assert diff_with_reference.max() < 1e-5
+        else:
+            assert (diff_with_reference < 1e-5).numpy().mean() > 0.5
+
+    for averager in averagers:
+        averager.shutdown()
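
Note: the SLOW_SENDING and SLOW_REDUCING faults above stall a peer for 10 seconds, while the averagers are created with next_chunk_timeout=0.5 and allreduce_timeout=5, so a stalled stream should be dropped instead of blocking the whole group. The self-contained sketch below illustrates that per-chunk timeout idea with plain asyncio; it is a simplified illustration, not hivemind's actual all-reduce code.

import asyncio
from typing import AsyncIterator, List


async def slow_sender() -> AsyncIterator[int]:
    yield 0
    yield 1
    await asyncio.sleep(10)  # emulates Fault.SLOW_SENDING: the peer stalls mid-stream
    yield 2


async def consume_with_chunk_timeout(stream: AsyncIterator[int], next_chunk_timeout: float) -> List[int]:
    received = []
    iterator = stream.__aiter__()
    while True:
        try:
            # give up on this peer if the next chunk does not arrive within the timeout
            chunk = await asyncio.wait_for(iterator.__anext__(), timeout=next_chunk_timeout)
        except StopAsyncIteration:
            break
        except asyncio.TimeoutError:
            break  # the stalled peer is skipped instead of blocking the whole round
        received.append(chunk)
    return received


assert asyncio.run(consume_with_chunk_timeout(slow_sender(), next_chunk_timeout=0.5)) == [0, 1]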

+ 76 - 2
tests/test_averaging.py

@@ -372,7 +372,6 @@ def test_load_state_from_peers():
         target_group_size=2,
     )
 
-    dht_instances[1].get("demo-run.all_averagers")
     averager2 = TestAverager(
         [torch.randn(3), torch.rand(5)],
         dht=dht_instances[1],
@@ -381,6 +380,8 @@ def test_load_state_from_peers():
         target_group_size=2,
     )
 
+    time.sleep(0.5)
+
     assert num_calls == 0
     got_metadata, got_tensors = averager2.load_state_from_peers()
     assert num_calls == 1
@@ -399,7 +400,9 @@ def test_load_state_from_peers():
 
     averager1.allow_state_sharing = False
     assert averager2.load_state_from_peers() is None
+
     averager1.allow_state_sharing = True
+    time.sleep(0.5)
     got_metadata, got_tensors = averager2.load_state_from_peers()
     assert num_calls == 3
     assert got_metadata == super_metadata
@@ -408,6 +411,47 @@ def test_load_state_from_peers():
         instance.shutdown()
 
 
+@pytest.mark.forked
+def test_load_state_priority():
+    dht_instances = launch_dht_instances(4)
+
+    averagers = []
+    for i in range(4):
+        averager = hivemind.DecentralizedAverager(
+            [torch.randn(3), torch.rand(5), torch.tensor([i], dtype=torch.float32)],
+            dht=dht_instances[i],
+            start=True,
+            prefix="demo-run",
+            target_group_size=2,
+            allow_state_sharing=i != 1,
+        )
+        averager.state_sharing_priority = 5 - abs(2 - i)
+        averagers.append(averager)
+
+    time.sleep(0.5)
+    metadata, tensors = averagers[0].load_state_from_peers(timeout=1)
+    assert tensors[-1].item() == 2
+
+    metadata, tensors = averagers[2].load_state_from_peers(timeout=1)
+    assert tensors[-1].item() == 3
+
+    averagers[0].state_sharing_priority = 10
+    time.sleep(0.2)
+
+    metadata, tensors = averagers[2].load_state_from_peers(timeout=1)
+    assert tensors[-1].item() == 0
+
+    averagers[1].allow_state_sharing = False
+    averagers[2].allow_state_sharing = False
+    metadata, tensors = averagers[0].load_state_from_peers(timeout=1)
+    assert tensors[-1].item() == 3
+
+    for averager in averagers:
+        averager.shutdown()
+    for dht in dht_instances:
+        dht.shutdown()
+
+
 @pytest.mark.forked
 def test_getset_bits():
     dht = hivemind.DHT(start=True)
@@ -428,7 +472,6 @@ def test_averaging_trigger():
         hivemind.averaging.DecentralizedAverager(
             averaged_tensors=[torch.randn(3)],
             dht=dht,
-            target_group_size=4,
             min_matchmaking_time=0.5,
             request_timeout=0.3,
             prefix="mygroup",
@@ -468,6 +511,37 @@ def test_averaging_trigger():
     c0.allow_allreduce()
 
 
+@pytest.mark.forked
+def test_averaging_cancel():
+    averagers = tuple(
+        hivemind.averaging.DecentralizedAverager(
+            averaged_tensors=[torch.randn(3)],
+            dht=dht,
+            min_matchmaking_time=0.5,
+            request_timeout=0.3,
+            client_mode=(i % 2 == 0),
+            prefix="mygroup",
+            start=True,
+        )
+        for i, dht in enumerate(launch_dht_instances(4))
+    )
+
+    step_controls = [averager.step(wait=False, scheduled_time=hivemind.get_dht_time() + 1) for averager in averagers]
+
+    time.sleep(0.1)
+    step_controls[0].cancel()
+    step_controls[1].cancel()
+
+    for i, control in enumerate(step_controls):
+        if i in (0, 1):
+            assert control.cancelled()
+        else:
+            assert control.result() is not None and len(control.result()) == 2
+
+    for averager in averagers:
+        averager.shutdown()
+
+
 @pytest.mark.forked
 def test_training_averager(n_steps: int = 10, n_dims: int = 16):
     torch.manual_seed(42)
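
Note: test_load_state_priority above asserts that load_state_from_peers downloads from the sharing-enabled peer with the highest state_sharing_priority. The sketch below expresses that selection rule over a plain list of peer records; the PeerStateInfo structure and the choose_state_donor helper are illustrative assumptions, not hivemind internals.

from typing import List, NamedTuple, Optional


class PeerStateInfo(NamedTuple):
    peer_id: str
    allow_state_sharing: bool
    state_sharing_priority: float


def choose_state_donor(peers: List[PeerStateInfo]) -> Optional[str]:
    """Pick the peer to download state from: the highest-priority peer that allows sharing."""
    candidates = [peer for peer in peers if peer.allow_state_sharing]
    if not candidates:
        return None  # nobody shares state: load_state_from_peers() is expected to return None
    return max(candidates, key=lambda peer: peer.state_sharing_priority).peer_id


# priorities follow the test above: 5 - abs(2 - i) for i in range(4), with peer 1 not sharing
peers = [
    PeerStateInfo("peer0", True, 3.0),
    PeerStateInfo("peer1", False, 4.0),
    PeerStateInfo("peer2", True, 5.0),
    PeerStateInfo("peer3", True, 4.0),
]
assert choose_state_donor(peers) == "peer2"  # matches tensors[-1].item() == 2 in the first assertion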

+ 220 - 7
tests/test_optimizer.py

@@ -1,3 +1,5 @@
+import ctypes
+import multiprocessing as mp
 import time
 from functools import partial
 
@@ -10,7 +12,10 @@ import torch.nn.functional as F
 import hivemind
 from hivemind.averaging.control import AveragingStage
 from hivemind.optim.experimental.grad_averager import GradientAverager
+from hivemind.optim.experimental.optimizer import Optimizer
+from hivemind.optim.experimental.progress_tracker import ProgressTracker
 from hivemind.optim.experimental.state_averager import TrainingStateAverager
+from hivemind.utils.crypto import RSAPrivateKey
 
 
 @pytest.mark.forked
@@ -74,7 +79,7 @@ def test_grad_averager():
 @pytest.mark.forked
 @pytest.mark.parametrize(
     "offload_optimizer, reuse_tensors, sync_epoch_when_averaging",
-    [(False, False, False), (True, False, False), (False, True, True), (True, False, True)],
+    [(False, False, False), (True, True, False), (True, False, False), (False, True, True), (True, False, True)],
 )
 def test_state_averager(offload_optimizer: bool, reuse_tensors: bool, sync_epoch_when_averaging: bool):
     dht1 = hivemind.DHT(start=True)
@@ -102,10 +107,10 @@ def test_state_averager(offload_optimizer: bool, reuse_tensors: bool, sync_epoch
     )
 
     avgr1 = TrainingStateAverager(
-        dht=dht1, param_groups=model1.parameters(), extra_tensors=extras1, start=True, **common_kwargs
+        dht=dht1, params=model1.parameters(), extra_tensors=extras1, start=True, **common_kwargs
     )
     avgr2 = TrainingStateAverager(
-        dht=dht2, param_groups=model2.parameters(), extra_tensors=extras2, start=True, **common_kwargs
+        dht=dht2, params=model2.parameters(), extra_tensors=extras2, start=True, **common_kwargs
     )
 
     x = torch.ones(2)
@@ -131,8 +136,8 @@ def test_state_averager(offload_optimizer: bool, reuse_tensors: bool, sync_epoch
     avgr1.step(increment_epoch=True, averaging_round=True, delay_averaging=True)
     avgr2.step(increment_epoch=True, averaging_round=True, delay_averaging=True)
 
-    avgr1.step(wait_for_delayed_update=True)
-    avgr2.step(wait_for_delayed_update=True)
+    avgr1.step(wait_for_delayed_updates=True)
+    avgr2.step(wait_for_delayed_updates=True)
 
     assert torch.allclose(model1(x), model2(x)), "model parameters were not averaged correctly"
     assert torch.allclose(avgr1.optimizer.state_dict()["state"][0]["exp_avg_sq"], (stats1 + stats2) / 2)
@@ -157,10 +162,10 @@ def test_load_state_from_peers():
     )
 
     avgr1 = TrainingStateAverager(
-        dht=dht1, param_groups=model1.parameters(), allow_state_sharing=False, start=True, **common_kwargs
+        dht=dht1, params=model1.parameters(), allow_state_sharing=False, start=True, **common_kwargs
     )
 
-    avgr2 = TrainingStateAverager(dht=dht2, param_groups=model2.parameters(), start=True, **common_kwargs)
+    avgr2 = TrainingStateAverager(dht=dht2, params=model2.parameters(), start=True, **common_kwargs)
 
     avgr2.local_epoch = 1337
     model2.weight.data[...] = 42
@@ -170,3 +175,211 @@ def test_load_state_from_peers():
     assert avgr1.local_epoch == 1337
     assert torch.all(model1.weight == 42).item()
     assert np.allclose(avgr1.optimizer.param_groups[0]["lr"], 0.1 / 1337)
+
+
+@pytest.mark.forked
+def test_progress_tracker():
+    # note to a curious reader: no, you cannot reduce the timings without compromising realism or stability
+    prefix = "my_exp"
+    target_batch_size = 256
+    dht_root = hivemind.DHT(start=True)
+    barrier = mp.Barrier(parties=5)
+    delayed_start_evt = mp.Event()
+    finished_evt = mp.Event()
+    emas = mp.Array(ctypes.c_double, 5)
+
+    def run_worker(index: int, batch_size: int, period: float, **kwargs):
+        dht = hivemind.DHT(initial_peers=dht_root.get_visible_maddrs(), start=True)
+        tracker = ProgressTracker(
+            dht,
+            prefix,
+            target_batch_size,
+            start=True,
+            min_refresh_period=0.1,
+            default_refresh_period=0.2,
+            max_refresh_period=0.5,
+            private_key=RSAPrivateKey(),
+            **kwargs,
+        )
+
+        barrier.wait()
+        if index == 4:
+            delayed_start_evt.wait()
+
+        local_epoch = 2 if index == 4 else 0
+        samples_accumulated = 0
+
+        while True:
+            time.sleep(period)
+            if finished_evt.is_set():
+                break
+
+            samples_accumulated += batch_size
+            tracker.report_local_progress(local_epoch, samples_accumulated)
+
+            if tracker.ready_to_update_epoch:
+                if index == 4 and local_epoch >= 4:
+                    time.sleep(0.5)
+                    break
+
+                with tracker.pause_updates():
+                    local_epoch = tracker.update_epoch(local_epoch + 1)
+                    samples_accumulated = 0
+
+        emas[index] = tracker.performance_ema.samples_per_second
+        tracker.shutdown()
+        dht.shutdown()
+
+    workers = [
+        mp.Process(target=run_worker, kwargs=dict(index=1, batch_size=12, period=0.6)),
+        mp.Process(target=run_worker, kwargs=dict(index=2, batch_size=16, period=0.5)),
+        mp.Process(target=run_worker, kwargs=dict(index=3, batch_size=24, period=0.4)),
+        mp.Process(target=run_worker, kwargs=dict(index=4, batch_size=64, period=0.4)),
+    ]
+    for worker in workers:
+        worker.start()
+
+    tracker = ProgressTracker(
+        dht_root,
+        prefix,
+        target_batch_size,
+        start=True,
+        min_refresh_period=0.1,
+        default_refresh_period=0.2,
+        max_refresh_period=0.5,
+    )
+    barrier.wait()
+
+    local_epoch = 0
+    last_timestamp = hivemind.get_dht_time()
+    step_time_deltas = []
+
+    while local_epoch < 6:
+        time.sleep(0.1)
+
+        if tracker.ready_to_update_epoch:
+            with tracker.pause_updates():
+                local_epoch = tracker.update_epoch(local_epoch + 1)
+
+            time_delta = hivemind.get_dht_time() - last_timestamp
+            if local_epoch == 2:
+                delayed_start_evt.set()
+
+            last_timestamp = hivemind.get_dht_time()
+            step_time_deltas.append(time_delta)
+
+    finished_evt.set()
+    for worker in workers:
+        worker.join()
+
+    tracker.shutdown()
+    dht_root.shutdown()
+    assert not tracker.is_alive()
+
+    mean_step_time = sum(step_time_deltas) / len(step_time_deltas)
+    for i in (0, 1, 5):  # Without the 4th worker (the fastest one)
+        assert 1.05 * mean_step_time < step_time_deltas[i] < 2.0 * mean_step_time
+    for i in (2, 3, 4):  # With the 4th worker
+        assert 0.5 * mean_step_time < step_time_deltas[i] < 0.95 * mean_step_time
+    assert emas[1] < emas[2] < emas[3] < emas[4]
+    assert tracker.performance_ema.samples_per_second < 1e-9
+
+
+@pytest.mark.forked
+def test_optimizer(
+    num_peers: int = 1,
+    num_clients: int = 0,
+    target_batch_size: int = 32,
+    total_epochs: int = 3,
+    reuse_grad_buffers: bool = True,
+    delay_grad_averaging: bool = True,
+    delay_optimizer_step: bool = True,
+    average_state_every: int = 1,
+):
+    dht = hivemind.DHT(start=True)
+
+    features = torch.randn(100, 5)
+    targets = features @ torch.randn(5, 1)
+    optimizer = None
+    total_samples_accumulated = mp.Value(ctypes.c_int32, 0)
+
+    def run_trainer(batch_size: int, batch_time: float, client_mode: bool):
+        nonlocal optimizer
+        model = nn.Linear(5, 1)
+
+        assert isinstance(model, torch.nn.Module), "model_arch must evaluate to a pytorch module"
+
+        optimizer = Optimizer(
+            run_id="test_run",
+            target_batch_size=target_batch_size,
+            batch_size_per_step=batch_size,
+            params=model.parameters(),
+            optimizer=partial(torch.optim.SGD, lr=0.1),
+            scheduler=partial(torch.optim.lr_scheduler.StepLR, gamma=0.5, step_size=1),
+            dht=hivemind.DHT(initial_peers=dht.get_visible_maddrs(), client_mode=client_mode, start=True),
+            tracker_opts=dict(private_key=RSAPrivateKey(), max_refresh_period=1.0),
+            averager_opts=dict(request_timeout=0.5),
+            matchmaking_time=1.0,
+            averaging_timeout=5.0,
+            reuse_grad_buffers=reuse_grad_buffers,
+            delay_grad_averaging=delay_grad_averaging,
+            delay_optimizer_step=delay_optimizer_step,
+            average_state_every=average_state_every,
+            client_mode=client_mode,
+            verbose=False,
+        )
+        optimizer.load_state_from_peers()
+
+        prev_time = time.perf_counter()
+
+        while optimizer.local_epoch < total_epochs:
+            time.sleep(max(0.0, prev_time + batch_time - time.perf_counter()))
+            batch = torch.randint(0, len(features), (batch_size,))
+
+            loss = F.mse_loss(model(features[batch]), targets[batch])
+            loss.backward()
+
+            optimizer.step()
+
+            total_samples_accumulated.value += batch_size
+
+            if not reuse_grad_buffers:
+                optimizer.zero_grad()
+
+            prev_time = time.perf_counter()
+
+        time.sleep(1.0)
+        optimizer.shutdown()
+        return optimizer
+
+    peers = []
+
+    for index in range(num_peers):
+        peers.append(
+            mp.Process(
+                target=run_trainer,
+                name=f"trainer-{index}",
+                kwargs=dict(
+                    batch_size=4 + index,
+                    batch_time=0.3 + 0.2 * index,
+                    client_mode=(index >= num_peers - num_clients),
+                ),
+            )
+        )
+
+    for peer in peers[1:]:
+        peer.start()
+    peers[0].run()
+    for peer in peers[1:]:
+        peer.join()
+
+    assert isinstance(optimizer, Optimizer)
+    assert optimizer.local_epoch == optimizer.tracker.global_epoch == total_epochs
+    expected_samples_accumulated = target_batch_size * total_epochs
+    assert expected_samples_accumulated <= total_samples_accumulated.value <= expected_samples_accumulated * 1.2
+    assert 4 / 0.3 * 0.8 <= optimizer.tracker.performance_ema.samples_per_second <= 4 / 0.3 * 1.2
+
+    assert not optimizer.state_averager.is_alive()
+    assert not optimizer.grad_averager.is_alive()
+    assert not optimizer.tracker.is_alive()
+    assert optimizer.scheduled_grads is None or optimizer.scheduled_grads.done()
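
Note: the new hivemind.Optimizer exercised by test_optimizer can be boiled down to the minimal training-loop sketch below. It uses only constructor arguments and methods that appear in the test above; the run_id, learning rate, and synthetic data are placeholders, and every other option is left unset.

from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

import hivemind
from hivemind.optim.experimental.optimizer import Optimizer

dht = hivemind.DHT(start=True)  # in a real swarm, pass initial_peers=... to join existing peers
model = nn.Linear(5, 1)

opt = Optimizer(
    run_id="demo_run",                      # peers with the same run_id train together
    dht=dht,
    params=model.parameters(),
    optimizer=partial(torch.optim.SGD, lr=0.1),
    target_batch_size=32,                   # global batch size that triggers a collaborative step
    batch_size_per_step=4,                  # samples contributed by this peer per local step()
    verbose=True,
)
opt.load_state_from_peers()                 # catch up with the swarm before training

features, targets = torch.randn(64, 5), torch.randn(64, 1)
for _ in range(16):
    batch = torch.randint(0, len(features), (4,))
    loss = F.mse_loss(model(features[batch]), targets[batch])
    loss.backward()
    opt.step()       # averages gradients and advances the epoch once the swarm accumulates target_batch_size samples
    opt.zero_grad()  # needed here because reuse_grad_buffers is not enabled in this sketch

opt.shutdown()
dht.shutdown()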