
[hivemind.Optimizer] TrainingStateAverager (#407)

This PR implements a component of hivemind.Optimizer (#400) that holds the training state and supports (delayed) optimizer steps and averaging rounds.
Unlike TrainingAverager, this class does not need data locks, as it only updates model parameters during .step.
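
For context, a minimal usage sketch adapted from the class docstring in this diff. It is illustrative only: `model`, `compute_loss`, and `make_dataloader` are placeholder names, and the keyword arguments mirror the ones used in the tests below.

    import torch
    import hivemind
    from hivemind.optim.experimental.state_averager import TrainingStateAverager

    dht = hivemind.DHT(start=True)
    avgr = TrainingStateAverager(
        dht=dht, optimizer=torch.optim.Adam, param_groups=model.parameters(),
        prefix="my_experiment", target_group_size=2, start=True,
    )
    avgr.load_state_from_peers()
    for i, batch in enumerate(make_dataloader()):
        loss = compute_loss(model, batch)
        loss.backward()
        # local optimizer step every iteration; every 10th step also launch a background averaging round
        avgr.step(optimizer_step=True, zero_grad=True, averaging_round=(i % 10 == 0), delay_averaging=True)

With delay_averaging=True, the averaging round runs in a background thread and its results are applied on a subsequent .step call, keeping all-reduce off the training critical path.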

Co-authored-by: Roman Zhytar <sir.roma2012@gmail.com>
Co-authored-by: Anton Sinitsin <ant.sinitsin@gmail.com>
Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
Co-authored-by: Aleksandr Borzunov <hxrussia@gmail.com>
justheuristic 3 years ago
parent
revision
02eee9292d
2 changed files with 673 additions and 0 deletions
  1. hivemind/optim/experimental/state_averager.py (+569, -0)
  2. tests/test_optimizer.py (+104, -0)

+ 569 - 0
hivemind/optim/experimental/state_averager.py

@@ -0,0 +1,569 @@
+""" An extension of averager that supports common optimization use cases. """
+import logging
+from asyncio import Future
+from concurrent.futures import ThreadPoolExecutor
+from itertools import chain
+from threading import Event
+from typing import Any, Callable, Dict, Iterable, Iterator, Optional, Sequence, Tuple, Union
+
+import torch
+
+import hivemind
+from hivemind import nested_compare
+from hivemind.averaging import DecentralizedAverager
+from hivemind.compression import CompressionInfo, TensorRole
+from hivemind.utils import get_logger, nested_flatten, nested_map, nested_pack
+
+logger = get_logger(__name__)
+
+
+Parameters = Iterable[torch.Tensor]
+ParamGroups = Iterable[Dict[str, Any]]
+TorchOptimizer = torch.optim.Optimizer
+LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
+OptimizerFactory = Callable[[Union[Parameters, ParamGroups]], TorchOptimizer]
+SchedulerFactory = Callable[[TorchOptimizer], LRSchedulerBase]
+
+
+class TrainingStateAverager(DecentralizedAverager):
+    """
+    An auxiliary class that holds peer's training state, including model parameters, optimizer statistics, scheduler
+    and any other variables that define the local training state (e.g. batchnorm moving averages).
+    TrainingStateAverager is intended to keep these parameters weakly synchronized across the swarm.
+
+    The intended use is to call .step(optimizer_step=..., averaging_round=...) periodically, e.g. after every batch.
+    If a peer gets out of sync with the swarm, one should call state_averager.load_state_from_peers() to re-synchronize.
+
+    Example:
+
+    >>> avgr = TrainingStateAverager(optimizer=torch.optim.Adam, param_groups=model.parameters(), ...)
+    >>> # alternative interface: TrainingStateAverager(optimizer=torch.optim.Adam(model.parameters()), ...)
+    >>> avgr.load_state_from_peers()
+    >>> for i, batch in enumerate(training_dataloader):
+    >>>     loss = compute_loss(model, batch)
+    >>>     loss.backward()
+    >>>     avgr.step(optimizer_step=i % 10 == 0, averaging_round=is_it_time_for_averaging(), delay_averaging=True)
+
+    :note: when using delay_averaging or delay_optimizer_step, calling optimizer directly is not recommended because
+      it may overlap with delayed updates from a background thread with unpredictable results. Instead, please call
+      TrainingStateAverager.step(..., optimizer_step=True)
+
+    :param optimizer: PyTorch Optimizer or a callable that creates an optimizer from param groups
+    :param param_groups: optional, a list/tuple of parameters or structured param groups for the optimizer
+    :param scheduler: optional learning rate scheduler or callable that creates one from optimizer instance
+    :note: if provided, scheduler will be updated based on averager.local_epoch, not the number of step cycles
+    :param initialize_optimizer: if True, run a speculative optimizer step with zero gradients to initialize all
+      state tensors. If False, the user must make sure that all tensors are pre-initialized at init.
+      By default, the optimizer is initialized unless it already has some state tensors to begin with.
+    :param offload_optimizer: if True, create optimizer on top of averaged parameters which may save device memory.
+    :param custom_gradients: if True, do *not* automatically load local gradients into the offloaded optimizer.
+      This assumes that offloaded gradients will be populated externally, e.g. by the user or by hivemind.Optimizer.
+    :param reuse_tensors: if True, reuse parameters and optimizer statistics as averaged_tensors for allreduce.
+      For this to work, all parameters must be on CPU and have the appropriate dtype for use in DecentralizedAverager
+    :param sync_epoch_when_averaging: if True, update local epoch to the latest epoch among averaging peers
+    :param parameter_names: optionally provide parameter names in the same order as param_groups
+    :param average_opt_statistics: names of optimizer statistics from state dict that should be averaged with peers
+    :param extra_tensors: if specified, these extra tensors will also be averaged and shared in load_state_from_peers.
+    :note: you can use extra_tensors for any tensors not used by the optimizer (e.g. batchnorm statistics)
+    :param kwargs: any additional parameters will be forwarded to DecentralizedAverager
+    """
+
+    def __init__(
+        self,
+        *,
+        dht: hivemind.DHT,
+        optimizer: Union[TorchOptimizer, OptimizerFactory],
+        param_groups: Optional[Union[Parameters, ParamGroups]] = None,
+        scheduler: Optional[Union[LRSchedulerBase, SchedulerFactory]] = None,
+        initialize_optimizer: Optional[bool] = None,
+        offload_optimizer: bool = False,
+        custom_gradients: bool = False,
+        reuse_tensors: bool = False,
+        sync_epoch_when_averaging: bool = False,
+        parameter_names: Optional[Sequence[str]] = None,
+        average_opt_statistics: Sequence[str] = (),
+        extra_tensors: Sequence[torch.Tensor] = (),
+        status_loglevel: int = logging.DEBUG,
+        **kwargs,
+    ):
+        average_opt_statistics = tuple(average_opt_statistics)
+        assert all(isinstance(key, str) for key in average_opt_statistics)
+        if offload_optimizer and reuse_tensors:
+            logger.warning("Setting offload_optimizer=True has no effect because reuse_parameters=True")
+        if custom_gradients and not offload_optimizer:
+            logger.warning("Setting custom_gradients=True has no effect because the optimizer is not offloaded")
+
+        param_groups, main_parameters, parameter_names = self._check_params(optimizer, param_groups, parameter_names)
+
+        self.status_loglevel = status_loglevel
+        self.reuse_tensors = reuse_tensors
+        self.offload_optimizer = offload_optimizer
+        self.custom_gradients = custom_gradients
+
+        self._main_parameters, self._parameter_names = main_parameters, parameter_names
+        self._averaged_parameters = tuple(map(self._make_host_tensor, main_parameters))
+        self.optimizer, self.scheduler = self._init_components(
+            param_groups, optimizer, scheduler, initialize_optimizer
+        )
+        self.opt_keys_for_averaging, self.extra_tensors = average_opt_statistics, extra_tensors
+        self.sync_epoch_when_averaging = sync_epoch_when_averaging
+        self.local_epoch = 0
+
+        self.step_executor = ThreadPoolExecutor(max_workers=1)
+        self.finished_optimizer_step = Event()
+        self.finished_averaging_round = Event()
+        self.pending_update = Future()
+        self.pending_update.set_result(None)
+
+        super().__init__(
+            dht=dht, averaged_tensors=self._init_averaged_tensors(), tensor_infos=self._init_tensor_infos(), **kwargs
+        )
+
+    @staticmethod
+    def _check_params(
+        optimizer: Union[TorchOptimizer, OptimizerFactory],
+        param_groups: Optional[Union[Parameters, ParamGroups]],
+        parameter_names: Optional[Sequence[str]],
+    ) -> Tuple[ParamGroups, Sequence[torch.Tensor], Sequence[str]]:
+        """Get and verify parameters, groups and names"""
+        if param_groups is None:
+            assert hasattr(optimizer, "param_groups"), "Must provide param_groups or an optimizer with .param_groups"
+            param_groups = optimizer.param_groups
+        param_groups = tuple(param_groups)
+        if all(isinstance(p, torch.Tensor) for p in param_groups):
+            param_groups = (dict(params=param_groups),)
+        for group in param_groups:
+            assert isinstance(group, dict) and group.get("params") is not None
+            assert all(isinstance(p, torch.Tensor) for p in group["params"])
+        parameters = tuple(chain(*(group["params"] for group in param_groups)))
+        if parameter_names is None:
+            parameter_names = tuple(range(len(parameters)))
+        parameter_names = tuple(nested_flatten(parameter_names))
+        assert len(parameters) == len(parameter_names), f"Expected {len(parameters)} names, got {len(parameter_names)}"
+        assert len(set(parameters)) == len(parameters), "Found duplicate parameters in param_groups"
+        return param_groups, parameters, parameter_names
+
+    def _make_host_tensor(self, source_tensor: torch.Tensor) -> torch.Tensor:
+        """Create a new tensor for averaging or reuse the existing one"""
+        if self.reuse_tensors:
+            assert source_tensor.device == torch.device("cpu") and source_tensor.dtype == torch.float32
+            if not source_tensor.is_shared():
+                source_tensor.share_memory_()
+            return source_tensor
+        else:
+            averaged_tensor = source_tensor.detach().to(device="cpu", dtype=torch.float32, copy=True)
+            return averaged_tensor.share_memory_().requires_grad_(source_tensor.requires_grad)
+
+    def _init_components(
+        self,
+        param_groups: ParamGroups,
+        optimizer_or_factory: Union[TorchOptimizer, OptimizerFactory],
+        scheduler_or_factory: Optional[Union[LRSchedulerBase, SchedulerFactory]],
+        initialize_optimizer: Optional[bool],
+    ) -> Tuple[TorchOptimizer, Optional[LRSchedulerBase]]:
+        """Get optimizer and scheduler by either instantiating user-provided factory or using pre-instantiated ones"""
+        assert hasattr(self, "_averaged_parameters"), "Internal error: must initialize averaged parameters first"
+        optimizer_is_factory = callable(optimizer_or_factory) and not isinstance(optimizer_or_factory, TorchOptimizer)
+        scheduler_is_factory = callable(scheduler_or_factory) and not isinstance(scheduler_or_factory, LRSchedulerBase)
+        if optimizer_is_factory and not scheduler_is_factory and scheduler_or_factory is not None:
+            raise ValueError("If optimizer is created internally, scheduler must also be initialized internally")
+        if self.offload_optimizer and not optimizer_is_factory:
+            raise ValueError("Using offload_optimizer requires creating optimizer inside hivemind")
+
+        # create optimizer
+        if optimizer_is_factory:
+            if self.offload_optimizer:
+                for param in self._averaged_parameters:
+                    if param.grad is None:
+                        param.grad = torch.zeros_like(param)
+
+                next_index = 0
+                param_groups_for_optimizer = []
+                for param_group in param_groups:
+                    num_params = len(param_group["params"])
+                    averaged_params_for_group = self._averaged_parameters[next_index : next_index + num_params]
+                    param_groups_for_optimizer.append(dict(param_group, params=averaged_params_for_group))
+                    next_index += num_params
+                assert next_index == len(self._averaged_parameters)
+
+            else:
+                param_groups_for_optimizer = param_groups
+            optimizer = optimizer_or_factory(param_groups_for_optimizer)
+        else:
+            optimizer = optimizer_or_factory
+
+        # optionally initialize optimizer state dict
+        if initialize_optimizer is None:
+            initialize_optimizer = not any(isinstance(x, torch.Tensor) for x in nested_flatten(optimizer.state_dict()))
+            logger.log(
+                self.status_loglevel,
+                "Initializing optimizer manually since it has no tensors in state dict"
+                "To override this, please provide initialize_optimizer=False",
+            )
+
+        if initialize_optimizer:
+            initialize_optimizer_state_(optimizer)  # note: this will run one optimizer step!
+
+        # create LR scheduler
+        if scheduler_is_factory:
+            assert callable(scheduler_or_factory)
+            scheduler = scheduler_or_factory(optimizer)
+        else:
+            scheduler = scheduler_or_factory
+
+        # verify optimizer and scheduler
+        assert isinstance(optimizer, TorchOptimizer) and len(optimizer.param_groups) == len(list(param_groups))
+        if self.offload_optimizer or self.reuse_tensors:
+            for param_group in optimizer.param_groups:
+                for param in param_group["params"]:
+                    assert param.is_shared()
+        assert isinstance(scheduler, (LRSchedulerBase, type(None)))
+        if scheduler is not None:
+            assert scheduler.optimizer == optimizer
+        return optimizer, scheduler
+
+    def _local_tensors(self) -> Iterator[torch.Tensor]:
+        """Iterate local trainer's tensors that should be averaged with peers"""
+        for param_group in self.optimizer.param_groups:
+            yield from param_group["params"]
+        for stats in self.opt_keys_for_averaging:
+            for param_group in self.optimizer.param_groups:
+                for param in param_group["params"]:
+                    yield self.optimizer.state[param][stats]
+        yield from self.extra_tensors
+
+    @torch.no_grad()
+    def _init_averaged_tensors(self) -> Sequence[torch.Tensor]:
+        """Create or reuse a tuple of all averaged tensors, including parameters, optimizer statistics and extras"""
+        assert hasattr(self, "optimizer"), "Optimizer should already be initialized by this point"
+        assert hasattr(self, "_averaged_parameters"), "Should initialize _averaged_parameters first"
+        assert not hasattr(self, "_averaged_tensors"), "Averager is already initialized"
+        assert all(isinstance(key, str) for key in self.opt_keys_for_averaging)
+
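+        # note: _local_tensors() yields parameters first, then optimizer statistics, then extra tensors;
+        # every tensor past the parameter count gets its own shared host-side buffer via _make_host_tensor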
+        local_tensors = tuple(self._local_tensors())
+        local_non_parameters = local_tensors[len(self._averaged_parameters) :]
+        averaged_tensors = tuple(map(torch.Tensor.detach, self._averaged_parameters))
+        averaged_non_parameters = tuple(map(self._make_host_tensor, local_non_parameters))
+        averaged_tensors = tuple(chain(averaged_tensors, averaged_non_parameters))
+
+        assert len(averaged_tensors) == len(local_tensors)
+        for local_tensor, averaged_tensor in zip(local_tensors, averaged_tensors):
+            assert local_tensor.shape == averaged_tensor.shape
+            if averaged_tensor.grad is not None:
+                logger.log(self.status_loglevel, "setting gradients for averaged tensor to None")
+
+        return averaged_tensors
+
+    def _init_tensor_infos(self) -> Sequence[CompressionInfo]:
+        """Get CompressionInfo for each state tensor, accounting for its role and specification"""
+        tensor_infos = []
+        for param, param_name in zip(self._main_parameters, self._parameter_names):
+            tensor_infos.append(CompressionInfo.from_tensor(param, key=param_name, role=TensorRole.PARAMETER))
+        for stats_name in self.opt_keys_for_averaging:
+            opt_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
+            assert len(opt_parameters) == len(self._parameter_names)
+            for param, param_name in zip(opt_parameters, self._parameter_names):
+                tensor_infos.append(
+                    CompressionInfo.from_tensor(
+                        self.optimizer.state[param][stats_name],
+                        key=(param_name, stats_name),
+                        role=TensorRole.OPTIMIZER,
+                    )
+                )
+        for i, extra_tensor in enumerate(self.extra_tensors):
+            tensor_infos.append(CompressionInfo.from_tensor(extra_tensor, key=i, role=TensorRole.UNSPECIFIED))
+        return tuple(tensor_infos)
+
+    def step(
+        self,
+        wait_for_delayed_update: Optional[bool] = None,
+        apply_delayed_updates: bool = True,
+        increment_epoch: bool = False,
+        optimizer_step: bool = False,
+        zero_grad: bool = False,
+        delay_optimizer_step: bool = False,
+        averaging_round: bool = False,
+        delay_averaging: Optional[bool] = None,
+        averaging_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        """
+        Perform one or several possible actions, depending on the specified keyword args.
+        The actions will be performed in the same order as specified below:
+
+        :param wait_for_delayed_update: if there are background optimizer steps or averaging rounds, wait for them
+          to finish. By default, wait only if this call also schedules an optimizer step, zero_grad or averaging round.
+        :param apply_delayed_updates: apply any averaging rounds that have finished but were not applied yet
+        :param increment_epoch: increment .local_epoch and update the learning rate scheduler (if present)
+        :param optimizer_step: perform a single optimizer step and update local parameters (without changing scheduler)
+        :param zero_grad: if True, reset local gradients after performing optimizer step
+        :param delay_optimizer_step: if True, run optimizer step in background and apply results in a future step
+        :param averaging_round: average parameters, chosen optimizer keys and extra tensors with a group of peers
+        :param delay_averaging: if True, perform averaging in background and apply results in a future step.
+          By default, averaging is delayed if the optimizer step is also delayed. Set to True to delay only averaging.
+        :param averaging_kwargs: a dict of keyword arguments forwarded into averaging round
+        """
+        if delay_averaging is None:
+            delay_averaging = delay_optimizer_step
+        if wait_for_delayed_update is None:
+            wait_for_delayed_update = optimizer_step or zero_grad or averaging_round
+        assert not delay_optimizer_step or delay_averaging, "Delayed optimizer step requires delayed averaging"
+        if optimizer_step or averaging_round or zero_grad:
+            assert wait_for_delayed_update, "Must wait for background updates to finish before scheduling new ones"
+        if delay_optimizer_step:
+            assert self.offload_optimizer, "Delayed optimizer step is only available with offload_optimizer"
+            assert not averaging_round or delay_averaging, "Averaging after delayed optimizer should also be delayed"
+        if averaging_kwargs and not averaging_round:
+            logger.warning(f"Averaging parameters not used because averaging_round=False: {averaging_kwargs}")
+        output = None
+
+        if wait_for_delayed_update:
+            if not self.pending_update.done():
+                logger.log(self.status_loglevel, "Waiting for delayed updates to finish...")
+                output = self.pending_update.result()
+
+        if self.pending_update.done() and self.pending_update.exception():
+            logger.warning(f"Background update failed with {self.pending_update.exception()} and will be ignored")
+
+        if apply_delayed_updates:
+            if self.finished_averaging_round.is_set():
+                if not self.reuse_tensors:
+                    self._apply_averaging_results_()
+                logger.log(self.status_loglevel, "Received results from background averaging round")
+                self.finished_averaging_round.clear()
+
+            if self.finished_optimizer_step.is_set():
+                if self.offload_optimizer:
+                    self._apply_optimizer_results_()
+                logger.log(self.status_loglevel, "Received results from background optimizer step")
+                self.finished_optimizer_step.clear()
+
+        if increment_epoch:
+            self.local_epoch += 1
+            logger.log(self.status_loglevel, f"Switching to epoch {self.local_epoch}")
+            self._update_scheduler()
+
+        if optimizer_step or zero_grad or averaging_round:
+            assert self.pending_update.done(), "Tried to perform a new update but previous update is still running"
+
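+            # before scheduling background work, copy local gradients into the offloaded optimizer's buffers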
+            if self.offload_optimizer and not self.custom_gradients:
+                self._load_local_grads_into_optimizer_()
+
+            self.pending_update = self.step_executor.submit(
+                self._do,
+                optimizer_step,
+                zero_grad,
+                averaging_round,
+                **averaging_kwargs or {},
+            )
+
+            if (optimizer_step or zero_grad) and not delay_optimizer_step:
+                self.finished_optimizer_step.wait()
+                self.finished_optimizer_step.clear()
+                if self.offload_optimizer:
+                    self._apply_optimizer_results_()
+                logger.log(self.status_loglevel, "Finished optimizer step")
+
+            if averaging_round and not delay_averaging:
+                self.finished_averaging_round.wait()
+                self.finished_averaging_round.clear()
+                if not self.reuse_tensors:
+                    self._apply_averaging_results_()
+                logger.log(self.status_loglevel, "Finished averaging round")
+
+            if not delay_averaging:
+                try:
+                    output = self.pending_update.result()
+                finally:
+                    self.finished_averaging_round.clear()
+                    self.finished_optimizer_step.clear()
+        return output
+
+    def _do(self, optimizer_step: bool, zero_grad: bool, averaging_round: bool, **kwargs):
+        """
+        Run the optimizer step, zero gradients and/or an averaging round; each stage is optional.
+        This method is meant to be called in the background executor.
+        """
+        try:
+            if optimizer_step:
+                logger.log(self.status_loglevel, f"Running optimizer step")
+                self.optimizer.step()
+            if zero_grad:
+                logger.log(self.status_loglevel, f"Running zero grad")
+                self.optimizer.zero_grad()
+                if self.offload_optimizer:
+                    for parameter in self._main_parameters:
+                        if parameter.grad is not None:
+                            parameter.grad.zero_()
+
+            self.finished_optimizer_step.set()
+
+            if averaging_round:
+                if not self.reuse_tensors:
+                    self._load_local_tensors_into_averager_()
+                try:
+                    gathered = super().step(gather=self.local_epoch, **kwargs)
+                    logger.log(self.status_loglevel, f"Averaged parameters with {len(gathered)} peers")
+                except BaseException as e:
+                    logger.log(self.status_loglevel, f"Averaging failed with {type(e)}")
+                    self.finished_averaging_round.set()
+                    gathered = {}
+
+                self.finished_averaging_round.set()
+
+                if self.sync_epoch_when_averaging:
+                    old_epoch = self.local_epoch
+                    for peer_epoch in gathered.values():
+                        self.local_epoch = max(self.local_epoch, peer_epoch)
+                    if self.local_epoch != old_epoch:
+                        logger.log(self.status_loglevel, f"Found peer with newer epoch ({self.local_epoch})")
+                        self._update_scheduler()
+
+        except Exception as e:
+            logger.exception(e)
+            self.finished_optimizer_step.set()
+            self.finished_averaging_round.set()
+
+    @torch.no_grad()
+    def _load_local_grads_into_optimizer_(self):
+        """Copy local gradients into the gradient buffers of the offloaded optimizer"""
+        assert self.offload_optimizer, "Loading into offloaded optimizer requires using offloaded optimizer"
+        opt_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
+        for main_param, opt_param in zip(self._main_parameters, opt_parameters):
+            if main_param.grad is not None:
+                opt_param.grad.copy_(main_param.grad, non_blocking=True)
+
+    @torch.no_grad()
+    def _apply_optimizer_results_(self):
+        """Copy parameters from offloaded optimizer to the main model"""
+        assert self.offload_optimizer, "Applying offloaded optimizer updates requires offloaded optimizer"
+        with self.lock_averaged_tensors:
+            offloaded_parameters = [param for group in self.optimizer.param_groups for param in group["params"]]
+            assert len(offloaded_parameters) == len(self._main_parameters), "opt parameters changed during training"
+            for main_param, offloaded_param in zip(self._main_parameters, offloaded_parameters):
+                main_param.copy_(offloaded_param, non_blocking=True)
+
+    @torch.no_grad()
+    def _load_local_tensors_into_averager_(self):
+        """Copy local tensors into the averaging buffers"""
+        assert not self.reuse_tensors, "No need to load tensors into averager: both tensors share the same memory"
+        with self.get_tensors() as averaged_tensors:
+            for local_tensor, averaged_tensor in zip(self._local_tensors(), averaged_tensors):
+                averaged_tensor.copy_(local_tensor, non_blocking=True)
+
+    @torch.no_grad()
+    def _apply_averaging_results_(self):
+        """Copy averaged tensors into their respective local tensors"""
+        assert not self.reuse_tensors, "No need to update averaged tensors since they reuse the same memory"
+        with self.get_tensors() as averaged_tensors:
+            local_tensors = list(self._local_tensors())
+            assert len(local_tensors) == len(averaged_tensors), "Tensor structure changed during training"
+            for local_tensor, averaged_tensor in zip(local_tensors, averaged_tensors):
+                local_tensor.copy_(averaged_tensor, non_blocking=True)
+
+    def get_current_state(self):
+        """
+        Get the current model/optimizer state to be shared when requested by a newly joined peer. Executed in the host process.
+        :returns: a tuple of (serializable_small_metadata, sequence of torch tensors)
+        """
+        with torch.no_grad():
+            optimized_parameters = tuple(
+                param.detach().cpu() for param_group in self.optimizer.param_groups for param in param_group["params"]
+            )
+            parameter_infos = [
+                CompressionInfo.from_tensor(param, key=key, role=TensorRole.PARAMETER)
+                for param, key in zip(optimized_parameters, self._parameter_names)
+            ]
+            extra_tensors = tuple(tensor.detach().cpu() for tensor in self.extra_tensors)
+            extra_infos = [
+                CompressionInfo.from_tensor(extra_tensor, key=i, role=TensorRole.UNSPECIFIED)
+                for i, extra_tensor in enumerate(extra_tensors)
+            ]
+            optimizer_metadata, optimizer_tensors = dump_optimizer_state(self.optimizer)
+            optimizer_infos = [
+                CompressionInfo.from_tensor(opt_tensor, key=i, role=TensorRole.OPTIMIZER)
+                for i, opt_tensor in enumerate(optimizer_tensors)
+            ]
+
+        metadata = dict(
+            epoch=self.local_epoch, group_bits=self.get_group_bits(), optimizer_metadata=optimizer_metadata
+        )
+        all_tensors = list(chain(optimized_parameters, extra_tensors, optimizer_tensors))
+        all_tensor_infos = list(chain(parameter_infos, extra_infos, optimizer_infos))
+        return metadata, all_tensors, all_tensor_infos
+
+    def load_state_from_peers(self, **kwargs):
+        """
+        Attempt to download the latest optimizer state from peers and update trainer parameters/statistics.
+        :returns: whether or not the averager succeeded in loading parameters
+        """
+        parameters_and_extras = tuple(chain(self._main_parameters, self.extra_tensors))
+        num_parameters_and_extras = len(parameters_and_extras)
+
+        loaded_state = super().load_state_from_peers(**kwargs)
+        if loaded_state is None:
+            return
+
+        metadata, flat_tensors = loaded_state
+        if (not isinstance(metadata.get("epoch"), int)) or metadata["epoch"] < self.local_epoch:
+            logger.warning("Cowardly refusing to load state from peer: peer's epoch is behind our local epoch")
+            return
+
+        loaded_parameters_and_extras = flat_tensors[:num_parameters_and_extras]
+        loaded_opt_tensors = flat_tensors[num_parameters_and_extras:]
+        if num_parameters_and_extras != len(loaded_parameters_and_extras):
+            logger.error("Failed to load state from peer, received parameters, extras or metadata.")
+            return
+
+        try:
+            load_optimizer_state(self.optimizer, metadata["optimizer_metadata"], loaded_opt_tensors)
+        except StopIteration:
+            logger.warning("Failed to load state from peer, received inconsistent number of optimizer statistics")
+            return
+
+        with torch.no_grad():
+            for local_param, loaded_param in zip(parameters_and_extras, loaded_parameters_and_extras):
+                local_param.copy_(loaded_param, non_blocking=True)
+        self.local_epoch = metadata["epoch"]
+        self._update_scheduler()
+
+    def _update_scheduler(self):
+        """Increase the scheduler state until it becomes synchronized with local epoch"""
+        if self.scheduler:
+            while self.scheduler._step_count <= self.local_epoch:
+                self.scheduler.step()
+
+
+def initialize_optimizer_state_(opt: torch.optim.Optimizer):
+    """Initialize optimizer statistics by running a virtual optimizer step with zero gradients"""
+    flat_params = tuple(param for group in opt.param_groups for param in group["params"])
+    old_grads = []
+    for param in flat_params:
+        old_grads.append(param.grad)
+        param.grad = torch.zeros_like(param)
+    opt.step()
+    for param, old_grad in zip(flat_params, old_grads):
+        param.grad = old_grad
+
+
+def dump_optimizer_state(opt: torch.optim.Optimizer):
+    """Convert optimizer state into a format of DecentralizedAverager's get_current_state/load_state_from_peers"""
+    with torch.no_grad():
+        flat_metadata, flat_tensors = [], []
+        for elem in nested_flatten(opt.state_dict()):
+            if isinstance(elem, torch.Tensor):
+                flat_metadata.append(dict(type="tensor", index=len(flat_tensors)))
+                flat_tensors.append(elem.cpu())
+            else:
+                flat_metadata.append(dict(type="value", value=elem))
+        return flat_metadata, flat_tensors
+
+
+def load_optimizer_state(optimizer: torch.optim.Optimizer, flat_metadata: Dict, flat_tensors: Sequence[torch.Tensor]):
+    """Load a state obtained by dump_optimizer_state back into the optimizer"""
+    flat_optimizer_state = []
+    for elem in flat_metadata:
+        if elem.get("type") == "tensor" and isinstance(elem.get("index"), int):
+            flat_optimizer_state.append(flat_tensors[elem["index"]])
+        elif elem.get("type") == "value" and "value" in elem:
+            flat_optimizer_state.append(elem["value"])
+    return optimizer.load_state_dict(nested_pack(flat_optimizer_state, structure=optimizer.state_dict()))
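
For reference, a minimal sketch of how the dump/load helpers above are meant to round-trip an optimizer state. It assumes the module path introduced by this PR; `nn.Linear` stands in for a real model.

    import torch
    import torch.nn as nn
    from hivemind.optim.experimental.state_averager import (
        dump_optimizer_state, initialize_optimizer_state_, load_optimizer_state,
    )

    opt = torch.optim.Adam(nn.Linear(4, 2).parameters(), lr=1e-3)
    initialize_optimizer_state_(opt)  # one step with zero gradients so exp_avg / exp_avg_sq tensors exist

    metadata, tensors = dump_optimizer_state(opt)   # flat metadata entries plus CPU tensors
    load_optimizer_state(opt, metadata, tensors)    # rebuilds the state dict via nested_pack and loads it back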

+ 104 - 0
tests/test_optimizer.py

@@ -1,5 +1,7 @@
 import time
+from functools import partial
 
+import numpy as np
 import pytest
 import torch
 import torch.nn as nn
@@ -8,6 +10,7 @@ import torch.nn.functional as F
 import hivemind
 from hivemind.averaging.control import AveragingStage
 from hivemind.optim.experimental.grad_averager import GradientAverager
+from hivemind.optim.experimental.state_averager import TrainingStateAverager
 
 
 @pytest.mark.forked
@@ -66,3 +69,104 @@ def test_grad_averager():
     # after no longer use_averaged_gradients
     assert not torch.allclose(model1.w.grad, ref_average)
     assert not torch.allclose(model2.w.grad, ref_average)
+
+
+@pytest.mark.forked
+@pytest.mark.parametrize(
+    "offload_optimizer, reuse_tensors, sync_epoch_when_averaging",
+    [(False, False, False), (True, False, False), (False, True, True), (True, False, True)],
+)
+def test_state_averager(offload_optimizer: bool, reuse_tensors: bool, sync_epoch_when_averaging: bool):
+    dht1 = hivemind.DHT(start=True)
+    dht2 = hivemind.DHT(initial_peers=dht1.get_visible_maddrs(), start=True)
+
+    torch.manual_seed(1337)
+    torch.use_deterministic_algorithms(True)
+    # note: use_deterministic_algorithms does not affect further tests because this test is forked
+
+    model1 = nn.Linear(2, 3)
+    model2 = nn.Linear(2, 3)
+
+    extras1 = (torch.randn(2, 2), -torch.rand(1))
+    extras2 = (-torch.randn(2, 2), torch.rand(1))
+
+    common_kwargs = dict(
+        optimizer=partial(torch.optim.Adam, lr=0.1, betas=(0.9, 0.9)),
+        scheduler=partial(torch.optim.lr_scheduler.LambdaLR, lr_lambda=lambda t: 1.0 / max(1, t)),
+        sync_epoch_when_averaging=sync_epoch_when_averaging,
+        average_opt_statistics=("exp_avg_sq",),
+        offload_optimizer=offload_optimizer,
+        reuse_tensors=reuse_tensors,
+        target_group_size=2,
+        prefix="my_exp",
+    )
+
+    avgr1 = TrainingStateAverager(
+        dht=dht1, param_groups=model1.parameters(), extra_tensors=extras1, start=True, **common_kwargs
+    )
+    avgr2 = TrainingStateAverager(
+        dht=dht2, param_groups=model2.parameters(), extra_tensors=extras2, start=True, **common_kwargs
+    )
+
+    x = torch.ones(2)
+
+    for step in range(20):
+        F.mse_loss(model1(x), torch.ones(3)).mul(2).backward()
+        avgr1.step(optimizer_step=True, zero_grad=True, averaging_round=(step == 10), delay_averaging=True)
+
+        F.mse_loss(model2(x), -torch.ones(3)).backward()
+        avgr2.step(optimizer_step=True, zero_grad=True, averaging_round=(step == 10), delay_averaging=False)
+
+    assert torch.all(model1.weight.grad == 0) and torch.all(model2.weight.grad == 0), "zero grad did not trigger"
+    assert model1(x).mean() > 0.5 and model2(x).mean() < -0.5, "models did not train properly"
+    assert torch.allclose(extras1[0], extras2[0]), "first extra tensors were not averaged"
+    assert torch.allclose(extras1[1], extras2[1]), "second extra tensors were not averaged"
+
+    stats1 = avgr1.optimizer.state_dict()["state"][0]["exp_avg_sq"].clone()
+    stats2 = avgr2.optimizer.state_dict()["state"][0]["exp_avg_sq"].clone()
+    assert not torch.allclose(stats1, stats2)
+
+    avgr1.step(increment_epoch=True)
+
+    avgr1.step(increment_epoch=True, averaging_round=True, delay_averaging=True)
+    avgr2.step(increment_epoch=True, averaging_round=True, delay_averaging=True)
+
+    avgr1.step(wait_for_delayed_update=True)
+    avgr2.step(wait_for_delayed_update=True)
+
+    assert torch.allclose(model1(x), model2(x)), "model parameters were not averaged correctly"
+    assert torch.allclose(avgr1.optimizer.state_dict()["state"][0]["exp_avg_sq"], (stats1 + stats2) / 2)
+    assert torch.allclose(avgr2.optimizer.state_dict()["state"][0]["exp_avg_sq"], (stats1 + stats2) / 2)
+    assert avgr1.local_epoch == 2
+    assert avgr2.local_epoch == (2 if sync_epoch_when_averaging else 1)
+
+
+@pytest.mark.forked
+def test_load_state_from_peers():
+    dht1 = hivemind.DHT(start=True)
+    dht2 = hivemind.DHT(initial_peers=dht1.get_visible_maddrs(), start=True)
+
+    model1 = nn.Linear(2, 3)
+    model2 = nn.Linear(2, 3)
+
+    common_kwargs = dict(
+        optimizer=partial(torch.optim.SGD, lr=0.1),
+        scheduler=partial(torch.optim.lr_scheduler.LambdaLR, lr_lambda=lambda t: 1.0 / max(1, t)),
+        target_group_size=2,
+        prefix="my_exp",
+    )
+
+    avgr1 = TrainingStateAverager(
+        dht=dht1, param_groups=model1.parameters(), allow_state_sharing=False, start=True, **common_kwargs
+    )
+
+    avgr2 = TrainingStateAverager(dht=dht2, param_groups=model2.parameters(), start=True, **common_kwargs)
+
+    avgr2.local_epoch = 1337
+    model2.weight.data[...] = 42
+    time.sleep(0.1)
+
+    avgr1.load_state_from_peers()
+    assert avgr1.local_epoch == 1337
+    assert torch.all(model1.weight == 42).item()
+    assert np.allclose(avgr1.optimizer.param_groups[0]["lr"], 0.1 / 1337)