
local updates and docstring

justheuristic 3 years ago
parent
commit
2a15c37f1b

+ 343 - 147
hivemind/optim/experimental/optimizer.py

@@ -3,11 +3,12 @@ from __future__ import annotations
 import logging
 import logging
 import os
 import os
 from functools import partial
 from functools import partial
-from typing import Callable, Optional, Union
+from typing import Callable, Optional, Sequence, Union
 
 
 import torch
 import torch
 
 
 from hivemind.averaging.control import StepControl
 from hivemind.averaging.control import StepControl
+from hivemind.compression import CompressionBase, NoCompression
 from hivemind.dht import DHT
 from hivemind.dht import DHT
 from hivemind.optim.experimental.grad_averager import GradientAverager
 from hivemind.optim.experimental.grad_averager import GradientAverager
 from hivemind.optim.experimental.progress_tracker import ProgressTracker
 from hivemind.optim.experimental.progress_tracker import ProgressTracker
@@ -21,22 +22,25 @@ from hivemind.optim.experimental.state_averager import (
     TrainingStateAverager,
     TrainingStateAverager,
 )
 )
 from hivemind.optim.grad_scaler import GradScaler
 from hivemind.optim.grad_scaler import GradScaler
-from hivemind.utils import get_dht_time, get_logger
+from hivemind.utils import get_dht_time, get_logger, DHTExpiration
 
 
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 
 
 
 class Optimizer(torch.optim.Optimizer):
 class Optimizer(torch.optim.Optimizer):
     """
     """
-    Hivemind Optimizer wraps your regular PyTorch Optimizer for training in a swarm of peers. It can be configured with
-     synchronous, delayed or asynchronous updates to trade between optimization guarantees and compute utilization.
+    Hivemind Optimizer wraps your regular PyTorch Optimizer for training collaboratively with peers.
+    By default, Optimizer is configured to be exactly **equivalent to synchronous training** with target_batch_size.
+    There are advanced options that make training semi-asynchronous (delay_optimizer_step and delay_grad_averaging)
+    or even fully asynchronous (use_local_updates=True). However, these options require careful tuning.
 
 
-    The Optimizer is meant as a drop-in replacement for your regular PyTorch code:
+    The Optimizer is meant as a drop-in replacement for your regular PyTorch Optimizer:
 
 
     >>> model = transformers.AutoModel("albert-xxlarge-v2")
     >>> model = transformers.AutoModel("albert-xxlarge-v2")
     >>> dht = hivemind.DHT(initial_peers=INITIAL_PEERS, start=True)
     >>> dht = hivemind.DHT(initial_peers=INITIAL_PEERS, start=True)
-    >>> opt = hivemind.Optimizer(model.parameters(), optim_cls=torch.optim.Adam, run_id="run_42",
-    >>>                          target_batch_size=4096, batch_size_per_step=4)
+    >>> opt = hivemind.Optimizer(dht, run_id="run_42", optimizer=torch.optim.Adam, params=model.parameters(),
+    >>>                          target_batch_size=4096, batch_size_per_step=4)  # recommended way to create Optimizer
+    >>> # alternative: opt = hivemind.Optimizer(dht, run_id="run_42", optimizer=torch.optim.Adam(model.parameters()))
     >>> while True:
     >>> while True:
     >>>     loss = compute_loss_on_batch(model, batch_size=4)
     >>>     loss = compute_loss_on_batch(model, batch_size=4)
     >>>     opt.zero_grad()
     >>>     opt.zero_grad()
@@ -44,33 +48,69 @@ class Optimizer(torch.optim.Optimizer):
     >>>     opt.step()  # <-- train collaboratively with any peers that use the same prefix (run_42)
     >>>     opt.step()  # <-- train collaboratively with any peers that use the same prefix (run_42)
 
 
     However, unlike regular optimizers, calling opt.step with hivemind.Optimizer can do one of the following:
     However, unlike regular optimizers, calling opt.step with hivemind.Optimizer can do one of the following:
-    - accumulate a minibatch of data towards the (global) target batch size without changing parameters (yet),
-    - after accumulating the target batch size, all-reduce gradients with peers and perform optimizer step,
-    - if, for any reason, your peer lags behind the rest of the swarm, it will load state from up-to-date peers.
+    - accumulate a minibatch of gradients towards the (global) target batch size, without updating parameters yet;
+    - after accumulating the target batch size, all-reduce gradients with peers and perform optimizer step;
+    - if your peer lags behind the rest of the swarm, it will download the latest state from other peers.
 
 
     :note: hivemind.Optimizer can be used the same way any other pytorch optimizer, but there is one limitation:
     :note: hivemind.Optimizer can be used the same way any other pytorch optimizer, but there is one limitation:
-      learning rate schedulers, curriculum and other time-dependent features should use opt.global_step (and not the
-      number of local forward-backward cycles). This is because any device can join midway through training, when
-      other peers have already made some progress and changed their learning rate accordingly.
+      learning rate schedulers, curriculum and other **time-dependent features should depend on Optimizer.local_epoch**
+      (and not the number of calls to opt.step). This is because peers are allowed to join midway through training,
+      when others have already made some progress and changed their learning rates accordingly.
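+    For example, an illustrative curriculum that keys off the collaboration-wide epoch (the formula is arbitrary):
+    >>> max_seq_length = min(128 + 8 * opt.local_epoch, 512)  # grows with the global epoch, not with local steps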
 
 
     :param dht: a running hivemind.DHT instance connected to other peers
     :param dht: a running hivemind.DHT instance connected to other peers
-    :param run_id: a unique identifier of this experiment, used as a common prefix for all DHT keys
-    :param target_batch_size: perform optimizer step after all peers collectively accumulate this many samples
+    :param run_id: a unique identifier of this training run, used as a common prefix for all DHT keys
+
+    :note: peers with the same run_id should *generally* train the same model and use the same optimizer configuration.
+      Some options can be safely changed by individual peers: `batch_size_per_step`, `client_mode`, `auxiliary`,
+      `reuse_grad_buffers`, `offload_optimizer`, and `verbose`. In some cases, other options may also be tuned
+      individually by each peer, but they should be changed with caution to avoid deadlocks or convergence issues.
+
+    :param target_batch_size: global batch size that must be accumulated before the swarm transitions to the next epoch
     :param batch_size_per_step: before each call to .step, user should accumulate gradients over this many samples
     :param batch_size_per_step: before each call to .step, user should accumulate gradients over this many samples
-    :param optimizer: a standard pytorch optimizer, preferably a large-batch one such as LAMB, LARS, etc.
-    :param params: optional, a list/tuple of parameters or structured param groups for the optimizer
-    :param scheduler: if specified, use this scheduler to update optimizer learning rate
-    :note: If you are using hivemind.Optimizer with lr_scheduler, it is recommended to pass this scheduler
-      explicitly into this class. Otherwise, it may become non-synchronized between peers.
+
+    :param optimizer: a callable(parameters) -> pytorch.optim.Optimizer or a pre-initialized PyTorch optimizer
+    :param params: parameters or param groups for the optimizer; required if optimizer is a callable(params)
+    :note: creating hivemind.Optimizer with params=model.parameters() and optimizer=lambda params: make_optim(params)
+      is required for advanced options: offload_optimizer, delay_optimizer_step and delay_grad_averaging.
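+    For example, a minimal illustrative factory-style setup (the learning rate is an arbitrary placeholder):
+    >>> opt = hivemind.Optimizer(dht, run_id="run_42", params=model.parameters(),
+    >>>                          optimizer=lambda params: torch.optim.Adam(params, lr=1e-3),
+    >>>                          target_batch_size=4096, batch_size_per_step=4, offload_optimizer=True)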
+
+    :param scheduler: callable(optimizer) -> PyTorch LRScheduler or a pre-initialized PyTorch scheduler
+    :note: the learning rate scheduler will adjust learning rate based on collaboration-wide epoch, not the number of
+      local calls to optimizer.step; this is required to keep different peers synchronized.
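+    For example, an illustrative scheduler factory (the decay schedule itself is arbitrary):
+    >>> opt = hivemind.Optimizer(..., params=model.parameters(), optimizer=lambda params: torch.optim.Adam(params),
+    >>>                          scheduler=lambda optim: torch.optim.lr_scheduler.LambdaLR(optim, lambda epoch: 0.99 ** epoch))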
 
 
     :param matchmaking_time: when looking for group, wait for peers to join for up to this many seconds
     :param matchmaking_time: when looking for group, wait for peers to join for up to this many seconds
     :param averaging_timeout: if an averaging step hangs for this long, it will be cancelled.
     :param averaging_timeout: if an averaging step hangs for this long, it will be cancelled.
     :param load_state_timeout: wait for at most this many seconds before giving up on load_state_from_peers
     :param load_state_timeout: wait for at most this many seconds before giving up on load_state_from_peers
     :param reuse_grad_buffers: if True, use model's .grad buffers for gradient accumulation.
     :param reuse_grad_buffers: if True, use model's .grad buffers for gradient accumulation.
       This is more memory efficient, but it requires that the user does *NOT* call model/opt zero_grad at all
       This is more memory efficient, but it requires that the user does *NOT* call model/opt zero_grad at all
-    :param average_state_every: average state (parameters, chosen opt statistics) with peers every this many epochs
-    :param delay_optimizer_step: if True, run optimizer step in background and apply results in a future step
-    :param client_mode: if True, runs training without incoming connections, in a firewall-compatible mode
+
+    :param offload_optimizer: offload the optimizer to host memory, saving GPU memory for parameters and gradients
+    :param delay_optimizer_step: run optimizer in background, apply results in future .step; requires offload_optimizer
+    :param delay_grad_averaging: average gradients in background; requires offload_optimizer and delay_optimizer_step
+    :note: offload_optimizer, delay_optimizer_step and delay_grad_averaging require that the optimizer is
+      created as follows: `hivemind.Optimizer(..., optimizer=callable_optimizer_factory, params=model.parameters())`
+
+    :param delay_state_averaging: if enabled (default), average parameters and extra tensors in a background thread;
+      if set to False, average parameters synchronously within the corresponding hivemind.Optimizer.step call.
+    :param average_state_every: average state (parameters, chosen opt statistics) with peers every this many **epochs**.
+      Increasing this value reduces communication overhead but can cause parameters to diverge if it is too large.
+    :note: The maximal tolerable average_state_every depends on how quickly peers diverge from one another. If peers
+      hardly ever skip averaging rounds, they can afford to average state less frequently. Network failures, lossy
+      gradient compression and use_local_updates cause parameters to diverge faster and require more frequent averaging.
+
+    :param use_local_updates: if enabled, peers will update parameters on each .step using local gradients;
+      if not enabled (default), gradients are accumulated towards target_batch_size and averaged before the optimizer step
+    :note: even if use_local_updates=True, the learning rate scheduler will still be called once per target_batch_size.
+
+    :param client_mode: if True, this peer will not accept incoming connections (firewall-compatible mode)
+    :param auxiliary: if True, optimizer.step will only assist other peers in averaging (for cpu-only workers)
+    :note: client_mode=True and auxiliary=True are mutually exclusive; auxiliary also requires batch_size_per_step=None
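+    A rough sketch of a CPU-only auxiliary peer (the exact setup may vary for your deployment):
+    >>> aux_opt = hivemind.Optimizer(dht, run_id="run_42", optimizer=torch.optim.Adam(model.parameters()),
+    >>>                              target_batch_size=4096, batch_size_per_step=None, auxiliary=True)
+    >>> while True:
+    >>>     aux_opt.step()  # only assists other peers in averaging, never accumulates gradients itself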
+
+    :param grad_compression: compression strategy used for averaging gradients, default = no compression
+    :param state_averaging_compression: compression for averaging params and state tensors, default = no compression
+    :param load_state_compression: compression strategy for loading state from peers, default = no compression
+    :param average_opt_statistics: names of optimizer statistics from state dict that should be averaged with peers
+    :param extra_tensors: if specified, these extra tensors will also be averaged and shared in load_state_from_peers.
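+    For example (illustrative; assumes Float16Compression is available in your version of hivemind.compression):
+    >>> from hivemind.compression import Float16Compression
+    >>> opt = hivemind.Optimizer(..., grad_compression=Float16Compression(),
+    >>>                          state_averaging_compression=Float16Compression())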
+
     :param averager_opts: additional keyword arguments forwarded to both GradientAverager and TrainingStateAverager
     :param averager_opts: additional keyword arguments forwarded to both GradientAverager and TrainingStateAverager
     :param tracker_opts: additional keyword arguments forwarded to ProgressTracker
     :param tracker_opts: additional keyword arguments forwarded to ProgressTracker
     :param verbose: if True, report internal events such as accumulating gradients and running background tasks
     :param verbose: if True, report internal events such as accumulating gradients and running background tasks
@@ -97,12 +137,20 @@ class Optimizer(torch.optim.Optimizer):
         matchmaking_time: Optional[float] = 15.0,
         matchmaking_time: Optional[float] = 15.0,
         averaging_timeout: Optional[float] = 300.0,
         averaging_timeout: Optional[float] = 300.0,
         load_state_timeout: float = 600.0,
         load_state_timeout: float = 600.0,
-        average_state_every: int = 1,
         reuse_grad_buffers: bool = False,
         reuse_grad_buffers: bool = False,
-        delay_grad_averaging: bool = False,
+        offload_optimizer: Optional[bool] = None,
         delay_optimizer_step: Optional[bool] = None,
         delay_optimizer_step: Optional[bool] = None,
+        delay_grad_averaging: bool = False,
+        delay_state_averaging: bool = True,
+        average_state_every: int = 1,
+        use_local_updates: bool = False,
         client_mode: bool = None,
         client_mode: bool = None,
         auxiliary: bool = False,
         auxiliary: bool = False,
+        grad_compression: CompressionBase = NoCompression(),
+        state_averaging_compression: CompressionBase = NoCompression(),
+        load_state_compression: CompressionBase = NoCompression(),
+        average_opt_statistics: Sequence[str] = (),
+        extra_tensors: Sequence[torch.Tensor] = (),
         averager_opts: Optional[dict] = None,
         averager_opts: Optional[dict] = None,
         tracker_opts: Optional[dict] = None,
         tracker_opts: Optional[dict] = None,
         shutdown_timeout: float = 5,
         shutdown_timeout: float = 5,
@@ -110,31 +158,65 @@ class Optimizer(torch.optim.Optimizer):
     ):
     ):
         client_mode = client_mode if client_mode is not None else dht.client_mode
         client_mode = client_mode if client_mode is not None else dht.client_mode
         delay_optimizer_step = delay_optimizer_step if delay_optimizer_step is not None else delay_grad_averaging
         delay_optimizer_step = delay_optimizer_step if delay_optimizer_step is not None else delay_grad_averaging
+        offload_optimizer = offload_optimizer if offload_optimizer is not None else (params is not None)
         assert not delay_grad_averaging or delay_optimizer_step, "delay_grad_averaging requires delay_optimizer_step"
         assert not delay_grad_averaging or delay_optimizer_step, "delay_grad_averaging requires delay_optimizer_step"
         assert not (client_mode and auxiliary), "Client-mode peers cannot serve as auxiliaries"
         assert not (client_mode and auxiliary), "Client-mode peers cannot serve as auxiliaries"
         assert not auxiliary or batch_size_per_step is None, "Auxiliary peers should not accumulate batches"
         assert not auxiliary or batch_size_per_step is None, "Auxiliary peers should not accumulate batches"
+        if callable(optimizer) and params is not None:
+            if scheduler is not None and (not callable(scheduler) or isinstance(scheduler, LRSchedulerBase)):
+                raise ValueError("For this mode, please provide scheduler factory: callable(optimizer) -> scheduler")
+        elif all(hasattr(optimizer, attr) for attr in ("param_groups", "step", "zero_grad")):
+            if offload_optimizer or delay_optimizer_step or delay_grad_averaging:
+                raise ValueError(
+                    "To enable offload_optimizer or delayed updates, please initialize Optimizer as "
+                    "hivemind.Optimizer(..., params=params, optimizer=lambda params: create_opt(params))"
+                )
+        else:
+            raise ValueError(
+                "Please initialize the optimizer in one of the following two ways:\n"
+                "(A) hivemind.Optimizer(..., params=params, optimizer=lambda params: create_opt(params))\n"
+                "(B) hivemind.Optimizer(..., optimizer=pre_initialized_optimizer)"
+            )
+        if use_local_updates:
+            assert not reuse_grad_buffers, "if use_local_updates is True, gradients will not be accumulated"
+            assert not delay_grad_averaging, "if use_local_updates is True, gradients will not be averaged"
 
 
         self.dht, self.run_id, self.client_mode, self.auxiliary = dht, run_id, client_mode, auxiliary
         self.dht, self.run_id, self.client_mode, self.auxiliary = dht, run_id, client_mode, auxiliary
         self.batch_size_per_step, self.target_batch_size = batch_size_per_step, target_batch_size
         self.batch_size_per_step, self.target_batch_size = batch_size_per_step, target_batch_size
-        self.matchmaking_time, self.average_state_every = matchmaking_time, average_state_every
+        self.delay_state_averaging, self.average_state_every = delay_state_averaging, average_state_every
+        self.matchmaking_time, self.offload_optimizer = matchmaking_time, offload_optimizer
         self.delay_grad_averaging, self.delay_optimizer_step = delay_grad_averaging, delay_optimizer_step
         self.delay_grad_averaging, self.delay_optimizer_step = delay_grad_averaging, delay_optimizer_step
         self.averaging_timeout, self.load_state_timeout = averaging_timeout, load_state_timeout
         self.averaging_timeout, self.load_state_timeout = averaging_timeout, load_state_timeout
         self.shutdown_timeout = shutdown_timeout
         self.shutdown_timeout = shutdown_timeout
 
 
         self.status_loglevel = logging.INFO if verbose else logging.DEBUG
         self.status_loglevel = logging.INFO if verbose else logging.DEBUG
-        self.scheduled_round: Optional[StepControl] = None
-        self.previous_round: Optional[StepControl] = None
+        self.scheduled_grads: Optional[StepControl] = None
+        self.scheduled_state: Optional[StepControl] = None
 
 
+        self.tracker = self._make_progress_tracker(target_batch_size, **tracker_opts or {})
         self.state_averager = self._make_state_averager(
         self.state_averager = self._make_state_averager(
-            optimizer=optimizer, params=params, scheduler=scheduler, **averager_opts or {}
+            optimizer=optimizer,
+            params=params,
+            scheduler=scheduler,
+            delta_rule_averaging=use_local_updates and self.delay_state_averaging,
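+            # note: with local updates + delayed state averaging, the delta rule (line above) lets local optimizer
+            # steps continue while a background averaging round is in progress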
+            compression=state_averaging_compression,
+            state_compression=load_state_compression,
+            average_opt_statistics=average_opt_statistics,
+            extra_tensors=extra_tensors,
+            **averager_opts or {},
         )
         )
-        self.grad_averager = self._make_gradient_averager(reuse_grad_buffers=reuse_grad_buffers, **averager_opts or {})
-        self.tracker = self._make_progress_tracker(target_batch_size, **tracker_opts or {})
+        if not use_local_updates:
+            self.grad_averager = self._make_gradient_averager(
+                reuse_grad_buffers=reuse_grad_buffers, compression=grad_compression, **averager_opts or {}
+            )
+        else:
+            self.grad_averager = None
+
         self._should_check_synchronization_on_update = True  # used in self.should_load_state_from_peers
         self._should_check_synchronization_on_update = True  # used in self.should_load_state_from_peers
         self._schema_hash = self._compute_schema_hash()
         self._schema_hash = self._compute_schema_hash()
         self._parent_pid = os.getpid()
         self._parent_pid = os.getpid()
 
 
-        self._step_supports_amp_scaling = self.grad_averager.reuse_grad_buffers
+        self._step_supports_amp_scaling = reuse_grad_buffers
         # note: the line above is used by pytorch AMP GradScaler to enable custom behavior needed when reusing gradient
         # note: the line above is used by pytorch AMP GradScaler to enable custom behavior needed when reusing gradient
         # buffers over multiple steps (to avoid repeated unscaling). Without reuse_grad_buffers, this is not needed.
         # buffers over multiple steps (to avoid repeated unscaling). Without reuse_grad_buffers, this is not needed.
 
 
@@ -144,11 +226,11 @@ class Optimizer(torch.optim.Optimizer):
             prefix=f"{self.run_id}_state_averager",
             prefix=f"{self.run_id}_state_averager",
             allreduce_timeout=self.averaging_timeout,
             allreduce_timeout=self.averaging_timeout,
             shutdown_timeout=self.shutdown_timeout,
             shutdown_timeout=self.shutdown_timeout,
+            offload_optimizer=self.offload_optimizer,
+            custom_gradients=self.offload_optimizer,
             status_loglevel=self.status_loglevel,
             status_loglevel=self.status_loglevel,
             client_mode=self.client_mode,
             client_mode=self.client_mode,
             auxiliary=self.auxiliary,
             auxiliary=self.auxiliary,
-            offload_optimizer=True,
-            custom_gradients=True,
             start=True,
             start=True,
             **kwargs,
             **kwargs,
         )
         )
@@ -166,12 +248,13 @@ class Optimizer(torch.optim.Optimizer):
             start=True,
             start=True,
             **kwargs,
             **kwargs,
         )
         )
-        optimized_param_groups = self.state_averager.optimizer.param_groups
-        optimized_parameters = [param for group in optimized_param_groups for param in group["params"]]
-        with grad_averager.get_tensors() as averaged_gradients:
-            assert len(averaged_gradients) == len(optimized_parameters)
-            for opt_param, averaged_grad in zip(optimized_parameters, averaged_gradients):
-                opt_param.grad = averaged_grad
+        if self.offload_optimizer:
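+            # point the offloaded optimizer's .grad buffers at the averaged gradient tensors so that all-reduce
+            # results are visible to the optimizer without extra copies (see _load_averaged_gradients_into_optimizer_)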
+            optimized_param_groups = self.state_averager.optimizer.param_groups
+            optimized_parameters = [param for group in optimized_param_groups for param in group["params"]]
+            with grad_averager.get_tensors() as averaged_gradients:
+                assert len(averaged_gradients) == len(optimized_parameters)
+                for opt_param, averaged_grad in zip(optimized_parameters, averaged_gradients):
+                    opt_param.grad = averaged_grad
         return grad_averager
         return grad_averager
 
 
     def _make_progress_tracker(self, target_batch_size: int, **kwargs) -> ProgressTracker:
     def _make_progress_tracker(self, target_batch_size: int, **kwargs) -> ProgressTracker:
@@ -189,7 +272,9 @@ class Optimizer(torch.optim.Optimizer):
         optimized_param_groups = self.state_averager.optimizer.param_groups
         optimized_param_groups = self.state_averager.optimizer.param_groups
         optimized_parameters = [param for group in optimized_param_groups for param in group["params"]]
         optimized_parameters = [param for group in optimized_param_groups for param in group["params"]]
         param_shapes = tuple(tuple(param.shape) for param in optimized_parameters)
         param_shapes = tuple(tuple(param.shape) for param in optimized_parameters)
-        grad_ids = tuple(id(param.grad) for param in optimized_parameters)
+
+        # offloaded optimizer requires that gradient tensors are reused between iterations
+        grad_ids = tuple(id(param.grad) for param in optimized_parameters) if self.offload_optimizer else None
         return hash((grad_ids, param_shapes))
         return hash((grad_ids, param_shapes))
 
 
     def is_alive(self) -> bool:
     def is_alive(self) -> bool:
@@ -199,25 +284,13 @@ class Optimizer(torch.optim.Optimizer):
     def local_epoch(self) -> int:
     def local_epoch(self) -> int:
         return self.state_averager.local_epoch
         return self.state_averager.local_epoch
 
 
-    def should_load_state_from_peers(self) -> bool:
-        """
-        If true, peer will discard local progress and attempt to download state from peers.
-        This method allows peer to continue training in two cases:
-         - peer is on the same epoch as other collaborators - keep training normally
-         - peer was on the same epoch and accumulated some grads, but some collaborators
-             have just transitioned to the next epoch - this peer should also transition.
+    @property
+    def use_local_updates(self) -> bool:
+        return self.grad_averager is None
 
 
-        :note: The latter case occurs due to the lack of network synchrony: the first peer that
-        detects enough samples will transition to the next step and start counting samples anew.
-        Some other peers may take time before they check with DHT and observe that
-          - the global epoch is technically one epoch ahead of the current one and
-          - the remaining (non-transitioned) peers no longer have target_batch_size between them
-        If this is the case, peer should transition to the next epoch and does *not* need to re-load state.
-        """
-        if self._should_check_synchronization_on_update and self.tracker.updated_progress_this_epoch.is_set():
-            self._should_check_synchronization_on_update = False
-            return self.local_epoch != self.tracker.global_epoch  # require exact synchronization once per step
-        return self.local_epoch < self.tracker.global_epoch - 1  # catch up if a peer just switched to next epoch
+    @property
+    def use_gradient_averaging(self) -> bool:
+        return self.grad_averager is not None
 
 
     def step(
     def step(
         self,
         self,
@@ -241,6 +314,9 @@ class Optimizer(torch.optim.Optimizer):
             raise ValueError("Auxiliary peers should not have batch size, run closures, or use grad_scaler")
             raise ValueError("Auxiliary peers should not have batch size, run closures, or use grad_scaler")
         batch_size = batch_size if batch_size is not None else self.batch_size_per_step
         batch_size = batch_size if batch_size is not None else self.batch_size_per_step
 
 
+        # if delayed updates finished before step, apply these updates; otherwise do nothing
+        self.state_averager.step(apply_delayed_updates=True)
+
         loss = None
         loss = None
         if closure is not None:
         if closure is not None:
             with torch.enable_grad():
             with torch.enable_grad():
@@ -249,111 +325,180 @@ class Optimizer(torch.optim.Optimizer):
         if not self.auxiliary and self.should_load_state_from_peers():
         if not self.auxiliary and self.should_load_state_from_peers():
             logger.log(self.status_loglevel, "Peer is out of sync.")
             logger.log(self.status_loglevel, "Peer is out of sync.")
             self.load_state_from_peers()
             self.load_state_from_peers()
-            return loss
+            return loss  # local gradients were computed with out-of-sync parameters, must start over
 
 
-        if grad_scaler is not None and not grad_scaler.are_grads_finite(self):
-            logger.log(self.status_loglevel, "Encountered incorrect value in fp16 grads, resetting local gradients")
-            self.tracker.report_local_progress(self.local_epoch, samples_accumulated=0)
-            self.grad_averager.reset_accumulated_grads_()
-            return loss
+        if self.use_gradient_averaging:
+            # accumulate gradients toward target batch size, then aggregate with peers and run optimizer
+            if not self.auxiliary:
+                grads_are_valid = self._check_and_accumulate_gradients(batch_size, grad_scaler)
+                if not grads_are_valid:
+                    return loss  # local gradients were reset due to overflow, must start over
 
 
-        if not self.auxiliary:
-            self.grad_averager.accumulate_grads_(batch_size)
-            self.tracker.report_local_progress(self.local_epoch, self.grad_averager.local_samples_accumulated)
-            self.state_averager.step(apply_delayed_updates=True)
+            self._maybe_schedule_gradient_averaging()
+            self._maybe_schedule_state_averaging()
 
 
-        if self.tracker.estimated_next_update_time - get_dht_time() <= self.matchmaking_time:
-            if self.scheduled_round is None or self.scheduled_round.triggered or self.scheduled_round.done():
-                if self.delay_grad_averaging:
-                    # wait for previous averaging to finish before starting a new one
-                    self.state_averager.step(wait_for_delayed_update=True)
+        else:
+            # use_local_updates=True: update parameters on every step independently of other peers
+            if not self.auxiliary:
+                if grad_scaler is not None:
+                    with grad_scaler.running_global_step():
+                        assert grad_scaler.unscale_(self)
+
+                new_samples_accumulated = self.tracker.local_progress.samples_accumulated + batch_size
+                self.tracker.report_local_progress(self.local_epoch, new_samples_accumulated)
+                self._maybe_schedule_state_averaging()
+
+                self.state_averager.step(
+                    increment_epoch=False,
+                    optimizer_step=True,
+                    delay_optimizer_step=self.delay_optimizer_step,
+                    grad_scaler=grad_scaler,
+                )
 
 
-                eta_seconds = self.tracker.estimated_next_update_time - get_dht_time()
-                eta_seconds = max(eta_seconds, self.grad_averager.matchmaking_kwargs["min_matchmaking_time"])
-                logger.log(self.status_loglevel, f"Pre-scheduling next averaging round in {eta_seconds:.2f}s.")
-                scheduled_time = self.tracker.estimated_next_update_time
-                if self.client_mode:
-                    scheduled_time = get_dht_time() + self.averaging_timeout
-                self.scheduled_round = self.grad_averager.schedule_step(scheduled_time, timeout=self.averaging_timeout)
+        if self.tracker.ready_to_update_epoch:
+            self._update_global_epoch(grad_scaler)
 
 
-        if not self.tracker.ready_to_update_epoch:
-            return loss
+        return loss
 
 
+    def _update_global_epoch(self, grad_scaler: Optional[GradScaler]) -> None:
+        """Depending on the configuration: aggregate gradients and/or parameters, perform global optimizer step"""
         assert self._schema_hash == self._compute_schema_hash(), "parameters or gradients changed during iteration"
         assert self._schema_hash == self._compute_schema_hash(), "parameters or gradients changed during iteration"
 
 
         with self.tracker.pause_updates():
         with self.tracker.pause_updates():
-            # note: we do not need to replace grads because we explicitly load grads into the optimizer
-
-            logger.log(self.status_loglevel, f"Beginning global optimizer step #{self.tracker.global_epoch}")
-
-            if grad_scaler is not None:
-                with grad_scaler.running_global_step():
-                    assert grad_scaler.unscale_(self)
-
-            if self.scheduled_round is not None and self.scheduled_round.triggered or self.scheduled_round.done():
-                logger.log(self.status_loglevel, f"Discarding failed matchmaking results: {self.scheduled_round}")
-                self.scheduled_round = None
+            wait_for_trigger = None
 
 
-            swarm_not_empty = self.tracker.global_progress.num_peers > 1
-            began_averaging_gradients = False
-            if swarm_not_empty:
-                try:
-                    self.scheduled_round = self.grad_averager.step(
-                        control=self.scheduled_round, reset_accumulators=True, wait=False
-                    )
-                    assert self.grad_averager.local_samples_accumulated == 0, "step should have reset accumulators"
-                    began_averaging_gradients = True
-                except BaseException as e:
-                    logger.exception(e)
-
-            if not began_averaging_gradients and self.scheduled_round is not None and not self.scheduled_round.done():
-                logger.log(self.status_loglevel, f"Cancelled pre-scheduled averaging round")
-                self.scheduled_round.cancel()
-                self.scheduled_round = None
-
-            if not self.delay_grad_averaging:
-                self._average_gradients_and_load_into_optimizer(self.scheduled_round)
+            if self.use_gradient_averaging:
+                logger.log(self.status_loglevel, f"Beginning optimizer step #{self.local_epoch}")
+                began_averaging_gradients = self._begin_averaging_gradients(grad_scaler)
+                if not began_averaging_gradients:
+                    return  # failed to start gradient averaging due to an internal error
+                if self.delay_grad_averaging:
+                    # if using delayed grad averaging, send this to state_averager as a pre-condition for the optimizer step
+                    wait_for_trigger = partial(self._average_gradients_and_load_into_optimizer, self.scheduled_grads)
+                else:
+                    # delay_grad_averaging=False, average gradients immediately
+                    self._average_gradients_and_load_into_optimizer(self.scheduled_grads)
 
 
             next_epoch = max(self.local_epoch + 1, self.tracker.global_epoch)
             next_epoch = max(self.local_epoch + 1, self.tracker.global_epoch)
+            swarm_not_empty = self.tracker.global_progress.num_peers > 1
+            should_perform_optimizer_step = not self.auxiliary and not self.use_local_updates
+            should_average_state = swarm_not_empty and next_epoch % self.average_state_every == 0
 
 
             self.state_averager.step(
             self.state_averager.step(
                 increment_epoch=True,
                 increment_epoch=True,
-                optimizer_step=not self.auxiliary,
-                delay_optimizer_step=self.delay_optimizer_step,
-                averaging_round=swarm_not_empty and next_epoch % self.average_state_every == 0,
-                delay_averaging=not self.auxiliary,
+                wait_for_trigger=wait_for_trigger,
+                optimizer_step=should_perform_optimizer_step,
+                delay_optimizer_step=self.delay_optimizer_step and should_perform_optimizer_step,
                 grad_scaler=grad_scaler,
                 grad_scaler=grad_scaler,
-                wait_for_trigger=partial(self._average_gradients_and_load_into_optimizer, self.scheduled_round)
-                if self.delay_grad_averaging
-                else None,
-                averaging_opts=dict(
-                    scheduled_time=get_dht_time() + self.matchmaking_time, timeout=self.averaging_timeout
-                )
-                if swarm_not_empty and next_epoch % self.average_state_every == 0
-                else None,
+                averaging_round=should_average_state,
+                delay_averaging=self.delay_state_averaging and not self.auxiliary,
+                averaging_control=self.scheduled_state if should_average_state else None,
+                averaging_opts=dict(timeout=self.averaging_timeout) if should_average_state else None,
             )
             )
 
 
-            if not self.auxiliary:
-                self.grad_averager.reset_accumulated_grads_()
-                self.tracker.update_epoch(new_epoch=self.state_averager.local_epoch)
-                self._should_check_synchronization_on_update = True
+            self.tracker.update_epoch(new_epoch=self.state_averager.local_epoch)
+            self.scheduled_grads = self.scheduled_state = None
+            self._should_check_synchronization_on_update = True
+            # the above line ensures that peers check for *strict* synchronization once per epoch
 
 
             if not self.client_mode:
             if not self.client_mode:
-                self.grad_averager.state_sharing_priority = self.local_epoch
                 self.state_averager.state_sharing_priority = self.local_epoch
                 self.state_averager.state_sharing_priority = self.local_epoch
 
 
+            if self.use_gradient_averaging and not self.auxiliary:
+                self.grad_averager.reset_accumulated_grads_()
+                if not self.client_mode:
+                    self.grad_averager.state_sharing_priority = self.local_epoch
+
             logger.log(self.status_loglevel, f"Transitioning to epoch {self.local_epoch}.")
             logger.log(self.status_loglevel, f"Transitioning to epoch {self.local_epoch}.")
-        return loss
+
+    def _begin_averaging_gradients(self, grad_scaler: Optional[GradScaler]) -> bool:
+        """Begin an all-reduce round to average gradients; return True if succeeded, False if failed"""
+        if grad_scaler is not None:
+            with grad_scaler.running_global_step():
+                assert grad_scaler.unscale_(self)
+
+        if self.scheduled_grads is not None and (self.scheduled_grads.triggered or self.scheduled_grads.done()):
+            logger.log(self.status_loglevel, f"Discarding failed matchmaking results: {self.scheduled_grads}")
+            self.scheduled_grads = None
+
+        began_averaging_gradients = False
+        if self.tracker.global_progress.num_peers > 1:
+            try:
+                self.scheduled_grads = self.grad_averager.step(
+                    control=self.scheduled_grads, reset_accumulators=True, wait=False
+                )
+                assert self.grad_averager.local_samples_accumulated == 0, "step should have reset accumulators"
+                began_averaging_gradients = True
+            except BaseException as e:
+                logger.exception(e)
+
+        if not began_averaging_gradients and self.scheduled_grads is not None and not self.scheduled_grads.done():
+            logger.log(self.status_loglevel, f"Cancelled pre-scheduled averaging round")
+            self.scheduled_grads.cancel()
+            self.scheduled_grads = None
+        return began_averaging_gradients
+
+    def _check_and_accumulate_gradients(self, batch_size: int, grad_scaler: Optional[GradScaler]) -> bool:
+        """Check if gradients are valid, accumulate and return True; otherwise, reset and return False"""
+        assert not self.use_local_updates and not self.auxiliary
+        if grad_scaler is not None and not grad_scaler.are_grads_finite(self):
+            logger.log(self.status_loglevel, "Encountered incorrect value in fp16 grads, resetting local gradients")
+            self.tracker.report_local_progress(self.local_epoch, samples_accumulated=0)
+            self.grad_averager.reset_accumulated_grads_()
+            return False
+
+        self.grad_averager.accumulate_grads_(batch_size)
+        self.tracker.report_local_progress(self.local_epoch, self.grad_averager.local_samples_accumulated)
+        return True
+
+    def _maybe_schedule_gradient_averaging(self) -> None:
+        """If next epoch is coming soon, schedule the next gradient averaging round at the estimated end of epoch"""
+        assert self.use_gradient_averaging
+        if self.tracker.estimated_next_update_time - get_dht_time() <= self.matchmaking_time:
+            if self.scheduled_grads is None or self.scheduled_grads.triggered or self.scheduled_grads.done():
+                if self.delay_grad_averaging:
+                    # wait for previous averaging to finish before starting a new one
+                    self.state_averager.step(wait_for_delayed_updates=True)
+
+                eta_seconds = self.tracker.estimated_next_update_time - get_dht_time()
+                eta_seconds = max(eta_seconds, self.grad_averager.matchmaking_kwargs["min_matchmaking_time"])
+                logger.log(self.status_loglevel, f"Pre-scheduling gradient averaging round in {eta_seconds:.2f}s.")
+                scheduled_time = self.tracker.estimated_next_update_time
+                if self.client_mode:
+                    scheduled_time = get_dht_time() + self.averaging_timeout
+                self.scheduled_grads = self.grad_averager.schedule_step(scheduled_time, timeout=self.averaging_timeout)
+
+    def _maybe_schedule_state_averaging(self) -> None:
+        """If the next epoch is coming soon, schedule state averaging at the estimated start of parameter averaging"""
+        next_epoch = max(self.local_epoch + 1, self.tracker.global_epoch)
+        if next_epoch % self.average_state_every != 0:
+            return  # averaging is not performed at this epoch
+
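+        # estimate when state averaging will begin: the next epoch's ETA plus the typical (EMA-tracked) delay
+        # between reaching the target batch size and actually starting to average state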
+        estimated_time = self.tracker.estimated_next_update_time
+        estimated_time += self.state_averager.delay_before_averaging.ema_seconds_per_sample
+        eta_seconds_to_averaging = self.tracker.estimated_next_update_time - get_dht_time()
+
+        if eta_seconds_to_averaging <= self.matchmaking_time:
+            if self.scheduled_state is None or self.scheduled_state.triggered or self.scheduled_state.done():
+                min_matchmaking_time = self.state_averager.matchmaking_kwargs["min_matchmaking_time"]
+                actual_seconds = max(eta_seconds_to_averaging, min_matchmaking_time)
+                logger.log(self.status_loglevel, f"Pre-scheduling state averaging round in {actual_seconds:.2f}s.")
+                if self.client_mode:
+                    estimated_time = get_dht_time() + self.averaging_timeout
+                self.scheduled_state = self.state_averager.schedule_step(
+                    estimated_time, gather=next_epoch, timeout=self.averaging_timeout
+                )
 
 
     def _average_gradients_and_load_into_optimizer(self, maybe_step_control: Optional[StepControl]):
     def _average_gradients_and_load_into_optimizer(self, maybe_step_control: Optional[StepControl]):
         """Run gradient averaging; on success, feed averaged gradients into optimizer; else, use local gradients"""
         """Run gradient averaging; on success, feed averaged gradients into optimizer; else, use local gradients"""
-        assert maybe_step_control is None or maybe_step_control.triggered
+        assert self.use_gradient_averaging and (maybe_step_control is None or maybe_step_control.triggered)
         averaged_gradients = False
         averaged_gradients = False
 
 
         try:
         try:
             if maybe_step_control is not None:
             if maybe_step_control is not None:
                 group_info = maybe_step_control.result(self.averaging_timeout)
                 group_info = maybe_step_control.result(self.averaging_timeout)
                 logger.log(self.status_loglevel, f"Averaged gradients with {len(group_info)} peers")
                 logger.log(self.status_loglevel, f"Averaged gradients with {len(group_info)} peers")
+                self._load_averaged_gradients_into_optimizer_()
                 averaged_gradients = True
                 averaged_gradients = True
             else:
             else:
                 logger.log(self.status_loglevel, f"Skipped averaging: there are no other peers")
                 logger.log(self.status_loglevel, f"Skipped averaging: there are no other peers")
@@ -363,28 +508,65 @@ class Optimizer(torch.optim.Optimizer):
         if not averaged_gradients:
         if not averaged_gradients:
             logger.log(self.status_loglevel, f"Proceeding with local gradients")
             logger.log(self.status_loglevel, f"Proceeding with local gradients")
             self.grad_averager.load_accumulators_into_averager_()
             self.grad_averager.load_accumulators_into_averager_()
+            self._load_averaged_gradients_into_optimizer_()
+
+    def _load_averaged_gradients_into_optimizer_(self):
+        """If required, load averaged gradients into optimizer; otherwise simply notify grad averager"""
+        assert self.use_gradient_averaging
+
+        if self.offload_optimizer:
+            pass  # averaged gradients are already baked into optimizer, see _make_gradient_averager
+        else:
+            # copy averaged gradients into optimizer .grad buffers
+            optimized_param_groups = self.state_averager.optimizer.param_groups
+            optimized_parameters = [param for group in optimized_param_groups for param in group["params"]]
+            with torch.no_grad(), self.grad_averager.get_tensors() as averaged_gradients:
+                assert len(averaged_gradients) == len(optimized_parameters)
+                for opt_param, averaged_grad in zip(optimized_parameters, averaged_gradients):
+                    opt_param.grad.copy_(averaged_grad, non_blocking=True)
 
 
         self.grad_averager.notify_used_averaged_gradients()
         self.grad_averager.notify_used_averaged_gradients()
 
 
     def zero_grad(self, set_to_none: bool = False):
     def zero_grad(self, set_to_none: bool = False):
         """Reset gradients from model. If these gradients are reused for accumulators, raise an error."""
         """Reset gradients from model. If these gradients are reused for accumulators, raise an error."""
-        if self.grad_averager.reuse_grad_buffers:
+        if self.use_gradient_averaging and self.grad_averager.reuse_grad_buffers:
             raise ValueError(
             raise ValueError(
                 f"When running {self.__class__.__name__} with reuse_grad_buffers=True, user should never "
                 f"When running {self.__class__.__name__} with reuse_grad_buffers=True, user should never "
                 f"call zero_grad manually. Gradients will be refreshed internally."
                 f"call zero_grad manually. Gradients will be refreshed internally."
             )
             )
-        for param in self.grad_averager.parameters:
-            if param.grad is None:
-                pass
-            elif set_to_none:
-                param.grad = None
-            else:
-                param.grad.zero_()
+        for param_group in self.param_groups:
+            for param in param_group["params"]:
+                if param.grad is None:
+                    pass
+                elif set_to_none:
+                    param.grad = None
+                else:
+                    param.grad.zero_()
+
+    def should_load_state_from_peers(self) -> bool:
+        """
+        If true, peer will discard local progress and attempt to download state from peers.
+        This method allows a peer to continue training in two cases:
+         - peer is on the same epoch as other collaborators - keep training normally
+         - peer was on the same epoch and accumulated some grads, but some collaborators
+             have just transitioned to the next epoch - this peer should also transition.
+
+        :note: The latter case occurs due to the lack of network synchrony: the first peer that
+        detects enough samples will transition to the next step and start counting samples anew.
+        Some other peers may take time before they check with DHT and observe that
+          - the global epoch is technically one epoch ahead of the current one and
+          - the remaining (non-transitioned) peers no longer have target_batch_size between them
+        If this is the case, peer should transition to the next epoch and does *not* need to re-load state.
+        """
+        if self._should_check_synchronization_on_update and self.tracker.fetched_global_progress_this_epoch.is_set():
+            self._should_check_synchronization_on_update = False
+            return self.local_epoch != self.tracker.global_epoch  # require exact synchronization once per step
+        return self.local_epoch < self.tracker.global_epoch - 1  # catch up if a peer just switched to next epoch
 
 
     def load_state_from_peers(self, **kwargs):
     def load_state_from_peers(self, **kwargs):
         """Attempt to fetch the newest collaboration state from other peers"""
         """Attempt to fetch the newest collaboration state from other peers"""
-        if self.scheduled_round is not None and not self.scheduled_round.done():
-            self.scheduled_round.cancel()
+        if self.scheduled_grads is not None and not self.scheduled_grads.done():
+            self.scheduled_grads.cancel()
 
 
         with self.tracker.pause_updates():
         with self.tracker.pause_updates():
             while True:
             while True:
@@ -402,11 +584,23 @@ class Optimizer(torch.optim.Optimizer):
                 self.state_averager.local_epoch = self.tracker.global_epoch
                 self.state_averager.local_epoch = self.tracker.global_epoch
 
 
             self.tracker.report_local_progress(local_epoch=self.local_epoch, samples_accumulated=0)
             self.tracker.report_local_progress(local_epoch=self.local_epoch, samples_accumulated=0)
-            self.grad_averager.reset_accumulated_grads_()
+
             if not self.client_mode:
             if not self.client_mode:
-                self.grad_averager.state_sharing_priority = self.local_epoch
                 self.state_averager.state_sharing_priority = self.local_epoch
                 self.state_averager.state_sharing_priority = self.local_epoch
 
 
+            self._cancel_scheduled_averaging()
+
+            if self.use_gradient_averaging:
+                self.grad_averager.reset_accumulated_grads_()
+                if not self.client_mode:
+                    self.grad_averager.state_sharing_priority = self.local_epoch
+
+    def _cancel_scheduled_averaging(self):
+        if self.scheduled_grads is not None and not self.scheduled_grads.done():
+            self.scheduled_grads.cancel()
+        if self.scheduled_state is not None and not self.scheduled_state.done():
+            self.scheduled_state.cancel()
+
     def state_dict(self) -> dict:
     def state_dict(self) -> dict:
         state_dict = self.state_averager.optimizer.state_dict()
         state_dict = self.state_averager.optimizer.state_dict()
         state_dict["state"]["local_epoch"] = self.local_epoch
         state_dict["state"]["local_epoch"] = self.local_epoch
@@ -448,11 +642,13 @@ class Optimizer(torch.optim.Optimizer):
 
 
     def shutdown(self):
     def shutdown(self):
         logger.debug("Sending goodbye to peers...")
         logger.debug("Sending goodbye to peers...")
+        self._cancel_scheduled_averaging()
         self.tracker.shutdown(self.shutdown_timeout)
         self.tracker.shutdown(self.shutdown_timeout)
-        logger.debug("Shutting down averager...")
-        self.state_averager.step(wait_for_delayed_update=True)
+        logger.debug("Shutting down averagers...")
+        self.state_averager.step(wait_for_delayed_updates=True)
         self.state_averager.shutdown()
         self.state_averager.shutdown()
-        self.grad_averager.shutdown()
+        if self.use_gradient_averaging:
+            self.grad_averager.shutdown()
         logger.debug(f"{self.__class__.__name__} is shut down.")
         logger.debug(f"{self.__class__.__name__} is shut down.")
 
 
     def __del__(self):
     def __del__(self):

+ 10 - 5
hivemind/optim/experimental/progress_tracker.py

@@ -114,7 +114,7 @@ class ProgressTracker(threading.Thread):
         metadata, _expiration = self.dht.get(self.training_progress_key, latest=True) or (None, -float("inf"))
         metadata, _expiration = self.dht.get(self.training_progress_key, latest=True) or (None, -float("inf"))
         self.global_progress = self._parse_swarm_progress_data(metadata)
         self.global_progress = self._parse_swarm_progress_data(metadata)
         self.lock_global_progress, self.global_state_updated = threading.Lock(), threading.Event()
         self.lock_global_progress, self.global_state_updated = threading.Lock(), threading.Event()
-        self.should_report_progress, self.updated_progress_this_epoch = threading.Event(), threading.Event()
+        self.should_report_progress, self.fetched_global_progress_this_epoch = threading.Event(), threading.Event()
         self.shutdown_triggered, self.shutdown_complete = threading.Event(), threading.Event()
         self.shutdown_triggered, self.shutdown_complete = threading.Event(), threading.Event()
         super().__init__(name=f"{self.__class__.__name__}({self.prefix})", daemon=daemon)
         super().__init__(name=f"{self.__class__.__name__}({self.prefix})", daemon=daemon)
         if start:
         if start:
@@ -150,15 +150,20 @@ class ProgressTracker(threading.Thread):
             client_mode=self.client_mode,
             client_mode=self.client_mode,
         )
         )
 
 
-    def report_local_progress(self, local_epoch: int, samples_accumulated: int):
+    def report_local_progress(self, local_epoch: int, samples_accumulated: int, update_global_samples: bool = True):
         """Update the number of locally accumulated samples and notify to other peers about this."""
         """Update the number of locally accumulated samples and notify to other peers about this."""
         extra_samples = samples_accumulated - self.local_progress.samples_accumulated
         extra_samples = samples_accumulated - self.local_progress.samples_accumulated
+        if update_global_samples and local_epoch == self.local_progress.epoch == self.global_progress.epoch:
+            self.global_progress.samples_accumulated += extra_samples
+            # note: the above line can decrease the number of samples, e.g. if forced to reset due to overflow
+
         if extra_samples > 0:
         if extra_samples > 0:
             self.performance_ema.update(task_size=extra_samples)
             self.performance_ema.update(task_size=extra_samples)
             logger.debug(f"Updated performance EMA: {self.performance_ema.samples_per_second:.5f}")
             logger.debug(f"Updated performance EMA: {self.performance_ema.samples_per_second:.5f}")
         else:
         else:
             logger.debug("Resetting performance timestamp to current time (progress was reset)")
             logger.debug("Resetting performance timestamp to current time (progress was reset)")
             self.performance_ema.reset_timer()
             self.performance_ema.reset_timer()
+
         self.local_progress = self._get_local_progress(local_epoch, samples_accumulated)
         self.local_progress = self._get_local_progress(local_epoch, samples_accumulated)
         self.should_report_progress.set()
         self.should_report_progress.set()
 
 
@@ -178,7 +183,7 @@ class ProgressTracker(threading.Thread):
             self.global_progress.samples_accumulated = 0
             self.global_progress.samples_accumulated = 0
             self.global_progress.eta_next_epoch = float("inf")
             self.global_progress.eta_next_epoch = float("inf")
         self.report_local_progress(new_epoch, samples_accumulated=0)
         self.report_local_progress(new_epoch, samples_accumulated=0)
-        self.updated_progress_this_epoch.clear()
+        self.fetched_global_progress_this_epoch.clear()
         return new_epoch
         return new_epoch
 
 
     def run(self):
     def run(self):
@@ -257,7 +262,7 @@ class ProgressTracker(threading.Thread):
                         break
                         break
                     metadata = maybe_metadata.value if isinstance(maybe_metadata, ValueWithExpiration) else None
                     metadata = maybe_metadata.value if isinstance(maybe_metadata, ValueWithExpiration) else None
                     self.global_progress = self._parse_swarm_progress_data(metadata)
                     self.global_progress = self._parse_swarm_progress_data(metadata)
-                    self.updated_progress_this_epoch.set()
+                    self.fetched_global_progress_this_epoch.set()
 
 
         finally:
         finally:
             logger.log(self.status_loglevel, f"No longer fetching {self.training_progress_key}.")
             logger.log(self.status_loglevel, f"No longer fetching {self.training_progress_key}.")
@@ -321,7 +326,7 @@ class ProgressTracker(threading.Thread):
         )
         )
         logger.log(
         logger.log(
             self.status_loglevel,
             self.status_loglevel,
-            f"{self.prefix} accumulated {total_samples_accumulated} samples for iteration #{global_epoch} from "
+            f"{self.prefix} accumulated {total_samples_accumulated} samples for epoch #{global_epoch} from "
             f"{num_peers} peers. ETA {estimated_time_to_next_epoch:.2f} sec (refresh in {time_to_next_fetch:.2f} sec)",
             f"{num_peers} peers. ETA {estimated_time_to_next_epoch:.2f} sec (refresh in {time_to_next_fetch:.2f} sec)",
         )
         )
         return GlobalTrainingProgress(
         return GlobalTrainingProgress(

+ 149 - 75
hivemind/optim/experimental/state_averager.py

@@ -1,19 +1,20 @@
 """ An extension of averager that supports common optimization use cases. """
 """ An extension of averager that supports common optimization use cases. """
 import logging
 import logging
-from asyncio import Future
+import time
 from concurrent.futures import ThreadPoolExecutor
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import nullcontext
 from contextlib import nullcontext
 from itertools import chain
 from itertools import chain
-from threading import Event
+import threading
 from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Sequence, Tuple, Union
 from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Sequence, Tuple, Union
 
 
 import torch
 import torch
 
 
 import hivemind
 import hivemind
 from hivemind.averaging import DecentralizedAverager
 from hivemind.averaging import DecentralizedAverager
+from hivemind.averaging.control import StepControl
 from hivemind.compression import CompressionInfo, TensorRole
 from hivemind.compression import CompressionInfo, TensorRole
 from hivemind.optim.grad_scaler import GradScaler
 from hivemind.optim.grad_scaler import GradScaler
-from hivemind.utils import get_logger, nested_flatten, nested_pack
+from hivemind.utils import get_logger, nested_flatten, nested_pack, get_dht_time, DHTExpiration, PerformanceEMA
 
 
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 
@@ -61,6 +62,9 @@ class TrainingStateAverager(DecentralizedAverager):
       This assumes that offloaded gradients will be populated externally, e.g. by the user or by hivemind.Optimizer.
       This assumes that offloaded gradients will be populated externally, e.g. by the user or by hivemind.Optimizer.
     :param reuse_tensors: if True, reuse parameters and optimizer statistics as averaged_tensors for allreduce.
     :param reuse_tensors: if True, reuse parameters and optimizer statistics as averaged_tensors for allreduce.
       For this to work, all parameters must be on CPU and have the appropriate dtype for use in DecentralizedAverager
       For this to work, all parameters must be on CPU and have the appropriate dtype for use in DecentralizedAverager
+      Defaults to True if offload_optimizer is enabled, False otherwise.
+    :param delta_rule_averaging: if True, averaging will use the delta rule to allow running local optimizer steps
+      while averaging. Delta rule: `state_tensor := state_tensor + averaging_result - state_tensor_before_averaging`
     :param sync_epoch_when_averaging: if True, update local epoch to the latest epoch among averaging peers
     :param sync_epoch_when_averaging: if True, update local epoch to the latest epoch among averaging peers
     :param parameter_names: optionally provide parameter names in the same order as in params
     :param parameter_names: optionally provide parameter names in the same order as in params
     :param average_opt_statistics: names of optimizer statistics from state dict that should be averaged with peers
     :param average_opt_statistics: names of optimizer statistics from state dict that should be averaged with peers
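
The delta rule above can be illustrated with plain tensors (a toy sketch, not the averager itself): any optimizer progress made locally while the round is running is preserved, and only the change produced by all-reduce is added on top.

import torch

local_state = torch.tensor([1.0, 2.0, 3.0])      # tensor the optimizer keeps updating
state_before_averaging = local_state.clone()     # snapshot taken when the averaging round starts

local_state += 0.1                               # a local optimizer step that happens during averaging

averaging_result = torch.tensor([1.5, 2.5, 3.5]) # pretend output of the all-reduce

# delta rule: state_tensor := state_tensor + averaging_result - state_tensor_before_averaging
local_state += averaging_result - state_before_averaging
print(local_state)  # the local +0.1 update is preserved on top of the averaged change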
@@ -79,7 +83,9 @@ class TrainingStateAverager(DecentralizedAverager):
         initialize_optimizer: Optional[bool] = None,
         initialize_optimizer: Optional[bool] = None,
         offload_optimizer: bool = False,
         offload_optimizer: bool = False,
         custom_gradients: bool = False,
         custom_gradients: bool = False,
-        reuse_tensors: bool = False,
+        reuse_tensors: Optional[bool] = None,
+        delta_rule_averaging: bool = False,
+        performance_ema_alpha: float = 0.1,
         sync_epoch_when_averaging: bool = False,
         sync_epoch_when_averaging: bool = False,
         parameter_names: Optional[Sequence[str]] = None,
         parameter_names: Optional[Sequence[str]] = None,
         average_opt_statistics: Sequence[str] = (),
         average_opt_statistics: Sequence[str] = (),
@@ -89,17 +95,19 @@ class TrainingStateAverager(DecentralizedAverager):
     ):
     ):
         average_opt_statistics = tuple(average_opt_statistics)
         average_opt_statistics = tuple(average_opt_statistics)
         assert all(isinstance(key, str) for key in average_opt_statistics)
         assert all(isinstance(key, str) for key in average_opt_statistics)
-        if offload_optimizer and reuse_tensors:
-            logger.warning("Setting offload_optimizer=True has no effect because reuse_parameters=True")
+        if reuse_tensors is None:
+            reuse_tensors = offload_optimizer and not delta_rule_averaging
         if custom_gradients and not offload_optimizer:
         if custom_gradients and not offload_optimizer:
             logger.warning("Setting custom_gradients=True has no effect because the optimizer is not offloaded")
             logger.warning("Setting custom_gradients=True has no effect because the optimizer is not offloaded")
+        if reuse_tensors and delta_rule_averaging:
+            raise ValueError("reuse_tensors and delta_rule_averaging are mutually exclusive")
 
 
         param_groups, main_parameters, parameter_names = self._check_params(optimizer, params, parameter_names)
         param_groups, main_parameters, parameter_names = self._check_params(optimizer, params, parameter_names)
 
 
         self.status_loglevel = status_loglevel
         self.status_loglevel = status_loglevel
-        self.reuse_tensors = reuse_tensors
-        self.offload_optimizer = offload_optimizer
-        self.custom_gradients = custom_gradients
+        self.offload_optimizer, self.custom_gradients = offload_optimizer, custom_gradients
+        self.reuse_tensors, self.delta_rule_averaging = reuse_tensors, delta_rule_averaging
+        self._old_tensors: Optional[Sequence[torch.Tensor]] = None  # for delta rule
 
 
         self.main_parameters, self.parameter_names = main_parameters, parameter_names
         self.main_parameters, self.parameter_names = main_parameters, parameter_names
         self._averaged_parameters = tuple(map(self._make_host_tensor, main_parameters))
         self._averaged_parameters = tuple(map(self._make_host_tensor, main_parameters))
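
The new defaulting logic for reuse_tensors, restated as a small self-contained sketch (it mirrors the constructor change above; the function name is illustrative):

def resolve_reuse_tensors(reuse_tensors, offload_optimizer, delta_rule_averaging):
    # mirrors the constructor logic above (sketch; not the actual TrainingStateAverager code path)
    if reuse_tensors is None:
        reuse_tensors = offload_optimizer and not delta_rule_averaging
    if reuse_tensors and delta_rule_averaging:
        raise ValueError("reuse_tensors and delta_rule_averaging are mutually exclusive")
    return reuse_tensors

assert resolve_reuse_tensors(None, offload_optimizer=True, delta_rule_averaging=False) is True
assert resolve_reuse_tensors(None, offload_optimizer=True, delta_rule_averaging=True) is False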
@@ -110,11 +118,12 @@ class TrainingStateAverager(DecentralizedAverager):
         self.sync_epoch_when_averaging = sync_epoch_when_averaging
         self.sync_epoch_when_averaging = sync_epoch_when_averaging
         self.local_epoch = 0
         self.local_epoch = 0
 
 
-        self.step_executor = ThreadPoolExecutor(max_workers=1)
-        self.finished_optimizer_step = Event()
-        self.finished_averaging_round = Event()
-        self.pending_update = Future()
-        self.pending_update.set_result(None)
+        self.delay_before_averaging = PerformanceEMA(alpha=performance_ema_alpha)
+        self.step_executor = ThreadPoolExecutor(max_workers=2 if self.delta_rule_averaging else 1)
+        self.finished_optimizer_step = threading.Event()
+        self.finished_averaging_round = threading.Event()
+        self.lock_optimizer = threading.Lock()
+        self.pending_updates = set()
 
 
         super().__init__(
         super().__init__(
             dht=dht, averaged_tensors=self._init_averaged_tensors(), tensor_infos=self._init_tensor_infos(), **kwargs
             dht=dht, averaged_tensors=self._init_averaged_tensors(), tensor_infos=self._init_tensor_infos(), **kwargs
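
delay_before_averaging feeds the scheduled_time estimate used later in _do (see below): all-reduce is scheduled roughly one typical pre-averaging delay into the future. A rough standalone sketch, assuming only that the EMA tracks seconds per step (SecondsPerStepEMA is an illustrative stand-in, not hivemind's PerformanceEMA):

import time

class SecondsPerStepEMA:
    """Illustrative stand-in for PerformanceEMA: tracks an EMA of seconds spent per step."""
    def __init__(self, alpha: float = 0.1):
        self.alpha = alpha
        self.ema_seconds_per_sample = 0.0

    def update(self, task_size: int, interval: float):
        per_sample = interval / max(task_size, 1)
        self.ema_seconds_per_sample = self.alpha * per_sample + (1 - self.alpha) * self.ema_seconds_per_sample

delay_before_averaging = SecondsPerStepEMA(alpha=0.1)

step_started = time.perf_counter()
# ... wait_for_trigger / optimizer step / zero_grad would run here ...
delay_before_averaging.update(task_size=1, interval=time.perf_counter() - step_started)

# choose the all-reduce deadline: "now" plus the typical delay observed before averaging begins
# (the real code uses get_dht_time() instead of time.time())
scheduled_time = time.time() + delay_before_averaging.ema_seconds_per_sample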
@@ -147,7 +156,8 @@ class TrainingStateAverager(DecentralizedAverager):
     def _make_host_tensor(self, source_tensor: torch.Tensor) -> torch.Tensor:
     def _make_host_tensor(self, source_tensor: torch.Tensor) -> torch.Tensor:
         """Create a new tensor for averaging or reuse the existing one"""
         """Create a new tensor for averaging or reuse the existing one"""
         if self.reuse_tensors:
         if self.reuse_tensors:
-            assert source_tensor.device == torch.device("cpu") and source_tensor.dtype == torch.float32
+            if source_tensor.device != torch.device("cpu"):
+                raise ValueError("reuse_tensors is only supported if all averaged tensors are on CPU.")
             if not source_tensor.is_shared():
             if not source_tensor.is_shared():
                 source_tensor.share_memory_()
                 source_tensor.share_memory_()
             return source_tensor
             return source_tensor
@@ -174,19 +184,26 @@ class TrainingStateAverager(DecentralizedAverager):
         # create optimizer
         # create optimizer
         if optimizer_is_factory:
         if optimizer_is_factory:
             if self.offload_optimizer:
             if self.offload_optimizer:
-                for param in self._averaged_parameters:
-                    if param.grad is None:
-                        param.grad = torch.zeros_like(param)
+                if self.reuse_tensors:
+                    parameters_for_optimizer = self._averaged_parameters
+                else:
+                    parameters_for_optimizer = tuple(
+                        tensor.detach().clone().requires_grad_(tensor.requires_grad)
+                        for tensor in self._averaged_parameters
+                    )
 
 
                 next_index = 0
                 next_index = 0
                 param_groups_for_optimizer = []
                 param_groups_for_optimizer = []
                 for param_group in param_groups:
                 for param_group in param_groups:
                     num_params = len(param_group["params"])
                     num_params = len(param_group["params"])
-                    averaged_params_for_group = self._averaged_parameters[next_index : next_index + num_params]
+                    averaged_params_for_group = parameters_for_optimizer[next_index : next_index + num_params]
                     param_groups_for_optimizer.append(dict(param_group, params=averaged_params_for_group))
                     param_groups_for_optimizer.append(dict(param_group, params=averaged_params_for_group))
                     next_index += num_params
                     next_index += num_params
-                assert next_index == len(self._averaged_parameters)
+                assert next_index == len(parameters_for_optimizer)
 
 
+                for param in parameters_for_optimizer:
+                    if param.grad is None:
+                        param.grad = torch.zeros_like(param)
             else:
             else:
                 param_groups_for_optimizer = param_groups
                 param_groups_for_optimizer = param_groups
             optimizer = optimizer_or_factory(param_groups_for_optimizer)
             optimizer = optimizer_or_factory(param_groups_for_optimizer)
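
A simplified sketch of the offloading branch above, using plain torch (all names are illustrative): the optimizer is constructed on host-side copies of the parameters, grouped exactly like the originals, with gradient buffers pre-allocated so they can be filled in from the training process.

import torch

model_params = [torch.randn(3, requires_grad=True) for _ in range(4)]
param_groups = [dict(params=model_params[:2], lr=1e-3), dict(params=model_params[2:], lr=1e-4)]

# host-side copies that the offloaded optimizer will actually update
host_params = tuple(p.detach().clone().requires_grad_(p.requires_grad) for p in model_params)

next_index, host_groups = 0, []
for group in param_groups:
    num_params = len(group["params"])
    host_groups.append(dict(group, params=host_params[next_index:next_index + num_params]))
    next_index += num_params
assert next_index == len(host_params)

for param in host_params:
    if param.grad is None:
        param.grad = torch.zeros_like(param)  # gradient buffers filled in by the training process

optimizer = torch.optim.Adam(host_groups)  # same group structure (and hyperparameters) as the originals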
@@ -214,7 +231,7 @@ class TrainingStateAverager(DecentralizedAverager):
 
 
         # verify optimizer and scheduler
         # verify optimizer and scheduler
         assert isinstance(optimizer, TorchOptimizer) and len(optimizer.param_groups) == len(list(param_groups))
         assert isinstance(optimizer, TorchOptimizer) and len(optimizer.param_groups) == len(list(param_groups))
-        if self.offload_optimizer or self.reuse_tensors:
+        if self.reuse_tensors:
             for param_group in optimizer.param_groups:
             for param_group in optimizer.param_groups:
                 for param in param_group["params"]:
                 for param in param_group["params"]:
                     assert param.is_shared()
                     assert param.is_shared()
@@ -237,7 +254,7 @@ class TrainingStateAverager(DecentralizedAverager):
     def _init_averaged_tensors(self) -> Sequence[torch.Tensor]:
     def _init_averaged_tensors(self) -> Sequence[torch.Tensor]:
         """Create or reuse a tuple of all averaged tensors, including parameters, optimizer statistics and extras"""
         """Create or reuse a tuple of all averaged tensors, including parameters, optimizer statistics and extras"""
         assert hasattr(self, "optimizer"), "Optimizer should already be initialized by this point"
         assert hasattr(self, "optimizer"), "Optimizer should already be initialized by this point"
-        assert hasattr(self, "_averaged_parameters"), "Should initialize _averaged_parameters first"
+        assert hasattr(self, "_averaged_parameters"), "ShoTrueuld initialize _averaged_parameters first"
         assert not hasattr(self, "_averaged_tensors"), "Averager is already initialized"
         assert not hasattr(self, "_averaged_tensors"), "Averager is already initialized"
         assert all(isinstance(key, str) for key in self.opt_keys_for_averaging)
         assert all(isinstance(key, str) for key in self.opt_keys_for_averaging)
 
 
@@ -251,7 +268,7 @@ class TrainingStateAverager(DecentralizedAverager):
         for local_tensor, averaged_tensor in zip(local_tensors, averaged_tensors):
         for local_tensor, averaged_tensor in zip(local_tensors, averaged_tensors):
             assert local_tensor.shape == averaged_tensor.shape
             assert local_tensor.shape == averaged_tensor.shape
             if averaged_tensor.grad is not None:
             if averaged_tensor.grad is not None:
-                logger.debug(self.status_loglevel, "setting gradients for averaged tensor to None")
+                logger.log(self.status_loglevel, "setting gradients for averaged tensor to None")
 
 
         return averaged_tensors
         return averaged_tensors
 
 
@@ -275,9 +292,22 @@ class TrainingStateAverager(DecentralizedAverager):
             tensor_infos.append(CompressionInfo.from_tensor(extra_tensor, key=i, role=TensorRole.UNSPECIFIED))
             tensor_infos.append(CompressionInfo.from_tensor(extra_tensor, key=i, role=TensorRole.UNSPECIFIED))
         return tuple(tensor_infos)
         return tuple(tensor_infos)
 
 
+    def schedule_step(self, scheduled_time: Optional[DHTExpiration] = None, **kwargs) -> StepControl:
+        """
+        Begin matchmaking: look for a group of peers and prepare to average parameters and optimizer statistics at a specified time.
+
+        :param scheduled_time: expected time at which to perform all-reduce. Can be changed later using control.scheduled_time
+        :param kwargs: any additional keyword args from DecentralizedAverager.step, such as gather, allow_retries, etc.
+        :note: setting weight at this stage is not supported; please leave this parameter as None
+        :returns: step_control - a handle that can be passed into TrainingStateAverager.step to use the pre-scheduled group
+        :note: in the current implementation, each step_control can only be used in one step.
+        """
+        assert kwargs.get("weight") is None, "setting weight in schedule_step is not supported"
+        return super().step(scheduled_time=scheduled_time, wait=False, require_trigger=True, **kwargs)
+
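
Intended usage of the new schedule_step, based on the docstring and the averaging_control parameter added to step below (state_averager is assumed to be a running TrainingStateAverager; the 30-second offset is arbitrary):

from hivemind.utils import get_dht_time

# `state_averager` is assumed to be an already-running TrainingStateAverager
control = state_averager.schedule_step(scheduled_time=get_dht_time() + 30)

# ... compute gradients and do other work while matchmaking proceeds in the background ...

state_averager.step(
    optimizer_step=True,
    zero_grad=True,
    averaging_round=True,
    averaging_control=control,  # reuse the pre-scheduled group; each control may be used for one step only
)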
     def step(
     def step(
         self,
         self,
-        wait_for_delayed_update: bool = None,
+        wait_for_delayed_updates: bool = None,
         apply_delayed_updates: bool = True,
         apply_delayed_updates: bool = True,
         increment_epoch: bool = False,
         increment_epoch: bool = False,
         optimizer_step: bool = False,
         optimizer_step: bool = False,
@@ -285,6 +315,7 @@ class TrainingStateAverager(DecentralizedAverager):
         delay_optimizer_step: bool = False,
         delay_optimizer_step: bool = False,
         averaging_round: bool = False,
         averaging_round: bool = False,
         delay_averaging: Optional[bool] = None,
         delay_averaging: Optional[bool] = None,
+        averaging_control: Optional[StepControl] = None,
         wait_for_trigger: Optional[Callable[[], Any]] = None,
         wait_for_trigger: Optional[Callable[[], Any]] = None,
         grad_scaler: Optional[GradScaler] = None,
         grad_scaler: Optional[GradScaler] = None,
         averaging_opts: Optional[Dict[str, Any]] = None,
         averaging_opts: Optional[Dict[str, Any]] = None,
@@ -293,7 +324,7 @@ class TrainingStateAverager(DecentralizedAverager):
         Perform one or several possible actions, depending on the specified keyword args.
         Perform one or several possible actions, depending on the specified keyword args.
         The actions will be performed in the same order as specified below:
         The actions will be performed in the same order as specified below:
 
 
-        :param wait_for_delayed_update: if there are background averaging rounds, wait for them to finish
+        :param wait_for_delayed_updates: if there are background averaging rounds, wait for them to finish
           by default, await delayed updates when scheduling the next optimizer step, otherwise do not update
           by default, await delayed updates when scheduling the next optimizer step, otherwise do not update
         :param apply_delayed_updates: apply any averaging rounds that have finished but were not applied yet
         :param apply_delayed_updates: apply any averaging rounds that have finished but were not applied yet
         :param increment_epoch: increment .local_epoch and update the learning rate scheduler (if present)
         :param increment_epoch: increment .local_epoch and update the learning rate scheduler (if present)
@@ -304,6 +335,7 @@ class TrainingStateAverager(DecentralizedAverager):
         :param averaging_round: average parameters, chosen optimizer keys and extra tensors with a group of peers
         :param averaging_round: average parameters, chosen optimizer keys and extra tensors with a group of peers
         :param delay_averaging: if True, perform averaging in background and apply results in a future step
         :param delay_averaging: if True, perform averaging in background and apply results in a future step
           by default, delay averaging if the optimizer step is also delayed. Set to true to delay only this phase.
           by default, delay averaging if the optimizer step is also delayed. Set to true to delay only this phase.
+        :param averaging_control: if specified, use this as a pre-scheduled averaging round. Should require_trigger.
         :param wait_for_trigger: wait for this (non-asyncio) function to finish before running optimizer step
         :param wait_for_trigger: wait for this (non-asyncio) function to finish before running optimizer step
         :note: if wait_for_trigger fails with any exception, it will abort optimizer step, zero grad and averaging
         :note: if wait_for_trigger fails with any exception, it will abort optimizer step, zero grad and averaging
         :param grad_scaler: when using hivemind.GradScaler, one must forward it to step after calling .unscale_
         :param grad_scaler: when using hivemind.GradScaler, one must forward it to step after calling .unscale_
@@ -311,16 +343,19 @@ class TrainingStateAverager(DecentralizedAverager):
         """
         """
         if delay_averaging is None:
         if delay_averaging is None:
             delay_averaging = delay_optimizer_step
             delay_averaging = delay_optimizer_step
-        if wait_for_delayed_update is None:
-            wait_for_delayed_update = optimizer_step or zero_grad or averaging_round
+        should_wait = (averaging_round or optimizer_step or zero_grad) if self.delta_rule_averaging else averaging_round
+        if wait_for_delayed_updates is None:
+            wait_for_delayed_updates = should_wait
+        if should_wait and not (wait_for_delayed_updates and apply_delayed_updates):
+            raise ValueError("Should wait for background operation to finish before scheduling new one")
         assert not delay_optimizer_step or delay_averaging, "Delayed optimizer step requires delayed averaging"
         assert not delay_optimizer_step or delay_averaging, "Delayed optimizer step requires delayed averaging"
-        if optimizer_step or averaging_round or zero_grad:
-            assert wait_for_delayed_update, "Must wait for background updates to finish before scheduling new ones"
         if delay_optimizer_step:
         if delay_optimizer_step:
             assert self.offload_optimizer, "Delayed optimizer step is only available with offload_optimizer"
             assert self.offload_optimizer, "Delayed optimizer step is only available with offload_optimizer"
             assert not averaging_round or delay_averaging, "Averaging after delayed optimizer should also be delayed"
             assert not averaging_round or delay_averaging, "Averaging after delayed optimizer should also be delayed"
         if averaging_opts and not averaging_round:
         if averaging_opts and not averaging_round:
             logger.warning(f"Averaging parameters not used because averaging_round=False: {averaging_opts}")
             logger.warning(f"Averaging parameters not used because averaging_round=False: {averaging_opts}")
+        if averaging_control is not None:
+            assert averaging_round, "averaging_control is unused if averaging_round is not performed"
         if wait_for_trigger is not None:
         if wait_for_trigger is not None:
             assert optimizer_step or zero_grad or averaging_round, "trigger is only used for updating parameters"
             assert optimizer_step or zero_grad or averaging_round, "trigger is only used for updating parameters"
             if not (self.reuse_tensors or self.custom_gradients):
             if not (self.reuse_tensors or self.custom_gradients):
@@ -333,68 +368,83 @@ class TrainingStateAverager(DecentralizedAverager):
                 )
                 )
         output = None
         output = None
 
 
-        if wait_for_delayed_update:
-            if not self.pending_update.done():
-                logger.log(self.status_loglevel, "Waiting for delayed updates to finish...")
-                output = self.pending_update.result()
-
-        if self.pending_update.done() and self.pending_update.exception():
-            logger.warning(f"Background update failed with {self.pending_update.exception()} and will be ignored")
+        if wait_for_delayed_updates:
+            for pending_update in self.pending_updates:
+                try:
+                    logger.log(self.status_loglevel, "Waiting for delayed updates to finish...")
+                    output = pending_update.result()
+                except BaseException:
+                    pass  # exception will be reported below
+
+        # remove finished updates, log any exceptions
+        finished_updates = {pending_update for pending_update in self.pending_updates if pending_update.done()}
+        self.pending_updates = {pending_update for pending_update in self.pending_updates if not pending_update.done()}
+        for finished_update in finished_updates:
+            if finished_update.exception():
+                logger.log(self.status_loglevel, f"Background update failed with {finished_update.exception()}")
 
 
         if apply_delayed_updates:
         if apply_delayed_updates:
             if self.finished_averaging_round.is_set():
             if self.finished_averaging_round.is_set():
                 if not self.reuse_tensors:
                 if not self.reuse_tensors:
                     self._apply_averaging_results_()
                     self._apply_averaging_results_()
-                    if self.offload_optimizer:
-                        self._apply_optimizer_parameters_()
+                if self.offload_optimizer and not self.finished_optimizer_step.is_set():
+                    self._apply_optimizer_parameters_()
                 logger.log(self.status_loglevel, "Received parameters from background averaging round")
                 logger.log(self.status_loglevel, "Received parameters from background averaging round")
                 self.finished_averaging_round.clear()
                 self.finished_averaging_round.clear()
 
 
             if self.finished_optimizer_step.is_set():
             if self.finished_optimizer_step.is_set():
                 if self.offload_optimizer:
                 if self.offload_optimizer:
                     self._apply_optimizer_parameters_()
                     self._apply_optimizer_parameters_()
-                logger.log(self.status_loglevel, "Received parameters from background optimizer step")
+                logger.debug("Received parameters from background optimizer step")
                 self.finished_optimizer_step.clear()
                 self.finished_optimizer_step.clear()
 
 
         if increment_epoch:
         if increment_epoch:
             self.local_epoch += 1
             self.local_epoch += 1
 
 
         if optimizer_step or zero_grad or averaging_round:
         if optimizer_step or zero_grad or averaging_round:
-            assert self.pending_update.done(), "Tried to perform a new update but previous update is still running"
-
             if self.offload_optimizer and not self.custom_gradients:
             if self.offload_optimizer and not self.custom_gradients:
                 self._load_local_grads_into_optimizer_()
                 self._load_local_grads_into_optimizer_()
 
 
-            self.pending_update = self.step_executor.submit(
+            pending_update = self.step_executor.submit(
                 self._do,
                 self._do,
                 wait_for_trigger,
                 wait_for_trigger,
                 optimizer_step,
                 optimizer_step,
                 zero_grad,
                 zero_grad,
                 averaging_round,
                 averaging_round,
+                averaging_control,
                 grad_scaler,
                 grad_scaler,
                 **averaging_opts or {},
                 **averaging_opts or {},
             )
             )
+            self.pending_updates.add(pending_update)
+
+            should_await_optimizer = (optimizer_step or zero_grad) and not delay_optimizer_step
+            should_await_averaging = averaging_round and not delay_averaging
 
 
-            if (optimizer_step or zero_grad) and not delay_optimizer_step:
+            if should_await_optimizer:
                 self.finished_optimizer_step.wait()
                 self.finished_optimizer_step.wait()
                 self.finished_optimizer_step.clear()
                 self.finished_optimizer_step.clear()
-                if self.offload_optimizer:
+                if self.offload_optimizer and not should_await_averaging:
                     self._apply_optimizer_parameters_()
                     self._apply_optimizer_parameters_()
-                logger.log(self.status_loglevel, "Finished optimizer step")
+                logger.debug("Finished optimizer step")
 
 
-            if averaging_round and not delay_averaging:
+            if should_await_averaging:
                 self.finished_averaging_round.wait()
                 self.finished_averaging_round.wait()
                 self.finished_averaging_round.clear()
                 self.finished_averaging_round.clear()
                 if not self.reuse_tensors:
                 if not self.reuse_tensors:
                     self._apply_averaging_results_()
                     self._apply_averaging_results_()
+                if self.offload_optimizer:
+                    self._apply_optimizer_parameters_()
                 logger.log(self.status_loglevel, "Finished averaging round")
                 logger.log(self.status_loglevel, "Finished averaging round")
 
 
-            if not delay_averaging:
+            async_averaging = averaging_round and delay_averaging
+            async_optimizer = (optimizer_step or zero_grad) and delay_optimizer_step
+
+            if not (async_averaging or async_optimizer):
                 try:
                 try:
-                    output = self.pending_update.result()
+                    output = pending_update.result()
                 finally:
                 finally:
-                    self.finished_averaging_round.clear()
-                    self.finished_optimizer_step.clear()
+                    self.pending_updates.remove(pending_update)
+
         return output
         return output
 
 
     def _do(
     def _do(
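
The single pending_update future is replaced here by a set of futures that is drained and inspected on each call. The bookkeeping pattern in isolation, as a generic concurrent.futures sketch (no hivemind code involved):

from concurrent.futures import ThreadPoolExecutor
import time

def slow_update(value):
    time.sleep(0.05)
    return value

def failing_update():
    raise RuntimeError("averaging failed")

executor = ThreadPoolExecutor(max_workers=2)
pending_updates = {executor.submit(slow_update, 1), executor.submit(failing_update)}

# wait for everything currently in flight, tolerating failures for now
for update in pending_updates:
    try:
        update.result()
    except BaseException:
        pass  # exceptions are reported below

# drop finished futures and report any exceptions they raised
finished = {u for u in pending_updates if u.done()}
pending_updates = {u for u in pending_updates if not u.done()}
for u in finished:
    if u.exception():
        print(f"background update failed with {u.exception()}")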
@@ -403,6 +453,7 @@ class TrainingStateAverager(DecentralizedAverager):
         optimizer_step: bool,
         optimizer_step: bool,
         zero_grad: bool,
         zero_grad: bool,
         averaging_round: bool,
         averaging_round: bool,
+        averaging_control: Optional[StepControl],
         grad_scaler: Optional[GradScaler],
         grad_scaler: Optional[GradScaler],
         timeout: Optional[float] = None,
         timeout: Optional[float] = None,
         **kwargs,
         **kwargs,
@@ -411,50 +462,64 @@ class TrainingStateAverager(DecentralizedAverager):
         Run the optimizer step, followed by a scheduler step and an averaging round, each stage is optional.
         Run the optimizer step, followed by a scheduler step and an averaging round, each stage is optional.
         This method is meant to be called in the background executor.
         This method is meant to be called in the background executor.
         """
         """
+        if averaging_control is not None and (averaging_control.triggered or averaging_control.done()):
+            logger.log(self.status_loglevel, f"Discarding failed matchmaking results: {averaging_control}")
+            averaging_control = None
+
+        start_time = time.perf_counter()
         began_running = False
         began_running = False
-        averaging_control = None
 
 
         try:
         try:
-            if averaging_round:
+            if averaging_round and averaging_control is None:
                 averaging_control = super().step(
                 averaging_control = super().step(
-                    gather=self.local_epoch, require_trigger=True, timeout=timeout, wait=False, **kwargs
+                    scheduled_time=get_dht_time() + self.delay_before_averaging.ema_seconds_per_sample,
+                    gather=self.local_epoch,
+                    require_trigger=True,
+                    timeout=timeout,
+                    wait=False,
+                    **kwargs,
                 )
                 )
 
 
             if wait_for_trigger is not None:
             if wait_for_trigger is not None:
                 wait_for_trigger()
                 wait_for_trigger()
             began_running = True
             began_running = True
 
 
-            if optimizer_step:
-                with self.lock_averaged_tensors if self.offload_optimizer or self.reuse_tensors else nullcontext():
-                    logger.log(self.status_loglevel, f"Running optimizer step")
-                    if grad_scaler is None:
-                        self.optimizer.step()
-                    else:
-                        with grad_scaler.running_global_step():
-                            assert grad_scaler.step(self.optimizer)
-
-            self._update_scheduler()
-
-            if zero_grad:
-                logger.log(self.status_loglevel, f"Running zero grad")
-                self.optimizer.zero_grad()
-                if self.offload_optimizer:
-                    for parameter in self.main_parameters:
-                        if parameter.grad is not None:
-                            parameter.grad.zero_()
+            with self.lock_optimizer:
+                if optimizer_step:
+                    with self.lock_averaged_tensors if self.reuse_tensors else nullcontext():
+                        logger.debug(f"Running optimizer step")
+                        if grad_scaler is None:
+                            self.optimizer.step()
+                        else:
+                            with grad_scaler.running_global_step():
+                                assert grad_scaler.step(self.optimizer)
+
+                if zero_grad:
+                    logger.debug(f"Running zero grad")
+                    self.optimizer.zero_grad()
+                    if self.offload_optimizer:
+                        for parameter in self.main_parameters:
+                            if parameter.grad is not None:
+                                parameter.grad.zero_()
 
 
-            self.finished_optimizer_step.set()
+                self._update_scheduler()
+                self.finished_optimizer_step.set()
 
 
             if averaging_round:
             if averaging_round:
                 if not self.reuse_tensors:
                 if not self.reuse_tensors:
                     self._load_local_tensors_into_averager_()
                     self._load_local_tensors_into_averager_()
+                if self.delta_rule_averaging:
+                    # remember tensors before averaging, update by (new_averaged_tensors - old_averaged_tensors)
+                    with torch.no_grad(), self.get_tensors() as averaged_tensors:
+                        self._old_tensors = tuple(x.cpu().clone() for x in averaged_tensors)
+
+                self.delay_before_averaging.update(task_size=1, interval=time.perf_counter() - start_time)
                 try:
                 try:
                     averaging_control.allow_allreduce()
                     averaging_control.allow_allreduce()
                     gathered = averaging_control.result(timeout=timeout)
                     gathered = averaging_control.result(timeout=timeout)
                     logger.log(self.status_loglevel, f"Averaged parameters with {len(gathered)} peers")
                     logger.log(self.status_loglevel, f"Averaged parameters with {len(gathered)} peers")
                 except BaseException as e:
                 except BaseException as e:
                     logger.log(self.status_loglevel, f"Averaging failed with {type(e)}")
                     logger.log(self.status_loglevel, f"Averaging failed with {type(e)}")
-                    self.finished_averaging_round.set()
                     gathered = {}
                     gathered = {}
 
 
                 self.finished_averaging_round.set()
                 self.finished_averaging_round.set()
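
The averaging call above follows a two-phase handshake: matchmaking is started early with wait=False and require_trigger=True, and all-reduce only begins once allow_allreduce() is called. A toy stand-in built on threading primitives (purely illustrative; ToyStepControl is not hivemind's StepControl):

import threading
from concurrent.futures import Future

class ToyStepControl:
    """Minimal stand-in for StepControl: work starts only after allow_allreduce()."""
    def __init__(self):
        self._trigger, self._future = threading.Event(), Future()
        threading.Thread(target=self._run, daemon=True).start()

    def _run(self):
        self._trigger.wait()              # matchmaking is done; wait for the caller's go-ahead
        self._future.set_result("averaged with 4 peers")

    def allow_allreduce(self):
        self._trigger.set()

    def result(self, timeout=None):
        return self._future.result(timeout)

control = ToyStepControl()     # "schedule" the round early
# ... run the local optimizer step while matchmaking proceeds ...
control.allow_allreduce()      # only now does the all-reduce actually begin
print(control.result(timeout=5))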
@@ -510,11 +575,20 @@ class TrainingStateAverager(DecentralizedAverager):
     def _apply_averaging_results_(self):
     def _apply_averaging_results_(self):
         """Copy averaged tensors into their respective local tensors"""
         """Copy averaged tensors into their respective local tensors"""
         assert not self.reuse_tensors, "No need to update averaged tensors since they reuse the same memory"
         assert not self.reuse_tensors, "No need to update averaged tensors since they reuse the same memory"
+        if self.delta_rule_averaging and self._old_tensors is None:
+            logger.warning("Using delta_rule_averaging, but old tensors were not found. Averaging may have failed.")
         with self.get_tensors() as averaged_tensors:
         with self.get_tensors() as averaged_tensors:
             local_tensors = list(self._local_tensors())
             local_tensors = list(self._local_tensors())
             assert len(local_tensors) == len(averaged_tensors), "Tensor structure changed during training"
             assert len(local_tensors) == len(averaged_tensors), "Tensor structure changed during training"
-            for local_tensor, averaged_tensor in zip(local_tensors, averaged_tensors):
-                local_tensor.copy_(averaged_tensor, non_blocking=True)
+            if not self.delta_rule_averaging or self._old_tensors is None:
+                for local_tensor, averaged_tensor in zip(local_tensors, averaged_tensors):
+                    local_tensor.copy_(averaged_tensor, non_blocking=True)
+            else:
+                assert len(self._old_tensors) == len(local_tensors)
+                for local_tensor, new_tensor, old_tensor in zip(local_tensors, averaged_tensors, self._old_tensors):
+                    delta = torch.sub(new_tensor, old_tensor, out=old_tensor)  # using old tensors as buffers
+                    local_tensor.add_(delta.to(device=local_tensor.device, dtype=local_tensor.dtype))
+
 
 
     def get_current_state(self):
     def get_current_state(self):
         """
         """

+ 1 - 0
hivemind/utils/__init__.py

@@ -7,4 +7,5 @@ from hivemind.utils.nested import *
 from hivemind.utils.networking import *
 from hivemind.utils.networking import *
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
 from hivemind.utils.serializer import MSGPackSerializer, SerializerBase
 from hivemind.utils.tensor_descr import BatchTensorDescriptor, TensorDescriptor
 from hivemind.utils.tensor_descr import BatchTensorDescriptor, TensorDescriptor
+from hivemind.utils.performance_ema import PerformanceEMA
 from hivemind.utils.timed_storage import *
 from hivemind.utils.timed_storage import *
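
With this re-export, the throughput tracker can be imported directly from hivemind.utils; a minimal usage line (the alpha and task_size values are arbitrary):

from hivemind.utils import PerformanceEMA

ema = PerformanceEMA(alpha=0.1)
ema.update(task_size=32)  # record that 32 samples were processed since the previous update, as in progress_tracker above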

+ 4 - 5
tests/test_optimizer.py

@@ -1,6 +1,5 @@
 import ctypes
 import ctypes
 import multiprocessing as mp
 import multiprocessing as mp
-import random
 import time
 import time
 from functools import partial
 from functools import partial
 
 
@@ -80,7 +79,7 @@ def test_grad_averager():
 @pytest.mark.forked
 @pytest.mark.forked
 @pytest.mark.parametrize(
 @pytest.mark.parametrize(
     "offload_optimizer, reuse_tensors, sync_epoch_when_averaging",
     "offload_optimizer, reuse_tensors, sync_epoch_when_averaging",
-    [(False, False, False), (True, False, False), (False, True, True), (True, False, True)],
+    [(False, False, False), (True, True, False), (True, False, False), (False, True, True), (True, False, True)],
 )
 )
 def test_state_averager(offload_optimizer: bool, reuse_tensors: bool, sync_epoch_when_averaging: bool):
 def test_state_averager(offload_optimizer: bool, reuse_tensors: bool, sync_epoch_when_averaging: bool):
     dht1 = hivemind.DHT(start=True)
     dht1 = hivemind.DHT(start=True)
@@ -137,8 +136,8 @@ def test_state_averager(offload_optimizer: bool, reuse_tensors: bool, sync_epoch
     avgr1.step(increment_epoch=True, averaging_round=True, delay_averaging=True)
     avgr1.step(increment_epoch=True, averaging_round=True, delay_averaging=True)
     avgr2.step(increment_epoch=True, averaging_round=True, delay_averaging=True)
     avgr2.step(increment_epoch=True, averaging_round=True, delay_averaging=True)
 
 
-    avgr1.step(wait_for_delayed_update=True)
-    avgr2.step(wait_for_delayed_update=True)
+    avgr1.step(wait_for_delayed_updates=True)
+    avgr2.step(wait_for_delayed_updates=True)
 
 
     assert torch.allclose(model1(x), model2(x)), "model parameters were not averaged correctly"
     assert torch.allclose(model1(x), model2(x)), "model parameters were not averaged correctly"
     assert torch.allclose(avgr1.optimizer.state_dict()["state"][0]["exp_avg_sq"], (stats1 + stats2) / 2)
     assert torch.allclose(avgr1.optimizer.state_dict()["state"][0]["exp_avg_sq"], (stats1 + stats2) / 2)
@@ -383,4 +382,4 @@ def test_optimizer(
     assert not optimizer.state_averager.is_alive()
     assert not optimizer.state_averager.is_alive()
     assert not optimizer.grad_averager.is_alive()
     assert not optimizer.grad_averager.is_alive()
     assert not optimizer.tracker.is_alive()
     assert not optimizer.tracker.is_alive()
-    assert optimizer.scheduled_round is None or optimizer.scheduled_round.done()
+    assert optimizer.scheduled_grads is None or optimizer.scheduled_grads.done()