
Implement core functionality of hivemind.Optimizer (#403)

This PR implements the main class of hivemind.Optimizer

- implemented main hivemind.Optimizer class
- implemented and tested pre-scheduling for gradient averaging
- implemented and tested pre-scheduling for parameter averaging
- implemented an option to perform local updates
- adapted hivemind.GradScaler to be compatible with the new optimizer
- tested convergence and stability with 8 peers

Thanks to @SeanNaren for a convenient testing playground in
https://github.com/SeanNaren/hivemind-lightning/blob/main/mingpt.py
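
For reference, a minimal usage sketch of the new optimizer (model, dataloader, compute_loss_on_batch
and INITIAL_PEERS are placeholders; see the hivemind.Optimizer docstring below for the full options):

    import torch
    import hivemind

    dht = hivemind.DHT(initial_peers=INITIAL_PEERS, start=True)
    opt = hivemind.Optimizer(
        dht=dht,
        run_id="run_42",                        # peers with the same run_id train together
        params=model.parameters(),
        optimizer=lambda params: torch.optim.SGD(params, lr=0.1),
        target_batch_size=4096,                 # global batch accumulated by the whole swarm
        batch_size_per_step=4,                  # samples processed per local .step() call
    )
    for batch in dataloader:
        loss = compute_loss_on_batch(model, batch)
        opt.zero_grad()
        loss.backward()
        opt.step()                              # accumulates until the swarm reaches 4096 samples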

Co-authored-by: Aleksandr Borzunov <hxrussia@gmail.com>
Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
Co-authored-by: foksly <mitya1510@ya.ru>
Co-authored-by: Alexey Bukhtiyarov <a.bukhtiyarov@yandex.ru>
Co-authored-by: Michael Diskin <yhn1124@gmail.com>
justheuristic 3 years ago
parent
commit
5d31c3ba9d

+ 163 - 0
benchmarks/benchmark_optimizer.py

@@ -0,0 +1,163 @@
+import multiprocessing as mp
+import random
+import time
+from contextlib import nullcontext
+from dataclasses import dataclass
+from functools import partial
+from typing import Callable
+
+import numpy as np
+import torch
+import torchvision
+from torch import nn as nn
+from torch.nn import functional as F
+from torch.utils.data import Dataset
+
+import hivemind
+from hivemind.optim.experimental.optimizer import Optimizer
+from hivemind.utils.crypto import RSAPrivateKey
+
+
+@dataclass(frozen=True)
+class TrainingArguments:
+    seed: int = 42
+    run_id: str = "my_exp"
+
+    num_peers: int = 8
+    num_clients: int = 3
+    target_batch_size: int = 256
+    reuse_grad_buffers: bool = True
+    delay_grad_averaging: bool = True
+    delay_optimizer_step: bool = True
+    average_state_every: int = 1
+    use_amp: bool = False
+
+    lr_base: float = 0.1
+    lr_gamma: float = 0.1
+    lr_step_size: int = 10
+    max_epoch: int = 25
+
+    batch_size_min: int = 2
+    batch_size_max: int = 16
+    batch_time_min: float = 1.0
+    batch_time_max: float = 4.5
+    batch_time_std: float = 0.5
+
+    matchmaking_time: float = 5.0
+    max_refresh_period: float = 5.0
+    averaging_timeout: float = 15.0
+    winddown_time: float = 5.0
+    verbose: bool = True
+
+    device: str = "cpu"
+    make_dataset: Callable[[], Dataset] = lambda: torchvision.datasets.MNIST(train=True, root=".", download=True)
+    make_model: Callable[[int, int], nn.Module] = lambda num_features, num_classes: nn.Sequential(
+        nn.Linear(num_features, 64), nn.ReLU(), nn.Linear(64, num_classes)
+    )
+
+
+def benchmark_optimizer(args: TrainingArguments):
+    random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    torch.set_num_threads(1)
+
+    dht = hivemind.DHT(start=True)
+
+    train_dataset = args.make_dataset()
+    num_features = train_dataset.data[0].numel()
+    num_classes = len(train_dataset.classes)
+    X_train = torch.as_tensor(train_dataset.data, dtype=torch.float32)
+    X_train = X_train.sub_(X_train.mean((0, 1, 2))).div_(X_train.std((0, 1, 2))).reshape((-1, num_features))
+    y_train = torch.as_tensor(train_dataset.targets, dtype=torch.int64)
+    del train_dataset
+
+    def run_trainer(batch_size: int, batch_time: float, client_mode: bool, verbose: bool):
+        model = args.make_model(num_features, num_classes).to(args.device)
+
+        assert isinstance(model, torch.nn.Module), "make_model must return a torch.nn.Module"
+
+        optimizer = Optimizer(
+            run_id=args.run_id,
+            target_batch_size=args.target_batch_size,
+            batch_size_per_step=batch_size,
+            params=model.parameters(),
+            optimizer=partial(torch.optim.SGD, lr=args.lr_base),
+            scheduler=partial(torch.optim.lr_scheduler.StepLR, gamma=args.lr_gamma, step_size=args.lr_step_size),
+            dht=hivemind.DHT(initial_peers=dht.get_visible_maddrs(), client_mode=client_mode, start=True),
+            tracker_opts=dict(private_key=RSAPrivateKey(), max_refresh_period=args.max_refresh_period),
+            matchmaking_time=args.matchmaking_time,
+            averaging_timeout=args.averaging_timeout,
+            reuse_grad_buffers=args.reuse_grad_buffers,
+            delay_grad_averaging=args.delay_grad_averaging,
+            delay_optimizer_step=args.delay_optimizer_step,
+            average_state_every=args.average_state_every,
+            client_mode=client_mode,
+            verbose=verbose,
+        )
+
+        if args.use_amp and args.reuse_grad_buffers:
+            grad_scaler = hivemind.GradScaler()
+        else:
+            # check that hivemind.Optimizer supports regular PyTorch grad scaler as well
+            grad_scaler = torch.cuda.amp.GradScaler(enabled=args.use_amp)
+
+        prev_time = time.perf_counter()
+
+        while optimizer.local_epoch < args.max_epoch:
+            time.sleep(max(0.0, prev_time + random.gauss(batch_time, args.batch_time_std) - time.perf_counter()))
+
+            batch = torch.randint(0, len(X_train), (batch_size,))
+
+            with torch.cuda.amp.autocast() if args.use_amp else nullcontext():
+                loss = F.cross_entropy(model(X_train[batch].to(args.device)), y_train[batch].to(args.device))
+                grad_scaler.scale(loss).backward()
+
+            grad_scaler.unscale_(optimizer)
+
+            if args.use_amp:
+                grad_scaler.step(optimizer)
+            else:
+                optimizer.step()
+
+            grad_scaler.update()
+
+            if not args.reuse_grad_buffers:
+                optimizer.zero_grad()
+
+            prev_time = time.perf_counter()
+
+        time.sleep(args.winddown_time)
+        optimizer.shutdown()
+
+    peers = []
+
+    for index in range(args.num_peers):
+        batch_size = random.randint(args.batch_size_min, args.batch_size_max)
+        batch_time = random.uniform(args.batch_time_min, args.batch_time_max)
+        peers.append(
+            mp.Process(
+                target=run_trainer,
+                name=f"trainer-{index}",
+                daemon=False,
+                kwargs=dict(
+                    batch_size=batch_size,
+                    batch_time=batch_time,
+                    client_mode=(index >= args.num_peers - args.num_clients),
+                    verbose=args.verbose and (index == 0),
+                ),
+            )
+        )
+
+    try:
+        for peer in peers[1:]:
+            peer.start()
+        peers[0].run()
+        for peer in peers[1:]:
+            peer.join()
+    finally:
+        for peer in peers[1:]:
+            peer.kill()
+
+
+if __name__ == "__main__":
+    benchmark_optimizer(TrainingArguments())
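
The benchmark can be run directly via its __main__ guard. A hedged sketch of driving it with
non-default arguments (field names taken from the TrainingArguments dataclass above; the import
path assumes the benchmarks directory is on PYTHONPATH):

    from benchmark_optimizer import TrainingArguments, benchmark_optimizer

    # smaller smoke-test configuration: 4 peers (1 client-mode), 5 epochs, AMP disabled
    benchmark_optimizer(TrainingArguments(num_peers=4, num_clients=1, max_epoch=5, use_amp=False))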

+ 29 - 3
docs/modules/optim.rst

@@ -1,14 +1,40 @@
 **hivemind.optim**
 ==================

-.. automodule:: hivemind.optim
-.. currentmodule:: hivemind.optim
-
 .. raw:: html

   This module contains decentralized optimizers that wrap regular pytorch optimizers to collaboratively train a shared model. Depending on the exact type, optimizer may average model parameters with peers, exchange gradients, or follow a more complicated distributed training strategy.
   <br><br>

+.. automodule:: hivemind.optim.experimental.optimizer
+.. currentmodule:: hivemind.optim.experimental.optimizer
+
+**hivemind.Optimizer**
+----------------------
+
+.. autoclass:: Optimizer
+   :members: step, zero_grad, load_state_from_peers, param_groups, shutdown
+   :member-order: bysource
+
+.. currentmodule:: hivemind.optim.grad_scaler
+.. autoclass:: GradScaler
+   :member-order: bysource
+
+
+**CollaborativeOptimizer**
+--------------------------
+
+.. raw:: html
+
+  CollaborativeOptimizer is a legacy version of hivemind.Optimizer. **For new projects, please use hivemind.Optimizer.**
+  Currently, hivemind.Optimizer supports all the features of CollaborativeOptimizer and then some.
+  CollaborativeOptimizer will still be supported for a while, but will eventually be deprecated.
+  <br><br>
+
+
+.. automodule:: hivemind.optim.collaborative
+.. currentmodule:: hivemind.optim
+
 .. autoclass:: CollaborativeOptimizer
    :members: step
    :member-order: bysource

+ 2 - 0
hivemind/__init__.py

@@ -16,6 +16,8 @@ from hivemind.optim import (
     DecentralizedOptimizer,
     DecentralizedOptimizerBase,
     DecentralizedSGD,
+    GradScaler,
+    Optimizer,
     TrainingAverager,
 )
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, PeerInfo

+ 28 - 5
hivemind/averaging/averager.py

@@ -35,6 +35,7 @@ from hivemind.proto import averaging_pb2
 from hivemind.utils import MPFuture, TensorDescriptor, get_logger
 from hivemind.utils.asyncio import (
     achain,
+    afirst,
     aiter_with_timeout,
     anext,
     as_aiter,
@@ -413,11 +414,28 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             step.attach(trigger, cancel)
             future_for_init.set_result((trigger, cancel))

+            async def find_peers_or_notify_cancel():
+                group_info = await self._matchmaking.look_for_group(step)
+                try:
+                    if not step.triggered:
+                        step.stage = AveragingStage.AWAITING_TRIGGER
+                        await step.wait_for_trigger()
+                    return group_info
+                except asyncio.CancelledError:
+                    await asyncio.wait(
+                        {
+                            self._send_error_to_peer(peer_id, group_info.group_id, averaging_pb2.CANCELLED)
+                            for peer_id in group_info.peer_ids
+                            if peer_id != self.peer_id
+                        }
+                    )
+                    raise
+
             while not step.done():
                 try:
                     self._pending_group_assembled.clear()
                     step.stage = AveragingStage.LOOKING_FOR_GROUP
-                    matchmaking_task = asyncio.create_task(self._matchmaking.look_for_group(step))
+                    matchmaking_task = asyncio.create_task(find_peers_or_notify_cancel())
                     check_cancel_task = asyncio.create_task(step.wait_for_cancel())

                     await asyncio.wait({matchmaking_task, check_cancel_task}, return_when=asyncio.FIRST_COMPLETED)
@@ -428,13 +446,10 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                         check_cancel_task.cancel()

                     group_info = await matchmaking_task
+
                     if group_info is None:
                         raise AllreduceException("Averaging step failed: could not find a group.")

-                    if not step.triggered:
-                        step.stage = AveragingStage.AWAITING_TRIGGER
-                        await step.wait_for_trigger()
-
                     step.stage = AveragingStage.RUNNING_ALLREDUCE

                     step.set_result(
@@ -478,6 +493,14 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     )
                 )

+    async def _send_error_to_peer(self, peer_id: PeerID, group_id: GroupID, code: averaging_pb2.MessageCode):
+        try:
+            error = averaging_pb2.AveragingData(group_id=group_id, code=code)
+            stub = type(self).get_stub(self._p2p, peer_id, namespace=self.prefix)
+            await afirst(await stub.rpc_aggregate_part(as_aiter(error)))
+        except Exception as e:
+            logger.debug(f"Caught {e} when sending error {averaging_pb2.MessageCode.Name(code)} to {peer_id}.")
+
     async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
         """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
         try:

+ 9 - 0
hivemind/averaging/control.py

@@ -1,3 +1,4 @@
+import os
 import struct
 from enum import Enum
 from typing import Optional
@@ -144,6 +145,14 @@ class StepControl(MPFuture):
         self._trigger, self._cancel, self._shared_buffer = state["_trigger"], state["_cancel"], state["_shared_buffer"]
         self._data_for_gather, self._deadline, self._allow_retries = state["immutable_params"]

+    def __del__(self):
+        if os.getpid() == self._origin_pid and not self.triggered:
+            logger.warning(
+                "Deleted an averaging StepControl, but the step was not triggered. This may cause other "
+                "peers to fail an averaging round via TimeoutError."
+            )
+        super().__del__()
+
     def cancel(self) -> bool:
         if self._trigger is not None:
             self._trigger.cancel()

+ 5 - 3
hivemind/averaging/matchmaking.py

@@ -88,9 +88,11 @@ class Matchmaking:
     async def looking_for_group(self, step_control: StepControl):
         async with self.lock_looking_for_group:
             assert self.step_control is None
-            self.step_control = step_control
-            yield
-            self.step_control = None
+            try:
+                self.step_control = step_control
+                yield
+            finally:
+                self.step_control = None
 
 
     @property
     def is_looking_for_group(self):

+ 1 - 1
hivemind/averaging/partition.py

@@ -35,7 +35,7 @@ class TensorPartContainer:
         compression: CompressionBase = NoCompression(),
         part_size_bytes: int = DEFAULT_PART_SIZE_BYTES,
         tensor_infos: Optional[Sequence[CompressionInfo]] = None,
-        prefetch: int = 5,
+        prefetch: int = 1,
     ):
         if tensor_infos is None:
             tensor_infos = tuple(CompressionInfo.from_tensor(x, key=i) for i, x in enumerate(tensors))

+ 2 - 1
hivemind/optim/__init__.py

@@ -1,6 +1,7 @@
 from hivemind.optim.adaptive import CollaborativeAdaptiveOptimizer
 from hivemind.optim.base import DecentralizedOptimizerBase
 from hivemind.optim.collaborative import CollaborativeOptimizer
-from hivemind.optim.grad_scaler import HivemindGradScaler
+from hivemind.optim.experimental.optimizer import Optimizer
+from hivemind.optim.grad_scaler import GradScaler, HivemindGradScaler
 from hivemind.optim.simple import DecentralizedAdam, DecentralizedOptimizer, DecentralizedSGD
 from hivemind.optim.training_averager import TrainingAverager

+ 2 - 2
hivemind/optim/collaborative.py

@@ -245,7 +245,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             self.averager.local_step = self.collaboration_state.optimizer_step
             logger.log(self.status_loglevel, f"Catching up with collaboration step {self.local_step}.")

-        if grad_scaler is not None and not grad_scaler.are_grads_finite(self.opt):
+        if grad_scaler is not None and not grad_scaler.are_grads_finite(self):
             logger.log(self.status_loglevel, "Encountered incorrect value in fp16 grads, resetting local gradients")
             self.local_samples_accumulated = self.local_steps_accumulated = 0
             self.reset_accumulated_grads_()
@@ -310,7 +310,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
 
             if grad_scaler is not None:
                 with grad_scaler.running_global_step():
-                    assert grad_scaler.step(self.opt)
+                    assert grad_scaler.step(self)
             else:
                 self.opt.step()


+ 15 - 10
hivemind/optim/experimental/grad_averager.py

@@ -170,7 +170,13 @@ class GradientAverager(DecentralizedAverager):
         elif len(kwargs) > 0:
             raise RuntimeError(f"Averaging with a pre-scheduled group, parameters {kwargs} will have no effect.")
         assert not control.triggered, f"This {type(control)} instance was already used."
-        self._load_accumulators_into_averager_()
+        if self._new_averaged_grads and self.warn:
+            logger.warning(
+                "[warn=True] Starting new averaging round, but previous round results were not used."
+                "This may be a sign of incorrect optimizer behavior."
+            )
+
+        self.load_accumulators_into_averager_()
         self._accumulators_used_in_step = True
         self._accumulators_used_in_step = True
         self._new_averaged_grads = True
         self._new_averaged_grads = True
 
 
@@ -182,14 +188,8 @@ class GradientAverager(DecentralizedAverager):
         return control.result(timeout) if wait else control

     @torch.no_grad()
-    def _load_accumulators_into_averager_(self):
+    def load_accumulators_into_averager_(self):
         """load locally accumulated gradients into the averager for aggregation"""
         """load locally accumulated gradients into the averager for aggregation"""
-        if self._new_averaged_grads and self.warn:
-            logger.warning(
-                "[warn=True] Starting new averaging round, but previous round results were not used."
-                "This may be a sign of incorrect optimizer behavior."
-            )
-            self._new_averaged_grads = False  # warn once per round
         # divide locally accumulated gradients by the number of times they were accumulated
         grad_scale = (1.0 / self.local_times_accumulated) if self.local_times_accumulated != 0 else 0.0
         with self.get_tensors() as averaged_grads:
@@ -208,14 +208,19 @@ class GradientAverager(DecentralizedAverager):
     @contextlib.contextmanager
     @torch.no_grad()
     def use_averaged_gradients(self):
+        """Substitute model's main gradients with averaged gradients (does not respect device placement)"""
         self._new_averaged_grads = False
         with self.get_tensors() as averaged_grads:
+            assert len(averaged_grads) == len(self.parameters)
             try:
-                assert len(averaged_grads) == len(self.parameters)
                 old_grads = [param.grad for param in self.parameters]
                 for param, new_grad in zip(self.parameters, averaged_grads):
                     param.grad = new_grad
-                yield
+                yield averaged_grads
             finally:
                 for param, old_grad in zip(self.parameters, old_grads):
                     param.grad = old_grad
+
+    def notify_used_averaged_gradients(self):
+        """Notify averager that the results of a previous averaging round are accounted for"""
+        self._new_averaged_grads = False
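
A hedged sketch of how the renamed GradientAverager methods above fit together (mirroring how
hivemind.Optimizer drives them further down in this diff; `grad_averager`, `inner_opt` and the
literal batch size are assumptions for illustration):

    grad_averager.accumulate_grads_(32)                     # accumulate a local batch of gradients
    control = grad_averager.schedule_step()                 # optionally pre-schedule the averaging round
    grad_averager.step(control=control, wait=True)          # runs load_accumulators_into_averager_ + allreduce
    with grad_averager.use_averaged_gradients():            # temporarily substitutes param.grad with averaged grads
        inner_opt.step()
    # alternatively, if averaged gradients are consumed some other way (e.g. offload_optimizer copies them
    # into the optimizer directly), call grad_averager.notify_used_averaged_gradients() so the averager
    # does not warn about unused results on the next round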

+ 725 - 0
hivemind/optim/experimental/optimizer.py

@@ -0,0 +1,725 @@
+from __future__ import annotations
+
+import logging
+import os
+import time
+from functools import partial
+from typing import Callable, Optional, Sequence, Union
+
+import torch
+
+from hivemind.averaging.control import AveragingStage, StepControl
+from hivemind.compression import CompressionBase, NoCompression
+from hivemind.dht import DHT
+from hivemind.optim.experimental.grad_averager import GradientAverager
+from hivemind.optim.experimental.progress_tracker import ProgressTracker
+from hivemind.optim.experimental.state_averager import (
+    LRSchedulerBase,
+    OptimizerFactory,
+    Parameters,
+    ParamGroups,
+    SchedulerFactory,
+    TorchOptimizer,
+    TrainingStateAverager,
+)
+from hivemind.optim.grad_scaler import GradScaler
+from hivemind.utils import PerformanceEMA, get_dht_time, get_logger
+
+logger = get_logger(__name__)
+
+
+class Optimizer(torch.optim.Optimizer):
+    """
+    Hivemind Optimizer wraps your regular PyTorch Optimizer for training collaboratively with peers.
+    By default, Optimizer is configured to be exactly **equivalent to synchronous training** with target_batch_size;
+    there are advanced options to make training semi-asynchronous (delay_optimizer_step and delay_grad_averaging)
+    or even fully asynchronous (use_local_updates=True). However, these options require careful tuning.
+
+    :example: The Optimizer can be used as a drop-in replacement for your regular PyTorch Optimizer:
+
+    >>> model = transformers.AutoModel.from_pretrained("albert-xxlarge-v2")
+    >>> dht = hivemind.DHT(initial_peers=INITIAL_PEERS, start=True)
+    >>> opt = hivemind.Optimizer(dht, run_id="run_42", optimizer=lambda params: torch.optim.Adam(params, ...),
+                                 params=model.parameters(), target_batch_size=4096, batch_size_per_step=4)
+    >>> # alternative: opt = hivemind.Optimizer(dht, run_id="run_42", optimizer=torch.optim.Adam(model.parameters()))
+    >>> while True:
+    >>>     loss = compute_loss_on_batch(model, batch_size=4)
+    >>>     opt.zero_grad()
+    >>>     loss.backward()
+    >>>     opt.step()  # <-- train collaboratively with any peers that use the same prefix (run_42)
+
+    However, unlike regular optimizers, calling opt.step with hivemind.Optimizer can do one of the following:
+
+     - accumulate a minibatch of gradients towards the (global) target batch size, without updating parameters yet;
+     - after accumulating the target batch size, all-reduce gradients with peers and perform optimizer step;
+     - if your peer lags behind the rest of the swarm, it will download latest state from other peers;
+
+    :example: the optimizer has many keyword arguments that may be difficult to understand in one go. Here's a quickstart
+      that will help you set up your first synchronous optimizer.
+
+    >>> hivemind.Optimizer(
+    >>>    dht=hivemind.DHT(initial_peers=ADDRESS_HERE, client_mode=TRUE_IF_BEHIND_FIREWALL_OR_UNRELIABLE, start=True),
+    >>>    run_id="a_unique_name_that_every_participant_will_see_when_training",
+    >>>    batch_size_per_step=ACTUAL_BATCH_SIZE_OF_THIS_PEER,
+    >>>    target_batch_size=LARGE_GLOBAL_BATCH,  # global batch will be this or *slightly* larger due to stragglers;
+    >>>      #  peers should finish averaging in roughly half the time they need to accumulate this batch between them
+    >>>    optimizer=lambda params: AnyPyTorchOptimizer(params, **config_that_makes_sense_for_target_batch_size),
+    >>>      # ^-- scale learning rate for your target_batch_size; good reference: https://arxiv.org/abs/1904.00962
+    >>>    offload_optimizer=True,  # this saves GPU memory; large-batch training does not need optimizer that often
+    >>>    scheduler=lambda opt: AnyPyTorchScheduler(opt, **config_that_makes_sense_for_target_batch_size),
+    >>>      # scheduler.step will be called once every time peers collectively accumulate target_batch_size
+    >>>    matchmaking_time=15.0, averaging_timeout=60.0,  # <-- if the network is fast reduce to 3-5s and 10-15s
+    >>>      # increase matchmaking_time if at least 25% of the time you see "averaged gradients with N peers",
+    >>>      # ... but N is less than 0.9x the actual number of peers. Increase averaging_timeout if half of the epochs
+    >>>      # ... print "Proceeding with local gradients" instead of "Averaged gradients with N peers"
+    >>>    grad_compression=hivemind.Float16Compression(),  state_averaging_compression=hivemind.Float16Compression(),
+    >>>      # it is generally fine to use pure 16-bit or even lower precision during communication with no precaution;
+    >>>      # See hivemind/examples/albert for an example of mixed 8-bit compression.
+    >>>    delay_grad_averaging=SHOULD_I_USE_DPU, delay_optimizer_step=SHOULD_I_USE_DPU,  # DPU stands for Delayed
+    >>>      # Parameter Updates, running allreduce and optimizer step in background. See https://arxiv.org/abs/2101.06840
+    >>>    verbose=True  # periodically report the training progress to the console
+    >>> )  # and you're done!
+
+    :note: hivemind.Optimizer can be used the same way as any other pytorch optimizer, but there is one caveat:
+      learning rate schedulers, curriculum and other **time-dependent features should depend on Optimizer.local_epoch**
+      (and not the number of calls to opt.step). This is because peers are allowed to join midway through training,
+      when others have already made some progress and changed their learning rates accordingly.
+
+    :param dht: a running hivemind.DHT instance connected to other peers
+    :param run_id: a unique identifier of this training run, used as a common prefix for all DHT keys.
+      **Note:** peers with the same run_id should *generally* train the same model and use compatible configurations.
+      Some options can be safely changed by individual peers: ``batch_size_per_step``, ``client_mode``, ``auxiliary``,
+      ``reuse_grad_buffers``, ``offload_optimizer``, and ``verbose``. In some cases, other options may also be tuned
+      individually by each peer, but they should be changed with caution to avoid deadlocks or convergence issues.
+
+    :param target_batch_size: global batch size that must be accumulated before the swarm transitions to the next epoch
+    :param batch_size_per_step: before each call to .step, user should accumulate gradients over this many samples
+
+    :param optimizer: a callable(parameters) -> pytorch.optim.Optimizer or a pre-initialized PyTorch optimizer
+      **Note:** some advanced options like offload_optimizer, delay_optimizer_step, or delay_grad_averaging are not
+      supported if hivemind.Optimizer is created with a pre-initialized optimizer; they require an optimizer factory
+    :param params: parameters or param groups for the optimizer; required if optimizer is a callable(params)
+    :param scheduler: callable(optimizer) -> PyTorch LRScheduler or a pre-initialized PyTorch scheduler.
+      The learning rate scheduler will adjust learning rate based on global epoch, not the number of
+      local calls to optimizer.step; this is required to keep different peers synchronized.
+
+    :param matchmaking_time: when looking for group, wait for peers to join for up to this many seconds
+    :param averaging_timeout: if an averaging step hangs for this long, it will be cancelled.
+    :param load_state_timeout: wait for at most this many seconds before giving up on load_state_from_peers
+    :param reuse_grad_buffers: if True, use model's .grad buffers for gradient accumulation.
+      This is more memory efficient, but it requires that the user does *NOT* call model/opt zero_grad at all
+
+    :param offload_optimizer: offload the optimizer to host memory, saving GPU memory for parameters and gradients
+    :param delay_optimizer_step: run optimizer in background, apply results in future .step; requires offload_optimizer
+    :param delay_grad_averaging: average gradients in background; requires offload_optimizer and delay_optimizer_step
+    :param delay_state_averaging: if enabled (default), average parameters and extra tensors in a background thread;
+      if set to False, average parameters synchronously within the corresponding hivemind.Optimizer.step call.
+      The above 3 options (offload_optimizer, delay_optimizer_step and delay_grad_averaging) require that the optimizer
+      is created with: ``hivemind.Optimizer(..., optimizer=callable_optimizer_factory, params=model.parameters())``
+
+    :param average_state_every: average state (parameters, chosen opt tensors) with peers every this many **epochs**.
+      This reduces communication overhead, but can cause parameters to diverge if set too large.
+      The maximal safe value of average_state_every depends on how quickly peers diverge from each other. If peers
+      hardly ever skip averaging rounds, they can average state less frequently. In turn, network failures, lossy
+      gradient compression and local updates cause parameters to diverge faster and require more frequent averaging.
+
+    :param use_local_updates: if enabled, peers will update parameters on each .step using local gradients;
+      if not enabled (default), accumulate gradients to target_batch_size, and then call .step with averaged gradients.
+      Even if use_local_updates=True, learning rate scheduler will still be called once per target_batch_size.
+
+    :param client_mode: if True, this peer will not accept incoming connections (firewall-compatible mode)
+    :param auxiliary: if True, optimizer.step will only assist other peers in averaging (for cpu-only workers)
+
+    :param grad_compression: compression strategy used for averaging gradients, default = no compression
+    :param state_averaging_compression: compression for averaging params and state tensors, default = no compression
+    :param load_state_compression: compression strategy for loading state from peers, default = no compression
+    :param average_opt_statistics: names of optimizer statistics from state dict that should be averaged with peers
+    :param extra_tensors: if specified, these extra tensors will also be averaged and shared in load_state_from_peers.
+
+    :param averager_opts: additional keyword arguments forwarded to both GradientAverager and TrainingStateAverager
+    :param tracker_opts: additional keyword arguments forwarded to ProgressTracker
+    :param performance_ema_alpha: moving average alpha in ProgressTracker, TrainingStateAverager and Optimizer
+    :param verbose: if True, report internal events such as accumulating gradients and running background tasks
+
+    :note: in a large-scale training, peers will inevitably fail and you will see error messages. hivemind.Optimizer
+      is designed to recover from such failures, but will sometimes need a minute or two to re-adjust.
+
+    """
+
+    def __init__(
+        self,
+        *,
+        dht: DHT,
+        run_id: str,
+        target_batch_size: int,
+        batch_size_per_step: Optional[int] = None,
+        optimizer: Union[TorchOptimizer, OptimizerFactory],
+        params: Optional[Union[Parameters, ParamGroups]] = None,
+        scheduler: Optional[Union[LRSchedulerBase, SchedulerFactory]] = None,
+        matchmaking_time: Optional[float] = 15.0,
+        averaging_timeout: Optional[float] = 60.0,
+        load_state_timeout: float = 600.0,
+        reuse_grad_buffers: bool = False,
+        offload_optimizer: Optional[bool] = None,
+        delay_optimizer_step: Optional[bool] = None,
+        delay_grad_averaging: bool = False,
+        delay_state_averaging: bool = True,
+        average_state_every: int = 1,
+        use_local_updates: bool = False,
+        client_mode: Optional[bool] = None,
+        auxiliary: bool = False,
+        grad_compression: CompressionBase = NoCompression(),
+        state_averaging_compression: CompressionBase = NoCompression(),
+        load_state_compression: CompressionBase = NoCompression(),
+        average_opt_statistics: Sequence[str] = (),
+        extra_tensors: Sequence[torch.Tensor] = (),
+        averager_opts: Optional[dict] = None,
+        tracker_opts: Optional[dict] = None,
+        performance_ema_alpha: float = 0.1,
+        shutdown_timeout: float = 5,
+        verbose: bool = False,
+    ):
+        client_mode = client_mode if client_mode is not None else dht.client_mode
+        delay_optimizer_step = delay_optimizer_step if delay_optimizer_step is not None else delay_grad_averaging
+        offload_optimizer = offload_optimizer if offload_optimizer is not None else (params is not None)
+        assert not delay_grad_averaging or delay_optimizer_step, "delay_grad_averaging requires delay_optimizer_step"
+        assert not (client_mode and auxiliary), "Client-mode peers cannot serve as auxiliaries"
+        assert not auxiliary or batch_size_per_step is None, "Auxiliary peers should not accumulate batches"
+        if callable(optimizer) and params is not None:
+            if scheduler is not None and (not callable(scheduler) or isinstance(scheduler, LRSchedulerBase)):
+                raise ValueError("For this mode, please provide scheduler factory: callable(optimizer) -> scheduler")
+        elif all(hasattr(optimizer, attr) for attr in ("param_groups", "step", "zero_grad")):
+            if offload_optimizer or delay_optimizer_step or delay_grad_averaging:
+                raise ValueError(
+                    "To enable offload_optimizer or delayed updates, please initialize Optimizer as "
+                    "hivemind.Optimizer(..., params=params, optimizer=lambda params: create_opt(params)"
+                )
+        else:
+            raise ValueError(
+                "Please initialize the optimizer in one of the following two ways:\n"
+                "(A) hivemind.Optimizer(..., params=params, optimizer=lambda params: create_opt(params)\n"
+                "(B) hivemind.Optimizer(..., optimizer=pre_initialize_optimizer)"
+            )
+        if use_local_updates:
+            assert not reuse_grad_buffers, "if local_updates is True, gradients will not be accumulated"
+            assert not delay_grad_averaging, "if local_updates is True, gradients will not be averaged"
+
+        self.dht, self.run_id, self.client_mode, self.auxiliary = dht, run_id, client_mode, auxiliary
+        self.batch_size_per_step, self.target_batch_size = batch_size_per_step, target_batch_size
+        self.delay_state_averaging, self.average_state_every = delay_state_averaging, average_state_every
+        self.matchmaking_time, self.offload_optimizer = matchmaking_time, offload_optimizer
+        self.delay_grad_averaging, self.delay_optimizer_step = delay_grad_averaging, delay_optimizer_step
+
+        self.averaging_timeout, self.load_state_timeout = averaging_timeout, load_state_timeout
+        self.shutdown_timeout = shutdown_timeout
+
+        self.status_loglevel = logging.INFO if verbose else logging.DEBUG
+        self.scheduled_grads: Optional[StepControl] = None
+        self.scheduled_state: Optional[StepControl] = None
+
+        self.tracker = self._make_progress_tracker(
+            target_batch_size, performance_ema_alpha=performance_ema_alpha, **tracker_opts or {}
+        )
+        self.state_averager = self._make_state_averager(
+            optimizer=optimizer,
+            params=params,
+            scheduler=scheduler,
+            delta_rule_averaging=use_local_updates and self.delay_state_averaging,
+            compression=state_averaging_compression,
+            state_compression=load_state_compression,
+            average_opt_statistics=average_opt_statistics,
+            performance_ema_alpha=performance_ema_alpha,
+            extra_tensors=extra_tensors,
+            **averager_opts or {},
+        )
+        if not use_local_updates:
+            self.grad_averager = self._make_gradient_averager(
+                reuse_grad_buffers=reuse_grad_buffers, compression=grad_compression, **averager_opts or {}
+            )
+        else:
+            self.grad_averager = None
+
+        self._should_check_synchronization_on_update = True  # used in self.should_load_state_from_peers
+        self._schema_hash = self._compute_schema_hash()
+        self._parent_pid = os.getpid()
+
+        self.delay_before_state_averaging = PerformanceEMA(alpha=performance_ema_alpha)
+        # measures the average time from the beginning of self._update_global_epoch to the call to state_averager
+        # used for pre-scheduling the averaging round in state_averager
+
+        self._step_supports_amp_scaling = reuse_grad_buffers
+        # note: the line above is used by pytorch AMP GradScaler to enable custom behavior needed when reusing gradient
+        # buffers over multiple steps (to avoid repeated unscaling). Without reuse_grad_buffers, this is not needed.
+
+    def _make_state_averager(self, **kwargs) -> TrainingStateAverager:
+        return TrainingStateAverager(
+            dht=self.dht,
+            prefix=f"{self.run_id}_state_averager",
+            min_matchmaking_time=self.matchmaking_time,
+            allreduce_timeout=self.averaging_timeout,
+            shutdown_timeout=self.shutdown_timeout,
+            offload_optimizer=self.offload_optimizer,
+            custom_gradients=self.offload_optimizer,
+            status_loglevel=self.status_loglevel,
+            client_mode=self.client_mode,
+            auxiliary=self.auxiliary,
+            start=True,
+            **kwargs,
+        )
+
+    def _make_gradient_averager(self, **kwargs) -> GradientAverager:
+        assert hasattr(self, "state_averager"), "must initialize state averager first"
+        grad_averager = GradientAverager(
+            dht=self.dht,
+            prefix=f"{self.run_id}_grad_averager",
+            parameters=self.state_averager.main_parameters,
+            min_matchmaking_time=self.matchmaking_time,
+            allreduce_timeout=self.averaging_timeout,
+            shutdown_timeout=self.shutdown_timeout,
+            client_mode=self.client_mode,
+            auxiliary=self.auxiliary,
+            start=True,
+            **kwargs,
+        )
+        if self.offload_optimizer:
+            optimized_param_groups = self.state_averager.optimizer.param_groups
+            optimized_parameters = [param for group in optimized_param_groups for param in group["params"]]
+            with grad_averager.get_tensors() as averaged_gradients:
+                assert len(averaged_gradients) == len(optimized_parameters)
+                for opt_param, averaged_grad in zip(optimized_parameters, averaged_gradients):
+                    opt_param.grad = averaged_grad
+        return grad_averager
+
+    def _make_progress_tracker(self, target_batch_size: int, **kwargs) -> ProgressTracker:
+        return ProgressTracker(
+            dht=self.dht,
+            prefix=self.run_id,
+            target_batch_size=target_batch_size,
+            client_mode=self.client_mode,
+            status_loglevel=self.status_loglevel,
+            start=True,
+            **kwargs,
+        )
+
+    def _compute_schema_hash(self) -> int:
+        optimized_param_groups = self.state_averager.optimizer.param_groups
+        optimized_parameters = [param for group in optimized_param_groups for param in group["params"]]
+        param_shapes = tuple(tuple(param.shape) for param in optimized_parameters)
+
+        # offloaded optimizer requires that gradient tensors are reused between iterations
+        grad_ids = tuple(id(param.grad) for param in optimized_parameters) if self.offload_optimizer else None
+        return hash((grad_ids, param_shapes))
+
+    def is_alive(self) -> bool:
+        return self.state_averager.is_alive()
+
+    @property
+    def local_epoch(self) -> int:
+        return self.state_averager.local_epoch
+
+    @property
+    def use_local_updates(self) -> bool:
+        return self.grad_averager is None
+
+    @property
+    def use_gradient_averaging(self) -> bool:
+        return self.grad_averager is not None
+
+    def step(
+        self,
+        closure: Optional[Callable[[], torch.Tensor]] = None,
+        batch_size: Optional[int] = None,
+        grad_scaler: Optional[GradScaler] = None,
+    ):
+        """
+        Update training progress after accumulating another local batch size. Depending on the configuration, this will
+        report progress to peers, run global or local optimizer step, average parameters or schedule background tasks.
+
+        :param closure: A closure that reevaluates the model and returns the loss
+        :param batch_size: optional override for batch_size_per_step from init
+        :param grad_scaler: if amp is enabled, this **must** be a hivemind-aware gradient scaler
+        :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
+        """
+        if grad_scaler is not None and not isinstance(grad_scaler, GradScaler):
+            raise ValueError("hivemind.Optimizer requires a hivemind-aware gradient scaler (hivemind.GradScaler)")
+        if self.batch_size_per_step is None and batch_size is None and not self.auxiliary:
+            raise ValueError("Please either set batch_size_per_step parameter at init or when calling .step")
+        if self.auxiliary and (closure is not None or batch_size is not None or grad_scaler is not None):
+            raise ValueError("Auxiliary peers should not have batch size, run closures, or use grad_scaler")
+        batch_size = batch_size if batch_size is not None else self.batch_size_per_step
+
+        # if delayed updates finished before step, apply these updates; otherwise do nothing
+        self.state_averager.step(apply_delayed_updates=True)
+
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        if not self.auxiliary and self.should_load_state_from_peers():
+            logger.log(self.status_loglevel, "Peer is out of sync.")
+            self.load_state_from_peers()
+            return loss  # local gradients were computed with out-of-sync parameters, must start over
+
+        if self.use_gradient_averaging:
+            # accumulate gradients toward target batch size, then aggregate with peers and run optimizer
+            if not self.auxiliary:
+                grads_are_valid = self._check_and_accumulate_gradients(batch_size, grad_scaler)
+                if not grads_are_valid:
+                    return loss  # local gradients were reset due to overflow, must start over
+
+            self._maybe_schedule_gradient_averaging()
+            self._maybe_schedule_state_averaging()
+
+        else:
+            # use_local_updates=True: update parameters on every step independently of other peers
+            if not self.auxiliary:
+                if grad_scaler is not None:
+                    with grad_scaler.running_global_step():
+                        assert grad_scaler.unscale_(self)
+
+                new_samples_accumulated = self.tracker.local_progress.samples_accumulated + batch_size
+                self.tracker.report_local_progress(self.local_epoch, new_samples_accumulated)
+                self._maybe_schedule_state_averaging()
+
+                self.state_averager.step(
+                    increment_epoch=False,
+                    optimizer_step=True,
+                    delay_optimizer_step=self.delay_optimizer_step,
+                    grad_scaler=grad_scaler,
+                )
+
+        if self.tracker.ready_to_update_epoch:
+            self._update_global_epoch(grad_scaler)
+
+        return loss
+
+    def _update_global_epoch(self, grad_scaler: Optional[GradScaler]) -> None:
+        """Depending on the configuration: aggregate gradients and/or parameters, perform global optimizer step"""
+        assert self._schema_hash == self._compute_schema_hash(), "parameters or gradients changed during iteration"
+        _epoch_start_time = time.perf_counter()
+
+        with self.tracker.pause_updates():
+            wait_for_trigger = None
+
+            if self.use_gradient_averaging:
+                logger.log(self.status_loglevel, f"Beginning optimizer step #{self.local_epoch}")
+                began_averaging_gradients = self._begin_averaging_gradients(grad_scaler)
+                if not began_averaging_gradients:
+                    pass  # failed to start gradient averaging due to an internal error
+                elif self.delay_grad_averaging:
+                    # if using delayed grad averaging, send this to state_averager as a pre-condition for optimizer step
+                    wait_for_trigger = partial(self._average_gradients_and_load_into_optimizer, self.scheduled_grads)
+                else:
+                    # delay_grad_averaging=False, average gradients immediately
+                    self._average_gradients_and_load_into_optimizer(self.scheduled_grads)
+
+            next_epoch = max(self.local_epoch + 1, self.tracker.global_epoch)
+            swarm_not_empty = self.tracker.global_progress.num_peers > 1
+            should_perform_optimizer_step = not self.auxiliary and not self.use_local_updates
+            should_average_state = (
+                swarm_not_empty
+                and next_epoch % self.average_state_every == 0
+                and not self.state_averager.averaging_in_progress
+            )
+
+            if should_average_state and self.scheduled_state is not None:
+                if self.scheduled_state.triggered or self.scheduled_state.done():
+                    logger.log(
+                        self.status_loglevel,
+                        f"Not using pre-scheduled group for state averaging because it"
+                        f"was already used elsewhere: {self.scheduled_state}",
+                    )
+                    self.scheduled_state = None
+                self.delay_before_state_averaging.update(task_size=1, interval=time.perf_counter() - _epoch_start_time)
+
+            self.state_averager.step(
+                increment_epoch=True,
+                wait_for_trigger=wait_for_trigger,
+                optimizer_step=should_perform_optimizer_step,
+                delay_optimizer_step=self.delay_optimizer_step and should_perform_optimizer_step,
+                grad_scaler=grad_scaler,
+                averaging_round=should_average_state,
+                delay_averaging=self.delay_state_averaging and not self.auxiliary,
+                averaging_control=self.scheduled_state if should_average_state else None,
+                averaging_opts=dict(timeout=self.averaging_timeout) if should_average_state else None,
+            )
+
+            if not should_average_state and self.scheduled_state is not None and not self.scheduled_state.done():
+                self.scheduled_state.cancel()
+            self.scheduled_state = None
+
+            self.tracker.update_epoch(new_epoch=self.state_averager.local_epoch)
+            self._should_check_synchronization_on_update = True
+            # the above line ensures that peers check for *strict* synchronization once per epoch
+
+            if not self.client_mode:
+                self.state_averager.state_sharing_priority = self.local_epoch
+
+            if self.use_gradient_averaging and not self.auxiliary:
+                self.grad_averager.reset_accumulated_grads_()
+                if not self.client_mode:
+                    self.grad_averager.state_sharing_priority = self.local_epoch
+
+            logger.log(self.status_loglevel, f"Transitioning to epoch {self.local_epoch}.")
+
+    def _begin_averaging_gradients(self, grad_scaler: Optional[GradScaler]) -> bool:
+        """Begin an all-reduce round to average gradients; return True if succeeded, False if failed"""
+        if grad_scaler is not None:
+            with grad_scaler.running_global_step():
+                assert grad_scaler.unscale_(self)
+
+        began_averaging_gradients = False
+        if self.scheduled_grads is not None and (self.scheduled_grads.triggered or self.scheduled_grads.done()):
+            logger.log(
+                self.status_loglevel,
+                f"Not using pre-scheduled group for state averaging because it"
+                f"was already used elsewhere: {self.scheduled_state}",
+            )
+            self.scheduled_grads = None
+
+        elif self.tracker.global_progress.num_peers > 1:
+            try:
+                self.scheduled_grads = self.grad_averager.step(
+                    control=self.scheduled_grads, reset_accumulators=True, wait=False
+                )
+                began_averaging_gradients = True
+            except BaseException as e:
+                logger.exception(e)
+
+        if not began_averaging_gradients and self.scheduled_grads is not None and not self.scheduled_grads.done():
+            logger.log(self.status_loglevel, f"Cancelled pre-scheduled gradient averaging round")
+            self.scheduled_grads.cancel()
+            self.scheduled_grads = None
+        return began_averaging_gradients
+
+    def _check_and_accumulate_gradients(self, batch_size: int, grad_scaler: Optional[GradScaler]) -> bool:
+        """Check if gradients are valid, accumulate and return True; otherwise, reset and return False"""
+        assert not self.use_local_updates and not self.auxiliary
+        if grad_scaler is not None and not grad_scaler.are_grads_finite(self):
+            logger.log(self.status_loglevel, "Encountered incorrect value in fp16 grads, resetting local gradients")
+            self.tracker.report_local_progress(self.local_epoch, samples_accumulated=0)
+            self.grad_averager.reset_accumulated_grads_()
+            return False
+
+        self.grad_averager.accumulate_grads_(batch_size)
+        self.tracker.report_local_progress(self.local_epoch, self.grad_averager.local_samples_accumulated)
+        return True
+
+    def _maybe_schedule_gradient_averaging(self) -> None:
+        """If next epoch is coming soon, schedule the next gradient averaging round at the estimated end of epoch"""
+        assert self.use_gradient_averaging
+        if self.tracker.estimated_next_update_time - get_dht_time() <= self.matchmaking_time:
+            if self.scheduled_grads is None or self.scheduled_grads.triggered or self.scheduled_grads.done():
+                if self.delay_grad_averaging:
+                    # wait for previous averaging to finish before starting a new one
+                    self.state_averager.step(wait_for_delayed_updates=True)
+
+                eta_seconds = self.tracker.estimated_next_update_time - get_dht_time()
+                eta_seconds = max(eta_seconds, self.grad_averager.matchmaking_kwargs["min_matchmaking_time"])
+                logger.log(self.status_loglevel, f"Pre-scheduling gradient averaging round in {eta_seconds:.2f}s.")
+                self.scheduled_grads = self.grad_averager.schedule_step(timeout=self.averaging_timeout)
+
+    def _maybe_schedule_state_averaging(self) -> None:
+        """If next epoch is coming soon, schedule the next state averaging at estimated parameter averaging start"""
+        next_epoch = max(self.local_epoch + 1, self.tracker.global_epoch)
+        if next_epoch % self.average_state_every != 0:
+            return  # averaging is not performed at this epoch
+        if self.state_averager.averaging_in_progress:
+            return  # previous run is still in progress
+
+        estimated_time = self.tracker.estimated_next_update_time
+        estimated_time += self.delay_before_state_averaging.ema_seconds_per_sample
+        estimated_time += self.state_averager.delay_before_averaging.ema_seconds_per_sample
+        eta_seconds_to_averaging = estimated_time - get_dht_time()
+
+        if eta_seconds_to_averaging <= self.matchmaking_time:
+            if self.scheduled_state is None or self.scheduled_state.triggered or self.scheduled_state.done():
+
+                min_matchmaking_time = self.state_averager.matchmaking_kwargs["min_matchmaking_time"]
+                actual_seconds = max(eta_seconds_to_averaging, min_matchmaking_time)
+                logger.log(self.status_loglevel, f"Pre-scheduling state averaging round in {actual_seconds:.2f}s.")
+                self.scheduled_state = self.state_averager.schedule_step(
+                    gather=next_epoch, timeout=self.averaging_timeout
+                )
+
+    def _average_gradients_and_load_into_optimizer(self, maybe_step_control: Optional[StepControl]):
+        """Run gradient averaging; on success, feed averaged gradients into optimizer; else, use local gradients"""
+        assert self.use_gradient_averaging and (maybe_step_control is None or maybe_step_control.triggered)
+        averaged_gradients = False
+
+        try:
+            if maybe_step_control is not None:
+                group_info = maybe_step_control.result(self.averaging_timeout)
+                logger.log(self.status_loglevel, f"Averaged gradients with {len(group_info)} peers")
+                self._load_averaged_gradients_into_optimizer_()
+                averaged_gradients = True
+            else:
+                logger.log(self.status_loglevel, f"Skipped averaging: there are no other peers")
+        except BaseException as e:
+            logger.log(self.status_loglevel, f"Averaging gradients failed with {repr(e)}")
+
+        if not averaged_gradients:
+            logger.log(self.status_loglevel, f"Proceeding with local gradients")
+            self.grad_averager.load_accumulators_into_averager_()
+            self._load_averaged_gradients_into_optimizer_()
+
+    def _load_averaged_gradients_into_optimizer_(self):
+        """If required, load averaged gradients into optimizer; otherwise simply notify grad averager"""
+        assert self.use_gradient_averaging
+
+        if self.offload_optimizer:
+            pass  # averaged gradients are already baked into optimizer, see _make_gradient_averager
+        else:
+            # copy averaged gradients into optimizer .grad buffers
+            optimized_param_groups = self.state_averager.optimizer.param_groups
+            optimized_parameters = [param for group in optimized_param_groups for param in group["params"]]
+            with torch.no_grad(), self.grad_averager.get_tensors() as averaged_gradients:
+                assert len(averaged_gradients) == len(optimized_parameters)
+                for opt_param, averaged_grad in zip(optimized_parameters, averaged_gradients):
+                    opt_param.grad.copy_(averaged_grad, non_blocking=True)
+
+        self.grad_averager.notify_used_averaged_gradients()
+
+    def zero_grad(self, set_to_none: bool = False):
+        """Reset gradients from model. If reuse_grad_buffers=True, this will raise an error."""
+        if self.use_gradient_averaging and self.grad_averager.reuse_grad_buffers:
+            raise ValueError(
+                f"When running {self.__class__.__name__} with reuse_grad_buffers=True, user should never "
+                f"call zero_grad manually. Gradients will be refreshed internally."
+            )
+        for param_group in self.param_groups:
+            for param in param_group["params"]:
+                if param.grad is None:
+                    pass
+                elif set_to_none:
+                    param.grad = None
+                else:
+                    param.grad.zero_()
+
+    def should_load_state_from_peers(self) -> bool:
+        """
+        If true, peer will discard local progress and attempt to download state from peers.
+        This method allows the peer to continue training in two cases:
+         - peer is on the same epoch as other collaborators - keep training normally
+         - peer was on the same epoch and accumulated some grads, but some collaborators
+             have just transitioned to the next epoch - this peer should also transition.
+
+        :note: The latter case occurs due to the lack of network synchrony: the first peer that
+        detects enough samples will transition to the next step and start counting samples anew.
+        Some other peers may take time before they check with DHT and observe that
+          - the global epoch is technically one epoch ahead of the current one and
+          - the remaining (non-transitioned) peers no longer have target_batch_size between them
+        If this is the case, the peer should transition to the next epoch but does *not* need to re-load state.
+        """
+        if self._should_check_synchronization_on_update and self.tracker.fetched_global_progress_this_epoch.is_set():
+            self._should_check_synchronization_on_update = False
+            return self.local_epoch != self.tracker.global_epoch  # require exact synchronization once per step
+        return self.local_epoch < self.tracker.global_epoch - 1  # catch up if a peer just switched to next epoch
+
+    def load_state_from_peers(self, **kwargs):
+        """
+        Attempt to load the newest collaboration state from other peers within the same run_id.
+
+        If successful, this will update parameters, optimizer state, local epoch and learning rate schedule in-place.
+        """
+        self._finish_background_averaging()
+        self.state_averager.step(wait_for_delayed_updates=True)
+
+        with self.tracker.pause_updates():
+            while True:
+                try:
+                    self.state_averager.load_state_from_peers(timeout=self.load_state_timeout, **kwargs)
+                    break
+                except KeyboardInterrupt:
+                    raise
+                except BaseException as e:
+                    logger.exception(f"Failed to load state from peers: {e}, retrying ...")
+                    continue
+
+            if self.tracker.global_epoch - 1 <= self.local_epoch < self.tracker.global_epoch:
+                logger.log(self.status_loglevel, f"Catching up with collaboration step {self.tracker.global_epoch}.")
+                self.state_averager.local_epoch = self.tracker.global_epoch
+
+            self.tracker.report_local_progress(local_epoch=self.local_epoch, samples_accumulated=0)
+
+            if not self.client_mode:
+                self.state_averager.state_sharing_priority = self.local_epoch
+
+            if self.use_gradient_averaging:
+                self.grad_averager.reset_accumulated_grads_()
+                if not self.client_mode:
+                    self.grad_averager.state_sharing_priority = self.local_epoch
+
+    def _finish_background_averaging(self):
+        for scheduled_round in self.scheduled_grads, self.scheduled_state:
+            if scheduled_round is not None:
+                if scheduled_round.stage == AveragingStage.LOOKING_FOR_GROUP:
+                    scheduled_round.cancel()
+                if not scheduled_round.triggered:
+                    scheduled_round.weight = 0
+                    scheduled_round.allow_allreduce()
+        for scheduled_round in self.scheduled_grads, self.scheduled_state:
+            if scheduled_round is not None and not scheduled_round.done():
+                try:
+                    time_to_deadline = scheduled_round.deadline - get_dht_time()
+                    scheduled_round.result(timeout=max(0.0, time_to_deadline))
+                except BaseException as e:
+                    logger.log(self.status_loglevel, f"Caught {e} while averaging gradients")
+                if not scheduled_round.done():
+                    scheduled_round.cancel()
+        self.scheduled_grads = self.scheduled_state = None
+
+    def state_dict(self) -> dict:
+        state_dict = self.state_averager.optimizer.state_dict()
+        state_dict["state"]["local_epoch"] = self.local_epoch
+        return state_dict
+
+    def load_state_dict(self, state_dict: dict):
+        if "local_epoch" in state_dict["state"]:
+            self.state_averager.local_epoch = state_dict["state"].pop("local_epoch")
+        return self.state_averager.optimizer.load_state_dict(state_dict)
+
+    @property
+    def state(self):
+        return dict(self.state_averager.optimizer.state, local_epoch=self.local_epoch)
+
+    @property
+    def opt(self) -> TorchOptimizer:
+        return self.state_averager.optimizer
+
+    @property
+    def param_groups(self) -> ParamGroups:
+        next_index = 0
+        param_groups = tuple(dict(param_group) for param_group in self.state_averager.optimizer.param_groups)
+        for param_group in param_groups:
+            num_params = len(param_group["params"])
+            main_params_for_group = self.state_averager.main_parameters[next_index : next_index + num_params]
+            param_group["params"] = main_params_for_group
+            next_index += num_params
+        assert next_index == len(self.state_averager.main_parameters)
+        return param_groups
+
+    def add_param_group(self, param_group: dict) -> None:
+        raise ValueError(
+            f"{self.__class__.__name__} does not support calling add_param_group after creation."
+            f"Please provide all parameter groups at init."
+        )
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(prefix={self.run_id}, epoch={self.local_epoch})"
+
+    def shutdown(self):
+        logger.log(self.status_loglevel, "Sending goodbye to peers...")
+        self.tracker.shutdown(self.shutdown_timeout)
+        self.state_averager.step(wait_for_delayed_updates=True)
+        self._finish_background_averaging()
+        logger.log(self.status_loglevel, "Shutting down averagers...")
+        self.state_averager.shutdown()
+        if self.use_gradient_averaging:
+            self.grad_averager.shutdown()
+        logger.log(self.status_loglevel, f"{self.__class__.__name__} is shut down.")
+
+    def __del__(self):
+        if self._parent_pid == os.getpid() and self.is_alive():
+            self.shutdown()
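For context, here is a rough usage sketch of the optimizer defined above, illustrating the zero_grad contract under reuse_grad_buffers and the state_dict round-trip that carries local_epoch. The model, DHT setup and constructor arguments are illustrative assumptions rather than a prescribed configuration:

    import torch
    import hivemind
    from hivemind.optim.experimental.optimizer import Optimizer

    model = torch.nn.Linear(16, 2)
    dht = hivemind.DHT(start=True)
    opt = Optimizer(
        dht=dht,
        run_id="demo_run",
        target_batch_size=256,
        batch_size_per_step=4,
        optimizer=lambda params: torch.optim.SGD(params, lr=0.1),
        params=model.parameters(),
        reuse_grad_buffers=True,  # with this flag, opt.zero_grad() raises ValueError by design
    )

    loss = model(torch.randn(4, 16)).pow(2).mean()
    loss.backward()
    opt.step()  # gradient buffers are refreshed internally; do not call opt.zero_grad() here

    # checkpointing: local_epoch is saved and restored together with the wrapped optimizer state
    torch.save(opt.state_dict(), "checkpoint.pt")
    opt.load_state_dict(torch.load("checkpoint.pt"))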

+ 47 - 14
hivemind/optim/experimental/progress_tracker.py

@@ -83,7 +83,7 @@ class ProgressTracker(threading.Thread):
         *,
         client_mode: Optional[bool] = None,
         min_refresh_period: float = 0.5,
-        max_refresh_period: float = 30,
+        max_refresh_period: float = 10,
         default_refresh_period: float = 3,
         expected_drift_peers: float = 3,
         expected_drift_rate: float = 0.2,
@@ -114,7 +114,7 @@ class ProgressTracker(threading.Thread):
         metadata, _expiration = self.dht.get(self.training_progress_key, latest=True) or (None, -float("inf"))
         self.global_progress = self._parse_swarm_progress_data(metadata)
         self.lock_global_progress, self.global_state_updated = threading.Lock(), threading.Event()
-        self.should_report_progress = threading.Event()
+        self.should_report_progress, self.fetched_global_progress_this_epoch = threading.Event(), threading.Event()
         self.shutdown_triggered, self.shutdown_complete = threading.Event(), threading.Event()
         super().__init__(name=f"{self.__class__.__name__}({self.prefix})", daemon=daemon)
         if start:
@@ -150,15 +150,20 @@ class ProgressTracker(threading.Thread):
             client_mode=self.client_mode,
         )
 
-    def report_local_progress(self, local_epoch: int, samples_accumulated: int):
+    def report_local_progress(self, local_epoch: int, samples_accumulated: int, update_global_samples: bool = True):
         """Update the number of locally accumulated samples and notify to other peers about this."""
         extra_samples = samples_accumulated - self.local_progress.samples_accumulated
+        if update_global_samples and local_epoch == self.local_progress.epoch == self.global_progress.epoch:
+            self.global_progress.samples_accumulated += extra_samples
+            # note: the above line can decrease the number of samples, e.g. if forced to reset due to overflow
+
         if extra_samples > 0:
             self.performance_ema.update(task_size=extra_samples)
             logger.debug(f"Updated performance EMA: {self.performance_ema.samples_per_second:.5f}")
         else:
             logger.debug("Resetting performance timestamp to current time (progress was reset)")
             self.performance_ema.reset_timer()
+
         self.local_progress = self._get_local_progress(local_epoch, samples_accumulated)
         self.should_report_progress.set()
 
@@ -178,6 +183,7 @@ class ProgressTracker(threading.Thread):
             self.global_progress.samples_accumulated = 0
             self.global_progress.eta_next_epoch = float("inf")
         self.report_local_progress(new_epoch, samples_accumulated=0)
+        self.fetched_global_progress_this_epoch.clear()
         return new_epoch
 
     def run(self):
@@ -189,6 +195,7 @@ class ProgressTracker(threading.Thread):
     async def _progress_reporter(self):
         """Periodically publish metadata and the current number of samples accumulated towards the next epoch"""
         last_report_time = -float("inf")
+        store_task = None
         try:
             while not self.shutdown_triggered.is_set():
                 wait_timeout = max(0.0, last_report_time + self.metadata_expiration - get_dht_time())
@@ -203,21 +210,42 @@ class ProgressTracker(threading.Thread):
                 local_progress = self.local_progress
                 last_report_time = get_dht_time()
 
-                await self.dht.store(
-                    key=self.training_progress_key,
-                    subkey=self._local_public_key,
-                    value=local_progress.dict(),
-                    expiration_time=last_report_time + self.metadata_expiration,
-                    return_future=True,
+                store_task = asyncio.create_task(
+                    asyncio.wait_for(
+                        self.dht.store(
+                            key=self.training_progress_key,
+                            subkey=self._local_public_key,
+                            value=local_progress.dict(),
+                            expiration_time=last_report_time + self.metadata_expiration,
+                            return_future=True,
+                        ),
+                        timeout=self.metadata_expiration,
+                    )
                 )
         finally:
             logger.log(self.status_loglevel, f"No longer reporting progress for {self.prefix}.")
+            if store_task is not None:
+                store_task.cancel()
 
     async def _progress_fetcher(self):
         """
         Periodically check the training progress from all peers. Trigger update after target_batch_size total samples
         """
         loop = asyncio.get_event_loop()
+        shutdown_checker = asyncio.create_task(
+            asyncio.wait_for(loop.run_in_executor(None, self.shutdown_triggered.wait), None)
+        )
+
+        async def _fetch_progress_unless_shutdown_triggered():
+            """Fetch progress, avoid deadlocks if DHT was shut down before this get finished."""
+            getter = asyncio.create_task(
+                asyncio.wait_for(self.dht.get(self.training_progress_key, latest=True, return_future=True), None)
+            )
+            await asyncio.wait({getter, shutdown_checker}, return_when=asyncio.FIRST_COMPLETED)
+            if self.shutdown_triggered.is_set():
+                return
+            return await getter
+
         try:
             while not self.shutdown_triggered.is_set():
                 time_to_next_update = max(0.0, self.global_progress.next_fetch_time - get_dht_time())
@@ -229,9 +257,13 @@ class ProgressTracker(threading.Thread):
                     continue
 
                 async with enter_asynchronously(self.lock_global_progress):
-                    progress_entry = await self.dht.get(self.training_progress_key, latest=True, return_future=True)
-                    metadata = progress_entry.value if isinstance(progress_entry, ValueWithExpiration) else None
+                    maybe_metadata = await _fetch_progress_unless_shutdown_triggered()
+                    if self.shutdown_triggered.is_set():
+                        break
+                    metadata = maybe_metadata.value if isinstance(maybe_metadata, ValueWithExpiration) else None
                     self.global_progress = self._parse_swarm_progress_data(metadata)
+                    self.fetched_global_progress_this_epoch.set()
+
         finally:
             logger.log(self.status_loglevel, f"No longer fetching {self.training_progress_key}.")
 
@@ -294,7 +326,7 @@ class ProgressTracker(threading.Thread):
         )
         logger.log(
             self.status_loglevel,
-            f"{self.prefix} accumulated {total_samples_accumulated} samples for iteration #{global_epoch} from "
+            f"{self.prefix} accumulated {total_samples_accumulated} samples for epoch #{global_epoch} from "
             f"{num_peers} peers. ETA {estimated_time_to_next_epoch:.2f} sec (refresh in {time_to_next_fetch:.2f} sec)",
         )
         return GlobalTrainingProgress(
@@ -307,15 +339,16 @@ class ProgressTracker(threading.Thread):
             next_fetch_time=current_time + time_to_next_fetch,
         )
 
-    def shutdown(self):
+    def shutdown(self, timeout: Optional[float] = None):
         """Permanently disable all tracking activity"""
         self.shutdown_triggered.set()
         self.should_report_progress.set()
         self.global_state_updated.set()
-        self.shutdown_complete.wait()
+        self.shutdown_complete.wait(timeout)
         self.dht.store(
             self.training_progress_key,
             subkey=self._local_public_key,
             value=None,
             expiration_time=get_dht_time() + self.metadata_expiration,
+            return_future=True,
         )
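The _progress_fetcher change above races the DHT request against the shutdown flag so that a peer never blocks on a DHT that is already stopping. A standalone sketch of the same pattern, with illustrative names that are not part of the hivemind API:

    import asyncio
    import threading

    shutdown_triggered = threading.Event()  # set from another thread when the peer shuts down

    async def fetch_unless_shutdown(fetch_coro):
        """Await fetch_coro, but give up as soon as the shutdown flag is raised (illustrative sketch)."""
        loop = asyncio.get_event_loop()
        shutdown_waiter = asyncio.ensure_future(loop.run_in_executor(None, shutdown_triggered.wait))
        getter = asyncio.ensure_future(fetch_coro)
        await asyncio.wait({getter, shutdown_waiter}, return_when=asyncio.FIRST_COMPLETED)
        if shutdown_triggered.is_set():
            getter.cancel()  # abandon the pending request instead of waiting for it
            return None
        return await getter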

+ 234 - 111
hivemind/optim/experimental/state_averager.py

@@ -1,18 +1,20 @@
 """ An extension of averager that supports common optimization use cases. """
 """ An extension of averager that supports common optimization use cases. """
 import logging
 import logging
-from asyncio import Future
+import threading
+import time
 from concurrent.futures import ThreadPoolExecutor
 from concurrent.futures import ThreadPoolExecutor
+from contextlib import nullcontext
 from itertools import chain
 from itertools import chain
-from threading import Event
 from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Sequence, Tuple, Union
 from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Sequence, Tuple, Union
 
 
 import torch
 import torch
 
 
 import hivemind
 import hivemind
 from hivemind.averaging import DecentralizedAverager
 from hivemind.averaging import DecentralizedAverager
+from hivemind.averaging.control import StepControl
 from hivemind.compression import CompressionInfo, TensorRole
 from hivemind.compression import CompressionInfo, TensorRole
 from hivemind.optim.grad_scaler import GradScaler
 from hivemind.optim.grad_scaler import GradScaler
-from hivemind.utils import get_logger, nested_flatten, nested_pack
+from hivemind.utils import DHTExpiration, PerformanceEMA, get_dht_time, get_logger, nested_flatten, nested_pack
 
 
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 
@@ -36,7 +38,7 @@ class TrainingStateAverager(DecentralizedAverager):
 
 
     Example:
     Example:
 
 
-    >>> avgr = TrainingStateAverager(optimizer=torch.optim.Adam, param_groups=model.parameters(), ...)
+    >>> avgr = TrainingStateAverager(optimizer=torch.optim.Adam, params=model.parameters(), ...)
     >>> # alternative interface: TrainingStateAverager(optimizer=torch.optim.Adam(model.parameters()), ...)
     >>> # alternative interface: TrainingStateAverager(optimizer=torch.optim.Adam(model.parameters()), ...)
     >>> avgr.load_state_from_peers()
     >>> avgr.load_state_from_peers()
     >>> for i, batch in enumerate(training_dataloader):
     >>> for i, batch in enumerate(training_dataloader):
@@ -49,7 +51,7 @@ class TrainingStateAverager(DecentralizedAverager):
       TrainingStateAverager.step(..., optimizer_step=True)
 
     :param optimizer: PyTorch Optimizer or a callable that creates a optimizer from param groups
-    :param param_groups: optional, a list/tuple of parameters or structured param groups for the optimizer
+    :param params: optional, a list/tuple of parameters or structured param groups for the optimizer
     :param scheduler: optional learning rate scheduler or callable that creates one from optimizer instance
     :note: if provided, scheduler will be updated based on averager.local_epoch, not the number of step cycles
     :param initialize_optimizer: if True, run a speculative optimizer step with zero gradients to initialize all
@@ -60,8 +62,11 @@ class TrainingStateAverager(DecentralizedAverager):
       This assumes that offloaded gradients will be populated externally, e.g. by the user or by hivemind.Optimizer.
     :param reuse_tensors: if True, reuse parameters and optimizer statistics as averaged_tensors for allreduce.
       For this to work, all parameters must be on CPU and have the appropriate dtype for use in DecentralizedAverager
+      Defaults to True if offload_optimizer, False otherwise.
+    :param delta_rule_averaging: if True, averaging will use delta rule to allow running local optimizer steps
+      while averaging. Delta rule: `state_tensor := state_tensor + averaging_result - state_tensor_before_averaging`
     :param sync_epoch_when_averaging: if True, update local epoch to the latest epoch among averaging peers
-    :param parameter_names: optionally provide parameter names in the same order as param_groups
+    :param parameter_names: optionally provide parameter names in the same order as in params
     :param average_opt_statistics: names of optimizer statistics from state dict that should be averaged with peers
     :param extra_tensors: if specified, these extra tensors will also be averaged and shared in load_state_from_peers.
     :note: you can use extra_tensors to for any tensors not used by the optimizer (e.g. batchnorm statistics)
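For intuition, the delta rule mentioned above applies only the change produced by averaging, so any local optimizer progress made while the averaging round was running is preserved. A minimal sketch of that update with placeholder tensors (illustrative, not hivemind API):

    import torch

    # snapshot taken when the averaging round started
    state_before_averaging = torch.tensor([1.0, 2.0, 3.0])
    # local state may have advanced while averaging was in flight
    state_tensor = state_before_averaging + 0.1
    # result agreed upon by the group
    averaging_result = torch.tensor([1.5, 2.5, 3.5])

    with torch.no_grad():
        # delta rule: state_tensor := state_tensor + averaging_result - state_tensor_before_averaging
        state_tensor += averaging_result - state_before_averaging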
@@ -73,12 +78,14 @@ class TrainingStateAverager(DecentralizedAverager):
         *,
         dht: hivemind.DHT,
         optimizer: Union[TorchOptimizer, OptimizerFactory],
-        param_groups: Optional[Union[Parameters, ParamGroups]] = None,
+        params: Optional[Union[Parameters, ParamGroups]] = None,
         scheduler: Optional[Union[LRSchedulerBase, SchedulerFactory]] = None,
         initialize_optimizer: Optional[bool] = None,
         offload_optimizer: bool = False,
         custom_gradients: bool = False,
-        reuse_tensors: bool = False,
+        reuse_tensors: Optional[bool] = None,
+        delta_rule_averaging: bool = False,
+        performance_ema_alpha: float = 0.1,
         sync_epoch_when_averaging: bool = False,
         parameter_names: Optional[Sequence[str]] = None,
         average_opt_statistics: Sequence[str] = (),
@@ -88,20 +95,22 @@ class TrainingStateAverager(DecentralizedAverager):
     ):
         average_opt_statistics = tuple(average_opt_statistics)
         assert all(isinstance(key, str) for key in average_opt_statistics)
-        if offload_optimizer and reuse_tensors:
-            logger.warning("Setting offload_optimizer=True has no effect because reuse_parameters=True")
+        if reuse_tensors is None:
+            reuse_tensors = offload_optimizer and not delta_rule_averaging
         if custom_gradients and not offload_optimizer:
             logger.warning("Setting custom_gradients=True has no effect because the optimizer is not offloaded")
+        if reuse_tensors and delta_rule_averaging:
+            raise ValueError("reuse_tensors and delta_rule_averaging are mutually exclusive")
 
-        param_groups, main_parameters, parameter_names = self._check_params(optimizer, param_groups, parameter_names)
+        param_groups, main_parameters, parameter_names = self._check_params(optimizer, params, parameter_names)
 
         self.status_loglevel = status_loglevel
-        self.reuse_tensors = reuse_tensors
-        self.offload_optimizer = offload_optimizer
-        self.custom_gradients = custom_gradients
+        self.offload_optimizer, self.custom_gradients = offload_optimizer, custom_gradients
+        self.reuse_tensors, self.delta_rule_averaging = reuse_tensors, delta_rule_averaging
+        self._old_tensors: Optional[Sequence[torch.Tensor]] = None  # for delta rule
 
         self.main_parameters, self.parameter_names = main_parameters, parameter_names
-        self._averaged_parameters = tuple(map(self._make_host_tensor, main_parameters))
+        self._averaged_parameters = self._make_averaged_parameters(main_parameters)
         self.optimizer, self.scheduler = self._init_components(
             param_groups, optimizer, scheduler, initialize_optimizer
         )
@@ -109,11 +118,13 @@ class TrainingStateAverager(DecentralizedAverager):
         self.sync_epoch_when_averaging = sync_epoch_when_averaging
         self.local_epoch = 0
 
-        self.step_executor = ThreadPoolExecutor(max_workers=1)
-        self.finished_optimizer_step = Event()
-        self.finished_averaging_round = Event()
-        self.pending_update = Future()
-        self.pending_update.set_result(None)
+        self.delay_before_averaging = PerformanceEMA(alpha=performance_ema_alpha)
+        self.step_executor = ThreadPoolExecutor(max_workers=2 if self.delta_rule_averaging else 1)
+        self.finished_optimizer_step = threading.Event()
+        self.finished_averaging_round = threading.Event()
+        self.lock_optimizer = threading.Lock()
+        self.lock_averaging = threading.Lock()
+        self.pending_updates = set()
 
         super().__init__(
             dht=dht, averaged_tensors=self._init_averaged_tensors(), tensor_infos=self._init_tensor_infos(), **kwargs
@@ -143,10 +154,15 @@ class TrainingStateAverager(DecentralizedAverager):
         assert len(set(parameters)) == len(parameters), "Found duplicate parameters in param_groups"
         return param_groups, parameters, parameter_names
 
-    def _make_host_tensor(self, source_tensor: torch.Tensor) -> torch.Tensor:
+    def _make_averaged_parameters(self, main_parameters: Sequence[torch.Tensor]):
+        """Initialize averaged parameters based on the optimizer and averaging mode"""
+        return tuple(self._make_host_tensor(param, force_copy=self.offload_optimizer) for param in main_parameters)
+
+    def _make_host_tensor(self, source_tensor: torch.Tensor, force_copy: bool = False) -> torch.Tensor:
         """Create a new tensor for averaging or reuse the existing one"""
-        if self.reuse_tensors:
-            assert source_tensor.device == torch.device("cpu") and source_tensor.dtype == torch.float32
+        if self.reuse_tensors and not force_copy:
+            if source_tensor.device != torch.device("cpu"):
+                raise ValueError("reuse_tensors is only supported if all averaged tensors are on CPU.")
             if not source_tensor.is_shared():
                 source_tensor.share_memory_()
             return source_tensor
@@ -173,19 +189,26 @@ class TrainingStateAverager(DecentralizedAverager):
         # create optimizer
         if optimizer_is_factory:
             if self.offload_optimizer:
-                for param in self._averaged_parameters:
-                    if param.grad is None:
-                        param.grad = torch.zeros_like(param)
+                if self.reuse_tensors:
+                    parameters_for_optimizer = self._averaged_parameters
+                else:
+                    parameters_for_optimizer = tuple(
+                        tensor.detach().clone().requires_grad_(tensor.requires_grad)
+                        for tensor in self._averaged_parameters
+                    )
 
                 next_index = 0
                 param_groups_for_optimizer = []
                 for param_group in param_groups:
                     num_params = len(param_group["params"])
-                    averaged_params_for_group = self._averaged_parameters[next_index : next_index + num_params]
+                    averaged_params_for_group = parameters_for_optimizer[next_index : next_index + num_params]
                     param_groups_for_optimizer.append(dict(param_group, params=averaged_params_for_group))
                     next_index += num_params
-                assert next_index == len(self._averaged_parameters)
+                assert next_index == len(parameters_for_optimizer)
 
+                for param in parameters_for_optimizer:
+                    if param.grad is None:
+                        param.grad = torch.zeros_like(param)
             else:
                 param_groups_for_optimizer = param_groups
             optimizer = optimizer_or_factory(param_groups_for_optimizer)
@@ -198,7 +221,7 @@ class TrainingStateAverager(DecentralizedAverager):
             logger.log(
                 self.status_loglevel,
                 "Initializing optimizer manually since it has no tensors in state dict. "
-                "To override this, please provide initialize_optimizer=False",
+                "To override this, provide initialize_optimizer=False",
             )
 
         if initialize_optimizer:
@@ -213,7 +236,7 @@ class TrainingStateAverager(DecentralizedAverager):
 
         # verify optimizer and scheduler
         assert isinstance(optimizer, TorchOptimizer) and len(optimizer.param_groups) == len(list(param_groups))
-        if self.offload_optimizer or self.reuse_tensors:
+        if self.reuse_tensors:
             for param_group in optimizer.param_groups:
                 for param in param_group["params"]:
                     assert param.is_shared()
@@ -250,7 +273,7 @@ class TrainingStateAverager(DecentralizedAverager):
         for local_tensor, averaged_tensor in zip(local_tensors, averaged_tensors):
             assert local_tensor.shape == averaged_tensor.shape
             if averaged_tensor.grad is not None:
-                logger.debug(self.status_loglevel, "setting gradients for averaged tensor to None")
+                logger.log(self.status_loglevel, "setting gradients for averaged tensor to None")
 
         return averaged_tensors
 
@@ -274,9 +297,22 @@ class TrainingStateAverager(DecentralizedAverager):
             tensor_infos.append(CompressionInfo.from_tensor(extra_tensor, key=i, role=TensorRole.UNSPECIFIED))
         return tuple(tensor_infos)
 
+    def schedule_step(self, scheduled_time: Optional[DHTExpiration] = None, **kwargs) -> StepControl:
+        """
+        Begin matchmaking: look for a group of peers and prepare for averaging gradients at a specified time.
+
+        :param scheduled_time: expected time when to perform all-reduce. Can be changed using control.scheduled_time
+        :param kwargs: any additional keyword args from DecentralizedAverager.step, such as gather, allow_retries, etc
+        :note: setting weight at this stage is not supported, please leave this parameter as None
+        :returns: step_control - a handle that can be passed into TrainingStateAverager.step to use pre-scheduled group
+        :note: in the current implementation, each step_control can only be used in one step.
+        """
+        assert kwargs.get("weight") is None, "setting weight in schedule_step is not supported"
+        return super().step(scheduled_time=scheduled_time, wait=False, require_trigger=True, **kwargs)
+
     def step(
         self,
-        wait_for_delayed_update: bool = None,
+        wait_for_delayed_updates: bool = None,
         apply_delayed_updates: bool = True,
         increment_epoch: bool = False,
         optimizer_step: bool = False,
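To make the pre-scheduling flow concrete, a rough usage sketch follows; it assumes that averager is an already-constructed TrainingStateAverager and that hivemind.get_dht_time provides the DHT clock (both are assumptions, not taken verbatim from the PR):

    import hivemind

    # begin matchmaking ~30 seconds before the planned update (illustrative timing)
    control = averager.schedule_step(scheduled_time=hivemind.get_dht_time() + 30)

    # ... keep accumulating gradients while matchmaking runs in the background ...

    # later, reuse the pre-scheduled group instead of matchmaking from scratch
    averager.step(optimizer_step=True, zero_grad=True, averaging_round=True, averaging_control=control)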
@@ -284,6 +320,8 @@ class TrainingStateAverager(DecentralizedAverager):
         delay_optimizer_step: bool = False,
         delay_optimizer_step: bool = False,
         averaging_round: bool = False,
         averaging_round: bool = False,
         delay_averaging: Optional[bool] = None,
         delay_averaging: Optional[bool] = None,
+        averaging_control: Optional[StepControl] = None,
+        wait_for_trigger: Optional[Callable[[], Any]] = None,
         grad_scaler: Optional[GradScaler] = None,
         grad_scaler: Optional[GradScaler] = None,
         averaging_opts: Optional[Dict[str, Any]] = None,
         averaging_opts: Optional[Dict[str, Any]] = None,
     ):
     ):
@@ -291,138 +329,205 @@ class TrainingStateAverager(DecentralizedAverager):
         Perform one or several possible actions, depending on the specified keyword args.
         Perform one or several possible actions, depending on the specified keyword args.
         The actions will be performed in the same order as specified below:
         The actions will be performed in the same order as specified below:
 
 
-        :param wait_for_delayed_update: if there are background averaging rounds, wait for them to finish
+        :param wait_for_delayed_updates: if there are background averaging rounds, wait for them to finish
           by default, await delayed updates when scheduling the next optimizer step, otherwise do not update
           by default, await delayed updates when scheduling the next optimizer step, otherwise do not update
         :param apply_delayed_updates: apply any averaging rounds that have finished but were not applied yet
         :param apply_delayed_updates: apply any averaging rounds that have finished but were not applied yet
         :param increment_epoch: increment .local_epoch and update the learning rate scheduler (if present)
         :param increment_epoch: increment .local_epoch and update the learning rate scheduler (if present)
+        :note: if specified, it is guaranteed that epoch is incremented immediately regardless of other options
         :param optimizer_step: perform a single optimizer step and update local parameters (without changing scheduler)
         :param optimizer_step: perform a single optimizer step and update local parameters (without changing scheduler)
         :param zero_grad: if True, reset local gradients after performing optimizer step
         :param zero_grad: if True, reset local gradients after performing optimizer step
         :param delay_optimizer_step: if True, run optimizer step in background and apply results in a future step
         :param delay_optimizer_step: if True, run optimizer step in background and apply results in a future step
         :param averaging_round: average parameters, chosen optimizer keys and extra tensors with a group of peers
         :param averaging_round: average parameters, chosen optimizer keys and extra tensors with a group of peers
-        :param grad_scaler: when using hivemind.GradScaler, one must forward it to step after calling .unscale_
         :param delay_averaging: if True, perform averaging in background and apply results in a future step
         :param delay_averaging: if True, perform averaging in background and apply results in a future step
           by default, delay averaging if the optimizer step is also delayed. Set to true to delay only this phase.
           by default, delay averaging if the optimizer step is also delayed. Set to true to delay only this phase.
+        :param averaging_control: if specified, use this as a pre-scheduled averaging round. Should require_trigger.
+        :param wait_for_trigger: wait for this (non-asyncio) function to finish before running optimizer step
+        :note: if wait_for_trigger fails with any exception, it will abort optimizer step, zero grad and averaging
+        :param grad_scaler: when using hivemind.GradScaler, one must forward it to step after calling .unscale_
         :param averaging_opts: a dict of keyword arguments forwarded into averaging round
         :param averaging_opts: a dict of keyword arguments forwarded into averaging round
         """
         """
         if delay_averaging is None:
         if delay_averaging is None:
             delay_averaging = delay_optimizer_step
             delay_averaging = delay_optimizer_step
-        if wait_for_delayed_update is None:
-            wait_for_delayed_update = optimizer_step or zero_grad or averaging_round
+        should_wait = averaging_round or optimizer_step or zero_grad if self.delta_rule_averaging else averaging_round
+        if wait_for_delayed_updates is None:
+            wait_for_delayed_updates = should_wait
+        if should_wait and not (wait_for_delayed_updates and apply_delayed_updates):
+            raise ValueError("Should wait for background operation to finish before scheduling new one")
         assert not delay_optimizer_step or delay_averaging, "Delayed optimizer step requires delayed averaging"
         assert not delay_optimizer_step or delay_averaging, "Delayed optimizer step requires delayed averaging"
-        if optimizer_step or averaging_round or zero_grad:
-            assert wait_for_delayed_update, "Must wait for background updates to finish before scheduling new ones"
         if delay_optimizer_step:
         if delay_optimizer_step:
             assert self.offload_optimizer, "Delayed optimizer step is only available with offload_optimizer"
             assert self.offload_optimizer, "Delayed optimizer step is only available with offload_optimizer"
             assert not averaging_round or delay_averaging, "Averaging after delayed optimizer should also be delayed"
             assert not averaging_round or delay_averaging, "Averaging after delayed optimizer should also be delayed"
         if averaging_opts and not averaging_round:
         if averaging_opts and not averaging_round:
             logger.warning(f"Averaging parameters not used because averaging_round=False: {averaging_opts}")
             logger.warning(f"Averaging parameters not used because averaging_round=False: {averaging_opts}")
+        if averaging_control is not None:
+            assert averaging_round, "averaging_control is unused if averaging_round is not performed"
+        if wait_for_trigger is not None:
+            assert optimizer_step or zero_grad or averaging_round, "trigger is only used for updating parameters"
+            if not (self.reuse_tensors or self.custom_gradients):
+                # averager was asked to wait_for_trigger in background, but it is not clear which version of gradients
+                # should be used for optimizer step (e.g. the gradients that were present during the call to .step or
+                # the possibly different gradients when wait_for_trigger has finished).
+                raise ValueError(
+                    "wait_for_trigger is a low-level option that requires manual gradient manipulation. "
+                    "If you know what you're doing, please refer to the comments in the source code for details."
+                )
         output = None
         output = None
 
 
-        if wait_for_delayed_update:
-            if not self.pending_update.done():
-                logger.log(self.status_loglevel, "Waiting for delayed updates to finish...")
-                output = self.pending_update.result()
-
-        if self.pending_update.done() and self.pending_update.exception():
-            logger.warning(f"Background update failed with {self.pending_update.exception()} and will be ignored")
+        if wait_for_delayed_updates:
+            for pending_update in self.pending_updates:
+                try:
+                    logger.log(self.status_loglevel, "Waiting for delayed updates to finish...")
+                    output = pending_update.result()
+                except BaseException:
+                    pass  # exception will be reported below
+
+        # remove finished updates, log any exceptions
+        finished_updates = {pending_update for pending_update in self.pending_updates if pending_update.done()}
+        self.pending_updates = {pending_update for pending_update in self.pending_updates if not pending_update.done()}
+        for finished_update in finished_updates:
+            if finished_update.exception():
+                logger.log(self.status_loglevel, f"Background update failed with {finished_update.exception()}")
 
 
         if apply_delayed_updates:
         if apply_delayed_updates:
             if self.finished_averaging_round.is_set():
             if self.finished_averaging_round.is_set():
                 if not self.reuse_tensors:
                 if not self.reuse_tensors:
                     self._apply_averaging_results_()
                     self._apply_averaging_results_()
+                if self.offload_optimizer and not self.finished_optimizer_step.is_set():
+                    self._apply_optimizer_parameters_()
                 logger.log(self.status_loglevel, "Received parameters from background averaging round")
                 logger.log(self.status_loglevel, "Received parameters from background averaging round")
                 self.finished_averaging_round.clear()
                 self.finished_averaging_round.clear()
 
 
             if self.finished_optimizer_step.is_set():
             if self.finished_optimizer_step.is_set():
                 if self.offload_optimizer:
                 if self.offload_optimizer:
-                    self._apply_optimizer_results_()
-                logger.log(self.status_loglevel, "Received parameters from background optimizer step")
+                    self._apply_optimizer_parameters_()
+                logger.debug("Received parameters from background optimizer step")
                 self.finished_optimizer_step.clear()
                 self.finished_optimizer_step.clear()
 
 
         if increment_epoch:
         if increment_epoch:
             self.local_epoch += 1
             self.local_epoch += 1
 
 
         if optimizer_step or zero_grad or averaging_round:
         if optimizer_step or zero_grad or averaging_round:
-            assert self.pending_update.done(), "Tried to perform a new update but previous update is still running"
-
             if self.offload_optimizer and not self.custom_gradients:
             if self.offload_optimizer and not self.custom_gradients:
                 self._load_local_grads_into_optimizer_()
                 self._load_local_grads_into_optimizer_()
 
 
-            self.pending_update = self.step_executor.submit(
+            pending_update = self.step_executor.submit(
                 self._do,
                 self._do,
+                wait_for_trigger,
                 optimizer_step,
                 optimizer_step,
                 zero_grad,
                 zero_grad,
                 averaging_round,
                 averaging_round,
+                averaging_control,
                 grad_scaler,
                 grad_scaler,
                 **averaging_opts or {},
                 **averaging_opts or {},
             )
             )
+            self.pending_updates.add(pending_update)
+
+            should_await_optimizer = (optimizer_step or zero_grad) and not delay_optimizer_step
+            should_await_averaging = averaging_round and not delay_averaging
 
 
-            if (optimizer_step or zero_grad) and not delay_optimizer_step:
+            if should_await_optimizer:
                 self.finished_optimizer_step.wait()
                 self.finished_optimizer_step.wait()
                 self.finished_optimizer_step.clear()
                 self.finished_optimizer_step.clear()
-                if self.offload_optimizer:
-                    self._apply_optimizer_results_()
-                logger.log(self.status_loglevel, "Finished optimizer step")
+                if self.offload_optimizer and not should_await_averaging:
+                    self._apply_optimizer_parameters_()
+                logger.debug("Finished optimizer step")
 
 
-            if averaging_round and not delay_averaging:
+            if should_await_averaging:
                 self.finished_averaging_round.wait()
                 self.finished_averaging_round.wait()
                 self.finished_averaging_round.clear()
                 self.finished_averaging_round.clear()
                 if not self.reuse_tensors:
                 if not self.reuse_tensors:
                     self._apply_averaging_results_()
                     self._apply_averaging_results_()
+                if self.offload_optimizer:
+                    self._apply_optimizer_parameters_()
                 logger.log(self.status_loglevel, "Finished averaging round")
                 logger.log(self.status_loglevel, "Finished averaging round")
 
 
-            if not delay_averaging:
+            async_averaging = averaging_round and delay_averaging
+            async_optimizer = (optimizer_step or zero_grad) and delay_optimizer_step
+
+            if not (async_averaging or async_optimizer):
                 try:
                 try:
-                    output = self.pending_update.result()
+                    output = pending_update.result()
                 finally:
                 finally:
-                    self.finished_averaging_round.clear()
-                    self.finished_optimizer_step.clear()
+                    self.pending_updates.remove(pending_update)
+
         return output
         return output
 
 
     def _do(
     def _do(
-        self, optimizer_step: bool, zero_grad: bool, averaging_round: bool, grad_scaler: Optional[GradScaler], **kwargs
+        self,
+        wait_for_trigger: Optional[Callable[[], Any]],
+        optimizer_step: bool,
+        zero_grad: bool,
+        averaging_round: bool,
+        averaging_control: Optional[StepControl],
+        grad_scaler: Optional[GradScaler],
+        timeout: Optional[float] = None,
+        **kwargs,
     ):
     ):
         """
         """
         Run the optimizer step, followed by a scheduler step and an averaging round, each stage is optional.
         Run the optimizer step, followed by a scheduler step and an averaging round, each stage is optional.
         This method is meant to be called in the background executor.
         This method is meant to be called in the background executor.
         """
         """
-        try:
-            if optimizer_step:
-                logger.log(self.status_loglevel, f"Running optimizer step")
-                if grad_scaler is None:
-                    self.optimizer.step()
-                else:
-                    with grad_scaler.running_global_step():
-                        assert grad_scaler.step(self.optimizer)
+        if averaging_control is not None and (averaging_control.triggered or averaging_control.done()):
+            logger.log(self.status_loglevel, f"Discarding failed matchmaking results: {averaging_control}")
+            averaging_control = None
 
 
-            if grad_scaler is not None:
-                with grad_scaler.running_global_step():
-                    assert grad_scaler.update()
+        start_time = time.perf_counter()
+        began_running = False
 
 
-            self._update_scheduler()
-
-            if zero_grad:
-                logger.log(self.status_loglevel, f"Running zero grad")
-                self.optimizer.zero_grad()
-                if self.offload_optimizer:
-                    for parameter in self.main_parameters:
-                        if parameter.grad is not None:
-                            parameter.grad.zero_()
+        try:
+            if averaging_round and averaging_control is None:
+                averaging_control = super().step(
+                    gather=self.local_epoch,
+                    require_trigger=True,
+                    timeout=timeout,
+                    wait=False,
+                    **kwargs,
+                )
 
 
-            self.finished_optimizer_step.set()
+            if wait_for_trigger is not None:
+                wait_for_trigger()
+            began_running = True
+
+            with self.lock_optimizer:
+                if optimizer_step:
+                    with self.lock_averaged_tensors if self.reuse_tensors else nullcontext():
+                        logger.debug(f"Running optimizer step")
+                        if grad_scaler is None:
+                            self.optimizer.step()
+                        else:
+                            with grad_scaler.running_global_step():
+                                assert grad_scaler.step(self.optimizer)
+
+                if zero_grad:
+                    logger.debug(f"Running zero grad")
+                    self.optimizer.zero_grad()
+                    if self.offload_optimizer:
+                        for parameter in self.main_parameters:
+                            if parameter.grad is not None:
+                                parameter.grad.zero_()
+
+                self._update_scheduler()
+                self.finished_optimizer_step.set()
 
 
             if averaging_round:
             if averaging_round:
-                if not self.reuse_tensors:
-                    self._load_local_tensors_into_averager_()
-                try:
-                    gathered = super().step(gather=self.local_epoch, **kwargs)
-                    logger.log(self.status_loglevel, f"Averaged parameters with {len(gathered)} peers")
-                except BaseException as e:
-                    logger.log(self.status_loglevel, f"Averaging failed with {type(e)}")
-                    self.finished_averaging_round.set()
-                    gathered = {}
+                with self.lock_averaging:
+                    if not self.reuse_tensors:
+                        self._load_local_tensors_into_averager_()
+                    if self.delta_rule_averaging:
+                        # remember tensors before averaging, update by (new_averaged_tensors - old_averaged_tensors)
+                        with torch.no_grad(), self.get_tensors() as averaged_tensors:
+                            self._old_tensors = tuple(x.cpu().clone() for x in averaged_tensors)
+
+                    self.delay_before_averaging.update(task_size=1, interval=time.perf_counter() - start_time)
+                    try:
+                        averaging_control.allow_allreduce()
+                        gathered = averaging_control.result(timeout=timeout)
+                        logger.log(self.status_loglevel, f"Averaged parameters with {len(gathered)} peers")
+                    except BaseException as e:
+                        logger.log(self.status_loglevel, f"Averaging failed with {type(e)}")
+                        gathered = {}
 
 
-                self.finished_averaging_round.set()
+                    self.finished_averaging_round.set()
 
 
                 if self.sync_epoch_when_averaging:
                     old_epoch = self.local_epoch
@@ -433,7 +538,12 @@ class TrainingStateAverager(DecentralizedAverager):
                         self._update_scheduler()

         except Exception as e:
+            if not began_running:
+                logger.error(f"Aborted {self.__class__.__name__}.step because wait_for_trigger raised exception")
             logger.exception(e)
+            if averaging_control is not None and not averaging_control.done():
+                logger.error(f"Cancelled scheduled state averaging round")
+                averaging_control.cancel()
             self.finished_optimizer_step.set()
             self.finished_averaging_round.set()

@@ -447,16 +557,13 @@ class TrainingStateAverager(DecentralizedAverager):
                 opt_param.grad.copy_(main_param.grad, non_blocking=True)

     @torch.no_grad()
-    def _apply_optimizer_results_(self):
+    def _apply_optimizer_parameters_(self):
         """Copy parameters from offloaded optimizer to the main model"""
         assert self.offload_optimizer, "Applying offloaded optimizer updates requires offloaded optimizer"
-        with self.lock_averaged_tensors:
-            offloaded_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
-            assert len(offloaded_parameters) == len(
-                self.main_parameters
-            ), "Optimizer parameters changed during training"
-            for main_param, offloaded_param in zip(self.main_parameters, offloaded_parameters):
-                main_param.copy_(offloaded_param, non_blocking=True)
+        offloaded_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
+        assert len(offloaded_parameters) == len(self.main_parameters), "Optimizer parameters changed during training"
+        for main_param, offloaded_param in zip(self.main_parameters, offloaded_parameters):
+            main_param.copy_(offloaded_param, non_blocking=True)
 
 
     @torch.no_grad()
     def _load_local_tensors_into_averager_(self):
@@ -470,18 +577,30 @@ class TrainingStateAverager(DecentralizedAverager):
     def _apply_averaging_results_(self):
         """Copy averaged tensors into their respective local tensors"""
         assert not self.reuse_tensors, "No need to update averaged tensors since they reuse the same memory"
+        if self.delta_rule_averaging and self._old_tensors is None:
+            logger.warning("Using delta_rule_averaging, but old tensors were not found. Averaging may have failed.")
         with self.get_tensors() as averaged_tensors:
             local_tensors = list(self._local_tensors())
             assert len(local_tensors) == len(averaged_tensors), "Tensor structure changed during training"
-            for local_tensor, averaged_tensor in zip(local_tensors, averaged_tensors):
-                local_tensor.copy_(averaged_tensor, non_blocking=True)
+            if not self.delta_rule_averaging or self._old_tensors is None:
+                for local_tensor, averaged_tensor in zip(local_tensors, averaged_tensors):
+                    local_tensor.copy_(averaged_tensor, non_blocking=True)
+            else:
+                assert len(self._old_tensors) == len(local_tensors)
+                for local_tensor, new_tensor, old_tensor in zip(local_tensors, averaged_tensors, self._old_tensors):
+                    delta = torch.sub(new_tensor, old_tensor, out=old_tensor)  # using old tensors as buffers
+                    local_tensor.add_(delta.to(device=local_tensor.device, dtype=local_tensor.dtype))
+
+    @property
+    def averaging_in_progress(self) -> bool:
+        return self.lock_averaging.locked()
 
 
     def get_current_state(self):
         """
         Get the current model/optimizer state to be shared when requested by a newbie peer. Executed in the host process.
         :returns: a tuple of (serializable_small_metadata, sequence of torch tensors)
         """
-        with torch.no_grad():
+        with torch.no_grad(), self.lock_averaged_tensors:
             optimized_parameters = tuple(
                 param.detach().cpu() for param_group in self.optimizer.param_groups for param in param_group["params"]
             )
@@ -512,8 +631,8 @@ class TrainingStateAverager(DecentralizedAverager):
         Attempt to download the latest optimizer state from peers and update trainer parameters/statistics.
         :returns: whether or not the averager succeeded in loading parameters
         """
-        parameters_and_extras = tuple(chain(self.main_parameters, self.extra_tensors))
-        num_parameters_and_extras = len(parameters_and_extras)
+        main_parameters_and_extras = tuple(chain(self.main_parameters, self.extra_tensors))
+        num_parameters_and_extras = len(main_parameters_and_extras)
 
 
         loaded_state = super().load_state_from_peers(**kwargs)
         if loaded_state is None:
@@ -530,15 +649,19 @@ class TrainingStateAverager(DecentralizedAverager):
             logger.error("Failed to load state from peer: received inconsistent number of parameters, extras or metadata")
             return
 
 
-        try:
-            load_optimizer_state(self.optimizer, metadata["optimizer_metadata"], loaded_opt_tensors)
-        except StopIteration:
-            logger.warning("Failed to load state from peer, received inconsistent number of optimizer statistics")
-            return
+        with torch.no_grad(), self.lock_averaged_tensors:
+            try:
+                load_optimizer_state(self.optimizer, metadata["optimizer_metadata"], loaded_opt_tensors)
+            except StopIteration:
+                logger.warning("Failed to load state from peer, received inconsistent number of optimizer statistics")
+                return
 
 
-        with torch.no_grad():
-            for local_param, loaded_param in zip(parameters_and_extras, loaded_parameters_and_extras):
+            for local_param, loaded_param in zip(main_parameters_and_extras, loaded_parameters_and_extras):
                 local_param.copy_(loaded_param, non_blocking=True)
+
+        if self.offload_optimizer:
+            self._apply_optimizer_parameters_()
+
         self.local_epoch = metadata["epoch"]
         self._update_scheduler()
 
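Note on the TrainingStateAverager.step() changes above: the averaging round is now pre-scheduled (matchmaking starts before the local optimizer step, via wait=False and require_trigger=True) and the actual all-reduce is only released afterwards through allow_allreduce(). A minimal sketch of that control flow, assuming `averager` is a started DecentralizedAverager and `run_local_updates` is a caller-supplied callable (both names are illustrative, not part of this PR):

def step_with_prescheduled_averaging(averager, run_local_updates, timeout: float = 30.0):
    # schedule matchmaking now, but hold off the actual all-reduce until triggered
    control = averager.step(wait=False, require_trigger=True, timeout=timeout)

    run_local_updates()  # e.g. apply the local optimizer step, overlapped with matchmaking

    try:
        control.allow_allreduce()                # release the pre-scheduled round
        return control.result(timeout=timeout)   # metadata gathered from the averaging peers
    except BaseException:
        if not control.done():
            control.cancel()                     # discard the round on failure, mirroring the code above
        return {}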
 
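Note on delta_rule_averaging above: with this flag, _apply_averaging_results_ does not overwrite local tensors with the averaged ones; it adds the change produced by the averaging round, so local progress made while the (possibly delayed) round was running is preserved. In isolation, the update rule amounts to the following (tensor names are illustrative):

import torch

local_tensor = torch.randn(4)        # a local model parameter
old_averaged = local_tensor.clone()  # snapshot taken right before the averaging round
new_averaged = torch.randn(4)        # what the all-reduce produced

delta = new_averaged - old_averaged  # change contributed by the averaging round
local_tensor.add_(delta)             # apply only that change to the (possibly updated) local tensor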

+ 60 - 36
hivemind/optim/grad_scaler.py

@@ -1,12 +1,14 @@
 import contextlib
+import threading
+from copy import deepcopy
 from typing import Dict, Optional

 import torch
 from torch.cuda.amp import GradScaler as TorchGradScaler
-from torch.cuda.amp.grad_scaler import _refresh_per_optimizer_state
+from torch.cuda.amp.grad_scaler import OptState, _refresh_per_optimizer_state
 from torch.optim import Optimizer as TorchOptimizer

-from hivemind.optim.base import DecentralizedOptimizerBase
+import hivemind
 from hivemind.utils.logging import get_logger

 logger = get_logger(__name__)
@@ -14,7 +16,12 @@ logger = get_logger(__name__)
 
 
 class GradScaler(TorchGradScaler):
     """
-    A thin wrapper over pytorch GradScaler that supports hivemind-style training with CollaborativeOptimizer, namely:
+    A wrapper over pytorch GradScaler made specifically for training hivemind.Optimizer with reuse_grad_buffers=True.
+
+    :note: if not using reuse_grad_buffers=True, one can and *should* train normally without this class, e.g. using
+      standard PyTorch AMP or Apex. This custom GradScaler is more memory-efficient, but requires custom training code.
+
+    Compared to the regular PyTorch GradScaler, this class modifies the workflow in several ways:
     - bypass .unscale_ and .update calls in order to accumulate gradients over several steps
     - limit increasing gradient scale to only immediately after global optimizer steps
     - allow training with some or all master parameters in fp16
@@ -23,52 +30,68 @@ class GradScaler(TorchGradScaler):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self._is_running_global_step = False
+        self._is_ready_to_update = False
         self._optimizer_states_to_reset = set()
+        self._lock = threading.RLock()
 
 
     @contextlib.contextmanager
     def running_global_step(self):
-        was_running, self._is_running_global_step = self._is_running_global_step, True
-        try:
-            yield
-        finally:
-            self._is_running_global_step = was_running
+        with self._lock:
+            was_running, self._is_running_global_step = self._is_running_global_step, True
+            try:
+                yield
+            finally:
+                self._is_running_global_step = was_running
 
 
     def unscale_(self, optimizer: TorchOptimizer) -> bool:
-        assert isinstance(optimizer, DecentralizedOptimizerBase)
-        if self._is_running_global_step:
-            super().unscale_(optimizer.opt)
-            return True
-        else:
-            self._check_inf_per_device(optimizer.opt)
-            self._optimizer_states_to_reset.add(id(optimizer))
-            return False
+        with self._lock:
+            assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
+            if self._is_running_global_step:
+                super().unscale_(optimizer)
+                self._per_optimizer_states[id(optimizer.opt)] = deepcopy(self._per_optimizer_states[id(optimizer)])
+                return True
+            else:
+                self._check_inf_per_device(optimizer)
+                self._optimizer_states_to_reset.add(id(optimizer))
+                return False
 
 
     def step(self, optimizer: TorchOptimizer, *args, **kwargs) -> bool:
         if self._is_running_global_step:
-            if self.are_grads_finite(optimizer):
-                super().step(optimizer, *args, **kwargs)
-            else:
-                logger.warning("Skipping global step due to gradient over/underflow")
-            return True
+            with self._lock:
+                if self._is_ready_to_update:
+                    logger.warning("Please call grad_scaler.update() after each step.")
+                assert not isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
+                assert (
+                    self._per_optimizer_states[id(optimizer)]["stage"] == OptState.UNSCALED
+                ), "InternalError: Optimizer should have called .unscale internally before invoking grad_scaler.step."
+                if self.are_grads_finite(optimizer, use_cached=True):
+                    super().step(optimizer, *args, **kwargs)
+                else:
+                    logger.warning("Skipping global step due to gradient over/underflow")
+                self._is_ready_to_update = True
+                return True
         else:
+            assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
             super().step(optimizer)
             self._optimizer_states_to_reset.add(id(optimizer))
             return False
 
 
     def update(self, new_scale: Optional[float] = None) -> bool:
-        total_infs = 0
-        for optimizer_state in self._per_optimizer_states.values():
-            total_infs += sum(v.item() for v in optimizer_state["found_inf_per_device"].values())
-
-        if self._is_running_global_step or total_infs != 0:
-            # note: we update either during actual optimizer step or if we need to reduce scale due to NaN
-            super().update(new_scale)
-            return True
-        else:
-            for opt_id in self._optimizer_states_to_reset:
-                self._per_optimizer_states[opt_id] = _refresh_per_optimizer_state()
-            self._optimizer_states_to_reset.clear()
-            return False
+        with self._lock:
+            total_infs = 0
+            for optimizer_state in self._per_optimizer_states.values():
+                total_infs += sum(v.item() for v in optimizer_state["found_inf_per_device"].values())
+
+            if self._is_ready_to_update or total_infs != 0:
+                # note: we update either during actual optimizer step or if we need to reduce scale due to NaN
+                super().update(new_scale)
+                self._is_ready_to_update = False
+                return True
+            else:
+                for opt_id in self._optimizer_states_to_reset:
+                    self._per_optimizer_states[opt_id] = _refresh_per_optimizer_state()
+                self._optimizer_states_to_reset.clear()
+                return False
 
 
     def _unscale_grads_(
         self, optimizer: TorchOptimizer, inv_scale: torch.Tensor, found_inf: torch.Tensor, allow_fp16: bool
@@ -77,8 +100,9 @@ class GradScaler(TorchGradScaler):
         # inspired by: https://github.com/facebookresearch/fairscale/blob/945b9666/fairscale/optim/grad_scaler.py
         return super()._unscale_grads_(optimizer, inv_scale, found_inf, allow_fp16=True)
 
 
-    def are_grads_finite(self, optimizer: TorchOptimizer) -> bool:
-        return not sum(v.item() for v in self._check_inf_per_device(optimizer).values())
+    def are_grads_finite(self, optimizer: TorchOptimizer, use_cached: bool = False) -> bool:
+        opt_dict = self._found_inf_per_device(optimizer) if use_cached else self._check_inf_per_device(optimizer)
+        return not sum(v.item() for v in opt_dict.values())
 
 
 
 
 class HivemindGradScaler(GradScaler):
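
Note on usage: a rough outline of the training loop this GradScaler is written for, inferred from the unscale_/step/update branches above; treat it as a sketch rather than the documented API. `model`, `data_loader`, and the loss are placeholders, and `opt` is assumed to be a hivemind.Optimizer created with reuse_grad_buffers=True.

import torch
from hivemind.optim.grad_scaler import GradScaler

# placeholders: substitute your own model, data and optimizer
model = torch.nn.Linear(8, 1)
data_loader = [(torch.randn(4, 8), torch.randn(4, 1)) for _ in range(10)]
opt = ...  # assumed: hivemind.Optimizer(..., reuse_grad_buffers=True)

scaler = GradScaler()

for inputs, targets in data_loader:
    with torch.cuda.amp.autocast():
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
    scaler.scale(loss).backward()

    # while gradients are still being accumulated, these calls defer the real unscale/update;
    # the scaler only unscales and adjusts its scale right after a global (collaborative) step
    scaler.unscale_(opt)
    scaler.step(opt)
    scaler.update()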

+ 1 - 0
hivemind/utils/__init__.py

@@ -5,6 +5,7 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from hivemind.utils.mpfuture import *
 from hivemind.utils.nested import *
 from hivemind.utils.networking import *
+from hivemind.utils.performance_ema import PerformanceEMA
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
 from hivemind.utils.tensor_descr import BatchTensorDescriptor, TensorDescriptor
 from hivemind.utils.timed_storage import *

+ 1 - 0
requirements-dev.txt

@@ -4,6 +4,7 @@ pytest-asyncio
 pytest-cov
 tqdm
 scikit-learn
+torchvision
 black==21.6b0
 isort
 psutil

+ 1 - 1
tests/test_averaging.py

@@ -528,7 +528,7 @@ def test_averaging_cancel():
 
 
     step_controls = [averager.step(wait=False, scheduled_time=hivemind.get_dht_time() + 1) for averager in averagers]
 
 
-    time.sleep(0.2)
+    time.sleep(0.1)
     step_controls[0].cancel()
     step_controls[1].cancel()
 
 

+ 120 - 16
tests/test_optimizer.py

@@ -12,6 +12,7 @@ import torch.nn.functional as F
 import hivemind
 from hivemind.averaging.control import AveragingStage
 from hivemind.optim.experimental.grad_averager import GradientAverager
+from hivemind.optim.experimental.optimizer import Optimizer
 from hivemind.optim.experimental.progress_tracker import ProgressTracker
 from hivemind.optim.experimental.state_averager import TrainingStateAverager
 from hivemind.utils.crypto import RSAPrivateKey
@@ -78,7 +79,7 @@ def test_grad_averager():
 @pytest.mark.forked
 @pytest.mark.parametrize(
     "offload_optimizer, reuse_tensors, sync_epoch_when_averaging",
-    [(False, False, False), (True, False, False), (False, True, True), (True, False, True)],
+    [(False, False, False), (True, True, False), (True, False, False), (False, True, True), (True, False, True)],
 )
 def test_state_averager(offload_optimizer: bool, reuse_tensors: bool, sync_epoch_when_averaging: bool):
     dht1 = hivemind.DHT(start=True)
@@ -106,10 +107,10 @@ def test_state_averager(offload_optimizer: bool, reuse_tensors: bool, sync_epoch
     )

     avgr1 = TrainingStateAverager(
-        dht=dht1, param_groups=model1.parameters(), extra_tensors=extras1, start=True, **common_kwargs
+        dht=dht1, params=model1.parameters(), extra_tensors=extras1, start=True, **common_kwargs
     )
     avgr2 = TrainingStateAverager(
-        dht=dht2, param_groups=model2.parameters(), extra_tensors=extras2, start=True, **common_kwargs
+        dht=dht2, params=model2.parameters(), extra_tensors=extras2, start=True, **common_kwargs
     )

     x = torch.ones(2)
@@ -135,8 +136,8 @@ def test_state_averager(offload_optimizer: bool, reuse_tensors: bool, sync_epoch
     avgr1.step(increment_epoch=True, averaging_round=True, delay_averaging=True)
     avgr2.step(increment_epoch=True, averaging_round=True, delay_averaging=True)
 
 
-    avgr1.step(wait_for_delayed_update=True)
-    avgr2.step(wait_for_delayed_update=True)
+    avgr1.step(wait_for_delayed_updates=True)
+    avgr2.step(wait_for_delayed_updates=True)
 
 
     assert torch.allclose(model1(x), model2(x)), "model parameters were not averaged correctly"
     assert torch.allclose(avgr1.optimizer.state_dict()["state"][0]["exp_avg_sq"], (stats1 + stats2) / 2)
@@ -161,10 +162,10 @@ def test_load_state_from_peers():
     )

     avgr1 = TrainingStateAverager(
-        dht=dht1, param_groups=model1.parameters(), allow_state_sharing=False, start=True, **common_kwargs
+        dht=dht1, params=model1.parameters(), allow_state_sharing=False, start=True, **common_kwargs
     )

-    avgr2 = TrainingStateAverager(dht=dht2, param_groups=model2.parameters(), start=True, **common_kwargs)
+    avgr2 = TrainingStateAverager(dht=dht2, params=model2.parameters(), start=True, **common_kwargs)
 
 
     avgr2.local_epoch = 1337
     model2.weight.data[...] = 42
@@ -217,14 +218,14 @@ def test_progress_tracker():
             tracker.report_local_progress(local_epoch, samples_accumulated)

             if tracker.ready_to_update_epoch:
+                if index == 4 and local_epoch >= 4:
+                    time.sleep(0.5)
+                    break
+
                 with tracker.pause_updates():
                     local_epoch = tracker.update_epoch(local_epoch + 1)
                     samples_accumulated = 0
 
 
-                if index == 4 and local_epoch >= 5:
-                    time.sleep(0.5)
-                    break
-
         emas[index] = tracker.performance_ema.samples_per_second
         tracker.shutdown()
         dht.shutdown()
@@ -249,16 +250,19 @@ def test_progress_tracker():
     )
     barrier.wait()
 
 
-    current_step = 0
+    local_epoch = 0
     last_timestamp = hivemind.get_dht_time()
     step_time_deltas = []
 
 
-    while current_step < 6:
+    while local_epoch < 6:
         time.sleep(0.1)
-        if tracker.global_progress.epoch > current_step:
+
+        if tracker.ready_to_update_epoch:
+            with tracker.pause_updates():
+                local_epoch = tracker.update_epoch(local_epoch + 1)
+
             time_delta = hivemind.get_dht_time() - last_timestamp
-            current_step = tracker.global_progress.epoch
-            if current_step == 2:
+            if local_epoch == 2:
                 delayed_start_evt.set()
 
 
             last_timestamp = hivemind.get_dht_time()
@@ -279,3 +283,103 @@ def test_progress_tracker():
         assert 0.5 * mean_step_time < step_time_deltas[i] < 0.95 * mean_step_time
     assert emas[1] < emas[2] < emas[3] < emas[4]
     assert tracker.performance_ema.samples_per_second < 1e-9
+
+
+@pytest.mark.forked
+def test_optimizer(
+    num_peers: int = 1,
+    num_clients: int = 0,
+    target_batch_size: int = 32,
+    total_epochs: int = 3,
+    reuse_grad_buffers: bool = True,
+    delay_grad_averaging: bool = True,
+    delay_optimizer_step: bool = True,
+    average_state_every: int = 1,
+):
+    dht = hivemind.DHT(start=True)
+
+    features = torch.randn(100, 5)
+    targets = features @ torch.randn(5, 1)
+    optimizer = None
+    total_samples_accumulated = mp.Value(ctypes.c_int32, 0)
+
+    def run_trainer(batch_size: int, batch_time: float, client_mode: bool):
+        nonlocal optimizer
+        model = nn.Linear(5, 1)
+
+        assert isinstance(model, torch.nn.Module), "model_arch must evaluate to a pytorch module"
+
+        optimizer = Optimizer(
+            run_id="test_run",
+            target_batch_size=target_batch_size,
+            batch_size_per_step=batch_size,
+            params=model.parameters(),
+            optimizer=partial(torch.optim.SGD, lr=0.1),
+            scheduler=partial(torch.optim.lr_scheduler.StepLR, gamma=0.5, step_size=1),
+            dht=hivemind.DHT(initial_peers=dht.get_visible_maddrs(), client_mode=client_mode, start=True),
+            tracker_opts=dict(private_key=RSAPrivateKey(), max_refresh_period=1.0),
+            averager_opts=dict(request_timeout=0.5),
+            matchmaking_time=1.0,
+            averaging_timeout=5.0,
+            reuse_grad_buffers=reuse_grad_buffers,
+            delay_grad_averaging=delay_grad_averaging,
+            delay_optimizer_step=delay_optimizer_step,
+            average_state_every=average_state_every,
+            client_mode=client_mode,
+            verbose=False,
+        )
+        optimizer.load_state_from_peers()
+
+        prev_time = time.perf_counter()
+
+        while optimizer.local_epoch < total_epochs:
+            time.sleep(max(0.0, prev_time + batch_time - time.perf_counter()))
+            batch = torch.randint(0, len(features), (batch_size,))
+
+            loss = F.mse_loss(model(features[batch]), targets[batch])
+            loss.backward()
+
+            optimizer.step()
+
+            total_samples_accumulated.value += batch_size
+
+            if not reuse_grad_buffers:
+                optimizer.zero_grad()
+
+            prev_time = time.perf_counter()
+
+        time.sleep(1.0)
+        optimizer.shutdown()
+        return optimizer
+
+    peers = []
+
+    for index in range(num_peers):
+        peers.append(
+            mp.Process(
+                target=run_trainer,
+                name=f"trainer-{index}",
+                kwargs=dict(
+                    batch_size=4 + index,
+                    batch_time=0.3 + 0.2 * index,
+                    client_mode=(index >= num_peers - num_clients),
+                ),
+            )
+        )
+
+    for peer in peers[1:]:
+        peer.start()
+    peers[0].run()
+    for peer in peers[1:]:
+        peer.join()
+
+    assert isinstance(optimizer, Optimizer)
+    assert optimizer.local_epoch == optimizer.tracker.global_epoch == total_epochs
+    expected_samples_accumulated = target_batch_size * total_epochs
+    assert expected_samples_accumulated <= total_samples_accumulated.value <= expected_samples_accumulated * 1.2
+    assert 4 / 0.3 * 0.8 <= optimizer.tracker.performance_ema.samples_per_second <= 4 / 0.3 * 1.2
+
+    assert not optimizer.state_averager.is_alive()
+    assert not optimizer.grad_averager.is_alive()
+    assert not optimizer.tracker.is_alive()
+    assert optimizer.scheduled_grads is None or optimizer.scheduled_grads.done()