瀏覽代碼

Implement core functionality of hivemind.Optimizer (#403)

This PR implements the main class of hivemind.Optimizer

- implemented main hivemind.Optimizer class
- implemented and tested pre-scheduling for gradient averaging
- implemented and tested pre-scheduling for parameter averaging
- implemented an option to perform local updates
- adapt hivemind.GradScaler to be compatible with new optimizer
- test convergence and stability with 8 peers

Thanks to @SeanNaren for a convenient testing playground in
https://github.com/SeanNaren/hivemind-lightning/blob/main/mingpt.py

Co-authored-by: Aleksandr Borzunov <hxrussia@gmail.com>
Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
Co-authored-by: foksly <mitya1510@ya.ru>
Co-authored-by: Alexey Bukhtiyarov <a.bukhtiyarov@yandex.ru>
Co-authored-by: Michael Diskin <yhn1124@gmail.com>
justheuristic 3 年之前
父節點
當前提交
5d31c3ba9d

+ 163 - 0
benchmarks/benchmark_optimizer.py

@@ -0,0 +1,163 @@
+import multiprocessing as mp
+import random
+import time
+from contextlib import nullcontext
+from dataclasses import dataclass
+from functools import partial
+from typing import Callable
+
+import numpy as np
+import torch
+import torchvision
+from torch import nn as nn
+from torch.nn import functional as F
+from torch.utils.data import Dataset
+
+import hivemind
+from hivemind.optim.experimental.optimizer import Optimizer
+from hivemind.utils.crypto import RSAPrivateKey
+
+
+@dataclass(frozen=True)
+class TrainingArguments:
+    seed: int = 42
+    run_id: str = "my_exp"
+
+    num_peers: int = 8
+    num_clients: int = 3
+    target_batch_size: int = 256
+    reuse_grad_buffers: bool = True
+    delay_grad_averaging: bool = True
+    delay_optimizer_step: bool = True
+    average_state_every: int = 1
+    use_amp: bool = False
+
+    lr_base: float = 0.1
+    lr_gamma: int = 0.1
+    lr_step_size: int = 10
+    max_epoch: int = 25
+
+    batch_size_min: int = 2
+    batch_size_max: int = 16
+    batch_time_min: float = 1.0
+    batch_time_max: float = 4.5
+    batch_time_std: float = 0.5
+
+    matchmaking_time: float = 5.0
+    max_refresh_period: float = 5.0
+    averaging_timeout: float = 15.0
+    winddown_time: float = 5.0
+    verbose: bool = True
+
+    device: str = "cpu"
+    make_dataset: Callable[[], Dataset] = lambda: torchvision.datasets.MNIST(train=True, root=".", download=True)
+    make_model: Callable[[int, int], nn.Module] = lambda num_features, num_classes: nn.Sequential(
+        nn.Linear(num_features, 64), nn.ReLU(), nn.Linear(64, num_classes)
+    )
+
+
+def benchmark_optimizer(args: TrainingArguments):
+    random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    torch.set_num_threads(1)
+
+    dht = hivemind.DHT(start=True)
+
+    train_dataset = args.make_dataset()
+    num_features = train_dataset.data[0].numel()
+    num_classes = len(train_dataset.classes)
+    X_train = torch.as_tensor(train_dataset.data, dtype=torch.float32)
+    X_train = X_train.sub_(X_train.mean((0, 1, 2))).div_(X_train.std((0, 1, 2))).reshape((-1, num_features))
+    y_train = torch.as_tensor(train_dataset.targets, dtype=torch.int64)
+    del train_dataset
+
+    def run_trainer(batch_size: int, batch_time: float, client_mode: bool, verbose: bool):
+        model = args.make_model(num_features, num_classes).to(args.device)
+
+        assert isinstance(model, torch.nn.Module), "model_arch must evaluate to a pytorch module"
+
+        optimizer = Optimizer(
+            run_id=args.run_id,
+            target_batch_size=args.target_batch_size,
+            batch_size_per_step=batch_size,
+            params=model.parameters(),
+            optimizer=partial(torch.optim.SGD, lr=args.lr_base),
+            scheduler=partial(torch.optim.lr_scheduler.StepLR, gamma=args.lr_gamma, step_size=args.lr_step_size),
+            dht=hivemind.DHT(initial_peers=dht.get_visible_maddrs(), client_mode=client_mode, start=True),
+            tracker_opts=dict(private_key=RSAPrivateKey(), max_refresh_period=args.max_refresh_period),
+            matchmaking_time=args.matchmaking_time,
+            averaging_timeout=args.averaging_timeout,
+            reuse_grad_buffers=args.reuse_grad_buffers,
+            delay_grad_averaging=args.delay_grad_averaging,
+            delay_optimizer_step=args.delay_optimizer_step,
+            average_state_every=args.average_state_every,
+            client_mode=client_mode,
+            verbose=verbose,
+        )
+
+        if args.use_amp and args.reuse_grad_buffers:
+            grad_scaler = hivemind.GradScaler()
+        else:
+            # check that hivemind.Optimizer supports regular PyTorch grad scaler as well
+            grad_scaler = torch.cuda.amp.GradScaler(enabled=args.use_amp)
+
+        prev_time = time.perf_counter()
+
+        while optimizer.local_epoch < args.max_epoch:
+            time.sleep(max(0.0, prev_time + random.gauss(batch_time, args.batch_time_std) - time.perf_counter()))
+
+            batch = torch.randint(0, len(X_train), (batch_size,))
+
+            with torch.cuda.amp.autocast() if args.use_amp else nullcontext():
+                loss = F.cross_entropy(model(X_train[batch].to(args.device)), y_train[batch].to(args.device))
+                grad_scaler.scale(loss).backward()
+
+            grad_scaler.unscale_(optimizer)
+
+            if args.use_amp:
+                grad_scaler.step(optimizer)
+            else:
+                optimizer.step()
+
+            grad_scaler.update()
+
+            if not args.reuse_grad_buffers:
+                optimizer.zero_grad()
+
+            prev_time = time.perf_counter()
+
+        time.sleep(args.winddown_time)
+        optimizer.shutdown()
+
+    peers = []
+
+    for index in range(args.num_peers):
+        batch_size = random.randint(args.batch_size_min, args.batch_size_max)
+        batch_time = random.uniform(args.batch_time_min, args.batch_time_max)
+        peers.append(
+            mp.Process(
+                target=run_trainer,
+                name=f"trainer-{index}",
+                daemon=False,
+                kwargs=dict(
+                    batch_size=batch_size,
+                    batch_time=batch_time,
+                    client_mode=(index >= args.num_peers - args.num_clients),
+                    verbose=args.verbose and (index == 0),
+                ),
+            )
+        )
+
+    try:
+        for peer in peers[1:]:
+            peer.start()
+        peers[0].run()
+        for peer in peers[1:]:
+            peer.join()
+    finally:
+        for peer in peers[1:]:
+            peer.kill()
+
+
+if __name__ == "__main__":
+    benchmark_optimizer(TrainingArguments())

+ 29 - 3
docs/modules/optim.rst

@@ -1,14 +1,40 @@
 **hivemind.optim**
 ==================
 
-.. automodule:: hivemind.optim
-.. currentmodule:: hivemind.optim
-
 .. raw:: html
 
   This module contains decentralized optimizers that wrap regular pytorch optimizers to collaboratively train a shared model. Depending on the exact type, optimizer may average model parameters with peers, exchange gradients, or follow a more complicated distributed training strategy.
   <br><br>
 
+.. automodule:: hivemind.optim.experimental.optimizer
+.. currentmodule:: hivemind.optim.experimental.optimizer
+
+**hivemind.Optimizer**
+----------------------
+
+.. autoclass:: Optimizer
+   :members: step, zero_grad, load_state_from_peers, param_groups, shutdown
+   :member-order: bysource
+
+.. currentmodule:: hivemind.optim.grad_scaler
+.. autoclass:: GradScaler
+   :member-order: bysource
+
+
+**CollaborativeOptimizer**
+--------------------------
+
+.. raw:: html
+
+  CollaborativeOptimizer is a legacy version of hivemind.Optimizer. **For new projects, please use hivemind.Optimizer.**
+  Currently, hivemind.Optimizer supports all the features of CollaborativeOptimizer and then some.
+  CollaborativeOptimizer will still be supported for awhile, but will eventually be deprecated.
+  <br><br>
+
+
+.. automodule:: hivemind.optim.collaborative
+.. currentmodule:: hivemind.optim
+
 .. autoclass:: CollaborativeOptimizer
    :members: step
    :member-order: bysource

+ 2 - 0
hivemind/__init__.py

@@ -16,6 +16,8 @@ from hivemind.optim import (
     DecentralizedOptimizer,
     DecentralizedOptimizerBase,
     DecentralizedSGD,
+    GradScaler,
+    Optimizer,
     TrainingAverager,
 )
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, PeerInfo

+ 28 - 5
hivemind/averaging/averager.py

@@ -35,6 +35,7 @@ from hivemind.proto import averaging_pb2
 from hivemind.utils import MPFuture, TensorDescriptor, get_logger
 from hivemind.utils.asyncio import (
     achain,
+    afirst,
     aiter_with_timeout,
     anext,
     as_aiter,
@@ -413,11 +414,28 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             step.attach(trigger, cancel)
             future_for_init.set_result((trigger, cancel))
 
+            async def find_peers_or_notify_cancel():
+                group_info = await self._matchmaking.look_for_group(step)
+                try:
+                    if not step.triggered:
+                        step.stage = AveragingStage.AWAITING_TRIGGER
+                        await step.wait_for_trigger()
+                    return group_info
+                except asyncio.CancelledError:
+                    await asyncio.wait(
+                        {
+                            self._send_error_to_peer(peer_id, group_info.group_id, averaging_pb2.CANCELLED)
+                            for peer_id in group_info.peer_ids
+                            if peer_id != self.peer_id
+                        }
+                    )
+                    raise
+
             while not step.done():
                 try:
                     self._pending_group_assembled.clear()
                     step.stage = AveragingStage.LOOKING_FOR_GROUP
-                    matchmaking_task = asyncio.create_task(self._matchmaking.look_for_group(step))
+                    matchmaking_task = asyncio.create_task(find_peers_or_notify_cancel())
                     check_cancel_task = asyncio.create_task(step.wait_for_cancel())
 
                     await asyncio.wait({matchmaking_task, check_cancel_task}, return_when=asyncio.FIRST_COMPLETED)
@@ -428,13 +446,10 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                         check_cancel_task.cancel()
 
                     group_info = await matchmaking_task
+
                     if group_info is None:
                         raise AllreduceException("Averaging step failed: could not find a group.")
 
-                    if not step.triggered:
-                        step.stage = AveragingStage.AWAITING_TRIGGER
-                        await step.wait_for_trigger()
-
                     step.stage = AveragingStage.RUNNING_ALLREDUCE
 
                     step.set_result(
@@ -478,6 +493,14 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     )
                 )
 
+    async def _send_error_to_peer(self, peer_id: PeerID, group_id: GroupID, code: averaging_pb2.MessageCode):
+        try:
+            error = averaging_pb2.AveragingData(group_id=group_id, code=code)
+            stub = type(self).get_stub(self._p2p, peer_id, namespace=self.prefix)
+            await afirst(await stub.rpc_aggregate_part(as_aiter(error)))
+        except Exception as e:
+            logger.debug(f"Caught {e} when sending error {averaging_pb2.MessageCode.Name(code)} to {peer_id}.")
+
     async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
         """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
         try:

+ 9 - 0
hivemind/averaging/control.py

@@ -1,3 +1,4 @@
+import os
 import struct
 from enum import Enum
 from typing import Optional
@@ -144,6 +145,14 @@ class StepControl(MPFuture):
         self._trigger, self._cancel, self._shared_buffer = state["_trigger"], state["_cancel"], state["_shared_buffer"]
         self._data_for_gather, self._deadline, self._allow_retries = state["immutable_params"]
 
+    def __del__(self):
+        if os.getpid() == self._origin_pid and not self.triggered:
+            logger.warning(
+                "Deleted an averaging StepControl, but the step was not triggered. This may cause other "
+                "peers to fail an averaging round via TimeoutError."
+            )
+        super().__del__()
+
     def cancel(self) -> bool:
         if self._trigger is not None:
             self._trigger.cancel()

+ 5 - 3
hivemind/averaging/matchmaking.py

@@ -88,9 +88,11 @@ class Matchmaking:
     async def looking_for_group(self, step_control: StepControl):
         async with self.lock_looking_for_group:
             assert self.step_control is None
-            self.step_control = step_control
-            yield
-            self.step_control = None
+            try:
+                self.step_control = step_control
+                yield
+            finally:
+                self.step_control = None
 
     @property
     def is_looking_for_group(self):

+ 1 - 1
hivemind/averaging/partition.py

@@ -35,7 +35,7 @@ class TensorPartContainer:
         compression: CompressionBase = NoCompression(),
         part_size_bytes: int = DEFAULT_PART_SIZE_BYTES,
         tensor_infos: Optional[Sequence[CompressionInfo]] = None,
-        prefetch: int = 5,
+        prefetch: int = 1,
     ):
         if tensor_infos is None:
             tensor_infos = tuple(CompressionInfo.from_tensor(x, key=i) for i, x in enumerate(tensors))

+ 2 - 1
hivemind/optim/__init__.py

@@ -1,6 +1,7 @@
 from hivemind.optim.adaptive import CollaborativeAdaptiveOptimizer
 from hivemind.optim.base import DecentralizedOptimizerBase
 from hivemind.optim.collaborative import CollaborativeOptimizer
-from hivemind.optim.grad_scaler import HivemindGradScaler
+from hivemind.optim.experimental.optimizer import Optimizer
+from hivemind.optim.grad_scaler import GradScaler, HivemindGradScaler
 from hivemind.optim.simple import DecentralizedAdam, DecentralizedOptimizer, DecentralizedSGD
 from hivemind.optim.training_averager import TrainingAverager

+ 2 - 2
hivemind/optim/collaborative.py

@@ -245,7 +245,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
             self.averager.local_step = self.collaboration_state.optimizer_step
             logger.log(self.status_loglevel, f"Catching up with collaboration step {self.local_step}.")
 
-        if grad_scaler is not None and not grad_scaler.are_grads_finite(self.opt):
+        if grad_scaler is not None and not grad_scaler.are_grads_finite(self):
             logger.log(self.status_loglevel, "Encountered incorrect value in fp16 grads, resetting local gradients")
             self.local_samples_accumulated = self.local_steps_accumulated = 0
             self.reset_accumulated_grads_()
@@ -310,7 +310,7 @@ class CollaborativeOptimizer(DecentralizedOptimizerBase):
 
             if grad_scaler is not None:
                 with grad_scaler.running_global_step():
-                    assert grad_scaler.step(self.opt)
+                    assert grad_scaler.step(self)
             else:
                 self.opt.step()
 

+ 15 - 10
hivemind/optim/experimental/grad_averager.py

@@ -170,7 +170,13 @@ class GradientAverager(DecentralizedAverager):
         elif len(kwargs) > 0:
             raise RuntimeError(f"Averaging with a pre-scheduled group, parameters {kwargs} will have no effect.")
         assert not control.triggered, f"This {type(control)} instance was already used."
-        self._load_accumulators_into_averager_()
+        if self._new_averaged_grads and self.warn:
+            logger.warning(
+                "[warn=True] Starting new averaging round, but previous round results were not used."
+                "This may be a sign of incorrect optimizer behavior."
+            )
+
+        self.load_accumulators_into_averager_()
         self._accumulators_used_in_step = True
         self._new_averaged_grads = True
 
@@ -182,14 +188,8 @@ class GradientAverager(DecentralizedAverager):
         return control.result(timeout) if wait else control
 
     @torch.no_grad()
-    def _load_accumulators_into_averager_(self):
+    def load_accumulators_into_averager_(self):
         """load locally accumulated gradients into the averager for aggregation"""
-        if self._new_averaged_grads and self.warn:
-            logger.warning(
-                "[warn=True] Starting new averaging round, but previous round results were not used."
-                "This may be a sign of incorrect optimizer behavior."
-            )
-            self._new_averaged_grads = False  # warn once per round
         # divide locally accumulated gradients by the number of times they were accumulated
         grad_scale = (1.0 / self.local_times_accumulated) if self.local_times_accumulated != 0 else 0.0
         with self.get_tensors() as averaged_grads:
@@ -208,14 +208,19 @@ class GradientAverager(DecentralizedAverager):
     @contextlib.contextmanager
     @torch.no_grad()
     def use_averaged_gradients(self):
+        """Substitute model's main gradients with averaged gradients (does not respect device placement)"""
         self._new_averaged_grads = False
         with self.get_tensors() as averaged_grads:
+            assert len(averaged_grads) == len(self.parameters)
             try:
-                assert len(averaged_grads) == len(self.parameters)
                 old_grads = [param.grad for param in self.parameters]
                 for param, new_grad in zip(self.parameters, averaged_grads):
                     param.grad = new_grad
-                yield
+                yield averaged_grads
             finally:
                 for param, old_grad in zip(self.parameters, old_grads):
                     param.grad = old_grad
+
+    def notify_used_averaged_gradients(self):
+        """Notify averager that the results of a previous averaging round are accounted for"""
+        self._new_averaged_grads = False

+ 725 - 0
hivemind/optim/experimental/optimizer.py

@@ -0,0 +1,725 @@
+from __future__ import annotations
+
+import logging
+import os
+import time
+from functools import partial
+from typing import Callable, Optional, Sequence, Union
+
+import torch
+
+from hivemind.averaging.control import AveragingStage, StepControl
+from hivemind.compression import CompressionBase, NoCompression
+from hivemind.dht import DHT
+from hivemind.optim.experimental.grad_averager import GradientAverager
+from hivemind.optim.experimental.progress_tracker import ProgressTracker
+from hivemind.optim.experimental.state_averager import (
+    LRSchedulerBase,
+    OptimizerFactory,
+    Parameters,
+    ParamGroups,
+    SchedulerFactory,
+    TorchOptimizer,
+    TrainingStateAverager,
+)
+from hivemind.optim.grad_scaler import GradScaler
+from hivemind.utils import PerformanceEMA, get_dht_time, get_logger
+
+logger = get_logger(__name__)
+
+
+class Optimizer(torch.optim.Optimizer):
+    """
+    Hivemind Optimizer wraps your regular PyTorch Optimizer for training collaboratively with peers.
+    By default, Optimizer is configured to be exactly **equivalent to synchronous training** with target_batch_size;
+    There are advanced options make training semi-asynchronous (delay_optimizer_step and delay_gradient_averaging)
+    or even fully asynchronous (local_updates=True). However, these options require careful tuning.
+
+    :example: The Optimizer can be used as a drop-in replacement for your regular PyTorch Optimizer:
+
+    >>> model = transformers.AutoModel("albert-xxlarge-v2")
+    >>> dht = hivemind.DHT(initial_peers=INITIAL_PEERS, start=True)
+    >>> opt = hivemind.Optimizer(dht, run_id="run_42", optimizer=lambda params: torch.optim.Adam(params, ...),
+                                 params=model.parameters(), target_batch_size=4096, batch_size_per_step=4)
+    >>> # alternative: opt = hivemind.Optimizer(dht, run_id="run_42", optimizer=torch.optim.Adam(model.parameters())
+    >>> while True:
+    >>>     loss = compute_loss_on_batch(model, batch_size=4)
+    >>>     opt.zero_grad()
+    >>>     loss.backward()
+    >>>     opt.step()  # <-- train collaboratively with any peers that use the same prefix (run_42)
+
+    However, unlike regular optimizers, calling opt.step with hivemind.Optimizer can do one of the following:
+
+     - accumulate a minibatch of gradients towards the (global) target batch size, without updating parameters yet;
+     - after accumulating the target batch size, all-reduce gradients with peers and perform optimizer step;
+     - if your peer lags behind the rest of the swarm, it will download latest state from other peers;
+
+    :example: the optimizer has many keyword arguments that may be difficult to understand in one go. Here's quickstart
+      that will help you setup your first synchronous optimizer.
+
+    >>> hivemind.Optimizer(
+    >>>    dht=hivemind.DHT(initial_peers=ADDRESS_HERE, client_mode=TRUE_IF_BEHIND_FIREWALL_OR_UNRELIABLE, start=True),
+    >>>    run_id="a_unique_name_that_every_participant_will_see_when_training",
+    >>>    batch_size_per_step=ACTUAL_BATCH_SIZE_OF_THIS_PEER,
+    >>>    target_batch_size=LARGE_GLOBAL_BATCH,  # global batch will be this or *slightly* larger due to stragglers;
+    >>>      #  peers should finish averaging in roughly half the time they need to accumulate this batch between them
+    >>>    optimizer=lambda params: AnyPyTorchOptimizer(params, **config_that_makes_sense_for_target_batch_size),
+    >>>      # ^-- scale learning rate for your target_batch_size; good reference: https://arxiv.org/abs/1904.00962
+    >>>    offload_optimizer=True,  # this saves GPU memory; large-batch training does not need optimizer that often
+    >>>    scheduler=lambda opt: AnyPytTorchScheduler(opt, **config_that_makes_sense_for_target_batch_size),
+    >>>      # scheduler.step will be called once every time peers collectively accumulate target_batch_size
+    >>>    matchmaking_time=15.0, averaging_timeout=60.0,  # <-- if the network is fast reduce to 3-5s and 10-15s
+    >>>      # increase matchmaking_time if at least 25% of the time you see "averaged gradients with <...> peers",
+    >>>      # ... but N is less than 0.9x the actual number of peers. Increase averaging_timeout if half of the epochs
+    >>>      # ... print "Proceeding with local gradients" instead of "Averaged gradients with N peers"
+    >>>    grad_compression=hivemind.Float16Compression(),  state_averaging_compression=hivemind.Float16Compression(),
+    >>>      # it is generally fine to use pure 16-bit or even lower precision during communication with no precaution;
+    >>>      # See hivemind/examples/albert for an example of mixed 8-bit compression.
+    >>>    delay_grad_averaging=SHOULD_I_USE_DPU, delay_optimizer_step=SHOULD_I_USE_DPU, # DPU stands for Delayed Para-
+    >>>      # -meter Updates, running allreduce and optimizer step in background. See https://arxiv.org/abs/2101.06840
+    >>>    verbose=True  # periodically report the training progress to the console
+    >>> )  # and you're done!
+
+    :note: hivemind.Optimizer can be used the same way any other pytorch optimizer, but there is one caveat:
+      learning rate schedulers, curriculum and other **time-dependent features should depend on Optimizer.local_epoch**
+      (and not the number ot calls to opt.step). This is because peers are allowed to join midway through training,
+      when others have already made some progress and changed their learning rates accordingly.
+
+    :param dht: a running hivemind.DHT instance connected to other peers
+    :param run_id: a unique identifier of this training run, used as a common prefix for all DHT keys.
+      **Note:** peers with the same run_id should *generally* train the same model and use compatible configurations.
+      Some options can be safely changed by individual peers: ``batch_size_per_step``, ``client_mode``, ``auxiliary``,
+      ``reuse_grad_buffers``, ``offload_optimizer``, and ``verbose``. In some cases, other options may also be tuned
+      individually by each peer, but they should be changed with caution to avoid deadlocks or convergence issues.
+
+    :param target_batch_size: global batch size that must be accumulated before the swarm transitions to the next epoch
+    :param batch_size_per_step: before each call to .step, user should accumulate gradients over this many samples
+
+    :param optimizer: a callable(parameters) -> pytorch.optim.Optimizer or a pre-initialized PyTorch optimizer
+      **Note:** some advanced options like offload_optimizer, delay_optimizer_step, or delay_grad_averaging are not
+      supported if hivemind.optimizer is created with a pre-initialized optimizer and require optimizer factory
+    :param params: parameters or param groups for the optimizer; required if optimizer is a callable(params)
+    :param scheduler: callable(optimizer) -> PyTorch LRScheduler or a pre-initialized PyTorch scheduler.
+      The learning rate scheduler will adjust learning rate based on global epoch, not the number of
+      local calls to optimizer.step; this is required to keep different peers synchronized.
+
+    :param matchmaking_time: when looking for group, wait for peers to join for up to this many seconds
+    :param averaging_timeout: if an averaging step hangs for this long, it will be cancelled.
+    :param load_state_timeout: wait for at most this many seconds before giving up on load_state_from_peers
+    :param reuse_grad_buffers: if True, use model's .grad buffers for gradient accumulation.
+      This is more memory efficient, but it requires that the user does *NOT* call model/opt zero_grad at all
+
+    :param offload_optimizer: offload the optimizer to host memory, saving GPU memory for parameters and gradients
+    :param delay_optimizer_step: run optimizer in background, apply results in future .step; requires offload_optimizer
+    :param delay_grad_averaging: average gradients in background; requires offload_optimizer and delay_optimizer_step
+    :param delay_state_averaging: if enabled (default), average parameters and extra tensors in a background thread;
+      if set to False, average parameters synchronously within the corresponding hivemind.Optimizer.step call.
+      The above 3 options (offload_optimizer, delay_optimizer_step and delay_grad_averaging) require that the optimizer
+      is created with: ``hivemind.Optimizer(..., optimizer=callable_optimizer_factory, params=model.parameters())``
+
+    :param average_state_every: average state (parameters, chosen opt tensors) with peers every this many **epochs**.
+      This reduces the communication overhead increasing, but can cause parameters to diverge if too large.
+      The maximal average_state_every=num_epochs depends on how often peers diverge from each other. If peers
+      hardly ever skip averaging rounds, they can average state less frequently. In turn, network failures, lossy
+      gradient compression and local_updates cause parameters to diverge faster and requires more frequent averaging.
+
+    :param use_local_updates: if enabled, peers will update parameters on each .step using local gradients;
+      if not enabled (default), accumulate gradients to target_batch_size, and then call .step with averaged gradients.
+      Even if use_local_updates=True, learning rate scheduler will still be called once per target_batch_size.
+
+    :param client_mode: if True, this peer will not accept incoming connections (firewall-compatible mode)
+    :param auxiliary: if True, optimizer.step will only assist other peers in averaging (for cpu-only workers)
+
+    :param grad_compression: compression strategy used for averaging gradients, default = no compression
+    :param state_averaging_compression: compression for averaging params and state tensors, default = no compression
+    :param load_state_compression: compression strategy for loading state from peers, default = no compression
+    :param average_opt_statistics: names of optimizer statistics from state dict that should be averaged with peers
+    :param extra_tensors: if specified, these extra tensors will also be averaged and shared in load_state_from_peers.
+
+    :param averager_opts: additional keyword arguments forwarded to both GradientAverager and TrainingStateAverager
+    :param tracker_opts: additional keyword arguments forwarded to ProgressTracker
+    :param performance_ema_alpha: moving average alpha  in ProgressTracer, TrainingStateAverager and Optimizer
+    :param verbose: if True, report internal events such as accumilating gradients and running background tasks
+
+    :note: in a large-scale training, peers will inevitably fail and you will see error messages. hivemind.Optimizer
+      is designed to recover from such failures, but will sometimes need a minute or two to re-adjust.
+
+    """
+
+    def __init__(
+        self,
+        *,
+        dht: DHT,
+        run_id: str,
+        target_batch_size: int,
+        batch_size_per_step: Optional[int] = None,
+        optimizer: Union[TorchOptimizer, OptimizerFactory],
+        params: Optional[Union[Parameters, ParamGroups]] = None,
+        scheduler: Optional[Union[LRSchedulerBase, SchedulerFactory]] = None,
+        matchmaking_time: Optional[float] = 15.0,
+        averaging_timeout: Optional[float] = 60.0,
+        load_state_timeout: float = 600.0,
+        reuse_grad_buffers: bool = False,
+        offload_optimizer: Optional[bool] = None,
+        delay_optimizer_step: Optional[bool] = None,
+        delay_grad_averaging: bool = False,
+        delay_state_averaging: bool = True,
+        average_state_every: int = 1,
+        use_local_updates: bool = False,
+        client_mode: bool = None,
+        auxiliary: bool = False,
+        grad_compression: CompressionBase = NoCompression(),
+        state_averaging_compression: CompressionBase = NoCompression(),
+        load_state_compression: CompressionBase = NoCompression(),
+        average_opt_statistics: Sequence[str] = (),
+        extra_tensors: Sequence[torch.Tensor] = (),
+        averager_opts: Optional[dict] = None,
+        tracker_opts: Optional[dict] = None,
+        performance_ema_alpha: float = 0.1,
+        shutdown_timeout: float = 5,
+        verbose: bool = False,
+    ):
+        client_mode = client_mode if client_mode is None else dht.client_mode
+        delay_optimizer_step = delay_optimizer_step if delay_optimizer_step is not None else delay_grad_averaging
+        offload_optimizer = offload_optimizer if offload_optimizer is not None else (params is not None)
+        assert not delay_grad_averaging or delay_optimizer_step, "delay_grad_averaging requires delay_optimizer_step"
+        assert not (client_mode and auxiliary), "Client-mode peers cannot serve as auxiliaries"
+        assert not auxiliary or batch_size_per_step is None, "Auxiliary peers should not accumulate batches"
+        if callable(optimizer) and params is not None:
+            if scheduler is not None and (not callable(scheduler) or isinstance(scheduler, LRSchedulerBase)):
+                raise ValueError("For this mode, please provide scheduler factory: callable(optimizer) -> scheduler")
+        elif all(hasattr(optimizer, attr) for attr in ("param_groups", "step", "zero_grad")):
+            if offload_optimizer or delay_optimizer_step or delay_grad_averaging:
+                raise ValueError(
+                    "To enable offload_optimizer or delayed updates, please initialize Optimizer as "
+                    "hivemind.Optimizer(..., params=params, optimizer=lambda params: create_opt(params)"
+                )
+        else:
+            raise ValueError(
+                "Please initialize the optimizer in one of the following two ways:\n"
+                "(A) hivemind.Optimizer(..., params=params, optimizer=lambda params: create_opt(params)\n"
+                "(B) hivemind.Optimizer(..., optimizer=pre_initialize_optimizer)"
+            )
+        if use_local_updates:
+            assert not reuse_grad_buffers, "if local_updates is True, gradients will not be accumulated"
+            assert not delay_grad_averaging, "if local_updates is True, gradients will not be averaged"
+
+        self.dht, self.run_id, self.client_mode, self.auxiliary = dht, run_id, client_mode, auxiliary
+        self.batch_size_per_step, self.target_batch_size = batch_size_per_step, target_batch_size
+        self.delay_state_averaging, self.average_state_every = delay_state_averaging, average_state_every
+        self.matchmaking_time, self.offload_optimizer = matchmaking_time, offload_optimizer
+        self.delay_grad_averaging, self.delay_optimizer_step = delay_grad_averaging, delay_optimizer_step
+
+        self.averaging_timeout, self.load_state_timeout = averaging_timeout, load_state_timeout
+        self.shutdown_timeout = shutdown_timeout
+
+        self.status_loglevel = logging.INFO if verbose else logging.DEBUG
+        self.scheduled_grads: Optional[StepControl] = None
+        self.scheduled_state: Optional[StepControl] = None
+
+        self.tracker = self._make_progress_tracker(
+            target_batch_size, performance_ema_alpha=performance_ema_alpha, **tracker_opts or {}
+        )
+        self.state_averager = self._make_state_averager(
+            optimizer=optimizer,
+            params=params,
+            scheduler=scheduler,
+            delta_rule_averaging=use_local_updates and self.delay_state_averaging,
+            compression=state_averaging_compression,
+            state_compression=load_state_compression,
+            average_opt_statistics=average_opt_statistics,
+            performance_ema_alpha=performance_ema_alpha,
+            extra_tensors=extra_tensors,
+            **averager_opts or {},
+        )
+        if not use_local_updates:
+            self.grad_averager = self._make_gradient_averager(
+                reuse_grad_buffers=reuse_grad_buffers, compression=grad_compression, **averager_opts or {}
+            )
+        else:
+            self.grad_averager = None
+
+        self._should_check_synchronization_on_update = True  # used in self.should_load_state_from_peers
+        self._schema_hash = self._compute_schema_hash()
+        self._parent_pid = os.getpid()
+
+        self.delay_before_state_averaging = PerformanceEMA(alpha=performance_ema_alpha)
+        # measures the average time from the beginning of self._update_global_epoch to the call to state_averager
+        # used for pre-scheduling the averaging round in state_averager
+
+        self._step_supports_amp_scaling = reuse_grad_buffers
+        # note: the line above is used by pytorch AMP GradScaler to enable custom behavior needed when reusing gradient
+        # buffers over multiple steps (to avoid repeated unscaling). Without reuse_grad_buffers, this is not needed.
+
+    def _make_state_averager(self, **kwargs) -> TrainingStateAverager:
+        return TrainingStateAverager(
+            dht=self.dht,
+            prefix=f"{self.run_id}_state_averager",
+            min_matchmaking_time=self.matchmaking_time,
+            allreduce_timeout=self.averaging_timeout,
+            shutdown_timeout=self.shutdown_timeout,
+            offload_optimizer=self.offload_optimizer,
+            custom_gradients=self.offload_optimizer,
+            status_loglevel=self.status_loglevel,
+            client_mode=self.client_mode,
+            auxiliary=self.auxiliary,
+            start=True,
+            **kwargs,
+        )
+
+    def _make_gradient_averager(self, **kwargs) -> GradientAverager:
+        assert hasattr(self, "state_averager"), "must initialize state averager first"
+        grad_averager = GradientAverager(
+            dht=self.dht,
+            prefix=f"{self.run_id}_grad_averager",
+            parameters=self.state_averager.main_parameters,
+            min_matchmaking_time=self.matchmaking_time,
+            allreduce_timeout=self.averaging_timeout,
+            shutdown_timeout=self.shutdown_timeout,
+            client_mode=self.client_mode,
+            auxiliary=self.auxiliary,
+            start=True,
+            **kwargs,
+        )
+        if self.offload_optimizer:
+            optimized_param_groups = self.state_averager.optimizer.param_groups
+            optimized_parameters = [param for group in optimized_param_groups for param in group["params"]]
+            with grad_averager.get_tensors() as averaged_gradients:
+                assert len(averaged_gradients) == len(optimized_parameters)
+                for opt_param, averaged_grad in zip(optimized_parameters, averaged_gradients):
+                    opt_param.grad = averaged_grad
+        return grad_averager
+
+    def _make_progress_tracker(self, target_batch_size: int, **kwargs) -> ProgressTracker:
+        return ProgressTracker(
+            dht=self.dht,
+            prefix=self.run_id,
+            target_batch_size=target_batch_size,
+            client_mode=self.client_mode,
+            status_loglevel=self.status_loglevel,
+            start=True,
+            **kwargs,
+        )
+
+    def _compute_schema_hash(self) -> int:
+        optimized_param_groups = self.state_averager.optimizer.param_groups
+        optimized_parameters = [param for group in optimized_param_groups for param in group["params"]]
+        param_shapes = tuple(tuple(param.shape) for param in optimized_parameters)
+
+        # offloaded optimizer requires that gradient tensors are reused between iterations
+        grad_ids = tuple(id(param.grad) for param in optimized_parameters) if self.offload_optimizer else None
+        return hash((grad_ids, param_shapes))
+
+    def is_alive(self) -> bool:
+        return self.state_averager.is_alive()
+
+    @property
+    def local_epoch(self) -> int:
+        return self.state_averager.local_epoch
+
+    @property
+    def use_local_updates(self) -> bool:
+        return self.grad_averager is None
+
+    @property
+    def use_gradient_averaging(self) -> bool:
+        return self.grad_averager is not None
+
+    def step(
+        self,
+        closure: Optional[Callable[[], torch.Tensor]] = None,
+        batch_size: Optional[int] = None,
+        grad_scaler: Optional[GradScaler] = None,
+    ):
+        """
+        Update training progress after accumulating another local batch size. Depending on the configuration, this will
+        report progress to peers, run global or local optimizer step, average parameters or schedule background tasks.
+
+        :param closure: A closure that reevaluates the model and returns the loss
+        :param batch_size: optional override for batch_size_per_step from init
+        :param grad_scaler: if amp is enabled, this **must** be a hivemind-aware gradient scaler
+        :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
+        """
+        if grad_scaler is not None and not isinstance(grad_scaler, GradScaler):
+            raise ValueError("hivemind.Optimizer requires a hivemind-aware gradient scaler (hivemind.GradScaler)")
+        if self.batch_size_per_step is None and batch_size is None and not self.auxiliary:
+            raise ValueError("Please either set batch_size_per_step parameter at init or when calling .step")
+        if self.auxiliary and (closure is not None or batch_size is not None or grad_scaler is not None):
+            raise ValueError("Auxiliary peers should not have batch size, run closures, or use grad_scaler")
+        batch_size = batch_size if batch_size is not None else self.batch_size_per_step
+
+        # if delayed updates finished before step, apply these updates; otherwise do nothing
+        self.state_averager.step(apply_delayed_updates=True)
+
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        if not self.auxiliary and self.should_load_state_from_peers():
+            logger.log(self.status_loglevel, "Peer is out of sync.")
+            self.load_state_from_peers()
+            return loss  # local gradients were computed with out-of-sync parameters, must start over
+
+        if self.use_gradient_averaging:
+            # accumulate gradients toward target batch size, then aggregate with peers and run optimizer
+            if not self.auxiliary:
+                grads_are_valid = self._check_and_accumulate_gradients(batch_size, grad_scaler)
+                if not grads_are_valid:
+                    return loss  # local gradients were reset due to overflow, must start over
+
+            self._maybe_schedule_gradient_averaging()
+            self._maybe_schedule_state_averaging()
+
+        else:
+            # use_local_updates=True: update parameters on every step independently of other peers
+            if not self.auxiliary:
+                if grad_scaler is not None:
+                    with grad_scaler.running_global_step():
+                        assert grad_scaler.unscale_(self)
+
+                new_samples_accumulated = self.tracker.local_progress.samples_accumulated + batch_size
+                self.tracker.report_local_progress(self.local_epoch, new_samples_accumulated)
+                self._maybe_schedule_state_averaging()
+
+                self.state_averager.step(
+                    increment_epoch=False,
+                    optimizer_step=True,
+                    delay_optimizer_step=self.delay_optimizer_step,
+                    grad_scaler=grad_scaler,
+                )
+
+        if self.tracker.ready_to_update_epoch:
+            self._update_global_epoch(grad_scaler)
+
+        return loss
+
+    def _update_global_epoch(self, grad_scaler: Optional[GradScaler]) -> None:
+        """Depending on the configuration: aggregate gradients and/or parameters, perform global optimizer step"""
+        assert self._schema_hash == self._compute_schema_hash(), "parameters or gradients changed during iteration"
+        _epoch_start_time = time.perf_counter()
+
+        with self.tracker.pause_updates():
+            wait_for_trigger = None
+
+            if self.use_gradient_averaging:
+                logger.log(self.status_loglevel, f"Beginning optimizer step #{self.local_epoch}")
+                began_averaging_gradients = self._begin_averaging_gradients(grad_scaler)
+                if not began_averaging_gradients:
+                    pass  # failed to start gradient averaging due to an internal error
+                elif self.delay_grad_averaging:
+                    # if using delayed grad averaing, send this to state_averager as a pre-condition for optimizer step
+                    wait_for_trigger = partial(self._average_gradients_and_load_into_optimizer, self.scheduled_grads)
+                else:
+                    # delay_grad_averaging=False, average gradients immediately
+                    self._average_gradients_and_load_into_optimizer(self.scheduled_grads)
+
+            next_epoch = max(self.local_epoch + 1, self.tracker.global_epoch)
+            swarm_not_empty = self.tracker.global_progress.num_peers > 1
+            should_perform_optimizer_step = not self.auxiliary and not self.use_local_updates
+            should_average_state = (
+                swarm_not_empty
+                and next_epoch % self.average_state_every == 0
+                and not self.state_averager.averaging_in_progress
+            )
+
+            if should_average_state and self.scheduled_state is not None:
+                if self.scheduled_state.triggered or self.scheduled_state.done():
+                    logger.log(
+                        self.status_loglevel,
+                        f"Not using pre-scheduled group for state averaging because it"
+                        f"was already used elsewhere: {self.scheduled_state}",
+                    )
+                    self.scheduled_state = None
+                self.delay_before_state_averaging.update(task_size=1, interval=time.perf_counter() - _epoch_start_time)
+
+            self.state_averager.step(
+                increment_epoch=True,
+                wait_for_trigger=wait_for_trigger,
+                optimizer_step=should_perform_optimizer_step,
+                delay_optimizer_step=self.delay_optimizer_step and should_perform_optimizer_step,
+                grad_scaler=grad_scaler,
+                averaging_round=should_average_state,
+                delay_averaging=self.delay_state_averaging and not self.auxiliary,
+                averaging_control=self.scheduled_state if should_average_state else None,
+                averaging_opts=dict(timeout=self.averaging_timeout) if should_average_state else None,
+            )
+
+            if not should_average_state and self.scheduled_state is not None and not self.scheduled_state.done():
+                self.scheduled_state.cancel()
+            self.scheduled_state = None
+
+            self.tracker.update_epoch(new_epoch=self.state_averager.local_epoch)
+            self._should_check_synchronization_on_update = True
+            # the above line ensures that peers check for *strict* synchronization once per epoch
+
+            if not self.client_mode:
+                self.state_averager.state_sharing_priority = self.local_epoch
+
+            if self.use_gradient_averaging and not self.auxiliary:
+                self.grad_averager.reset_accumulated_grads_()
+                if not self.client_mode:
+                    self.grad_averager.state_sharing_priority = self.local_epoch
+
+            logger.log(self.status_loglevel, f"Transitioning to epoch {self.local_epoch}.")
+
+    def _begin_averaging_gradients(self, grad_scaler: Optional[GradScaler]) -> bool:
+        """Begin an all-reduce round to average gradients; return True if succeeded, False if failed"""
+        if grad_scaler is not None:
+            with grad_scaler.running_global_step():
+                assert grad_scaler.unscale_(self)
+
+        began_averaging_gradients = False
+        if self.scheduled_grads is not None and (self.scheduled_grads.triggered or self.scheduled_grads.done()):
+            logger.log(
+                self.status_loglevel,
+                f"Not using pre-scheduled group for state averaging because it"
+                f"was already used elsewhere: {self.scheduled_state}",
+            )
+            self.scheduled_grads = None
+
+        elif self.tracker.global_progress.num_peers > 1:
+            try:
+                self.scheduled_grads = self.grad_averager.step(
+                    control=self.scheduled_grads, reset_accumulators=True, wait=False
+                )
+                began_averaging_gradients = True
+            except BaseException as e:
+                logger.exception(e)
+
+        if not began_averaging_gradients and self.scheduled_grads is not None and not self.scheduled_grads.done():
+            logger.log(self.status_loglevel, f"Cancelled pre-scheduled gradient averaging round")
+            self.scheduled_grads.cancel()
+            self.scheduled_grads = None
+        return began_averaging_gradients
+
+    def _check_and_accumulate_gradients(self, batch_size: int, grad_scaler: Optional[GradScaler]) -> bool:
+        """Check if gradients are valid, accumulate and return True; otherwise, reset and return False"""
+        assert not self.use_local_updates and not self.auxiliary
+        if grad_scaler is not None and not grad_scaler.are_grads_finite(self):
+            logger.log(self.status_loglevel, "Encountered incorrect value in fp16 grads, resetting local gradients")
+            self.tracker.report_local_progress(self.local_epoch, samples_accumulated=0)
+            self.grad_averager.reset_accumulated_grads_()
+            return False
+
+        self.grad_averager.accumulate_grads_(batch_size)
+        self.tracker.report_local_progress(self.local_epoch, self.grad_averager.local_samples_accumulated)
+        return True
+
+    def _maybe_schedule_gradient_averaging(self) -> None:
+        """If next epoch is coming soon, schedule the next gradient averaging round at the estimated end of epoch"""
+        assert self.use_gradient_averaging
+        if self.tracker.estimated_next_update_time - get_dht_time() <= self.matchmaking_time:
+            if self.scheduled_grads is None or self.scheduled_grads.triggered or self.scheduled_grads.done():
+                if self.delay_grad_averaging:
+                    # wait for previous averaging to finish before starting a new one
+                    self.state_averager.step(wait_for_delayed_updates=True)
+
+                eta_seconds = self.tracker.estimated_next_update_time - get_dht_time()
+                eta_seconds = max(eta_seconds, self.grad_averager.matchmaking_kwargs["min_matchmaking_time"])
+                logger.log(self.status_loglevel, f"Pre-scheduling gradient averaging round in {eta_seconds:.2f}s.")
+                self.scheduled_grads = self.grad_averager.schedule_step(timeout=self.averaging_timeout)
+
+    def _maybe_schedule_state_averaging(self) -> None:
+        """If next epoch is coming soon, schedule the next state averaging at estimated parameter averaging start"""
+        next_epoch = max(self.local_epoch + 1, self.tracker.global_epoch)
+        if next_epoch % self.average_state_every != 0:
+            return  # averaging is not performed at this epoch
+        if self.state_averager.averaging_in_progress:
+            return  # previous run is still in progress
+
+        estimated_time = self.tracker.estimated_next_update_time
+        estimated_time += self.delay_before_state_averaging.ema_seconds_per_sample
+        estimated_time += self.state_averager.delay_before_averaging.ema_seconds_per_sample
+        eta_seconds_to_averaging = estimated_time - get_dht_time()
+
+        if eta_seconds_to_averaging <= self.matchmaking_time:
+            if self.scheduled_state is None or self.scheduled_state.triggered or self.scheduled_state.done():
+
+                min_matchmaking_time = self.state_averager.matchmaking_kwargs["min_matchmaking_time"]
+                actual_seconds = max(eta_seconds_to_averaging, min_matchmaking_time)
+                logger.log(self.status_loglevel, f"Pre-scheduling state averaging round in {actual_seconds:.2f}s.")
+                self.scheduled_state = self.state_averager.schedule_step(
+                    gather=next_epoch, timeout=self.averaging_timeout
+                )
+
+    def _average_gradients_and_load_into_optimizer(self, maybe_step_control: Optional[StepControl]):
+        """Run gradient averaging; on success, feed averaged gradients into optimizer; else, use local gradients"""
+        assert self.use_gradient_averaging and maybe_step_control is None or maybe_step_control.triggered
+        averaged_gradients = False
+
+        try:
+            if maybe_step_control is not None:
+                group_info = maybe_step_control.result(self.averaging_timeout)
+                logger.log(self.status_loglevel, f"Averaged gradients with {len(group_info)} peers")
+                self._load_averaged_gradients_into_optimizer_()
+                averaged_gradients = True
+            else:
+                logger.log(self.status_loglevel, f"Skipped averaging: there are no other peers")
+        except BaseException as e:
+            logger.log(self.status_loglevel, f"Averaging gradients failed with {repr(e)}")
+
+        if not averaged_gradients:
+            logger.log(self.status_loglevel, f"Proceeding with local gradients")
+            self.grad_averager.load_accumulators_into_averager_()
+            self._load_averaged_gradients_into_optimizer_()
+
+    def _load_averaged_gradients_into_optimizer_(self):
+        """If required, load averaged gradients into optimizer; otherwise simply notify grad averager"""
+        assert self.use_gradient_averaging
+
+        if self.offload_optimizer:
+            pass  # averaged gradients are already baked into optimizer, see _make_gradient_averager
+        else:
+            # copy averaged gradients into optimizer .grad buffers
+            optimized_param_groups = self.state_averager.optimizer.param_groups
+            optimized_parameters = [param for group in optimized_param_groups for param in group["params"]]
+            with torch.no_grad(), self.grad_averager.get_tensors() as averaged_gradients:
+                assert len(averaged_gradients) == len(optimized_parameters)
+                for opt_param, averaged_grad in zip(optimized_parameters, averaged_gradients):
+                    opt_param.grad.copy_(averaged_grad, non_blocking=True)
+
+        self.grad_averager.notify_used_averaged_gradients()
+
+    def zero_grad(self, set_to_none: bool = False):
+        """Reset gradients from model. If reuse_grad_buffers=True, this will raise an error."""
+        if self.use_gradient_averaging and self.grad_averager.reuse_grad_buffers:
+            raise ValueError(
+                f"When running {self.__class__.__name__} with reuse_grad_buffers=True, user should never "
+                f"call zero_grad manually. Gradients will be refreshed internally."
+            )
+        for param_group in self.param_groups:
+            for param in param_group["params"]:
+                if param.grad is None:
+                    pass
+                elif set_to_none:
+                    param.grad = None
+                else:
+                    param.grad.zero_()
+
+    def should_load_state_from_peers(self) -> bool:
+        """
+        If true, peer will discard local progress and attempt to download state from peers.
+        This method allows peer to continue training in two cases:
+         - peer is on the same epoch as other collaborators - keep training normally
+         - peer was on the same epoch and accumulated some grads, but some collaborators
+             have just transitioned to the next epoch - this peer should also transition.
+
+        :note: The latter case occurs due to the lack of network synchrony: the first peer that
+        detects enough samples will transition to the next step and start counting samples anew.
+        Some other peers may take time before they check with DHT and observe that
+          - the global epoch is technically one epoch ahead of the current one and
+          - the remaining (non-transitioned) peers no longer have target_batch_size between them
+        If this is the case, peer should transition to the next epoch and does *not* need to re-load state.
+        """
+        if self._should_check_synchronization_on_update and self.tracker.fetched_global_progress_this_epoch.is_set():
+            self._should_check_synchronization_on_update = False
+            return self.local_epoch != self.tracker.global_epoch  # require exact synchronization once per step
+        return self.local_epoch < self.tracker.global_epoch - 1  # catch up if a peer just switched to next epoch
+
+    def load_state_from_peers(self, **kwargs):
+        """
+        Attempt to load the newest collaboration state from other peers within the same run_id.
+
+        If successful, this will update parameters, optimizer state, local epoch and learning rate schedule in-place.
+        """
+        self._finish_background_averaging()
+        self.state_averager.step(wait_for_delayed_updates=True)
+
+        with self.tracker.pause_updates():
+            while True:
+                try:
+                    self.state_averager.load_state_from_peers(timeout=self.load_state_timeout, **kwargs)
+                    break
+                except KeyboardInterrupt:
+                    raise
+                except BaseException as e:
+                    logger.exception(f"Failed to load state from peers: {e}, retrying ...")
+                    continue
+
+            if self.tracker.global_epoch - 1 <= self.local_epoch < self.tracker.global_epoch:
+                logger.log(self.status_loglevel, f"Catching up with collaboration step {self.tracker.global_epoch}.")
+                self.state_averager.local_epoch = self.tracker.global_epoch
+
+            self.tracker.report_local_progress(local_epoch=self.local_epoch, samples_accumulated=0)
+
+            if not self.client_mode:
+                self.state_averager.state_sharing_priority = self.local_epoch
+
+            if self.use_gradient_averaging:
+                self.grad_averager.reset_accumulated_grads_()
+                if not self.client_mode:
+                    self.grad_averager.state_sharing_priority = self.local_epoch
+
+    def _finish_background_averaging(self):
+        for scheduled_round in self.scheduled_grads, self.scheduled_state:
+            if scheduled_round is not None:
+                if scheduled_round.stage == AveragingStage.LOOKING_FOR_GROUP:
+                    scheduled_round.cancel()
+                if not scheduled_round.triggered:
+                    scheduled_round.weight = 0
+                    scheduled_round.allow_allreduce()
+        for scheduled_round in self.scheduled_grads, self.scheduled_state:
+            if scheduled_round is not None and not scheduled_round.done():
+                try:
+                    time_to_deadline = scheduled_round.deadline - get_dht_time()
+                    scheduled_round.result(timeout=max(0.0, time_to_deadline))
+                except BaseException as e:
+                    logger.log(self.status_loglevel, f"Caught {e} while averaging gradients")
+                if not scheduled_round.done():
+                    scheduled_round.cancel()
+        self.scheduled_grads = self.scheduled_state = None
+
+    def state_dict(self) -> dict:
+        state_dict = self.state_averager.optimizer.state_dict()
+        state_dict["state"]["local_epoch"] = self.local_epoch
+        return state_dict
+
+    def load_state_dict(self, state_dict: dict):
+        if "local_epoch" in state_dict["state"]:
+            self.state_averager.local_epoch = state_dict["state"].pop("local_epoch")
+        return self.state_averager.optimizer.load_state_dict(state_dict)
+
+    @property
+    def state(self):
+        return dict(self.state_averager.optimizer.state, local_epoch=self.local_epoch)
+
+    @property
+    def opt(self) -> TorchOptimizer:
+        return self.state_averager.optimizer
+
+    @property
+    def param_groups(self) -> ParamGroups:
+        next_index = 0
+        param_groups = tuple(dict(param_group) for param_group in self.state_averager.optimizer.param_groups)
+        for param_group in param_groups:
+            num_params = len(param_group["params"])
+            main_params_for_group = self.state_averager.main_parameters[next_index : next_index + num_params]
+            param_group["params"] = main_params_for_group
+            next_index += num_params
+        assert next_index == len(self.state_averager.main_parameters)
+        return param_groups
+
+    def add_param_group(self, param_group: dict) -> None:
+        raise ValueError(
+            f"{self.__class__.__name__} does not support calling add_param_group after creation."
+            f"Please provide all parameter groups at init."
+        )
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(prefix={self.run_id}, epoch={self.local_epoch})"
+
+    def shutdown(self):
+        logger.log(self.status_loglevel, "Sending goodbye to peers...")
+        self.tracker.shutdown(self.shutdown_timeout)
+        self.state_averager.step(wait_for_delayed_updates=True)
+        self._finish_background_averaging()
+        logger.log(self.status_loglevel, "Shutting down averagers...")
+        self.state_averager.shutdown()
+        if self.use_gradient_averaging:
+            self.grad_averager.shutdown()
+        logger.log(self.status_loglevel, f"{self.__class__.__name__} is shut down.")
+
+    def __del__(self):
+        if self._parent_pid == os.getpid() and self.is_alive():
+            self.shutdown()

+ 47 - 14
hivemind/optim/experimental/progress_tracker.py

@@ -83,7 +83,7 @@ class ProgressTracker(threading.Thread):
         *,
         client_mode: Optional[bool] = None,
         min_refresh_period: float = 0.5,
-        max_refresh_period: float = 30,
+        max_refresh_period: float = 10,
         default_refresh_period: float = 3,
         expected_drift_peers: float = 3,
         expected_drift_rate: float = 0.2,
@@ -114,7 +114,7 @@ class ProgressTracker(threading.Thread):
         metadata, _expiration = self.dht.get(self.training_progress_key, latest=True) or (None, -float("inf"))
         self.global_progress = self._parse_swarm_progress_data(metadata)
         self.lock_global_progress, self.global_state_updated = threading.Lock(), threading.Event()
-        self.should_report_progress = threading.Event()
+        self.should_report_progress, self.fetched_global_progress_this_epoch = threading.Event(), threading.Event()
         self.shutdown_triggered, self.shutdown_complete = threading.Event(), threading.Event()
         super().__init__(name=f"{self.__class__.__name__}({self.prefix})", daemon=daemon)
         if start:
@@ -150,15 +150,20 @@ class ProgressTracker(threading.Thread):
             client_mode=self.client_mode,
         )
 
-    def report_local_progress(self, local_epoch: int, samples_accumulated: int):
+    def report_local_progress(self, local_epoch: int, samples_accumulated: int, update_global_samples: bool = True):
         """Update the number of locally accumulated samples and notify to other peers about this."""
         extra_samples = samples_accumulated - self.local_progress.samples_accumulated
+        if update_global_samples and local_epoch == self.local_progress.epoch == self.global_progress.epoch:
+            self.global_progress.samples_accumulated += extra_samples
+            # note: the above line can decrease the number of samples, e.g. if forced to reset due to overflow
+
         if extra_samples > 0:
             self.performance_ema.update(task_size=extra_samples)
             logger.debug(f"Updated performance EMA: {self.performance_ema.samples_per_second:.5f}")
         else:
             logger.debug("Resetting performance timestamp to current time (progress was reset)")
             self.performance_ema.reset_timer()
+
         self.local_progress = self._get_local_progress(local_epoch, samples_accumulated)
         self.should_report_progress.set()
 
@@ -178,6 +183,7 @@ class ProgressTracker(threading.Thread):
             self.global_progress.samples_accumulated = 0
             self.global_progress.eta_next_epoch = float("inf")
         self.report_local_progress(new_epoch, samples_accumulated=0)
+        self.fetched_global_progress_this_epoch.clear()
         return new_epoch
 
     def run(self):
@@ -189,6 +195,7 @@ class ProgressTracker(threading.Thread):
     async def _progress_reporter(self):
         """Periodically publish metadata and the current number of samples accumulated towards the next epoch"""
         last_report_time = -float("inf")
+        store_task = None
         try:
             while not self.shutdown_triggered.is_set():
                 wait_timeout = max(0.0, last_report_time + self.metadata_expiration - get_dht_time())
@@ -203,21 +210,42 @@ class ProgressTracker(threading.Thread):
                 local_progress = self.local_progress
                 last_report_time = get_dht_time()
 
-                await self.dht.store(
-                    key=self.training_progress_key,
-                    subkey=self._local_public_key,
-                    value=local_progress.dict(),
-                    expiration_time=last_report_time + self.metadata_expiration,
-                    return_future=True,
+                store_task = asyncio.create_task(
+                    asyncio.wait_for(
+                        self.dht.store(
+                            key=self.training_progress_key,
+                            subkey=self._local_public_key,
+                            value=local_progress.dict(),
+                            expiration_time=last_report_time + self.metadata_expiration,
+                            return_future=True,
+                        ),
+                        timeout=self.metadata_expiration,
+                    )
                 )
         finally:
             logger.log(self.status_loglevel, f"No longer reporting progress for {self.prefix}.")
+            if store_task is not None:
+                store_task.cancel()
 
     async def _progress_fetcher(self):
         """
         Periodically check the training progress from all peers. Trigger update after target_batch_size total samples
         """
         loop = asyncio.get_event_loop()
+        shutdown_checker = asyncio.create_task(
+            asyncio.wait_for(loop.run_in_executor(None, self.shutdown_triggered.wait), None)
+        )
+
+        async def _fetch_progress_unless_shutdown_triggered():
+            """Fetch progress, avoid deadlocks if DHT was shut down before this get finished."""
+            getter = asyncio.create_task(
+                asyncio.wait_for(self.dht.get(self.training_progress_key, latest=True, return_future=True), None)
+            )
+            await asyncio.wait({getter, shutdown_checker}, return_when=asyncio.FIRST_COMPLETED)
+            if self.shutdown_triggered.is_set():
+                return
+            return await getter
+
         try:
             while not self.shutdown_triggered.is_set():
                 time_to_next_update = max(0.0, self.global_progress.next_fetch_time - get_dht_time())
@@ -229,9 +257,13 @@ class ProgressTracker(threading.Thread):
                     continue
 
                 async with enter_asynchronously(self.lock_global_progress):
-                    progress_entry = await self.dht.get(self.training_progress_key, latest=True, return_future=True)
-                    metadata = progress_entry.value if isinstance(progress_entry, ValueWithExpiration) else None
+                    maybe_metadata = await _fetch_progress_unless_shutdown_triggered()
+                    if self.shutdown_triggered.is_set():
+                        break
+                    metadata = maybe_metadata.value if isinstance(maybe_metadata, ValueWithExpiration) else None
                     self.global_progress = self._parse_swarm_progress_data(metadata)
+                    self.fetched_global_progress_this_epoch.set()
+
         finally:
             logger.log(self.status_loglevel, f"No longer fetching {self.training_progress_key}.")
 
@@ -294,7 +326,7 @@ class ProgressTracker(threading.Thread):
         )
         logger.log(
             self.status_loglevel,
-            f"{self.prefix} accumulated {total_samples_accumulated} samples for iteration #{global_epoch} from "
+            f"{self.prefix} accumulated {total_samples_accumulated} samples for epoch #{global_epoch} from "
             f"{num_peers} peers. ETA {estimated_time_to_next_epoch:.2f} sec (refresh in {time_to_next_fetch:.2f} sec)",
         )
         return GlobalTrainingProgress(
@@ -307,15 +339,16 @@ class ProgressTracker(threading.Thread):
             next_fetch_time=current_time + time_to_next_fetch,
         )
 
-    def shutdown(self):
+    def shutdown(self, timeout: Optional[float] = None):
         """Permanently disable all tracking activity"""
         self.shutdown_triggered.set()
         self.should_report_progress.set()
         self.global_state_updated.set()
-        self.shutdown_complete.wait()
+        self.shutdown_complete.wait(timeout)
         self.dht.store(
             self.training_progress_key,
             subkey=self._local_public_key,
             value=None,
             expiration_time=get_dht_time() + self.metadata_expiration,
+            return_future=True,
         )

+ 234 - 111
hivemind/optim/experimental/state_averager.py

@@ -1,18 +1,20 @@
 """ An extension of averager that supports common optimization use cases. """
 import logging
-from asyncio import Future
+import threading
+import time
 from concurrent.futures import ThreadPoolExecutor
+from contextlib import nullcontext
 from itertools import chain
-from threading import Event
 from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Sequence, Tuple, Union
 
 import torch
 
 import hivemind
 from hivemind.averaging import DecentralizedAverager
+from hivemind.averaging.control import StepControl
 from hivemind.compression import CompressionInfo, TensorRole
 from hivemind.optim.grad_scaler import GradScaler
-from hivemind.utils import get_logger, nested_flatten, nested_pack
+from hivemind.utils import DHTExpiration, PerformanceEMA, get_dht_time, get_logger, nested_flatten, nested_pack
 
 logger = get_logger(__name__)
 
@@ -36,7 +38,7 @@ class TrainingStateAverager(DecentralizedAverager):
 
     Example:
 
-    >>> avgr = TrainingStateAverager(optimizer=torch.optim.Adam, param_groups=model.parameters(), ...)
+    >>> avgr = TrainingStateAverager(optimizer=torch.optim.Adam, params=model.parameters(), ...)
     >>> # alternative interface: TrainingStateAverager(optimizer=torch.optim.Adam(model.parameters()), ...)
     >>> avgr.load_state_from_peers()
     >>> for i, batch in enumerate(training_dataloader):
@@ -49,7 +51,7 @@ class TrainingStateAverager(DecentralizedAverager):
       TrainingStateAverager.step(..., optimizer_step=True)
 
     :param optimizer: PyTorch Optimizer or a callable that creates a optimizer from param groups
-    :param param_groups: optional, a list/tuple of parameters or structured param groups for the optimizer
+    :param params: optional, a list/tuple of parameters or structured param groups for the optimizer
     :param scheduler: optional learning rate scheduler or callable that creates one from optimizer instance
     :note: if provided, scheduler will be updated based on averager.local_epoch, not the number of step cycles
     :param initialize_optimizer: if True, run a speculative optimizer step with zero gradients to initialize all
@@ -60,8 +62,11 @@ class TrainingStateAverager(DecentralizedAverager):
       This assumes that offloaded gradients will be populated externally, e.g. by the user or by hivemind.Optimizer.
     :param reuse_tensors: if True, reuse parameters and optimizer statistics as averaged_tensors for allreduce.
       For this to work, all parameters must be on CPU and have the appropriate dtype for use in DecentralizedAverager
+      Defaults to True if offload_optimizer, False otherwise.
+    :param delta_rule_averaging: if True, averaging will use delta rule to allow running local optimizer steps
+      while averaging. Delta rule: `state_tensor := state_tensor + averaging_result - state_tensor_before_averaging`
     :param sync_epoch_when_averaging: if True, update local epoch to the latest epoch among averaging peers
-    :param parameter_names: optionally provide parameter names in the same order as param_groups
+    :param parameter_names: optionally provide parameter names in the same order as in params
     :param average_opt_statistics: names of optimizer statistics from state dict that should be averaged with peers
     :param extra_tensors: if specified, these extra tensors will also be averaged and shared in load_state_from_peers.
     :note: you can use extra_tensors to for any tensors not used by the optimizer (e.g. batchnorm statistics)
@@ -73,12 +78,14 @@ class TrainingStateAverager(DecentralizedAverager):
         *,
         dht: hivemind.DHT,
         optimizer: Union[TorchOptimizer, OptimizerFactory],
-        param_groups: Optional[Union[Parameters, ParamGroups]] = None,
+        params: Optional[Union[Parameters, ParamGroups]] = None,
         scheduler: Optional[Union[LRSchedulerBase, SchedulerFactory]] = None,
         initialize_optimizer: Optional[bool] = None,
         offload_optimizer: bool = False,
         custom_gradients: bool = False,
-        reuse_tensors: bool = False,
+        reuse_tensors: Optional[bool] = None,
+        delta_rule_averaging: bool = False,
+        performance_ema_alpha: float = 0.1,
         sync_epoch_when_averaging: bool = False,
         parameter_names: Optional[Sequence[str]] = None,
         average_opt_statistics: Sequence[str] = (),
@@ -88,20 +95,22 @@ class TrainingStateAverager(DecentralizedAverager):
     ):
         average_opt_statistics = tuple(average_opt_statistics)
         assert all(isinstance(key, str) for key in average_opt_statistics)
-        if offload_optimizer and reuse_tensors:
-            logger.warning("Setting offload_optimizer=True has no effect because reuse_parameters=True")
+        if reuse_tensors is None:
+            reuse_tensors = offload_optimizer and not delta_rule_averaging
         if custom_gradients and not offload_optimizer:
             logger.warning("Setting custom_gradients=True has no effect because the optimizer is not offloaded")
+        if reuse_tensors and delta_rule_averaging:
+            raise ValueError("reuse_tensors and delta_rule_averaging are mutually exclusive")
 
-        param_groups, main_parameters, parameter_names = self._check_params(optimizer, param_groups, parameter_names)
+        param_groups, main_parameters, parameter_names = self._check_params(optimizer, params, parameter_names)
 
         self.status_loglevel = status_loglevel
-        self.reuse_tensors = reuse_tensors
-        self.offload_optimizer = offload_optimizer
-        self.custom_gradients = custom_gradients
+        self.offload_optimizer, self.custom_gradients = offload_optimizer, custom_gradients
+        self.reuse_tensors, self.delta_rule_averaging = reuse_tensors, delta_rule_averaging
+        self._old_tensors: Optional[Sequence[torch.Tensor]] = None  # for delta rule
 
         self.main_parameters, self.parameter_names = main_parameters, parameter_names
-        self._averaged_parameters = tuple(map(self._make_host_tensor, main_parameters))
+        self._averaged_parameters = self._make_averaged_parameters(main_parameters)
         self.optimizer, self.scheduler = self._init_components(
             param_groups, optimizer, scheduler, initialize_optimizer
         )
@@ -109,11 +118,13 @@ class TrainingStateAverager(DecentralizedAverager):
         self.sync_epoch_when_averaging = sync_epoch_when_averaging
         self.local_epoch = 0
 
-        self.step_executor = ThreadPoolExecutor(max_workers=1)
-        self.finished_optimizer_step = Event()
-        self.finished_averaging_round = Event()
-        self.pending_update = Future()
-        self.pending_update.set_result(None)
+        self.delay_before_averaging = PerformanceEMA(alpha=performance_ema_alpha)
+        self.step_executor = ThreadPoolExecutor(max_workers=2 if self.delta_rule_averaging else 1)
+        self.finished_optimizer_step = threading.Event()
+        self.finished_averaging_round = threading.Event()
+        self.lock_optimizer = threading.Lock()
+        self.lock_averaging = threading.Lock()
+        self.pending_updates = set()
 
         super().__init__(
             dht=dht, averaged_tensors=self._init_averaged_tensors(), tensor_infos=self._init_tensor_infos(), **kwargs
@@ -143,10 +154,15 @@ class TrainingStateAverager(DecentralizedAverager):
         assert len(set(parameters)) == len(parameters), "Found duplicate parameters in param_groups"
         return param_groups, parameters, parameter_names
 
-    def _make_host_tensor(self, source_tensor: torch.Tensor) -> torch.Tensor:
+    def _make_averaged_parameters(self, main_parameters: Sequence[torch.Tensor]):
+        """Initialize averaged parameters based on the optimizer and averaging mode"""
+        return tuple(self._make_host_tensor(param, force_copy=self.offload_optimizer) for param in main_parameters)
+
+    def _make_host_tensor(self, source_tensor: torch.Tensor, force_copy: bool = False) -> torch.Tensor:
         """Create a new tensor for averaging or reuse the existing one"""
-        if self.reuse_tensors:
-            assert source_tensor.device == torch.device("cpu") and source_tensor.dtype == torch.float32
+        if self.reuse_tensors and not force_copy:
+            if source_tensor.device != torch.device("cpu"):
+                raise ValueError("reuse_tensors is only supported if all averaged tensors are on CPU.")
             if not source_tensor.is_shared():
                 source_tensor.share_memory_()
             return source_tensor
@@ -173,19 +189,26 @@ class TrainingStateAverager(DecentralizedAverager):
         # create optimizer
         if optimizer_is_factory:
             if self.offload_optimizer:
-                for param in self._averaged_parameters:
-                    if param.grad is None:
-                        param.grad = torch.zeros_like(param)
+                if self.reuse_tensors:
+                    parameters_for_optimizer = self._averaged_parameters
+                else:
+                    parameters_for_optimizer = tuple(
+                        tensor.detach().clone().requires_grad_(tensor.requires_grad)
+                        for tensor in self._averaged_parameters
+                    )
 
                 next_index = 0
                 param_groups_for_optimizer = []
                 for param_group in param_groups:
                     num_params = len(param_group["params"])
-                    averaged_params_for_group = self._averaged_parameters[next_index : next_index + num_params]
+                    averaged_params_for_group = parameters_for_optimizer[next_index : next_index + num_params]
                     param_groups_for_optimizer.append(dict(param_group, params=averaged_params_for_group))
                     next_index += num_params
-                assert next_index == len(self._averaged_parameters)
+                assert next_index == len(parameters_for_optimizer)
 
+                for param in parameters_for_optimizer:
+                    if param.grad is None:
+                        param.grad = torch.zeros_like(param)
             else:
                 param_groups_for_optimizer = param_groups
             optimizer = optimizer_or_factory(param_groups_for_optimizer)
@@ -198,7 +221,7 @@ class TrainingStateAverager(DecentralizedAverager):
             logger.log(
                 self.status_loglevel,
                 "Initializing optimizer manually since it has no tensors in state dict. "
-                "To override this, please provide initialize_optimizer=False",
+                "To override this, provide initialize_optimizer=False",
             )
 
         if initialize_optimizer:
@@ -213,7 +236,7 @@ class TrainingStateAverager(DecentralizedAverager):
 
         # verify optimizer and scheduler
         assert isinstance(optimizer, TorchOptimizer) and len(optimizer.param_groups) == len(list(param_groups))
-        if self.offload_optimizer or self.reuse_tensors:
+        if self.reuse_tensors:
             for param_group in optimizer.param_groups:
                 for param in param_group["params"]:
                     assert param.is_shared()
@@ -250,7 +273,7 @@ class TrainingStateAverager(DecentralizedAverager):
         for local_tensor, averaged_tensor in zip(local_tensors, averaged_tensors):
             assert local_tensor.shape == averaged_tensor.shape
             if averaged_tensor.grad is not None:
-                logger.debug(self.status_loglevel, "setting gradients for averaged tensor to None")
+                logger.log(self.status_loglevel, "setting gradients for averaged tensor to None")
 
         return averaged_tensors
 
@@ -274,9 +297,22 @@ class TrainingStateAverager(DecentralizedAverager):
             tensor_infos.append(CompressionInfo.from_tensor(extra_tensor, key=i, role=TensorRole.UNSPECIFIED))
         return tuple(tensor_infos)
 
+    def schedule_step(self, scheduled_time: Optional[DHTExpiration] = None, **kwargs) -> StepControl:
+        """
+        Begin matchmaking: look for a group of peers and prepare for averaging gradients at a specified time.
+
+        :param scheduled_time: expected time when to perform all-reduce. Can be changed using control.scheduled_time
+        :param kwargs: any additional keyword args from DecentralizedAverager.step, such as gather, allow_retries, etc
+        :note: setting weight at this stage is not supported, please leave this parameter as None
+        :returns: step_control - a handle that can be passed into TrainingStateAverager.step to use pre-scheduled group
+        :note: in the current implementation, each step_control can only be used in one step.
+        """
+        assert kwargs.get("weight") is None, "setting weight in schedule_step is not supported"
+        return super().step(scheduled_time=scheduled_time, wait=False, require_trigger=True, **kwargs)
+
     def step(
         self,
-        wait_for_delayed_update: bool = None,
+        wait_for_delayed_updates: bool = None,
         apply_delayed_updates: bool = True,
         increment_epoch: bool = False,
         optimizer_step: bool = False,
@@ -284,6 +320,8 @@ class TrainingStateAverager(DecentralizedAverager):
         delay_optimizer_step: bool = False,
         averaging_round: bool = False,
         delay_averaging: Optional[bool] = None,
+        averaging_control: Optional[StepControl] = None,
+        wait_for_trigger: Optional[Callable[[], Any]] = None,
         grad_scaler: Optional[GradScaler] = None,
         averaging_opts: Optional[Dict[str, Any]] = None,
     ):
@@ -291,138 +329,205 @@ class TrainingStateAverager(DecentralizedAverager):
         Perform one or several possible actions, depending on the specified keyword args.
         The actions will be performed in the same order as specified below:
 
-        :param wait_for_delayed_update: if there are background averaging rounds, wait for them to finish
+        :param wait_for_delayed_updates: if there are background averaging rounds, wait for them to finish
           by default, await delayed updates when scheduling the next optimizer step, otherwise do not update
         :param apply_delayed_updates: apply any averaging rounds that have finished but were not applied yet
         :param increment_epoch: increment .local_epoch and update the learning rate scheduler (if present)
+        :note: if specified, it is guaranteed that epoch is incremented immediately regardless of other options
         :param optimizer_step: perform a single optimizer step and update local parameters (without changing scheduler)
         :param zero_grad: if True, reset local gradients after performing optimizer step
         :param delay_optimizer_step: if True, run optimizer step in background and apply results in a future step
         :param averaging_round: average parameters, chosen optimizer keys and extra tensors with a group of peers
-        :param grad_scaler: when using hivemind.GradScaler, one must forward it to step after calling .unscale_
         :param delay_averaging: if True, perform averaging in background and apply results in a future step
           by default, delay averaging if the optimizer step is also delayed. Set to true to delay only this phase.
+        :param averaging_control: if specified, use this as a pre-scheduled averaging round. Should require_trigger.
+        :param wait_for_trigger: wait for this (non-asyncio) function to finish before running optimizer step
+        :note: if wait_for_trigger fails with any exception, it will abort optimizer step, zero grad and averaging
+        :param grad_scaler: when using hivemind.GradScaler, one must forward it to step after calling .unscale_
         :param averaging_opts: a dict of keyword arguments forwarded into averaging round
         """
         if delay_averaging is None:
             delay_averaging = delay_optimizer_step
-        if wait_for_delayed_update is None:
-            wait_for_delayed_update = optimizer_step or zero_grad or averaging_round
+        should_wait = averaging_round or optimizer_step or zero_grad if self.delta_rule_averaging else averaging_round
+        if wait_for_delayed_updates is None:
+            wait_for_delayed_updates = should_wait
+        if should_wait and not (wait_for_delayed_updates and apply_delayed_updates):
+            raise ValueError("Should wait for background operation to finish before scheduling new one")
         assert not delay_optimizer_step or delay_averaging, "Delayed optimizer step requires delayed averaging"
-        if optimizer_step or averaging_round or zero_grad:
-            assert wait_for_delayed_update, "Must wait for background updates to finish before scheduling new ones"
         if delay_optimizer_step:
             assert self.offload_optimizer, "Delayed optimizer step is only available with offload_optimizer"
             assert not averaging_round or delay_averaging, "Averaging after delayed optimizer should also be delayed"
         if averaging_opts and not averaging_round:
             logger.warning(f"Averaging parameters not used because averaging_round=False: {averaging_opts}")
+        if averaging_control is not None:
+            assert averaging_round, "averaging_control is unused if averaging_round is not performed"
+        if wait_for_trigger is not None:
+            assert optimizer_step or zero_grad or averaging_round, "trigger is only used for updating parameters"
+            if not (self.reuse_tensors or self.custom_gradients):
+                # averager was asked to wait_for_trigger in background, but it is not clear which version of gradients
+                # should be used for optimizer step (e.g. the gradients that were present during the call to .step or
+                # the possibly different gradients when wait_for_trigger has finished).
+                raise ValueError(
+                    "wait_for_trigger is a low-level option that requires manual gradient manipulation. "
+                    "If you know what you're doing, please refer to the comments in the source code for details."
+                )
         output = None
 
-        if wait_for_delayed_update:
-            if not self.pending_update.done():
-                logger.log(self.status_loglevel, "Waiting for delayed updates to finish...")
-                output = self.pending_update.result()
-
-        if self.pending_update.done() and self.pending_update.exception():
-            logger.warning(f"Background update failed with {self.pending_update.exception()} and will be ignored")
+        if wait_for_delayed_updates:
+            for pending_update in self.pending_updates:
+                try:
+                    logger.log(self.status_loglevel, "Waiting for delayed updates to finish...")
+                    output = pending_update.result()
+                except BaseException:
+                    pass  # exception will be reported below
+
+        # remove finished updates, log any exceptions
+        finished_updates = {pending_update for pending_update in self.pending_updates if pending_update.done()}
+        self.pending_updates = {pending_update for pending_update in self.pending_updates if not pending_update.done()}
+        for finished_update in finished_updates:
+            if finished_update.exception():
+                logger.log(self.status_loglevel, f"Background update failed with {finished_update.exception()}")
 
         if apply_delayed_updates:
             if self.finished_averaging_round.is_set():
                 if not self.reuse_tensors:
                     self._apply_averaging_results_()
+                if self.offload_optimizer and not self.finished_optimizer_step.is_set():
+                    self._apply_optimizer_parameters_()
                 logger.log(self.status_loglevel, "Received parameters from background averaging round")
                 self.finished_averaging_round.clear()
 
             if self.finished_optimizer_step.is_set():
                 if self.offload_optimizer:
-                    self._apply_optimizer_results_()
-                logger.log(self.status_loglevel, "Received parameters from background optimizer step")
+                    self._apply_optimizer_parameters_()
+                logger.debug("Received parameters from background optimizer step")
                 self.finished_optimizer_step.clear()
 
         if increment_epoch:
             self.local_epoch += 1
 
         if optimizer_step or zero_grad or averaging_round:
-            assert self.pending_update.done(), "Tried to perform a new update but previous update is still running"
-
             if self.offload_optimizer and not self.custom_gradients:
                 self._load_local_grads_into_optimizer_()
 
-            self.pending_update = self.step_executor.submit(
+            pending_update = self.step_executor.submit(
                 self._do,
+                wait_for_trigger,
                 optimizer_step,
                 zero_grad,
                 averaging_round,
+                averaging_control,
                 grad_scaler,
                 **averaging_opts or {},
             )
+            self.pending_updates.add(pending_update)
+
+            should_await_optimizer = (optimizer_step or zero_grad) and not delay_optimizer_step
+            should_await_averaging = averaging_round and not delay_averaging
 
-            if (optimizer_step or zero_grad) and not delay_optimizer_step:
+            if should_await_optimizer:
                 self.finished_optimizer_step.wait()
                 self.finished_optimizer_step.clear()
-                if self.offload_optimizer:
-                    self._apply_optimizer_results_()
-                logger.log(self.status_loglevel, "Finished optimizer step")
+                if self.offload_optimizer and not should_await_averaging:
+                    self._apply_optimizer_parameters_()
+                logger.debug("Finished optimizer step")
 
-            if averaging_round and not delay_averaging:
+            if should_await_averaging:
                 self.finished_averaging_round.wait()
                 self.finished_averaging_round.clear()
                 if not self.reuse_tensors:
                     self._apply_averaging_results_()
+                if self.offload_optimizer:
+                    self._apply_optimizer_parameters_()
                 logger.log(self.status_loglevel, "Finished averaging round")
 
-            if not delay_averaging:
+            async_averaging = averaging_round and delay_averaging
+            async_optimizer = (optimizer_step or zero_grad) and delay_optimizer_step
+
+            if not (async_averaging or async_optimizer):
                 try:
-                    output = self.pending_update.result()
+                    output = pending_update.result()
                 finally:
-                    self.finished_averaging_round.clear()
-                    self.finished_optimizer_step.clear()
+                    self.pending_updates.remove(pending_update)
+
         return output
 
     def _do(
-        self, optimizer_step: bool, zero_grad: bool, averaging_round: bool, grad_scaler: Optional[GradScaler], **kwargs
+        self,
+        wait_for_trigger: Optional[Callable[[], Any]],
+        optimizer_step: bool,
+        zero_grad: bool,
+        averaging_round: bool,
+        averaging_control: Optional[StepControl],
+        grad_scaler: Optional[GradScaler],
+        timeout: Optional[float] = None,
+        **kwargs,
     ):
         """
         Run the optimizer step, followed by a scheduler step and an averaging round, each stage is optional.
         This method is meant to be called in the background executor.
         """
-        try:
-            if optimizer_step:
-                logger.log(self.status_loglevel, f"Running optimizer step")
-                if grad_scaler is None:
-                    self.optimizer.step()
-                else:
-                    with grad_scaler.running_global_step():
-                        assert grad_scaler.step(self.optimizer)
+        if averaging_control is not None and (averaging_control.triggered or averaging_control.done()):
+            logger.log(self.status_loglevel, f"Discarding failed matchmaking results: {averaging_control}")
+            averaging_control = None
 
-            if grad_scaler is not None:
-                with grad_scaler.running_global_step():
-                    assert grad_scaler.update()
+        start_time = time.perf_counter()
+        began_running = False
 
-            self._update_scheduler()
-
-            if zero_grad:
-                logger.log(self.status_loglevel, f"Running zero grad")
-                self.optimizer.zero_grad()
-                if self.offload_optimizer:
-                    for parameter in self.main_parameters:
-                        if parameter.grad is not None:
-                            parameter.grad.zero_()
+        try:
+            if averaging_round and averaging_control is None:
+                averaging_control = super().step(
+                    gather=self.local_epoch,
+                    require_trigger=True,
+                    timeout=timeout,
+                    wait=False,
+                    **kwargs,
+                )
 
-            self.finished_optimizer_step.set()
+            if wait_for_trigger is not None:
+                wait_for_trigger()
+            began_running = True
+
+            with self.lock_optimizer:
+                if optimizer_step:
+                    with self.lock_averaged_tensors if self.reuse_tensors else nullcontext():
+                        logger.debug(f"Running optimizer step")
+                        if grad_scaler is None:
+                            self.optimizer.step()
+                        else:
+                            with grad_scaler.running_global_step():
+                                assert grad_scaler.step(self.optimizer)
+
+                if zero_grad:
+                    logger.debug(f"Running zero grad")
+                    self.optimizer.zero_grad()
+                    if self.offload_optimizer:
+                        for parameter in self.main_parameters:
+                            if parameter.grad is not None:
+                                parameter.grad.zero_()
+
+                self._update_scheduler()
+                self.finished_optimizer_step.set()
 
             if averaging_round:
-                if not self.reuse_tensors:
-                    self._load_local_tensors_into_averager_()
-                try:
-                    gathered = super().step(gather=self.local_epoch, **kwargs)
-                    logger.log(self.status_loglevel, f"Averaged parameters with {len(gathered)} peers")
-                except BaseException as e:
-                    logger.log(self.status_loglevel, f"Averaging failed with {type(e)}")
-                    self.finished_averaging_round.set()
-                    gathered = {}
+                with self.lock_averaging:
+                    if not self.reuse_tensors:
+                        self._load_local_tensors_into_averager_()
+                    if self.delta_rule_averaging:
+                        # remember tensors before averaging, update by (new_averaged_tensors - old_averaged_tensors)
+                        with torch.no_grad(), self.get_tensors() as averaged_tensors:
+                            self._old_tensors = tuple(x.cpu().clone() for x in averaged_tensors)
+
+                    self.delay_before_averaging.update(task_size=1, interval=time.perf_counter() - start_time)
+                    try:
+                        averaging_control.allow_allreduce()
+                        gathered = averaging_control.result(timeout=timeout)
+                        logger.log(self.status_loglevel, f"Averaged parameters with {len(gathered)} peers")
+                    except BaseException as e:
+                        logger.log(self.status_loglevel, f"Averaging failed with {type(e)}")
+                        gathered = {}
 
-                self.finished_averaging_round.set()
+                    self.finished_averaging_round.set()
 
                 if self.sync_epoch_when_averaging:
                     old_epoch = self.local_epoch
@@ -433,7 +538,12 @@ class TrainingStateAverager(DecentralizedAverager):
                         self._update_scheduler()
 
         except Exception as e:
+            if not began_running:
+                logger.error(f"Aborted {self.__class__.__name__}.step because wait_for_trigger raised exception")
             logger.exception(e)
+            if averaging_control is not None and not averaging_control.done():
+                logger.error(f"Cancelled scheduled state averaging round")
+                averaging_control.cancel()
             self.finished_optimizer_step.set()
             self.finished_averaging_round.set()
 
@@ -447,16 +557,13 @@ class TrainingStateAverager(DecentralizedAverager):
                 opt_param.grad.copy_(main_param.grad, non_blocking=True)
 
     @torch.no_grad()
-    def _apply_optimizer_results_(self):
+    def _apply_optimizer_parameters_(self):
         """Copy parameters from offloaded optimizer to the main model"""
         assert self.offload_optimizer, "Applying offloaded optimizer updates requires offloaded optimizer"
-        with self.lock_averaged_tensors:
-            offloaded_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
-            assert len(offloaded_parameters) == len(
-                self.main_parameters
-            ), "Optimizer parameters changed during training"
-            for main_param, offloaded_param in zip(self.main_parameters, offloaded_parameters):
-                main_param.copy_(offloaded_param, non_blocking=True)
+        offloaded_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
+        assert len(offloaded_parameters) == len(self.main_parameters), "Optimizer parameters changed during training"
+        for main_param, offloaded_param in zip(self.main_parameters, offloaded_parameters):
+            main_param.copy_(offloaded_param, non_blocking=True)
 
     @torch.no_grad()
     def _load_local_tensors_into_averager_(self):
@@ -470,18 +577,30 @@ class TrainingStateAverager(DecentralizedAverager):
     def _apply_averaging_results_(self):
         """Copy averaged tensors into their respective local tensors"""
         assert not self.reuse_tensors, "No need to update averaged tensors since they reuse the same memory"
+        if self.delta_rule_averaging and self._old_tensors is None:
+            logger.warning("Using delta_rule_averaging, but old tensors were not found. Averaging may have failed.")
         with self.get_tensors() as averaged_tensors:
             local_tensors = list(self._local_tensors())
             assert len(local_tensors) == len(averaged_tensors), "Tensor structure changed during training"
-            for local_tensor, averaged_tensor in zip(local_tensors, averaged_tensors):
-                local_tensor.copy_(averaged_tensor, non_blocking=True)
+            if not self.delta_rule_averaging or self._old_tensors is None:
+                for local_tensor, averaged_tensor in zip(local_tensors, averaged_tensors):
+                    local_tensor.copy_(averaged_tensor, non_blocking=True)
+            else:
+                assert len(self._old_tensors) == len(local_tensors)
+                for local_tensor, new_tensor, old_tensor in zip(local_tensors, averaged_tensors, self._old_tensors):
+                    delta = torch.sub(new_tensor, old_tensor, out=old_tensor)  # using old tensors as buffers
+                    local_tensor.add_(delta.to(device=local_tensor.device, dtype=local_tensor.dtype))
+
+    @property
+    def averaging_in_progress(self) -> bool:
+        return self.lock_averaging.locked()
 
     def get_current_state(self):
         """
         Get current model/optimizer state and when requested by a newbie peer. executed in the host process.
         :returns: a tuple of (serializable_small_metadata, sequence of torch tensors)
         """
-        with torch.no_grad():
+        with torch.no_grad(), self.lock_averaged_tensors:
             optimized_parameters = tuple(
                 param.detach().cpu() for param_group in self.optimizer.param_groups for param in param_group["params"]
             )
@@ -512,8 +631,8 @@ class TrainingStateAverager(DecentralizedAverager):
         Attempt to download the latest optimizer state from peers and update trainer parameters/statistics.
         :returns: whether or the averager succeeded in loading parameters
         """
-        parameters_and_extras = tuple(chain(self.main_parameters, self.extra_tensors))
-        num_parameters_and_extras = len(parameters_and_extras)
+        main_parameters_and_extras = tuple(chain(self.main_parameters, self.extra_tensors))
+        num_parameters_and_extras = len(main_parameters_and_extras)
 
         loaded_state = super().load_state_from_peers(**kwargs)
         if loaded_state is None:
@@ -530,15 +649,19 @@ class TrainingStateAverager(DecentralizedAverager):
             logger.error("Failed to load state from peer, received parameters, extras or metadata.")
             return
 
-        try:
-            load_optimizer_state(self.optimizer, metadata["optimizer_metadata"], loaded_opt_tensors)
-        except StopIteration:
-            logger.warning("Failed to load state from peer, received inconsistent number of optimizer statistics")
-            return
+        with torch.no_grad(), self.lock_averaged_tensors:
+            try:
+                load_optimizer_state(self.optimizer, metadata["optimizer_metadata"], loaded_opt_tensors)
+            except StopIteration:
+                logger.warning("Failed to load state from peer, received inconsistent number of optimizer statistics")
+                return
 
-        with torch.no_grad():
-            for local_param, loaded_param in zip(parameters_and_extras, loaded_parameters_and_extras):
+            for local_param, loaded_param in zip(main_parameters_and_extras, loaded_parameters_and_extras):
                 local_param.copy_(loaded_param, non_blocking=True)
+
+        if self.offload_optimizer:
+            self._apply_optimizer_parameters_()
+
         self.local_epoch = metadata["epoch"]
         self._update_scheduler()
 

+ 60 - 36
hivemind/optim/grad_scaler.py

@@ -1,12 +1,14 @@
 import contextlib
+import threading
+from copy import deepcopy
 from typing import Dict, Optional
 
 import torch
 from torch.cuda.amp import GradScaler as TorchGradScaler
-from torch.cuda.amp.grad_scaler import _refresh_per_optimizer_state
+from torch.cuda.amp.grad_scaler import OptState, _refresh_per_optimizer_state
 from torch.optim import Optimizer as TorchOptimizer
 
-from hivemind.optim.base import DecentralizedOptimizerBase
+import hivemind
 from hivemind.utils.logging import get_logger
 
 logger = get_logger(__name__)
@@ -14,7 +16,12 @@ logger = get_logger(__name__)
 
 class GradScaler(TorchGradScaler):
     """
-    A thin wrapper over pytorch GradScaler that supports hivemind-style training with CollaborativeOptimizer, namely:
+    A wrapper over pytorch GradScaler made specifically for training hivemind.Optimizer with reuse_grad_buffers=True.
+
+    :note: if not using reuse_grad_buffers=True, one can and *should* train normally without this class, e.g. using
+      standard PyTorch AMP or Apex. This custom GradScaler is more memory-efficient, but requires custom training code.
+
+    GradScaler removes several:
     - bypass .unscale_ and .update calls in order to accumulate gradients over several steps
     - limit increasing gradient scale to only immediately after global optimizer steps
     - allow training with some or all master parameters in fp16
@@ -23,52 +30,68 @@ class GradScaler(TorchGradScaler):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self._is_running_global_step = False
+        self._is_ready_to_update = False
         self._optimizer_states_to_reset = set()
+        self._lock = threading.RLock()
 
     @contextlib.contextmanager
     def running_global_step(self):
-        was_running, self._is_running_global_step = self._is_running_global_step, True
-        try:
-            yield
-        finally:
-            self._is_running_global_step = was_running
+        with self._lock:
+            was_running, self._is_running_global_step = self._is_running_global_step, True
+            try:
+                yield
+            finally:
+                self._is_running_global_step = was_running
 
     def unscale_(self, optimizer: TorchOptimizer) -> bool:
-        assert isinstance(optimizer, DecentralizedOptimizerBase)
-        if self._is_running_global_step:
-            super().unscale_(optimizer.opt)
-            return True
-        else:
-            self._check_inf_per_device(optimizer.opt)
-            self._optimizer_states_to_reset.add(id(optimizer))
-            return False
+        with self._lock:
+            assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
+            if self._is_running_global_step:
+                super().unscale_(optimizer)
+                self._per_optimizer_states[id(optimizer.opt)] = deepcopy(self._per_optimizer_states[id(optimizer)])
+                return True
+            else:
+                self._check_inf_per_device(optimizer)
+                self._optimizer_states_to_reset.add(id(optimizer))
+                return False
 
     def step(self, optimizer: TorchOptimizer, *args, **kwargs) -> bool:
         if self._is_running_global_step:
-            if self.are_grads_finite(optimizer):
-                super().step(optimizer, *args, **kwargs)
-            else:
-                logger.warning("Skipping global step due to gradient over/underflow")
-            return True
+            with self._lock:
+                if self._is_ready_to_update:
+                    logger.warning("Please call grad_scaler.update() after each step.")
+                assert not isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
+                assert (
+                    self._per_optimizer_states[id(optimizer)]["stage"] == OptState.UNSCALED
+                ), "InternalError: Optimizer should have called .unscale internally before invoking grad_scaler.step."
+                if self.are_grads_finite(optimizer, use_cached=True):
+                    super().step(optimizer, *args, **kwargs)
+                else:
+                    logger.warning("Skipping global step due to gradient over/underflow")
+                self._is_ready_to_update = True
+                return True
         else:
+            assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
             super().step(optimizer)
             self._optimizer_states_to_reset.add(id(optimizer))
             return False
 
     def update(self, new_scale: Optional[float] = None) -> bool:
-        total_infs = 0
-        for optimizer_state in self._per_optimizer_states.values():
-            total_infs += sum(v.item() for v in optimizer_state["found_inf_per_device"].values())
-
-        if self._is_running_global_step or total_infs != 0:
-            # note: we update either during actual optimizer step or if we need to reduce scale due to NaN
-            super().update(new_scale)
-            return True
-        else:
-            for opt_id in self._optimizer_states_to_reset:
-                self._per_optimizer_states[opt_id] = _refresh_per_optimizer_state()
-            self._optimizer_states_to_reset.clear()
-            return False
+        with self._lock:
+            total_infs = 0
+            for optimizer_state in self._per_optimizer_states.values():
+                total_infs += sum(v.item() for v in optimizer_state["found_inf_per_device"].values())
+
+            if self._is_ready_to_update or total_infs != 0:
+                # note: we update either during actual optimizer step or if we need to reduce scale due to NaN
+                super().update(new_scale)
+                self._is_ready_to_update = False
+                return True
+            else:
+                for opt_id in self._optimizer_states_to_reset:
+                    self._per_optimizer_states[opt_id] = _refresh_per_optimizer_state()
+                self._optimizer_states_to_reset.clear()
+                return False
 
     def _unscale_grads_(
         self, optimizer: TorchOptimizer, inv_scale: torch.Tensor, found_inf: torch.Tensor, allow_fp16: bool
@@ -77,8 +100,9 @@ class GradScaler(TorchGradScaler):
         # inspired by: https://github.com/facebookresearch/fairscale/blob/945b9666/fairscale/optim/grad_scaler.py
         return super()._unscale_grads_(optimizer, inv_scale, found_inf, allow_fp16=True)
 
-    def are_grads_finite(self, optimizer: TorchOptimizer) -> bool:
-        return not sum(v.item() for v in self._check_inf_per_device(optimizer).values())
+    def are_grads_finite(self, optimizer: TorchOptimizer, use_cached: bool = False) -> bool:
+        opt_dict = self._found_inf_per_device(optimizer) if use_cached else self._check_inf_per_device(optimizer)
+        return not sum(v.item() for v in opt_dict.values())
 
 
 class HivemindGradScaler(GradScaler):

+ 1 - 0
hivemind/utils/__init__.py

@@ -5,6 +5,7 @@ from hivemind.utils.logging import get_logger, use_hivemind_log_handler
 from hivemind.utils.mpfuture import *
 from hivemind.utils.nested import *
 from hivemind.utils.networking import *
+from hivemind.utils.performance_ema import PerformanceEMA
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
 from hivemind.utils.tensor_descr import BatchTensorDescriptor, TensorDescriptor
 from hivemind.utils.timed_storage import *

+ 1 - 0
requirements-dev.txt

@@ -4,6 +4,7 @@ pytest-asyncio
 pytest-cov
 tqdm
 scikit-learn
+torchvision
 black==21.6b0
 isort
 psutil

+ 1 - 1
tests/test_averaging.py

@@ -528,7 +528,7 @@ def test_averaging_cancel():
 
     step_controls = [averager.step(wait=False, scheduled_time=hivemind.get_dht_time() + 1) for averager in averagers]
 
-    time.sleep(0.2)
+    time.sleep(0.1)
     step_controls[0].cancel()
     step_controls[1].cancel()
 

+ 120 - 16
tests/test_optimizer.py

@@ -12,6 +12,7 @@ import torch.nn.functional as F
 import hivemind
 from hivemind.averaging.control import AveragingStage
 from hivemind.optim.experimental.grad_averager import GradientAverager
+from hivemind.optim.experimental.optimizer import Optimizer
 from hivemind.optim.experimental.progress_tracker import ProgressTracker
 from hivemind.optim.experimental.state_averager import TrainingStateAverager
 from hivemind.utils.crypto import RSAPrivateKey
@@ -78,7 +79,7 @@ def test_grad_averager():
 @pytest.mark.forked
 @pytest.mark.parametrize(
     "offload_optimizer, reuse_tensors, sync_epoch_when_averaging",
-    [(False, False, False), (True, False, False), (False, True, True), (True, False, True)],
+    [(False, False, False), (True, True, False), (True, False, False), (False, True, True), (True, False, True)],
 )
 def test_state_averager(offload_optimizer: bool, reuse_tensors: bool, sync_epoch_when_averaging: bool):
     dht1 = hivemind.DHT(start=True)
@@ -106,10 +107,10 @@ def test_state_averager(offload_optimizer: bool, reuse_tensors: bool, sync_epoch
     )
 
     avgr1 = TrainingStateAverager(
-        dht=dht1, param_groups=model1.parameters(), extra_tensors=extras1, start=True, **common_kwargs
+        dht=dht1, params=model1.parameters(), extra_tensors=extras1, start=True, **common_kwargs
     )
     avgr2 = TrainingStateAverager(
-        dht=dht2, param_groups=model2.parameters(), extra_tensors=extras2, start=True, **common_kwargs
+        dht=dht2, params=model2.parameters(), extra_tensors=extras2, start=True, **common_kwargs
     )
 
     x = torch.ones(2)
@@ -135,8 +136,8 @@ def test_state_averager(offload_optimizer: bool, reuse_tensors: bool, sync_epoch
     avgr1.step(increment_epoch=True, averaging_round=True, delay_averaging=True)
     avgr2.step(increment_epoch=True, averaging_round=True, delay_averaging=True)
 
-    avgr1.step(wait_for_delayed_update=True)
-    avgr2.step(wait_for_delayed_update=True)
+    avgr1.step(wait_for_delayed_updates=True)
+    avgr2.step(wait_for_delayed_updates=True)
 
     assert torch.allclose(model1(x), model2(x)), "model parameters were not averaged correctly"
     assert torch.allclose(avgr1.optimizer.state_dict()["state"][0]["exp_avg_sq"], (stats1 + stats2) / 2)
@@ -161,10 +162,10 @@ def test_load_state_from_peers():
     )
 
     avgr1 = TrainingStateAverager(
-        dht=dht1, param_groups=model1.parameters(), allow_state_sharing=False, start=True, **common_kwargs
+        dht=dht1, params=model1.parameters(), allow_state_sharing=False, start=True, **common_kwargs
     )
 
-    avgr2 = TrainingStateAverager(dht=dht2, param_groups=model2.parameters(), start=True, **common_kwargs)
+    avgr2 = TrainingStateAverager(dht=dht2, params=model2.parameters(), start=True, **common_kwargs)
 
     avgr2.local_epoch = 1337
     model2.weight.data[...] = 42
@@ -217,14 +218,14 @@ def test_progress_tracker():
             tracker.report_local_progress(local_epoch, samples_accumulated)
 
             if tracker.ready_to_update_epoch:
+                if index == 4 and local_epoch >= 4:
+                    time.sleep(0.5)
+                    break
+
                 with tracker.pause_updates():
                     local_epoch = tracker.update_epoch(local_epoch + 1)
                     samples_accumulated = 0
 
-                if index == 4 and local_epoch >= 5:
-                    time.sleep(0.5)
-                    break
-
         emas[index] = tracker.performance_ema.samples_per_second
         tracker.shutdown()
         dht.shutdown()
@@ -249,16 +250,19 @@ def test_progress_tracker():
     )
     barrier.wait()
 
-    current_step = 0
+    local_epoch = 0
     last_timestamp = hivemind.get_dht_time()
     step_time_deltas = []
 
-    while current_step < 6:
+    while local_epoch < 6:
         time.sleep(0.1)
-        if tracker.global_progress.epoch > current_step:
+
+        if tracker.ready_to_update_epoch:
+            with tracker.pause_updates():
+                local_epoch = tracker.update_epoch(local_epoch + 1)
+
             time_delta = hivemind.get_dht_time() - last_timestamp
-            current_step = tracker.global_progress.epoch
-            if current_step == 2:
+            if local_epoch == 2:
                 delayed_start_evt.set()
 
             last_timestamp = hivemind.get_dht_time()
@@ -279,3 +283,103 @@ def test_progress_tracker():
         assert 0.5 * mean_step_time < step_time_deltas[i] < 0.95 * mean_step_time
     assert emas[1] < emas[2] < emas[3] < emas[4]
     assert tracker.performance_ema.samples_per_second < 1e-9
+
+
+@pytest.mark.forked
+def test_optimizer(
+    num_peers: int = 1,
+    num_clients: int = 0,
+    target_batch_size: int = 32,
+    total_epochs: int = 3,
+    reuse_grad_buffers: bool = True,
+    delay_grad_averaging: bool = True,
+    delay_optimizer_step: bool = True,
+    average_state_every: int = 1,
+):
+    dht = hivemind.DHT(start=True)
+
+    features = torch.randn(100, 5)
+    targets = features @ torch.randn(5, 1)
+    optimizer = None
+    total_samples_accumulated = mp.Value(ctypes.c_int32, 0)
+
+    def run_trainer(batch_size: int, batch_time: float, client_mode: bool):
+        nonlocal optimizer
+        model = nn.Linear(5, 1)
+
+        assert isinstance(model, torch.nn.Module), "model_arch must evaluate to a pytorch module"
+
+        optimizer = Optimizer(
+            run_id="test_run",
+            target_batch_size=target_batch_size,
+            batch_size_per_step=batch_size,
+            params=model.parameters(),
+            optimizer=partial(torch.optim.SGD, lr=0.1),
+            scheduler=partial(torch.optim.lr_scheduler.StepLR, gamma=0.5, step_size=1),
+            dht=hivemind.DHT(initial_peers=dht.get_visible_maddrs(), client_mode=client_mode, start=True),
+            tracker_opts=dict(private_key=RSAPrivateKey(), max_refresh_period=1.0),
+            averager_opts=dict(request_timeout=0.5),
+            matchmaking_time=1.0,
+            averaging_timeout=5.0,
+            reuse_grad_buffers=reuse_grad_buffers,
+            delay_grad_averaging=delay_grad_averaging,
+            delay_optimizer_step=delay_optimizer_step,
+            average_state_every=average_state_every,
+            client_mode=client_mode,
+            verbose=False,
+        )
+        optimizer.load_state_from_peers()
+
+        prev_time = time.perf_counter()
+
+        while optimizer.local_epoch < total_epochs:
+            time.sleep(max(0.0, prev_time + batch_time - time.perf_counter()))
+            batch = torch.randint(0, len(features), (batch_size,))
+
+            loss = F.mse_loss(model(features[batch]), targets[batch])
+            loss.backward()
+
+            optimizer.step()
+
+            total_samples_accumulated.value += batch_size
+
+            if not reuse_grad_buffers:
+                optimizer.zero_grad()
+
+            prev_time = time.perf_counter()
+
+        time.sleep(1.0)
+        optimizer.shutdown()
+        return optimizer
+
+    peers = []
+
+    for index in range(num_peers):
+        peers.append(
+            mp.Process(
+                target=run_trainer,
+                name=f"trainer-{index}",
+                kwargs=dict(
+                    batch_size=4 + index,
+                    batch_time=0.3 + 0.2 * index,
+                    client_mode=(index >= num_peers - num_clients),
+                ),
+            )
+        )
+
+    for peer in peers[1:]:
+        peer.start()
+    peers[0].run()
+    for peer in peers[1:]:
+        peer.join()
+
+    assert isinstance(optimizer, Optimizer)
+    assert optimizer.local_epoch == optimizer.tracker.global_epoch == total_epochs
+    expected_samples_accumulated = target_batch_size * total_epochs
+    assert expected_samples_accumulated <= total_samples_accumulated.value <= expected_samples_accumulated * 1.2
+    assert 4 / 0.3 * 0.8 <= optimizer.tracker.performance_ema.samples_per_second <= 4 / 0.3 * 1.2
+
+    assert not optimizer.state_averager.is_alive()
+    assert not optimizer.grad_averager.is_alive()
+    assert not optimizer.tracker.is_alive()
+    assert optimizer.scheduled_grads is None or optimizer.scheduled_grads.done()