Browse Source

local updates and docstring

justheuristic 3 years ago
parent
commit
2a15c37f1b

+ 343 - 147
hivemind/optim/experimental/optimizer.py

@@ -3,11 +3,12 @@ from __future__ import annotations
 import logging
 import os
 from functools import partial
-from typing import Callable, Optional, Union
+from typing import Callable, Optional, Sequence, Union
 
 import torch
 
 from hivemind.averaging.control import StepControl
+from hivemind.compression import CompressionBase, NoCompression
 from hivemind.dht import DHT
 from hivemind.optim.experimental.grad_averager import GradientAverager
 from hivemind.optim.experimental.progress_tracker import ProgressTracker
@@ -21,22 +22,25 @@ from hivemind.optim.experimental.state_averager import (
     TrainingStateAverager,
 )
 from hivemind.optim.grad_scaler import GradScaler
-from hivemind.utils import get_dht_time, get_logger
+from hivemind.utils import get_dht_time, get_logger, DHTExpiration
 
 logger = get_logger(__name__)
 
 
 class Optimizer(torch.optim.Optimizer):
     """
-    Hivemind Optimizer wraps your regular PyTorch Optimizer for training in a swarm of peers. It can be configured with
-     synchronous, delayed or asynchronous updates to trade between optimization guarantees and compute utilization.
+    Hivemind Optimizer wraps your regular PyTorch Optimizer for training collaboratively with peers.
+    By default, Optimizer is configured to be exactly **equivalent to synchronous training** with target_batch_size;
+    There are advanced options make training semi-asynchronous (delay_optimizer_step and delay_gradient_averaging)
+     or even fully asynchronous (local_updates=True). However, these options require careful tuning.
 
-    The Optimizer is meant as a drop-in replacement for your regular PyTorch code:
+    The Optimizer is meant as a drop-in replacement for your regular PyTorch Optimizer:
 
     >>> model = transformers.AutoModel("albert-xxlarge-v2")
     >>> dht = hivemind.DHT(initial_peers=INITIAL_PEERS, start=True)
-    >>> opt = hivemind.Optimizer(model.parameters(), optim_cls=torch.optim.Adam, run_id="run_42",
-    >>>                          target_batch_size=4096, batch_size_per_step=4)
+    >>> opt = hivemind.Optimizer(dht, run_id="run_42", optimizer=torch.optim.Adam, params=model.parameters(),
+    >>>                          target_batch_size=4096, batch_size_per_step=4)  # recommended way to create Optimizer
+    >>> # alternative: opt = hivemind.Optimizer(dht, run_id="run_42", optimizer=torch.optim.Adam(model.parameters())
     >>> while True:
     >>>     loss = compute_loss_on_batch(model, batch_size=4)
     >>>     opt.zero_grad()
@@ -44,33 +48,69 @@ class Optimizer(torch.optim.Optimizer):
     >>>     opt.step()  # <-- train collaboratively with any peers that use the same prefix (run_42)
 
     However, unlike regular optimizers, calling opt.step with hivemind.Optimizer can do one of the following:
-    - accumulate a minibatch of data towards the (global) target batch size without changing parameters (yet),
-    - after accumulating the target batch size, all-reduce gradients with peers and perform optimizer step,
-    - if, for any reason, your peer lags behind the rest of the swarm, it will load state from up-to-date peers.
+    - accumulate a minibatch of gradients towards the (global) target batch size, without updating parameters yet;
+    - after accumulating the target batch size, all-reduce gradients with peers and perform optimizer step;
+    - if your peer lags behind the rest of the swarm, it will download latest state from other peers;
 
     :note: hivemind.Optimizer can be used the same way any other pytorch optimizer, but there is one limitation:
-      learning rate schedulers, curriculum and other time-dependent features should use opt.global_step (and not the
-      number of local forward-backward cycles). This is because any device can join midway through training, when
-      other peers have already made some progress and changed their learning rate accordingly.
+      learning rate schedulers, curriculum and other **time-dependent features should depend on Optimizer.local_epoch**
+      (and not the number ot calls to opt.step). This is because peers are allowed to join midway through training,
+      when others have already made some progress and changed their learning rates accordingly.
 
     :param dht: a running hivemind.DHT instance connected to other peers
-    :param run_id: a unique identifier of this experiment, used as a common prefix for all DHT keys
-    :param target_batch_size: perform optimizer step after all peers collectively accumulate this many samples
+    :param run_id: a unique identifier of this training run, used as a common prefix for all DHT keys
+
+    :note: peers with the same run_id should *generally* train the same model and use the same optimizer configuration.
+      Some options can be safely changed by individual peers: `batch_size_per_step`, `client_mode`, `auxiliary`,
+      `reuse_grad_buffers`, `offload_optimizer`, and `verbose`. In some cases, other options may also be tuned
+      individually by each peer, but they should be changed with caution to avoid deadlocks or convergence issues.
+
+    :param target_batch_size: global batch size that must be accumulated before the swarm transitions to the next epoch
     :param batch_size_per_step: before each call to .step, user should accumulate gradients over this many samples
-    :param optimizer: a standard pytorch optimizer, preferably a large-batch one such as LAMB, LARS, etc.
-    :param params: optional, a list/tuple of parameters or structured param groups for the optimizer
-    :param scheduler: if specified, use this scheduler to update optimizer learning rate
-    :note: If you are using hivemind.Optimizer with lr_scheduler, it is recommended to pass this scheduler
-      explicitly into this class. Otherwise, it may become non-synchronized between peers.
+
+    :param optimizer: a callable(parameters) -> pytorch.optim.Optimizer or a pre-initialized PyTorch optimizer
+    :param params: parameters or param groups for the optimizer; required if optimizer is a callable(params)
+    :note: creating hivemind.Optimizer with params=model.parameters() and optimizer=lambda params: make_optim(params)
+      is required for advanced options: offload_optimizer, delay_optimizer_step and delay_grad_averaging.
+
+    :param scheduler: callable(optimizer) -> PyTorch LRScheduler or a pre-initialized PyTorch scheduler
+    :note: the learning rate scheduler will adjust learning rate based on collaboration-wide epoch, not the number of
+      local calls to optimizer.step; this is required to keep different peers synchronized.
 
     :param matchmaking_time: when looking for group, wait for peers to join for up to this many seconds
     :param averaging_timeout: if an averaging step hangs for this long, it will be cancelled.
     :param load_state_timeout: wait for at most this many seconds before giving up on load_state_from_peers
     :param reuse_grad_buffers: if True, use model's .grad buffers for gradient accumulation.
       This is more memory efficient, but it requires that the user does *NOT* call model/opt zero_grad at all
-    :param average_state_every: average state (parameters, chosen opt statistics) with peers every this many epochs
-    :param delay_optimizer_step: if True, run optimizer step in background and apply results in a future step
-    :param client_mode: if True, runs training without incoming connections, in a firewall-compatible mode
+
+    :param offload_optimizer: offload the optimizer to host memory, saving GPU memory for parameters and gradients
+    :param delay_optimizer_step: run optimizer in background, apply results in future .step; requires offload_optimizer
+    :param delay_grad_averaging: average gradients in background; requires offload_optimizer and delay_optimizer_step
+    :note: offload_optimizer, delay_optimizer_step and delay_grad_averaging require that the optimizer is
+      created as follows: `hivemind.Optimizer(..., optimizer=callable_optimizer_factory, params=model.parameters())`
+
+    :param delay_state_averaging: if enabled (default), average parameters and extra tensors in a background thread;
+      if set to False, average parameters synchronously within the corresponding hivemind.Optimizer.step call.
+    :param average_state_every: average state (parameters, chosen opt statistics) with peers every this many **epochs**
+      This reduces the communication overhead increasing, but can cause parameters to diverge if too large
+    :note: The maximal average_state_every=num_epochs depends on how often peers diverge from each other. If peers
+      hardly ever skip averaging rounds, they can average state less frequently. Network failures, lossy gradient
+      compression and local_updates cause parameters to diverge faster and requires more frequent averaging.
+
+    :param use_local_updates: if enabled, peers will update parameters on each .step using local gradients;
+      if not enabled (default), accumulate gradients to target_batch_size, and then call .step with averaged gradients
+    :note: even if use_local_updates=True, learning rate scheduler will still be called once per target_batch_size.
+
+    :param client_mode: if True, this peer will not accept incoming connections (firewall-compatible mode)
+    :param auxiliary: if True, optimizer.step will only assist other peers in averaging (for cpu-only workers)
+    :note: client_mode=True and auxiliary=True are mutually exclusive; auxiliary also requires batch_size_per_step=None
+
+    :param grad_compression: compression strategy used for averaging gradients, default = no compression
+    :param state_averaging_compression: compression for averaging params and state tensors, default = no compression
+    :param load_state_compression: compression strategy for loading state from peers, default = no compression
+    :param average_opt_statistics: names of optimizer statistics from state dict that should be averaged with peers
+    :param extra_tensors: if specified, these extra tensors will also be averaged and shared in load_state_from_peers.
+
     :param averager_opts: additional keyword arguments forwarded to both GradientAverager and TrainingStateAverager
     :param tracker_opts: additional keyword arguments forwarded to ProgressTracker
     :param verbose: if True, report internal events such as accumilating gradients and running background tasks
@@ -97,12 +137,20 @@ class Optimizer(torch.optim.Optimizer):
         matchmaking_time: Optional[float] = 15.0,
         averaging_timeout: Optional[float] = 300.0,
         load_state_timeout: float = 600.0,
-        average_state_every: int = 1,
         reuse_grad_buffers: bool = False,
-        delay_grad_averaging: bool = False,
+        offload_optimizer: Optional[bool] = None,
         delay_optimizer_step: Optional[bool] = None,
+        delay_grad_averaging: bool = False,
+        delay_state_averaging: bool = True,
+        average_state_every: int = 1,
+        use_local_updates: bool = False,
         client_mode: bool = None,
         auxiliary: bool = False,
+        grad_compression: CompressionBase = NoCompression(),
+        state_averaging_compression: CompressionBase = NoCompression(),
+        load_state_compression: CompressionBase = NoCompression(),
+        average_opt_statistics: Sequence[str] = (),
+        extra_tensors: Sequence[torch.Tensor] = (),
         averager_opts: Optional[dict] = None,
         tracker_opts: Optional[dict] = None,
         shutdown_timeout: float = 5,
@@ -110,31 +158,65 @@ class Optimizer(torch.optim.Optimizer):
     ):
         client_mode = client_mode if client_mode is None else dht.client_mode
         delay_optimizer_step = delay_optimizer_step if delay_optimizer_step is not None else delay_grad_averaging
+        offload_optimizer = offload_optimizer if offload_optimizer is not None else (params is not None)
         assert not delay_grad_averaging or delay_optimizer_step, "delay_grad_averaging requires delay_optimizer_step"
         assert not (client_mode and auxiliary), "Client-mode peers cannot serve as auxiliaries"
         assert not auxiliary or batch_size_per_step is None, "Auxiliary peers should not accumulate batches"
+        if callable(optimizer) and params is not None:
+            if scheduler is not None and (not callable(scheduler) or isinstance(scheduler, LRSchedulerBase)):
+                raise ValueError("For this mode, please provide scheduler factory: callable(optimizer) -> scheduler")
+        elif all(hasattr(optimizer, attr) for attr in ("param_groups", "step", "zero_grad")):
+            if offload_optimizer or delay_optimizer_step or delay_grad_averaging:
+                raise ValueError(
+                    "To enable offload_optimizer or delayed updates, please initialize Optimizer as "
+                    "hivemind.Optimizer(..., params=params, optimizer=lambda params: create_opt(params)"
+                )
+        else:
+            raise ValueError(
+                "Please initialize the optimizer in one of the following two ways:\n"
+                "(A) hivemind.Optimizer(..., params=params, optimizer=lambda params: create_opt(params)\n"
+                "(B) hivemind.Optimizer(..., optimizer=pre_initialize_optimizer)"
+            )
+        if use_local_updates:
+            assert not reuse_grad_buffers, "if local_updates is True, gradients will not be accumulated"
+            assert not delay_grad_averaging, "if local_updates is True, gradients will not be averaged"
 
         self.dht, self.run_id, self.client_mode, self.auxiliary = dht, run_id, client_mode, auxiliary
         self.batch_size_per_step, self.target_batch_size = batch_size_per_step, target_batch_size
-        self.matchmaking_time, self.average_state_every = matchmaking_time, average_state_every
+        self.delay_state_averaging, self.average_state_every = delay_state_averaging, average_state_every
+        self.matchmaking_time, self.offload_optimizer = matchmaking_time, offload_optimizer
         self.delay_grad_averaging, self.delay_optimizer_step = delay_grad_averaging, delay_optimizer_step
         self.averaging_timeout, self.load_state_timeout = averaging_timeout, load_state_timeout
         self.shutdown_timeout = shutdown_timeout
 
         self.status_loglevel = logging.INFO if verbose else logging.DEBUG
-        self.scheduled_round: Optional[StepControl] = None
-        self.previous_round: Optional[StepControl] = None
+        self.scheduled_grads: Optional[StepControl] = None
+        self.scheduled_state: Optional[StepControl] = None
 
+        self.tracker = self._make_progress_tracker(target_batch_size, **tracker_opts or {})
         self.state_averager = self._make_state_averager(
-            optimizer=optimizer, params=params, scheduler=scheduler, **averager_opts or {}
+            optimizer=optimizer,
+            params=params,
+            scheduler=scheduler,
+            delta_rule_averaging=use_local_updates and self.delay_state_averaging,
+            compression=state_averaging_compression,
+            state_compression=load_state_compression,
+            average_opt_statistics=average_opt_statistics,
+            extra_tensors=extra_tensors,
+            **averager_opts or {},
         )
-        self.grad_averager = self._make_gradient_averager(reuse_grad_buffers=reuse_grad_buffers, **averager_opts or {})
-        self.tracker = self._make_progress_tracker(target_batch_size, **tracker_opts or {})
+        if not use_local_updates:
+            self.grad_averager = self._make_gradient_averager(
+                reuse_grad_buffers=reuse_grad_buffers, compression=grad_compression, **averager_opts or {}
+            )
+        else:
+            self.grad_averager = None
+
         self._should_check_synchronization_on_update = True  # used in self.should_load_state_from_peers
         self._schema_hash = self._compute_schema_hash()
         self._parent_pid = os.getpid()
 
-        self._step_supports_amp_scaling = self.grad_averager.reuse_grad_buffers
+        self._step_supports_amp_scaling = reuse_grad_buffers
         # note: the line above is used by pytorch AMP GradScaler to enable custom behavior needed when reusing gradient
         # buffers over multiple steps (to avoid repeated unscaling). Without reuse_grad_buffers, this is not needed.
 
@@ -144,11 +226,11 @@ class Optimizer(torch.optim.Optimizer):
             prefix=f"{self.run_id}_state_averager",
             allreduce_timeout=self.averaging_timeout,
             shutdown_timeout=self.shutdown_timeout,
+            offload_optimizer=self.offload_optimizer,
+            custom_gradients=self.offload_optimizer,
             status_loglevel=self.status_loglevel,
             client_mode=self.client_mode,
             auxiliary=self.auxiliary,
-            offload_optimizer=True,
-            custom_gradients=True,
             start=True,
             **kwargs,
         )
@@ -166,12 +248,13 @@ class Optimizer(torch.optim.Optimizer):
             start=True,
             **kwargs,
         )
-        optimized_param_groups = self.state_averager.optimizer.param_groups
-        optimized_parameters = [param for group in optimized_param_groups for param in group["params"]]
-        with grad_averager.get_tensors() as averaged_gradients:
-            assert len(averaged_gradients) == len(optimized_parameters)
-            for opt_param, averaged_grad in zip(optimized_parameters, averaged_gradients):
-                opt_param.grad = averaged_grad
+        if self.offload_optimizer:
+            optimized_param_groups = self.state_averager.optimizer.param_groups
+            optimized_parameters = [param for group in optimized_param_groups for param in group["params"]]
+            with grad_averager.get_tensors() as averaged_gradients:
+                assert len(averaged_gradients) == len(optimized_parameters)
+                for opt_param, averaged_grad in zip(optimized_parameters, averaged_gradients):
+                    opt_param.grad = averaged_grad
         return grad_averager
 
     def _make_progress_tracker(self, target_batch_size: int, **kwargs) -> ProgressTracker:
@@ -189,7 +272,9 @@ class Optimizer(torch.optim.Optimizer):
         optimized_param_groups = self.state_averager.optimizer.param_groups
         optimized_parameters = [param for group in optimized_param_groups for param in group["params"]]
         param_shapes = tuple(tuple(param.shape) for param in optimized_parameters)
-        grad_ids = tuple(id(param.grad) for param in optimized_parameters)
+
+        # offloaded optimizer requires that gradient tensors are reused between iterations
+        grad_ids = tuple(id(param.grad) for param in optimized_parameters) if self.offload_optimizer else None
         return hash((grad_ids, param_shapes))
 
     def is_alive(self) -> bool:
@@ -199,25 +284,13 @@ class Optimizer(torch.optim.Optimizer):
     def local_epoch(self) -> int:
         return self.state_averager.local_epoch
 
-    def should_load_state_from_peers(self) -> bool:
-        """
-        If true, peer will discard local progress and attempt to download state from peers.
-        This method allows peer to continue training in two cases:
-         - peer is on the same epoch as other collaborators - keep training normally
-         - peer was on the same epoch and accumulated some grads, but some collaborators
-             have just transitioned to the next epoch - this peer should also transition.
+    @property
+    def use_local_updates(self) -> bool:
+        return self.grad_averager is None
 
-        :note: The latter case occurs due to the lack of network synchrony: the first peer that
-        detects enough samples will transition to the next step and start counting samples anew.
-        Some other peers may take time before they check with DHT and observe that
-          - the global epoch is technically one epoch ahead of the current one and
-          - the remaining (non-transitioned) peers no longer have target_batch_size between them
-        If this is the case, peer should transition to the next epoch and does *not* need to re-load state.
-        """
-        if self._should_check_synchronization_on_update and self.tracker.updated_progress_this_epoch.is_set():
-            self._should_check_synchronization_on_update = False
-            return self.local_epoch != self.tracker.global_epoch  # require exact synchronization once per step
-        return self.local_epoch < self.tracker.global_epoch - 1  # catch up if a peer just switched to next epoch
+    @property
+    def use_gradient_averaging(self) -> bool:
+        return self.grad_averager is not None
 
     def step(
         self,
@@ -241,6 +314,9 @@ class Optimizer(torch.optim.Optimizer):
             raise ValueError("Auxiliary peers should not have batch size, run closures, or use grad_scaler")
         batch_size = batch_size if batch_size is not None else self.batch_size_per_step
 
+        # if delayed updates finished before step, apply these updates; otherwise do nothing
+        self.state_averager.step(apply_delayed_updates=True)
+
         loss = None
         if closure is not None:
             with torch.enable_grad():
@@ -249,111 +325,180 @@ class Optimizer(torch.optim.Optimizer):
         if not self.auxiliary and self.should_load_state_from_peers():
             logger.log(self.status_loglevel, "Peer is out of sync.")
             self.load_state_from_peers()
-            return loss
+            return loss  # local gradients were computed with out-of-sync parameters, must start over
 
-        if grad_scaler is not None and not grad_scaler.are_grads_finite(self):
-            logger.log(self.status_loglevel, "Encountered incorrect value in fp16 grads, resetting local gradients")
-            self.tracker.report_local_progress(self.local_epoch, samples_accumulated=0)
-            self.grad_averager.reset_accumulated_grads_()
-            return loss
+        if self.use_gradient_averaging:
+            # accumulate gradients toward target batch size, then aggregate with peers and run optimizer
+            if not self.auxiliary:
+                grads_are_valid = self._check_and_accumulate_gradients(batch_size, grad_scaler)
+                if not grads_are_valid:
+                    return loss  # local gradients were reset due to overflow, must start over
 
-        if not self.auxiliary:
-            self.grad_averager.accumulate_grads_(batch_size)
-            self.tracker.report_local_progress(self.local_epoch, self.grad_averager.local_samples_accumulated)
-            self.state_averager.step(apply_delayed_updates=True)
+            self._maybe_schedule_gradient_averaging()
+            self._maybe_schedule_state_averaging()
 
-        if self.tracker.estimated_next_update_time - get_dht_time() <= self.matchmaking_time:
-            if self.scheduled_round is None or self.scheduled_round.triggered or self.scheduled_round.done():
-                if self.delay_grad_averaging:
-                    # wait for previous averaging to finish before starting a new one
-                    self.state_averager.step(wait_for_delayed_update=True)
+        else:
+            # use_local_updates=True: update parameters on every step independently of other peers
+            if not self.auxiliary:
+                if grad_scaler is not None:
+                    with grad_scaler.running_global_step():
+                        assert grad_scaler.unscale_(self)
+
+                new_samples_accumulated = self.tracker.local_progress.samples_accumulated + batch_size
+                self.tracker.report_local_progress(self.local_epoch, new_samples_accumulated)
+                self._maybe_schedule_state_averaging()
+
+                self.state_averager.step(
+                    increment_epoch=False,
+                    optimizer_step=True,
+                    delay_optimizer_step=self.delay_optimizer_step,
+                    grad_scaler=grad_scaler,
+                )
 
-                eta_seconds = self.tracker.estimated_next_update_time - get_dht_time()
-                eta_seconds = max(eta_seconds, self.grad_averager.matchmaking_kwargs["min_matchmaking_time"])
-                logger.log(self.status_loglevel, f"Pre-scheduling next averaging round in {eta_seconds:.2f}s.")
-                scheduled_time = self.tracker.estimated_next_update_time
-                if self.client_mode:
-                    scheduled_time = get_dht_time() + self.averaging_timeout
-                self.scheduled_round = self.grad_averager.schedule_step(scheduled_time, timeout=self.averaging_timeout)
+        if self.tracker.ready_to_update_epoch:
+            self._update_global_epoch(grad_scaler)
 
-        if not self.tracker.ready_to_update_epoch:
-            return loss
+        return loss
 
+    def _update_global_epoch(self, grad_scaler: Optional[GradScaler]) -> None:
+        """Depending on the configuration: aggregate gradients and/or parameters, perform global optimizer step"""
         assert self._schema_hash == self._compute_schema_hash(), "parameters or gradients changed during iteration"
 
         with self.tracker.pause_updates():
-            # note: we do not need to replace grads because we explicitly load grads into the optimizer
-
-            logger.log(self.status_loglevel, f"Beginning global optimizer step #{self.tracker.global_epoch}")
-
-            if grad_scaler is not None:
-                with grad_scaler.running_global_step():
-                    assert grad_scaler.unscale_(self)
-
-            if self.scheduled_round is not None and self.scheduled_round.triggered or self.scheduled_round.done():
-                logger.log(self.status_loglevel, f"Discarding failed matchmaking results: {self.scheduled_round}")
-                self.scheduled_round = None
+            wait_for_trigger = None
 
-            swarm_not_empty = self.tracker.global_progress.num_peers > 1
-            began_averaging_gradients = False
-            if swarm_not_empty:
-                try:
-                    self.scheduled_round = self.grad_averager.step(
-                        control=self.scheduled_round, reset_accumulators=True, wait=False
-                    )
-                    assert self.grad_averager.local_samples_accumulated == 0, "step should have reset accumulators"
-                    began_averaging_gradients = True
-                except BaseException as e:
-                    logger.exception(e)
-
-            if not began_averaging_gradients and self.scheduled_round is not None and not self.scheduled_round.done():
-                logger.log(self.status_loglevel, f"Cancelled pre-scheduled averaging round")
-                self.scheduled_round.cancel()
-                self.scheduled_round = None
-
-            if not self.delay_grad_averaging:
-                self._average_gradients_and_load_into_optimizer(self.scheduled_round)
+            if self.use_gradient_averaging:
+                logger.log(self.status_loglevel, f"Beginning optimizer step #{self.local_epoch}")
+                began_averaging_gradients = self._begin_averaging_gradients(grad_scaler)
+                if not began_averaging_gradients:
+                    return  # failed to start gradient averaging due to an internal error
+                if self.delay_grad_averaging:
+                    # if using delayed grad averaing, send this to state_averager as a pre-condition for optimizer step
+                    wait_for_trigger = partial(self._average_gradients_and_load_into_optimizer, self.scheduled_grads)
+                else:
+                    # delay_grad_averaging=False, average gradients immediately
+                    self._average_gradients_and_load_into_optimizer(self.scheduled_grads)
 
             next_epoch = max(self.local_epoch + 1, self.tracker.global_epoch)
+            swarm_not_empty = self.tracker.global_progress.num_peers > 1
+            should_perform_optimizer_step = not self.auxiliary and not self.use_local_updates
+            should_average_state = swarm_not_empty and next_epoch % self.average_state_every == 0
 
             self.state_averager.step(
                 increment_epoch=True,
-                optimizer_step=not self.auxiliary,
-                delay_optimizer_step=self.delay_optimizer_step,
-                averaging_round=swarm_not_empty and next_epoch % self.average_state_every == 0,
-                delay_averaging=not self.auxiliary,
+                wait_for_trigger=wait_for_trigger,
+                optimizer_step=should_perform_optimizer_step,
+                delay_optimizer_step=self.delay_optimizer_step and should_perform_optimizer_step,
                 grad_scaler=grad_scaler,
-                wait_for_trigger=partial(self._average_gradients_and_load_into_optimizer, self.scheduled_round)
-                if self.delay_grad_averaging
-                else None,
-                averaging_opts=dict(
-                    scheduled_time=get_dht_time() + self.matchmaking_time, timeout=self.averaging_timeout
-                )
-                if swarm_not_empty and next_epoch % self.average_state_every == 0
-                else None,
+                averaging_round=should_average_state,
+                delay_averaging=self.delay_state_averaging and not self.auxiliary,
+                averaging_control=self.scheduled_state if should_average_state else None,
+                averaging_opts=dict(timeout=self.averaging_timeout) if should_average_state else None,
             )
 
-            if not self.auxiliary:
-                self.grad_averager.reset_accumulated_grads_()
-                self.tracker.update_epoch(new_epoch=self.state_averager.local_epoch)
-                self._should_check_synchronization_on_update = True
+            self.tracker.update_epoch(new_epoch=self.state_averager.local_epoch)
+            self.scheduled_grads = self.scheduled_state = None
+            self._should_check_synchronization_on_update = True
+            # the above line ensures that peers check for *strict* synchronization once per epoch
 
             if not self.client_mode:
-                self.grad_averager.state_sharing_priority = self.local_epoch
                 self.state_averager.state_sharing_priority = self.local_epoch
 
+            if self.use_gradient_averaging and not self.auxiliary:
+                self.grad_averager.reset_accumulated_grads_()
+                if not self.client_mode:
+                    self.grad_averager.state_sharing_priority = self.local_epoch
+
             logger.log(self.status_loglevel, f"Transitioning to epoch {self.local_epoch}.")
-        return loss
+
+    def _begin_averaging_gradients(self, grad_scaler: Optional[GradScaler]) -> bool:
+        """Begin an all-reduce round to average gradients; return True if succeeded, False if failed"""
+        if grad_scaler is not None:
+            with grad_scaler.running_global_step():
+                assert grad_scaler.unscale_(self)
+
+        if self.scheduled_grads is not None and (self.scheduled_grads.triggered or self.scheduled_grads.done()):
+            logger.log(self.status_loglevel, f"Discarding failed matchmaking results: {self.scheduled_grads}")
+            self.scheduled_grads = None
+
+        began_averaging_gradients = False
+        if self.tracker.global_progress.num_peers > 1:
+            try:
+                self.scheduled_grads = self.grad_averager.step(
+                    control=self.scheduled_grads, reset_accumulators=True, wait=False
+                )
+                assert self.grad_averager.local_samples_accumulated == 0, "step should have reset accumulators"
+                began_averaging_gradients = True
+            except BaseException as e:
+                logger.exception(e)
+
+        if not began_averaging_gradients and self.scheduled_grads is not None and not self.scheduled_grads.done():
+            logger.log(self.status_loglevel, f"Cancelled pre-scheduled averaging round")
+            self.scheduled_grads.cancel()
+            self.scheduled_grads = None
+        return began_averaging_gradients
+
+    def _check_and_accumulate_gradients(self, batch_size: int, grad_scaler: Optional[GradScaler]) -> bool:
+        """Check if gradients are valid, accumulate and return True; otherwise, reset and return False"""
+        assert not self.use_local_updates and not self.auxiliary
+        if grad_scaler is not None and not grad_scaler.are_grads_finite(self):
+            logger.log(self.status_loglevel, "Encountered incorrect value in fp16 grads, resetting local gradients")
+            self.tracker.report_local_progress(self.local_epoch, samples_accumulated=0)
+            self.grad_averager.reset_accumulated_grads_()
+            return False
+
+        self.grad_averager.accumulate_grads_(batch_size)
+        self.tracker.report_local_progress(self.local_epoch, self.grad_averager.local_samples_accumulated)
+        return True
+
+    def _maybe_schedule_gradient_averaging(self) -> None:
+        """If next epoch is coming soon, schedule the next gradient averaging round at the estimated end of epoch"""
+        assert self.use_gradient_averaging
+        if self.tracker.estimated_next_update_time - get_dht_time() <= self.matchmaking_time:
+            if self.scheduled_grads is None or self.scheduled_grads.triggered or self.scheduled_grads.done():
+                if self.delay_grad_averaging:
+                    # wait for previous averaging to finish before starting a new one
+                    self.state_averager.step(wait_for_delayed_updates=True)
+
+                eta_seconds = self.tracker.estimated_next_update_time - get_dht_time()
+                eta_seconds = max(eta_seconds, self.grad_averager.matchmaking_kwargs["min_matchmaking_time"])
+                logger.log(self.status_loglevel, f"Pre-scheduling gradient averaging round in {eta_seconds:.2f}s.")
+                scheduled_time = self.tracker.estimated_next_update_time
+                if self.client_mode:
+                    scheduled_time = get_dht_time() + self.averaging_timeout
+                self.scheduled_grads = self.grad_averager.schedule_step(scheduled_time, timeout=self.averaging_timeout)
+
+    def _maybe_schedule_state_averaging(self) -> None:
+        """If next epoch is coming soon, schedule the next state averaging at estimated parameter averaging start"""
+        next_epoch = max(self.local_epoch + 1, self.tracker.global_epoch)
+        if next_epoch % self.average_state_every != 0:
+            return  # averaging is not performed at this epoch
+
+        estimated_time = self.tracker.estimated_next_update_time
+        estimated_time += self.state_averager.delay_before_averaging.ema_seconds_per_sample
+        eta_seconds_to_averaging = self.tracker.estimated_next_update_time - get_dht_time()
+
+        if eta_seconds_to_averaging <= self.matchmaking_time:
+            if self.scheduled_state is None or self.scheduled_state.triggered or self.scheduled_state.done():
+                min_matchmaking_time = self.state_averager.matchmaking_kwargs["min_matchmaking_time"]
+                actual_seconds = max(eta_seconds_to_averaging, min_matchmaking_time)
+                logger.log(self.status_loglevel, f"Pre-scheduling state averaging round in {actual_seconds:.2f}s.")
+                if self.client_mode:
+                    estimated_time = get_dht_time() + self.averaging_timeout
+                self.scheduled_state = self.state_averager.schedule_step(
+                    estimated_time, gather=next_epoch, timeout=self.averaging_timeout
+                )
 
     def _average_gradients_and_load_into_optimizer(self, maybe_step_control: Optional[StepControl]):
         """Run gradient averaging; on success, feed averaged gradients into optimizer; else, use local gradients"""
-        assert maybe_step_control is None or maybe_step_control.triggered
+        assert self.use_gradient_averaging and maybe_step_control is None or maybe_step_control.triggered
         averaged_gradients = False
 
         try:
             if maybe_step_control is not None:
                 group_info = maybe_step_control.result(self.averaging_timeout)
                 logger.log(self.status_loglevel, f"Averaged gradients with {len(group_info)} peers")
+                self._load_averaged_gradients_into_optimizer_()
                 averaged_gradients = True
             else:
                 logger.log(self.status_loglevel, f"Skipped averaging: there are no other peers")
@@ -363,28 +508,65 @@ class Optimizer(torch.optim.Optimizer):
         if not averaged_gradients:
             logger.log(self.status_loglevel, f"Proceeding with local gradients")
             self.grad_averager.load_accumulators_into_averager_()
+            self._load_averaged_gradients_into_optimizer_()
+
+    def _load_averaged_gradients_into_optimizer_(self):
+        """If required, load averaged gradients into optimizer; otherwise simply notify grad averager"""
+        assert self.use_gradient_averaging
+
+        if self.offload_optimizer:
+            pass  # averaged gradients are already baked into optimizer, see _make_gradient_averager
+        else:
+            # copy averaged gradients into optimizer .grad buffers
+            optimized_param_groups = self.state_averager.optimizer.param_groups
+            optimized_parameters = [param for group in optimized_param_groups for param in group["params"]]
+            with torch.no_grad(), self.grad_averager.get_tensors() as averaged_gradients:
+                assert len(averaged_gradients) == len(optimized_parameters)
+                for opt_param, averaged_grad in zip(optimized_parameters, averaged_gradients):
+                    opt_param.grad.copy_(averaged_grad, non_blocking=True)
 
         self.grad_averager.notify_used_averaged_gradients()
 
     def zero_grad(self, set_to_none: bool = False):
         """Reset gradients from model. If these gradients are reused for accumulators, raise an error."""
-        if self.grad_averager.reuse_grad_buffers:
+        if self.use_gradient_averaging and self.grad_averager.reuse_grad_buffers:
             raise ValueError(
                 f"When running {self.__class__.__name__} with reuse_grad_buffers=True, user should never "
                 f"call zero_grad manually. Gradients will be refreshed internally."
             )
-        for param in self.grad_averager.parameters:
-            if param.grad is None:
-                pass
-            elif set_to_none:
-                param.grad = None
-            else:
-                param.grad.zero_()
+        for param_group in self.param_groups:
+            for param in param_group["params"]:
+                if param.grad is None:
+                    pass
+                elif set_to_none:
+                    param.grad = None
+                else:
+                    param.grad.zero_()
+
+    def should_load_state_from_peers(self) -> bool:
+        """
+        If true, peer will discard local progress and attempt to download state from peers.
+        This method allows peer to continue training in two cases:
+         - peer is on the same epoch as other collaborators - keep training normally
+         - peer was on the same epoch and accumulated some grads, but some collaborators
+             have just transitioned to the next epoch - this peer should also transition.
+
+        :note: The latter case occurs due to the lack of network synchrony: the first peer that
+        detects enough samples will transition to the next step and start counting samples anew.
+        Some other peers may take time before they check with DHT and observe that
+          - the global epoch is technically one epoch ahead of the current one and
+          - the remaining (non-transitioned) peers no longer have target_batch_size between them
+        If this is the case, peer should transition to the next epoch and does *not* need to re-load state.
+        """
+        if self._should_check_synchronization_on_update and self.tracker.fetched_global_progress_this_epoch.is_set():
+            self._should_check_synchronization_on_update = False
+            return self.local_epoch != self.tracker.global_epoch  # require exact synchronization once per step
+        return self.local_epoch < self.tracker.global_epoch - 1  # catch up if a peer just switched to next epoch
 
     def load_state_from_peers(self, **kwargs):
         """Attempt to fetch the newest collaboration state from other peers"""
-        if self.scheduled_round is not None and not self.scheduled_round.done():
-            self.scheduled_round.cancel()
+        if self.scheduled_grads is not None and not self.scheduled_grads.done():
+            self.scheduled_grads.cancel()
 
         with self.tracker.pause_updates():
             while True:
@@ -402,11 +584,23 @@ class Optimizer(torch.optim.Optimizer):
                 self.state_averager.local_epoch = self.tracker.global_epoch
 
             self.tracker.report_local_progress(local_epoch=self.local_epoch, samples_accumulated=0)
-            self.grad_averager.reset_accumulated_grads_()
+
             if not self.client_mode:
-                self.grad_averager.state_sharing_priority = self.local_epoch
                 self.state_averager.state_sharing_priority = self.local_epoch
 
+            self._cancel_scheduled_averaging()
+
+            if self.use_gradient_averaging:
+                self.grad_averager.reset_accumulated_grads_()
+                if not self.client_mode:
+                    self.grad_averager.state_sharing_priority = self.local_epoch
+
+    def _cancel_scheduled_averaging(self):
+        if self.scheduled_grads is not None and not self.scheduled_grads.done():
+            self.scheduled_grads.cancel()
+        if self.scheduled_state is not None and not self.scheduled_state.done():
+            self.scheduled_state.cancel()
+
     def state_dict(self) -> dict:
         state_dict = self.state_averager.optimizer.state_dict()
         state_dict["state"]["local_epoch"] = self.local_epoch
@@ -448,11 +642,13 @@ class Optimizer(torch.optim.Optimizer):
 
     def shutdown(self):
         logger.debug("Sending goodbye to peers...")
+        self._cancel_scheduled_averaging()
         self.tracker.shutdown(self.shutdown_timeout)
-        logger.debug("Shutting down averager...")
-        self.state_averager.step(wait_for_delayed_update=True)
+        logger.debug("Shutting down averagers...")
+        self.state_averager.step(wait_for_delayed_updates=True)
         self.state_averager.shutdown()
-        self.grad_averager.shutdown()
+        if self.use_gradient_averaging:
+            self.grad_averager.shutdown()
         logger.debug(f"{self.__class__.__name__} is shut down.")
 
     def __del__(self):

+ 10 - 5
hivemind/optim/experimental/progress_tracker.py

@@ -114,7 +114,7 @@ class ProgressTracker(threading.Thread):
         metadata, _expiration = self.dht.get(self.training_progress_key, latest=True) or (None, -float("inf"))
         self.global_progress = self._parse_swarm_progress_data(metadata)
         self.lock_global_progress, self.global_state_updated = threading.Lock(), threading.Event()
-        self.should_report_progress, self.updated_progress_this_epoch = threading.Event(), threading.Event()
+        self.should_report_progress, self.fetched_global_progress_this_epoch = threading.Event(), threading.Event()
         self.shutdown_triggered, self.shutdown_complete = threading.Event(), threading.Event()
         super().__init__(name=f"{self.__class__.__name__}({self.prefix})", daemon=daemon)
         if start:
@@ -150,15 +150,20 @@ class ProgressTracker(threading.Thread):
             client_mode=self.client_mode,
         )
 
-    def report_local_progress(self, local_epoch: int, samples_accumulated: int):
+    def report_local_progress(self, local_epoch: int, samples_accumulated: int, update_global_samples: bool = True):
         """Update the number of locally accumulated samples and notify to other peers about this."""
         extra_samples = samples_accumulated - self.local_progress.samples_accumulated
+        if update_global_samples and local_epoch == self.local_progress.epoch == self.global_progress.epoch:
+            self.global_progress.samples_accumulated += extra_samples
+            # note: the above line can decrease the number of samples, e.g. if forced to reset due to overflow
+
         if extra_samples > 0:
             self.performance_ema.update(task_size=extra_samples)
             logger.debug(f"Updated performance EMA: {self.performance_ema.samples_per_second:.5f}")
         else:
             logger.debug("Resetting performance timestamp to current time (progress was reset)")
             self.performance_ema.reset_timer()
+
         self.local_progress = self._get_local_progress(local_epoch, samples_accumulated)
         self.should_report_progress.set()
 
@@ -178,7 +183,7 @@ class ProgressTracker(threading.Thread):
             self.global_progress.samples_accumulated = 0
             self.global_progress.eta_next_epoch = float("inf")
         self.report_local_progress(new_epoch, samples_accumulated=0)
-        self.updated_progress_this_epoch.clear()
+        self.fetched_global_progress_this_epoch.clear()
         return new_epoch
 
     def run(self):
@@ -257,7 +262,7 @@ class ProgressTracker(threading.Thread):
                         break
                     metadata = maybe_metadata.value if isinstance(maybe_metadata, ValueWithExpiration) else None
                     self.global_progress = self._parse_swarm_progress_data(metadata)
-                    self.updated_progress_this_epoch.set()
+                    self.fetched_global_progress_this_epoch.set()
 
         finally:
             logger.log(self.status_loglevel, f"No longer fetching {self.training_progress_key}.")
@@ -321,7 +326,7 @@ class ProgressTracker(threading.Thread):
         )
         logger.log(
             self.status_loglevel,
-            f"{self.prefix} accumulated {total_samples_accumulated} samples for iteration #{global_epoch} from "
+            f"{self.prefix} accumulated {total_samples_accumulated} samples for epoch #{global_epoch} from "
             f"{num_peers} peers. ETA {estimated_time_to_next_epoch:.2f} sec (refresh in {time_to_next_fetch:.2f} sec)",
         )
         return GlobalTrainingProgress(

+ 149 - 75
hivemind/optim/experimental/state_averager.py

@@ -1,19 +1,20 @@
 """ An extension of averager that supports common optimization use cases. """
 import logging
-from asyncio import Future
+import time
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import nullcontext
 from itertools import chain
-from threading import Event
+import threading
 from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Sequence, Tuple, Union
 
 import torch
 
 import hivemind
 from hivemind.averaging import DecentralizedAverager
+from hivemind.averaging.control import StepControl
 from hivemind.compression import CompressionInfo, TensorRole
 from hivemind.optim.grad_scaler import GradScaler
-from hivemind.utils import get_logger, nested_flatten, nested_pack
+from hivemind.utils import get_logger, nested_flatten, nested_pack, get_dht_time, DHTExpiration, PerformanceEMA
 
 logger = get_logger(__name__)
 
@@ -61,6 +62,9 @@ class TrainingStateAverager(DecentralizedAverager):
       This assumes that offloaded gradients will be populated externally, e.g. by the user or by hivemind.Optimizer.
     :param reuse_tensors: if True, reuse parameters and optimizer statistics as averaged_tensors for allreduce.
       For this to work, all parameters must be on CPU and have the appropriate dtype for use in DecentralizedAverager
+      Defaults to True if offload_optimizer, False otherwise.
+    :param delta_rule_averaging: if True, averaging will use delta rule to allow running local optimizer steps
+      while averaging. Delta rule: `state_tensor := state_tensor + averaging_result - state_tensor_before_averaging`
     :param sync_epoch_when_averaging: if True, update local epoch to the latest epoch among averaging peers
     :param parameter_names: optionally provide parameter names in the same order as in params
     :param average_opt_statistics: names of optimizer statistics from state dict that should be averaged with peers
@@ -79,7 +83,9 @@ class TrainingStateAverager(DecentralizedAverager):
         initialize_optimizer: Optional[bool] = None,
         offload_optimizer: bool = False,
         custom_gradients: bool = False,
-        reuse_tensors: bool = False,
+        reuse_tensors: Optional[bool] = None,
+        delta_rule_averaging: bool = False,
+        performance_ema_alpha: float = 0.1,
         sync_epoch_when_averaging: bool = False,
         parameter_names: Optional[Sequence[str]] = None,
         average_opt_statistics: Sequence[str] = (),
@@ -89,17 +95,19 @@ class TrainingStateAverager(DecentralizedAverager):
     ):
         average_opt_statistics = tuple(average_opt_statistics)
         assert all(isinstance(key, str) for key in average_opt_statistics)
-        if offload_optimizer and reuse_tensors:
-            logger.warning("Setting offload_optimizer=True has no effect because reuse_parameters=True")
+        if reuse_tensors is None:
+            reuse_tensors = offload_optimizer and not delta_rule_averaging
         if custom_gradients and not offload_optimizer:
             logger.warning("Setting custom_gradients=True has no effect because the optimizer is not offloaded")
+        if reuse_tensors and delta_rule_averaging:
+            raise ValueError("reuse_tensors and delta_rule_averaging are mutually exclusive")
 
         param_groups, main_parameters, parameter_names = self._check_params(optimizer, params, parameter_names)
 
         self.status_loglevel = status_loglevel
-        self.reuse_tensors = reuse_tensors
-        self.offload_optimizer = offload_optimizer
-        self.custom_gradients = custom_gradients
+        self.offload_optimizer, self.custom_gradients = offload_optimizer, custom_gradients
+        self.reuse_tensors, self.delta_rule_averaging = reuse_tensors, delta_rule_averaging
+        self._old_tensors: Optional[Sequence[torch.Tensor]] = None  # for delta rule
 
         self.main_parameters, self.parameter_names = main_parameters, parameter_names
         self._averaged_parameters = tuple(map(self._make_host_tensor, main_parameters))
@@ -110,11 +118,12 @@ class TrainingStateAverager(DecentralizedAverager):
         self.sync_epoch_when_averaging = sync_epoch_when_averaging
         self.local_epoch = 0
 
-        self.step_executor = ThreadPoolExecutor(max_workers=1)
-        self.finished_optimizer_step = Event()
-        self.finished_averaging_round = Event()
-        self.pending_update = Future()
-        self.pending_update.set_result(None)
+        self.delay_before_averaging = PerformanceEMA(alpha=performance_ema_alpha)
+        self.step_executor = ThreadPoolExecutor(max_workers=2 if self.delta_rule_averaging else 1)
+        self.finished_optimizer_step = threading.Event()
+        self.finished_averaging_round = threading.Event()
+        self.lock_optimizer = threading.Lock()
+        self.pending_updates = set()
 
         super().__init__(
             dht=dht, averaged_tensors=self._init_averaged_tensors(), tensor_infos=self._init_tensor_infos(), **kwargs
@@ -147,7 +156,8 @@ class TrainingStateAverager(DecentralizedAverager):
     def _make_host_tensor(self, source_tensor: torch.Tensor) -> torch.Tensor:
         """Create a new tensor for averaging or reuse the existing one"""
         if self.reuse_tensors:
-            assert source_tensor.device == torch.device("cpu") and source_tensor.dtype == torch.float32
+            if source_tensor.device != torch.device("cpu"):
+                raise ValueError("reuse_tensors is only supported if all averaged tensors are on CPU.")
             if not source_tensor.is_shared():
                 source_tensor.share_memory_()
             return source_tensor
@@ -174,19 +184,26 @@ class TrainingStateAverager(DecentralizedAverager):
         # create optimizer
         if optimizer_is_factory:
             if self.offload_optimizer:
-                for param in self._averaged_parameters:
-                    if param.grad is None:
-                        param.grad = torch.zeros_like(param)
+                if self.reuse_tensors:
+                    parameters_for_optimizer = self._averaged_parameters
+                else:
+                    parameters_for_optimizer = tuple(
+                        tensor.detach().clone().requires_grad_(tensor.requires_grad)
+                        for tensor in self._averaged_parameters
+                    )
 
                 next_index = 0
                 param_groups_for_optimizer = []
                 for param_group in param_groups:
                     num_params = len(param_group["params"])
-                    averaged_params_for_group = self._averaged_parameters[next_index : next_index + num_params]
+                    averaged_params_for_group = parameters_for_optimizer[next_index : next_index + num_params]
                     param_groups_for_optimizer.append(dict(param_group, params=averaged_params_for_group))
                     next_index += num_params
-                assert next_index == len(self._averaged_parameters)
+                assert next_index == len(parameters_for_optimizer)
 
+                for param in parameters_for_optimizer:
+                    if param.grad is None:
+                        param.grad = torch.zeros_like(param)
             else:
                 param_groups_for_optimizer = param_groups
             optimizer = optimizer_or_factory(param_groups_for_optimizer)
@@ -214,7 +231,7 @@ class TrainingStateAverager(DecentralizedAverager):
 
         # verify optimizer and scheduler
         assert isinstance(optimizer, TorchOptimizer) and len(optimizer.param_groups) == len(list(param_groups))
-        if self.offload_optimizer or self.reuse_tensors:
+        if self.reuse_tensors:
             for param_group in optimizer.param_groups:
                 for param in param_group["params"]:
                     assert param.is_shared()
@@ -237,7 +254,7 @@ class TrainingStateAverager(DecentralizedAverager):
     def _init_averaged_tensors(self) -> Sequence[torch.Tensor]:
         """Create or reuse a tuple of all averaged tensors, including parameters, optimizer statistics and extras"""
         assert hasattr(self, "optimizer"), "Optimizer should already be initialized by this point"
-        assert hasattr(self, "_averaged_parameters"), "Should initialize _averaged_parameters first"
+        assert hasattr(self, "_averaged_parameters"), "ShoTrueuld initialize _averaged_parameters first"
         assert not hasattr(self, "_averaged_tensors"), "Averager is already initialized"
         assert all(isinstance(key, str) for key in self.opt_keys_for_averaging)
 
@@ -251,7 +268,7 @@ class TrainingStateAverager(DecentralizedAverager):
         for local_tensor, averaged_tensor in zip(local_tensors, averaged_tensors):
             assert local_tensor.shape == averaged_tensor.shape
             if averaged_tensor.grad is not None:
-                logger.debug(self.status_loglevel, "setting gradients for averaged tensor to None")
+                logger.log(self.status_loglevel, "setting gradients for averaged tensor to None")
 
         return averaged_tensors
 
@@ -275,9 +292,22 @@ class TrainingStateAverager(DecentralizedAverager):
             tensor_infos.append(CompressionInfo.from_tensor(extra_tensor, key=i, role=TensorRole.UNSPECIFIED))
         return tuple(tensor_infos)
 
+    def schedule_step(self, scheduled_time: Optional[DHTExpiration] = None, **kwargs) -> StepControl:
+        """
+        Begin matchmaking: look for a group of peers and prepare for averaging gradients at a specified time.
+
+        :param scheduled_time: expected time when to perform all-reduce. Can be changed using control.scheduled_time
+        :param kwargs: any additional keyword args from DecentralizedAverager.step, such as gather, allow_retries, etc
+        :note: setting weight at this stage is not supported, please leave this parameter as None
+        :returns: step_control - a handle that can be passed into TrainingStateAverager.step to use pre-scheduled group
+        :note: in the current implementation, each step_control can only be used in one step.
+        """
+        assert kwargs.get("weight") is None, "setting weight in schedule_step is not supported"
+        return super().step(scheduled_time=scheduled_time, wait=False, require_trigger=True, **kwargs)
+
     def step(
         self,
-        wait_for_delayed_update: bool = None,
+        wait_for_delayed_updates: bool = None,
         apply_delayed_updates: bool = True,
         increment_epoch: bool = False,
         optimizer_step: bool = False,
@@ -285,6 +315,7 @@ class TrainingStateAverager(DecentralizedAverager):
         delay_optimizer_step: bool = False,
         averaging_round: bool = False,
         delay_averaging: Optional[bool] = None,
+        averaging_control: Optional[StepControl] = None,
         wait_for_trigger: Optional[Callable[[], Any]] = None,
         grad_scaler: Optional[GradScaler] = None,
         averaging_opts: Optional[Dict[str, Any]] = None,
@@ -293,7 +324,7 @@ class TrainingStateAverager(DecentralizedAverager):
         Perform one or several possible actions, depending on the specified keyword args.
         The actions will be performed in the same order as specified below:
 
-        :param wait_for_delayed_update: if there are background averaging rounds, wait for them to finish
+        :param wait_for_delayed_updates: if there are background averaging rounds, wait for them to finish
           by default, await delayed updates when scheduling the next optimizer step, otherwise do not update
         :param apply_delayed_updates: apply any averaging rounds that have finished but were not applied yet
         :param increment_epoch: increment .local_epoch and update the learning rate scheduler (if present)
@@ -304,6 +335,7 @@ class TrainingStateAverager(DecentralizedAverager):
         :param averaging_round: average parameters, chosen optimizer keys and extra tensors with a group of peers
         :param delay_averaging: if True, perform averaging in background and apply results in a future step
           by default, delay averaging if the optimizer step is also delayed. Set to true to delay only this phase.
+        :param averaging_control: if specified, use this as a pre-scheduled averaging round. Should require_trigger.
         :param wait_for_trigger: wait for this (non-asyncio) function to finish before running optimizer step
         :note: if wait_for_trigger fails with any exception, it will abort optimizer step, zero grad and averaging
         :param grad_scaler: when using hivemind.GradScaler, one must forward it to step after calling .unscale_
@@ -311,16 +343,19 @@ class TrainingStateAverager(DecentralizedAverager):
         """
         if delay_averaging is None:
             delay_averaging = delay_optimizer_step
-        if wait_for_delayed_update is None:
-            wait_for_delayed_update = optimizer_step or zero_grad or averaging_round
+        should_wait = averaging_round or optimizer_step or zero_grad if self.delta_rule_averaging else averaging_round
+        if wait_for_delayed_updates is None:
+            wait_for_delayed_updates = should_wait
+        if should_wait and not (wait_for_delayed_updates and apply_delayed_updates):
+            raise ValueError("Should wait for background operation to finish before scheduling new one")
         assert not delay_optimizer_step or delay_averaging, "Delayed optimizer step requires delayed averaging"
-        if optimizer_step or averaging_round or zero_grad:
-            assert wait_for_delayed_update, "Must wait for background updates to finish before scheduling new ones"
         if delay_optimizer_step:
             assert self.offload_optimizer, "Delayed optimizer step is only available with offload_optimizer"
             assert not averaging_round or delay_averaging, "Averaging after delayed optimizer should also be delayed"
         if averaging_opts and not averaging_round:
             logger.warning(f"Averaging parameters not used because averaging_round=False: {averaging_opts}")
+        if averaging_control is not None:
+            assert averaging_round, "averaging_control is unused if averaging_round is not performed"
         if wait_for_trigger is not None:
             assert optimizer_step or zero_grad or averaging_round, "trigger is only used for updating parameters"
             if not (self.reuse_tensors or self.custom_gradients):
@@ -333,68 +368,83 @@ class TrainingStateAverager(DecentralizedAverager):
                 )
         output = None
 
-        if wait_for_delayed_update:
-            if not self.pending_update.done():
-                logger.log(self.status_loglevel, "Waiting for delayed updates to finish...")
-                output = self.pending_update.result()
-
-        if self.pending_update.done() and self.pending_update.exception():
-            logger.warning(f"Background update failed with {self.pending_update.exception()} and will be ignored")
+        if wait_for_delayed_updates:
+            for pending_update in self.pending_updates:
+                try:
+                    logger.log(self.status_loglevel, "Waiting for delayed updates to finish...")
+                    output = pending_update.result()
+                except BaseException:
+                    pass  # exception will be reported below
+
+        # remove finished updates, log any exceptions
+        finished_updates = {pending_update for pending_update in self.pending_updates if pending_update.done()}
+        self.pending_updates = {pending_update for pending_update in self.pending_updates if not pending_update.done()}
+        for finished_update in finished_updates:
+            if finished_update.exception():
+                logger.log(self.status_loglevel, f"Background update failed with {finished_update.exception()}")
 
         if apply_delayed_updates:
             if self.finished_averaging_round.is_set():
                 if not self.reuse_tensors:
                     self._apply_averaging_results_()
-                    if self.offload_optimizer:
-                        self._apply_optimizer_parameters_()
+                if self.offload_optimizer and not self.finished_optimizer_step.is_set():
+                    self._apply_optimizer_parameters_()
                 logger.log(self.status_loglevel, "Received parameters from background averaging round")
                 self.finished_averaging_round.clear()
 
             if self.finished_optimizer_step.is_set():
                 if self.offload_optimizer:
                     self._apply_optimizer_parameters_()
-                logger.log(self.status_loglevel, "Received parameters from background optimizer step")
+                logger.debug("Received parameters from background optimizer step")
                 self.finished_optimizer_step.clear()
 
         if increment_epoch:
             self.local_epoch += 1
 
         if optimizer_step or zero_grad or averaging_round:
-            assert self.pending_update.done(), "Tried to perform a new update but previous update is still running"
-
             if self.offload_optimizer and not self.custom_gradients:
                 self._load_local_grads_into_optimizer_()
 
-            self.pending_update = self.step_executor.submit(
+            pending_update = self.step_executor.submit(
                 self._do,
                 wait_for_trigger,
                 optimizer_step,
                 zero_grad,
                 averaging_round,
+                averaging_control,
                 grad_scaler,
                 **averaging_opts or {},
             )
+            self.pending_updates.add(pending_update)
+
+            should_await_optimizer = (optimizer_step or zero_grad) and not delay_optimizer_step
+            should_await_averaging = averaging_round and not delay_averaging
 
-            if (optimizer_step or zero_grad) and not delay_optimizer_step:
+            if should_await_optimizer:
                 self.finished_optimizer_step.wait()
                 self.finished_optimizer_step.clear()
-                if self.offload_optimizer:
+                if self.offload_optimizer and not should_await_averaging:
                     self._apply_optimizer_parameters_()
-                logger.log(self.status_loglevel, "Finished optimizer step")
+                logger.debug("Finished optimizer step")
 
-            if averaging_round and not delay_averaging:
+            if should_await_averaging:
                 self.finished_averaging_round.wait()
                 self.finished_averaging_round.clear()
                 if not self.reuse_tensors:
                     self._apply_averaging_results_()
+                if self.offload_optimizer:
+                    self._apply_optimizer_parameters_()
                 logger.log(self.status_loglevel, "Finished averaging round")
 
-            if not delay_averaging:
+            async_averaging = averaging_round and delay_averaging
+            async_optimizer = (optimizer_step or zero_grad) and delay_optimizer_step
+
+            if not (async_averaging or async_optimizer):
                 try:
-                    output = self.pending_update.result()
+                    output = pending_update.result()
                 finally:
-                    self.finished_averaging_round.clear()
-                    self.finished_optimizer_step.clear()
+                    self.pending_updates.remove(pending_update)
+
         return output
 
     def _do(
@@ -403,6 +453,7 @@ class TrainingStateAverager(DecentralizedAverager):
         optimizer_step: bool,
         zero_grad: bool,
         averaging_round: bool,
+        averaging_control: Optional[StepControl],
         grad_scaler: Optional[GradScaler],
         timeout: Optional[float] = None,
         **kwargs,
@@ -411,50 +462,64 @@ class TrainingStateAverager(DecentralizedAverager):
         Run the optimizer step, followed by a scheduler step and an averaging round, each stage is optional.
         This method is meant to be called in the background executor.
         """
+        if averaging_control is not None and (averaging_control.triggered or averaging_control.done()):
+            logger.log(self.status_loglevel, f"Discarding failed matchmaking results: {averaging_control}")
+            averaging_control = None
+
+        start_time = time.perf_counter()
         began_running = False
-        averaging_control = None
 
         try:
-            if averaging_round:
+            if averaging_round and averaging_control is None:
                 averaging_control = super().step(
-                    gather=self.local_epoch, require_trigger=True, timeout=timeout, wait=False, **kwargs
+                    scheduled_time=get_dht_time() + self.delay_before_averaging.ema_seconds_per_sample,
+                    gather=self.local_epoch,
+                    require_trigger=True,
+                    timeout=timeout,
+                    wait=False,
+                    **kwargs,
                 )
 
             if wait_for_trigger is not None:
                 wait_for_trigger()
             began_running = True
 
-            if optimizer_step:
-                with self.lock_averaged_tensors if self.offload_optimizer or self.reuse_tensors else nullcontext():
-                    logger.log(self.status_loglevel, f"Running optimizer step")
-                    if grad_scaler is None:
-                        self.optimizer.step()
-                    else:
-                        with grad_scaler.running_global_step():
-                            assert grad_scaler.step(self.optimizer)
-
-            self._update_scheduler()
-
-            if zero_grad:
-                logger.log(self.status_loglevel, f"Running zero grad")
-                self.optimizer.zero_grad()
-                if self.offload_optimizer:
-                    for parameter in self.main_parameters:
-                        if parameter.grad is not None:
-                            parameter.grad.zero_()
+            with self.lock_optimizer:
+                if optimizer_step:
+                    with self.lock_averaged_tensors if self.reuse_tensors else nullcontext():
+                        logger.debug(f"Running optimizer step")
+                        if grad_scaler is None:
+                            self.optimizer.step()
+                        else:
+                            with grad_scaler.running_global_step():
+                                assert grad_scaler.step(self.optimizer)
+
+                if zero_grad:
+                    logger.debug(f"Running zero grad")
+                    self.optimizer.zero_grad()
+                    if self.offload_optimizer:
+                        for parameter in self.main_parameters:
+                            if parameter.grad is not None:
+                                parameter.grad.zero_()
 
-            self.finished_optimizer_step.set()
+                self._update_scheduler()
+                self.finished_optimizer_step.set()
 
             if averaging_round:
                 if not self.reuse_tensors:
                     self._load_local_tensors_into_averager_()
+                if self.delta_rule_averaging:
+                    # remember tensors before averaging, update by (new_averaged_tensors - old_averaged_tensors)
+                    with torch.no_grad(), self.get_tensors() as averaged_tensors:
+                        self._old_tensors = tuple(x.cpu().clone() for x in averaged_tensors)
+
+                self.delay_before_averaging.update(task_size=1, interval=time.perf_counter() - start_time)
                 try:
                     averaging_control.allow_allreduce()
                     gathered = averaging_control.result(timeout=timeout)
                     logger.log(self.status_loglevel, f"Averaged parameters with {len(gathered)} peers")
                 except BaseException as e:
                     logger.log(self.status_loglevel, f"Averaging failed with {type(e)}")
-                    self.finished_averaging_round.set()
                     gathered = {}
 
                 self.finished_averaging_round.set()
@@ -510,11 +575,20 @@ class TrainingStateAverager(DecentralizedAverager):
     def _apply_averaging_results_(self):
         """Copy averaged tensors into their respective local tensors"""
         assert not self.reuse_tensors, "No need to update averaged tensors since they reuse the same memory"
+        if self.delta_rule_averaging and self._old_tensors is None:
+            logger.warning("Using delta_rule_averaging, but old tensors were not found. Averaging may have failed.")
         with self.get_tensors() as averaged_tensors:
             local_tensors = list(self._local_tensors())
             assert len(local_tensors) == len(averaged_tensors), "Tensor structure changed during training"
-            for local_tensor, averaged_tensor in zip(local_tensors, averaged_tensors):
-                local_tensor.copy_(averaged_tensor, non_blocking=True)
+            if not self.delta_rule_averaging or self._old_tensors is None:
+                for local_tensor, averaged_tensor in zip(local_tensors, averaged_tensors):
+                    local_tensor.copy_(averaged_tensor, non_blocking=True)
+            else:
+                assert len(self._old_tensors) == len(local_tensors)
+                for local_tensor, new_tensor, old_tensor in zip(local_tensors, averaged_tensors, self._old_tensors):
+                    delta = torch.sub(new_tensor, old_tensor, out=old_tensor)  # using old tensors as buffers
+                    local_tensor.add_(delta.to(device=local_tensor.device, dtype=local_tensor.dtype))
+
 
     def get_current_state(self):
         """

+ 1 - 0
hivemind/utils/__init__.py

@@ -7,4 +7,5 @@ from hivemind.utils.nested import *
 from hivemind.utils.networking import *
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
 from hivemind.utils.tensor_descr import BatchTensorDescriptor, TensorDescriptor
+from hivemind.utils.performance_ema import PerformanceEMA
 from hivemind.utils.timed_storage import *

+ 4 - 5
tests/test_optimizer.py

@@ -1,6 +1,5 @@
 import ctypes
 import multiprocessing as mp
-import random
 import time
 from functools import partial
 
@@ -80,7 +79,7 @@ def test_grad_averager():
 @pytest.mark.forked
 @pytest.mark.parametrize(
     "offload_optimizer, reuse_tensors, sync_epoch_when_averaging",
-    [(False, False, False), (True, False, False), (False, True, True), (True, False, True)],
+    [(False, False, False), (True, True, False), (True, False, False), (False, True, True), (True, False, True)],
 )
 def test_state_averager(offload_optimizer: bool, reuse_tensors: bool, sync_epoch_when_averaging: bool):
     dht1 = hivemind.DHT(start=True)
@@ -137,8 +136,8 @@ def test_state_averager(offload_optimizer: bool, reuse_tensors: bool, sync_epoch
     avgr1.step(increment_epoch=True, averaging_round=True, delay_averaging=True)
     avgr2.step(increment_epoch=True, averaging_round=True, delay_averaging=True)
 
-    avgr1.step(wait_for_delayed_update=True)
-    avgr2.step(wait_for_delayed_update=True)
+    avgr1.step(wait_for_delayed_updates=True)
+    avgr2.step(wait_for_delayed_updates=True)
 
     assert torch.allclose(model1(x), model2(x)), "model parameters were not averaged correctly"
     assert torch.allclose(avgr1.optimizer.state_dict()["state"][0]["exp_avg_sq"], (stats1 + stats2) / 2)
@@ -383,4 +382,4 @@ def test_optimizer(
     assert not optimizer.state_averager.is_alive()
     assert not optimizer.grad_averager.is_alive()
     assert not optimizer.tracker.is_alive()
-    assert optimizer.scheduled_round is None or optimizer.scheduled_round.done()
+    assert optimizer.scheduled_grads is None or optimizer.scheduled_grads.done()