
Remove deprecated code in hivemind.optim and hivemind.averaging before the 1.1.0 release (#480)

* Deprecate averaging_expiration

* Remove DecentralizedOptimizerBase with subclasses

* Update docs
Max Ryabinin, 3 years ago
parent
commit 4a3d8fb843
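
For downstream users, the practical migration is from the removed `CollaborativeOptimizer` (and the `averaging_expiration` argument) to `hivemind.Optimizer` with `matchmaking_time`. A minimal sketch, assuming the `hivemind.Optimizer` keyword names documented for the 1.0/1.1 line (`run_id`, `target_batch_size`, `batch_size_per_step`, `matchmaking_time`); all values are illustrative:

```python
import torch
import hivemind

dht = hivemind.DHT(start=True)
model = torch.nn.Linear(16, 1)

# Before (removed in this commit):
#   hivemind.CollaborativeOptimizer(opt, dht=dht, prefix="my_run",
#                                   target_batch_size=4096, averaging_expiration=5.0)
opt = hivemind.Optimizer(
    dht=dht,
    run_id="my_run",                  # assumed replacement for the old `prefix`
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
    target_batch_size=4096,           # global samples accumulated before each optimizer step
    batch_size_per_step=32,           # local samples contributed by each .step() call
    matchmaking_time=5.0,             # assumed replacement for `averaging_expiration`
    verbose=True,
)
```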

+ 5 - 6
benchmarks/benchmark_averaging.py

@@ -6,10 +6,9 @@ import time
 import torch
 
 import hivemind
-from hivemind.proto import runtime_pb2
+from hivemind.compression import Float16Compression
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
-from hivemind.utils.networking import LOCALHOST
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
@@ -38,7 +37,7 @@ def benchmark_averaging(
     num_peers: int,
     target_group_size: int,
     num_rounds: int,
-    averaging_expiration: float,
+    min_matchmaking_time: float,
     request_timeout: float,
     round_timeout: float,
     hid_size: int,
@@ -64,9 +63,9 @@ def benchmark_averaging(
             dht,
             prefix="my_tensor",
             initial_group_bits=initial_bits,
-            compression_type=runtime_pb2.CompressionType.FLOAT16,
+            compression=Float16Compression(),
             target_group_size=target_group_size,
-            averaging_expiration=averaging_expiration,
+            min_matchmaking_time=min_matchmaking_time,
             request_timeout=request_timeout,
             start=True,
         )
@@ -108,7 +107,7 @@ if __name__ == "__main__":
     parser.add_argument("--num_rounds", type=int, default=5, required=False)
     parser.add_argument("--hid_size", type=int, default=256, required=False)
     parser.add_argument("--num_layers", type=int, default=3, required=False)
-    parser.add_argument("--averaging_expiration", type=float, default=5, required=False)
+    parser.add_argument("--min_matchmaking_time", type=float, default=5, required=False)
     parser.add_argument("--round_timeout", type=float, default=15, required=False)
     parser.add_argument("--request_timeout", type=float, default=1, required=False)
     parser.add_argument("--spawn_dtime", type=float, default=0.1, required=False)

+ 1 - 17
docs/modules/optim.rst

@@ -21,20 +21,4 @@
 
 .. currentmodule:: hivemind.optim.grad_scaler
 .. autoclass:: GradScaler
-   :member-order: bysource
-
-
-**CollaborativeOptimizer**
---------------------------
-
-
-.. automodule:: hivemind.optim.collaborative
-.. currentmodule:: hivemind.optim
-
-.. autoclass:: CollaborativeOptimizer
-   :members: step
-   :member-order: bysource
-
-.. autoclass:: CollaborativeAdaptiveOptimizer
-   :members:
-   :member-order: bysource
+   :member-order: bysource

+ 3 - 3
examples/albert/README.md

@@ -3,7 +3,7 @@
 This tutorial will walk you through the steps to set up collaborative training with the ALBERT-large-v2 model and the
 WikiText103 dataset. It uses Hugging Face [datasets](https://github.com/huggingface/datasets)
 and [transformers](https://github.com/huggingface/transformers/) libraries to compute local updates,
-using `hivemind.CollaborativeOptimizer` to exchange information between peers.
+using `hivemind.Optimizer` to exchange information between peers.
 
 ## Preparation
 
@@ -143,7 +143,7 @@ you need to **(a)** make it listen a specific TCP/UDP port and **(b)** provide a
 
 The optimal training parameters for each peer depend on its GPU and internet connection. If a peer cannot accept
 incoming connections (e.g. when in colab or behind a firewall), add `--client_mode` to the training script (see example
-below). In case of high network latency, you may want to increase `--averaging_expiration` by a few seconds or
+below). In case of high network latency, you may want to increase `--matchmaking_time` by a few seconds or
 set `--batch_size_lead` to start averaging a bit earlier than the rest of the collaboration. GPU-wise, each peer should
 be able to process one local microbatch each 0.5–1 seconds (see trainer's progress bar). To achieve that, we
 recommend tuning `--per_device_train_batch_size` and `--gradient_accumulation_steps`.
@@ -182,7 +182,7 @@ Here's an example of a full trainer script for Google Colab:
 !ulimit -n 4096 && ./hivemind/examples/albert/run_trainer.py \
     --initial_peers ONE_OR_MORE_PEERS \
     --logging_dir ./logs --logging_first_step --output_dir ./outputs --overwrite_output_dir \
-    --client_mode --averaging_expiration 10 --batch_size_lead 300 --gradient_accumulation_steps 1
+    --client_mode --matchmaking_time 10 --batch_size_lead 300 --gradient_accumulation_steps 1
 ```
 
 ### Using IPFS

+ 1 - 11
hivemind/__init__.py

@@ -9,17 +9,7 @@ from hivemind.moe import (
     Server,
     register_expert_class,
 )
-from hivemind.optim import (
-    CollaborativeAdaptiveOptimizer,
-    CollaborativeOptimizer,
-    DecentralizedAdam,
-    DecentralizedOptimizer,
-    DecentralizedOptimizerBase,
-    DecentralizedSGD,
-    GradScaler,
-    Optimizer,
-    TrainingAverager,
-)
+from hivemind.optim import GradScaler, Optimizer, TrainingAverager
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, PeerInfo
 from hivemind.utils import *
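
After the import cleanup, only three optimization-related names remain at the package top level; a quick, purely illustrative sanity check:

```python
import hivemind

# Kept after this commit:
assert all(hasattr(hivemind, name) for name in ("GradScaler", "Optimizer", "TrainingAverager"))

# Removed from the public namespace together with hivemind.optim.base / .simple / .adaptive / .collaborative:
removed = ("CollaborativeOptimizer", "CollaborativeAdaptiveOptimizer",
           "DecentralizedSGD", "DecentralizedAdam", "DecentralizedOptimizerBase")
assert not any(hasattr(hivemind, name) for name in removed)
```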
 

+ 0 - 6
hivemind/averaging/averager.py

@@ -117,7 +117,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         target_group_size: Optional[int] = None,
         min_group_size: int = 2,
         initial_group_bits: str = "",
-        averaging_expiration: Optional[float] = None,
         min_matchmaking_time: float = 5.0,
         request_timeout: float = 3.0,
         averaging_alpha: float = 1.0,
@@ -145,11 +144,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         assert all(bit in "01" for bit in initial_group_bits)
         assert not client_mode or not auxiliary, "auxiliary peers must accept incoming connections"
 
-        if averaging_expiration is not None:
-            logger.warning("averaging_expiration is deprecated and will be removed soon, use min_matchmaking_time")
-            assert min_matchmaking_time == 5.0, "Can't set both averaging_expiration and min_matchmaking_time"
-            min_matchmaking_time = averaging_expiration
-
         super().__init__()
         self.dht = dht
         self.prefix = prefix

+ 0 - 4
hivemind/optim/__init__.py

@@ -1,7 +1,3 @@
-from hivemind.optim.adaptive import CollaborativeAdaptiveOptimizer
-from hivemind.optim.base import DecentralizedOptimizerBase
-from hivemind.optim.collaborative import CollaborativeOptimizer
 from hivemind.optim.grad_scaler import GradScaler, HivemindGradScaler
 from hivemind.optim.optimizer import Optimizer
-from hivemind.optim.simple import DecentralizedAdam, DecentralizedOptimizer, DecentralizedSGD
 from hivemind.optim.training_averager import TrainingAverager

+ 0 - 34
hivemind/optim/adaptive.py

@@ -1,34 +0,0 @@
-from typing import Sequence
-
-import torch.optim
-
-from hivemind.optim.collaborative import CollaborativeOptimizer
-from hivemind.optim.training_averager import TrainingAverager
-
-
-class CollaborativeAdaptiveOptimizer(CollaborativeOptimizer):
-    """
-    Behaves exactly as CollaborativeOptimizer except:
-
-    * averages adaptive learning rates of an optimizer
-    * doesn't average gradients
-
-    :param average_opt_statistics: average optimizer statistics with corresponding names in statedict
-    :param kwargs: options for CollaborativeOptimizer
-    """
-
-    def __init__(self, opt: torch.optim.Optimizer, average_opt_statistics: Sequence[str], **kwargs):
-        super().__init__(opt, average_opt_statistics=average_opt_statistics, **kwargs)
-
-    def _make_averager(self, average_opt_statistics, **kwargs):
-        return TrainingAverager(
-            self.opt,
-            dht=self.dht,
-            average_parameters=True,
-            average_gradients=False,
-            average_opt_statistics=average_opt_statistics,
-            prefix=f"{self.prefix}_averaging",
-            allreduce_timeout=self.averaging_timeout,
-            client_mode=self.client_mode,
-            **kwargs,
-        )

+ 0 - 44
hivemind/optim/base.py

@@ -1,44 +0,0 @@
-from warnings import warn
-
-import torch
-
-from hivemind.dht import DHT
-
-
-class DecentralizedOptimizerBase(torch.optim.Optimizer):
-    """A shared interface for all hivemind optimizers. Cooperates with DHT peers to train a shared model"""
-
-    def __init__(self, opt: torch.optim.Optimizer, dht: DHT):
-        self.opt, self.dht = opt, dht
-        warn(
-            "DecentralizedOptimizerBase and its subclasses have been deprecated and will be removed "
-            "in hivemind 1.1.0. Use hivemind.Optimizer instead",
-            FutureWarning,
-            stacklevel=2,
-        )
-
-    @property
-    def state(self):
-        return self.opt.state
-
-    @property
-    def param_groups(self):
-        return self.opt.param_groups
-
-    def add_param_group(self, param_group: dict) -> None:
-        raise ValueError(
-            f"{self.__class__.__name__} does not support calling add_param_group after creation."
-            f"Please provide all parameter groups at init."
-        )
-
-    def state_dict(self) -> dict:
-        return self.opt.state_dict()
-
-    def load_state_dict(self, state_dict: dict):
-        return self.opt.load_state_dict(state_dict)
-
-    def __repr__(self):
-        return f"{self.__class__.__name__}(opt={repr(self.opt)}, dht={repr(self.dht)})"
-
-    def shutdown(self):
-        raise NotImplementedError()

+ 0 - 558
hivemind/optim/collaborative.py

@@ -1,558 +0,0 @@
-from __future__ import annotations
-
-import logging
-from dataclasses import dataclass
-from threading import Event, Lock, Thread
-from typing import Dict, Iterator, Optional
-
-import numpy as np
-import torch
-from pydantic import BaseModel, StrictBool, StrictFloat, confloat, conint
-
-from hivemind.dht import DHT
-from hivemind.dht.crypto import RSASignatureValidator
-from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
-from hivemind.optim.base import DecentralizedOptimizerBase
-from hivemind.optim.grad_scaler import HivemindGradScaler
-from hivemind.optim.training_averager import TrainingAverager
-from hivemind.utils import get_dht_time, get_logger
-from hivemind.utils.performance_ema import PerformanceEMA
-
-logger = get_logger(__name__)
-LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
-
-
-@dataclass(frozen=False)
-class CollaborationState:
-    optimizer_step: int
-    samples_accumulated: int
-    target_batch_size: int
-    num_peers: int
-    num_clients: int
-    eta_next_step: float
-    next_fetch_time: float
-
-    @property
-    def ready_for_step(self):
-        return self.samples_accumulated >= self.target_batch_size or get_dht_time() >= self.eta_next_step
-
-    def register_step(self, local_step: int):
-        self.optimizer_step = max(local_step, self.optimizer_step)
-        self.samples_accumulated = 0
-        self.eta_next_step = float("inf")
-
-
-class TrainingState(BaseModel):
-    peer_id: bytes
-    step: conint(ge=0, strict=True)
-    samples_accumulated: conint(ge=0, strict=True)
-    samples_per_second: confloat(ge=0.0, strict=True)
-    time: StrictFloat
-    client_mode: StrictBool
-
-
-class TrainingProgressSchema(BaseModel):
-    progress: Dict[BytesWithPublicKey, Optional[TrainingState]]
-
-
-class CollaborativeOptimizer(DecentralizedOptimizerBase):
-    """
-    An optimizer that performs model updates after collaboratively accumulating a target (large) batch size across peers.
-
-    These optimizers use DHT to track how much progress did the collaboration make towards target batch size.
-    Once enough samples were accumulated, optimizers will compute a weighted average of their statistics.
-
-    :note: **For new projects, please use hivemind.Optimizer**. CollaborativeOptimizer is an older version of that.
-      Currently, hivemind.Optimizer supports all the features of CollaborativeOptimizer and many advanced ones.
-      CollaborativeOptimizer will still be supported for a while, but it will be deprecated in v1.1.0.
-
-    :note: This optimizer behaves unlike regular pytorch optimizers in two ways:
-
-      * calling .step will periodically zero-out gradients w.r.t. model parameters after each step
-      * it may take multiple .step calls without updating model parameters, waiting for peers to accumulate enough samples
-
-
-    :param opt: a standard pytorch optimizer, preferably a large-batch one such as LAMB, LARS, etc.
-    :param dht: a running hivemind.DHT daemon connected to other peers
-    :param prefix: a common prefix for all metadata stored by CollaborativeOptimizer in the DHT
-    :param target_batch_size: perform optimizer step after all peers collectively accumulate this many samples
-    :param batch_size_per_step: before each call to .step, user should accumulate gradients over this many samples
-    :param min_refresh_period: wait for at least this many seconds before fetching new collaboration state
-    :param max_refresh_period: wait for at most this many seconds before fetching new collaboration state
-    :param default_refresh_period: if no peers are detected, attempt to fetch collaboration state this often (seconds)
-    :param expected_drift_peers: assume that this many new peers can join between steps
-    :param expected_drift_rate: assumes that this fraction of current collaboration can join/leave between steps
-    :note: The expected collaboration drift parameters are used to adjust the frequency with which this optimizer will
-      refresh the collaboration-wide statistics (to avoid missing the moment when to run the next step)
-    :param bandwidth: peer's network bandwidth for the purpose of load balancing (recommended: internet speed in mbps)
-    :param step_tolerance: a peer can temporarily be delayed by this many steps without being deemed out of sync
-    :param performance_ema_alpha: smoothing value used to estimate this peer's performance (training samples per second)
-    :param averaging_expiration: peer's requests for averaging will be valid for this many seconds
-    :param metadata_expiration: peer's metadata (e.g. samples processed) is stored onto DHT for this many seconds
-    :param averaging_timeout: if an averaging step hangs for this long, it will be cancelled.
-    :param load_state_timeout: wait for at most this many seconds before giving up on load_state_from_peers
-    :param scheduler: if specified, use this scheduler to update optimizer learning rate
-    :param reuse_grad_buffers: if True, use model's .grad buffers for gradient accumulation.
-      This is more memory efficient, but it requires that the user does *NOT* call model/opt zero_grad at all
-    :param accumulate_grads_on: if specified, accumulate gradients on this device. By default, this will use the same
-     device as model parameters. One can specify a different device (e.g. 'cpu' vs 'cuda') to save device memory at
-     the cost of extra time per step. If reuse_gradient_accumulators is True, this parameter has no effect.
-    :param client_mode: if True, runs training without incoming connections, in a firewall-compatible mode
-    :param kwargs: additional parameters forwarded to DecentralizedAverager
-    :note: If you are using CollaborativeOptimizer with lr_scheduler, it is recommended to pass this scheduler
-      explicitly into this class. Otherwise, scheduler may not be synchronized between peers.
-    """
-
-    def __init__(
-        self,
-        opt: torch.optim.Optimizer,
-        *,
-        dht: DHT,
-        prefix: str,
-        target_batch_size: int,
-        batch_size_per_step: Optional[int] = None,
-        scheduler: Optional[LRSchedulerBase] = None,
-        min_refresh_period: float = 0.5,
-        max_refresh_period: float = 30,
-        default_refresh_period: float = 3,
-        expected_drift_peers: float = 3,
-        expected_drift_rate: float = 0.2,
-        performance_ema_alpha: float = 0.1,
-        metadata_expiration: float = 60.0,
-        averaging_timeout: Optional[float] = None,
-        load_state_timeout: float = 600.0,
-        step_tolerance: int = 1,
-        reuse_grad_buffers: bool = False,
-        accumulate_grads_on: Optional[torch.device] = None,
-        client_mode: bool = False,
-        verbose: bool = False,
-        **kwargs,
-    ):
-        super().__init__(opt, dht)
-
-        signature_validator = RSASignatureValidator()
-        self._local_public_key = signature_validator.local_public_key
-        dht.add_validators([SchemaValidator(TrainingProgressSchema, prefix=prefix), signature_validator])
-
-        if reuse_grad_buffers and accumulate_grads_on is not None:
-            logger.warning("Setting 'accumulate_grads_on' has no effect if reuse_grad_buffers=True")
-        self.prefix, self.scheduler = prefix, scheduler
-        self.target_batch_size, self.batch_size_per_step = target_batch_size, batch_size_per_step
-        self.min_refresh_period, self.max_refresh_period, self.default_refresh_period = (
-            min_refresh_period,
-            max_refresh_period,
-            default_refresh_period,
-        )
-        self.expected_drift_peers, self.expected_drift_rate = expected_drift_peers, expected_drift_rate
-        self.averaging_timeout = averaging_timeout
-        self.load_state_timeout = load_state_timeout
-        self.metadata_expiration = metadata_expiration
-        self._grads, self.reuse_grad_buffers, self.accumulate_grads_on = None, reuse_grad_buffers, accumulate_grads_on
-        self.client_mode, self.step_tolerance = client_mode, step_tolerance
-        self.status_loglevel = logging.INFO if verbose else logging.DEBUG
-        self.averager = self._make_averager(**kwargs)
-
-        self._step_supports_amp_scaling = self.reuse_grad_buffers  # enable custom execution with torch GradScaler
-
-        self.training_progress_key = f"{self.prefix}_progress"
-        self.local_samples_accumulated = 0  # a number of local samples accumulated since last optimizer update
-        self.local_updates_accumulated = 0  # a number of calls to step() since last optimizer update
-        self.performance_ema = PerformanceEMA(alpha=performance_ema_alpha)
-        self.last_step_time = None
-
-        self.collaboration_state = self._fetch_state()
-        self.lock_collaboration_state, self.collaboration_state_updated = Lock(), Event()
-        self.lock_local_progress, self.should_report_progress = Lock(), Event()
-        self.progress_reporter = Thread(target=self.report_training_progress, daemon=True, name=f"{self}.reporter")
-        self.progress_reporter.start()
-        self.collaboration_state_updater = Thread(
-            target=self.check_collaboration_state_periodically, daemon=True, name=f"{self}.collaboration_state_updater"
-        )
-        self.collaboration_state_updater.start()
-
-    def _make_averager(self, **kwargs):
-        return TrainingAverager(
-            self.opt,
-            dht=self.dht,
-            average_parameters=True,
-            average_gradients=True,
-            prefix=f"{self.prefix}_averaging",
-            allreduce_timeout=self.averaging_timeout,
-            client_mode=self.client_mode,
-            **kwargs,
-        )
-
-    @property
-    def local_step(self) -> int:
-        return self.averager.local_step
-
-    @property
-    def is_synchronized(self) -> bool:
-        return self.local_step >= self.collaboration_state.optimizer_step
-
-    @property
-    def is_within_tolerance(self) -> bool:
-        return self.local_step >= self.collaboration_state.optimizer_step - self.step_tolerance
-
-    def is_alive(self) -> bool:
-        return self.averager.is_alive()
-
-    def load_state_from_peers(self, **kwargs):
-        """Attempt to fetch the newest collaboration state from other peers"""
-        with self.lock_collaboration_state:
-            while True:
-                try:
-                    self.averager.load_state_from_peers(timeout=self.load_state_timeout, **kwargs)
-                    break
-                except KeyboardInterrupt:
-                    raise
-                except BaseException as e:
-                    logger.exception(f"Failed to load state from peers: {e}, retrying ...")
-                    continue
-
-            self.local_samples_accumulated = self.local_updates_accumulated = 0
-            self.reset_accumulated_grads_()
-            self.update_scheduler()
-
-    def state_dict(self) -> dict:
-        state_dict = super().state_dict()
-        state_dict["state"]["collaborative_step"] = self.local_step
-        return state_dict
-
-    def load_state_dict(self, state_dict: dict):
-        if "collaborative_step" in state_dict["state"]:
-            self.averager.local_step = state_dict["state"].pop("collaborative_step")
-        return super().load_state_dict(state_dict)
-
-    def step(self, batch_size: Optional[int] = None, grad_scaler: Optional[HivemindGradScaler] = None, **kwargs):
-        """
-        Report accumulating gradients w.r.t. batch_size additional samples, optionally update model parameters
-
-        :param batch_size: optional override for batch_size_per_step from init
-        :param grad_scaler: if amp is enabled, this **must** be a hivemind-aware gradient scaler
-        :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
-        """
-        if grad_scaler is not None and not isinstance(grad_scaler, HivemindGradScaler):
-            raise ValueError("CollaborativeOptimizer requires a hivemind-aware gradient scaler (HivemindGradScaler)")
-        if self.batch_size_per_step is None:
-            if batch_size is None:
-                raise ValueError("Please either set batch_size_per_step parameter at init or when calling .step")
-            logger.log(self.status_loglevel, f"Setting default batch_size_per_step to {batch_size}")
-            self.batch_size_per_step = batch_size
-        batch_size = batch_size if batch_size is not None else self.batch_size_per_step
-
-        if not self.is_synchronized and not self.is_within_tolerance:
-            logger.log(self.status_loglevel, "Peer is out of sync")
-            self.load_state_from_peers()
-            return
-        elif not self.is_synchronized and self.is_within_tolerance:
-            self.averager.local_step = self.collaboration_state.optimizer_step
-            logger.log(self.status_loglevel, f"Catching up with collaboration step {self.local_step}")
-
-        if grad_scaler is not None and not grad_scaler.are_grads_finite(self):
-            logger.log(self.status_loglevel, "Encountered incorrect value in fp16 grads, resetting local gradients")
-            self.local_samples_accumulated = self.local_steps_accumulated = 0
-            self.reset_accumulated_grads_()
-            self.should_report_progress.set()
-            return
-
-        if self.last_step_time is not None and get_dht_time() - self.last_step_time > self.metadata_expiration:
-            logger.warning(
-                f"Training step took {get_dht_time() - self.last_step_time}, "
-                f"but metadata expired in {self.metadata_expiration} s."
-            )
-
-        self.accumulate_grads_(batch_size)
-
-        with self.lock_local_progress:
-            self.local_samples_accumulated += batch_size
-            self.local_updates_accumulated += 1
-            self.performance_ema.update(task_size=batch_size)
-            self.should_report_progress.set()
-
-        if not self.collaboration_state.ready_for_step:
-            return
-
-        logger.log(self.status_loglevel, f"Beginning global optimizer step #{self.collaboration_state.optimizer_step}")
-        with self.performance_ema.pause(), self.lock_collaboration_state:
-            self.collaboration_state = self._fetch_state()
-            self.collaboration_state_updated.set()
-
-            # divide accumulators by local steps to recover the true average grad w.r.t. local_samples_accumulated
-            self.apply_accumulated_grads_(scale_by=1.0 / self.local_updates_accumulated)
-            if grad_scaler is not None:
-                with grad_scaler.running_global_step():
-                    assert grad_scaler.unscale_(self)
-
-            current_step, group_info = self.averager.local_step, None
-
-            if self.collaboration_state.num_peers > 1:
-                mean_samples_per_worker = self.target_batch_size / self.collaboration_state.num_peers
-                weight = self.local_samples_accumulated / mean_samples_per_worker
-                try:
-                    group_info = self.averager.step(
-                        weight=weight, gather=current_step, timeout=self.averaging_timeout, **kwargs
-                    )
-                    if group_info:
-                        logger.log(self.status_loglevel, f"Averaged tensors successfully with {len(group_info)} peers")
-
-                        # update our current step if we averaged with another peer that was at a more recent step
-                        for peer, peer_step in group_info.items():
-                            if isinstance(peer_step, int):
-                                current_step = max(current_step, peer_step)
-                            else:
-                                logger.warning(f"Peer {peer} sent malformed data about current step: {peer_step}")
-
-                except BaseException as e:
-                    logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}")
-
-            else:
-                logger.log(
-                    self.status_loglevel,
-                    f"Skipped averaging: collaboration consists of " f"{self.collaboration_state.num_peers} peer(s)",
-                )
-
-            if grad_scaler is not None:
-                with grad_scaler.running_global_step():
-                    assert grad_scaler.step(self)
-            else:
-                self.opt.step()
-
-            self.reset_accumulated_grads_()
-            self.local_samples_accumulated = self.local_updates_accumulated = 0
-            self.collaboration_state.register_step(current_step + 1)
-            self.averager.local_step = current_step + 1
-            self.collaboration_state_updated.set()
-            self.update_scheduler()
-
-            if grad_scaler is not None:
-                with grad_scaler.running_global_step():
-                    assert grad_scaler.update()
-
-            if not self.averager.client_mode:
-                self.averager.state_sharing_priority = self.local_step
-
-        logger.log(self.status_loglevel, f"Optimizer step: done!")
-
-        return group_info
-
-    def step_aux(self, **kwargs):
-        """
-        Find and assist other peers in averaging without sending local gradients.
-
-        :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
-        """
-
-        if not self.collaboration_state.ready_for_step:
-            return
-
-        logger.log(self.status_loglevel, f"Beginning global optimizer step #{self.collaboration_state.optimizer_step}")
-        self.collaboration_state = self._fetch_state()
-        self.collaboration_state_updated.set()
-
-        with self.lock_collaboration_state:
-            current_step, group_info = self.averager.local_step, None
-
-            try:
-                group_info = self.averager.step(timeout=self.averaging_timeout, gather=current_step, **kwargs)
-                if group_info:
-                    logger.log(self.status_loglevel, f"Averaged tensors successfully with {len(group_info)} peers")
-
-                    # update our current step if we averaged with another peer that was at a more recent step
-                    for peer, peer_step in group_info.items():
-                        if isinstance(peer_step, int):
-                            current_step = max(current_step, peer_step)
-                        else:
-                            logger.warning(f"Peer {peer} sent malformed data about current step: {peer_step}")
-            except BaseException as e:
-                logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}")
-
-            self.collaboration_state.register_step(current_step + 1)
-            self.averager.local_step = current_step + 1
-            self.collaboration_state_updated.set()
-
-        logger.log(self.status_loglevel, f"Optimizer step: done!")
-
-        return group_info
-
-    def _grad_buffers(self) -> Iterator[torch.Tensor]:
-        """pytorch-internal gradient buffers"""
-        for param_group in self.opt.param_groups:
-            for param in param_group["params"]:
-                if param.grad is None:
-                    yield torch.zeros_like(param)
-                else:
-                    yield param.grad
-
-    @torch.no_grad()
-    def accumulated_grads(self) -> Iterator[torch.Tensor]:
-        """local gradient accumulators"""
-        if self.reuse_grad_buffers:
-            yield from self._grad_buffers()
-            return
-
-        if self._grads is None:
-            self._grads = [torch.zeros_like(grad, device=self.accumulate_grads_on) for grad in self._grad_buffers()]
-        yield from self._grads
-
-    @torch.no_grad()
-    def accumulate_grads_(self, batch_size: int):
-        """add current gradients to grad accumulators (if any)"""
-        if self.reuse_grad_buffers:
-            # user is responsible for accumulating gradients in .grad buffers
-            assert batch_size == self.batch_size_per_step, "Custom batch size is not supported if reuse_grad_buffers"
-        else:
-            alpha = float(batch_size) / self.batch_size_per_step
-            for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
-                grad_acc.add_(grad_buf.to(grad_acc.device), alpha=alpha)
-
-    @torch.no_grad()
-    def apply_accumulated_grads_(self, scale_by: Optional[float] = None):
-        if not self.reuse_grad_buffers:
-            for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
-                grad_buf.copy_(grad_acc.to(grad_buf.device), non_blocking=True)
-        if scale_by is not None:
-            for grad_buf in self._grad_buffers():
-                grad_buf.mul_(scale_by)
-
-    @torch.no_grad()
-    def reset_accumulated_grads_(self):
-        for grad_buf in self.accumulated_grads():
-            grad_buf.zero_()
-
-    def report_training_progress(self):
-        """Periodically publish metadata and the current number of samples accumulated towards the next step"""
-        while self.is_alive():
-            self.should_report_progress.wait()
-            self.should_report_progress.clear()
-            with self.lock_local_progress:
-                current_time = get_dht_time()
-                local_state_info = TrainingState(
-                    peer_id=self.averager.peer_id.to_bytes(),
-                    step=self.local_step,
-                    samples_accumulated=self.local_samples_accumulated,
-                    samples_per_second=self.performance_ema.samples_per_second,
-                    time=current_time,
-                    client_mode=self.averager.client_mode,
-                )
-
-            self.dht.store(
-                key=self.training_progress_key,
-                subkey=self._local_public_key,
-                value=local_state_info.dict(),
-                expiration_time=current_time + self.metadata_expiration,
-                return_future=True,
-            )
-
-    def check_collaboration_state_periodically(self):
-        """
-        Periodically check the training progress from all peers. Trigger update after target_batch_size total samples
-        """
-        while self.is_alive():
-            time_to_next_update = max(0.0, self.collaboration_state.next_fetch_time - get_dht_time())
-            if self.collaboration_state_updated.wait(time_to_next_update):
-                self.collaboration_state_updated.clear()
-                continue  # if state was updated externally, reset timer
-
-            with self.lock_collaboration_state:
-                self.collaboration_state = self._fetch_state()
-
-    def _fetch_state(self) -> CollaborationState:
-        """Read performance statistics reported by peers, estimate progress towards next batch"""
-        response, _expiration = self.dht.get(self.training_progress_key, latest=True) or (None, -float("inf"))
-        current_time = get_dht_time()
-
-        if not isinstance(response, dict) or len(response) == 0:
-            logger.log(self.status_loglevel, f"Found no active peers: {response}")
-            samples_left_to_target_batch_size = max(0, self.target_batch_size - self.local_samples_accumulated)
-            local_eta_next_step = samples_left_to_target_batch_size / self.performance_ema.samples_per_second
-
-            return CollaborationState(
-                self.local_step,
-                self.local_samples_accumulated,
-                self.target_batch_size,
-                num_peers=0,
-                num_clients=0,
-                eta_next_step=current_time + local_eta_next_step,
-                next_fetch_time=current_time + self.default_refresh_period,
-            )
-
-        valid_peer_states = [
-            TrainingState.parse_obj(peer_state.value)
-            for peer_state in response.values()
-            if peer_state.value is not None
-        ]
-
-        num_peers = len(valid_peer_states)
-        num_clients = sum(state.client_mode for state in valid_peer_states)
-        global_optimizer_step = self.local_step
-        for state in valid_peer_states:
-            if not state.client_mode:
-                global_optimizer_step = max(global_optimizer_step, state.step)
-
-        total_samples_accumulated = estimated_current_samples = total_samples_per_second = 0
-
-        for state in valid_peer_states:
-            total_samples_per_second += state.samples_per_second
-            if state.step == global_optimizer_step:
-                total_samples_accumulated += state.samples_accumulated
-                estimated_current_samples += (
-                    state.samples_accumulated + max(0, current_time - state.time) * state.samples_per_second
-                )
-            # note: we deliberately count only valid peers for samples_accumulated, but all peers for performance;
-            # the rationale behind this is that outdated peers will synchronize and begin contributing shortly.
-
-        estimated_samples_remaining = self.target_batch_size - estimated_current_samples
-        estimated_time_to_next_step = max(0, estimated_samples_remaining) / total_samples_per_second
-
-        expected_max_peers = max(num_peers + self.expected_drift_peers, num_peers * (1 + self.expected_drift_rate))
-        time_to_next_fetch = float(
-            np.clip(
-                a=estimated_time_to_next_step * num_peers / expected_max_peers,
-                a_min=self.min_refresh_period,
-                a_max=self.max_refresh_period,
-            )
-        )
-        logger.log(
-            self.status_loglevel,
-            f"{self.prefix} accumulated {total_samples_accumulated} samples from "
-            f"{num_peers} peers for step #{global_optimizer_step}. "
-            f"ETA {estimated_time_to_next_step:.2f} sec (refresh in {time_to_next_fetch:.2f} sec)",
-        )
-        return CollaborationState(
-            global_optimizer_step,
-            total_samples_accumulated,
-            target_batch_size=self.target_batch_size,
-            num_peers=num_peers,
-            num_clients=num_clients,
-            eta_next_step=current_time + estimated_time_to_next_step,
-            next_fetch_time=current_time + time_to_next_fetch,
-        )
-
-    def zero_grad(self, *args, **kwargs):
-        if self.reuse_grad_buffers:
-            raise ValueError(
-                f"When running {self.__class__.__name__} with reuse_grad_buffers=True, user should never "
-                f"call zero_grad manually. Gradients will be refreshed internally."
-            )
-        return self.opt.zero_grad(*args, **kwargs)
-
-    def update_scheduler(self):
-        if self.scheduler:
-            while self.scheduler._step_count < self.local_step:
-                self.scheduler.step()
-
-    def shutdown(self):
-        logger.debug("Shutting down averager...")
-        self.averager.shutdown()
-        logger.debug("Sending goodbye to peers...")
-        self.dht.store(
-            self.training_progress_key,
-            subkey=self._local_public_key,
-            value=None,
-            expiration_time=get_dht_time() + self.metadata_expiration,
-        )
-        logger.debug(f"{self.__class__.__name__} is shut down")
-
-    def __del__(self):
-        self.shutdown()

+ 1 - 1
hivemind/optim/grad_scaler.py

@@ -50,7 +50,7 @@ class GradScaler(TorchGradScaler):
 
     def unscale_(self, optimizer: TorchOptimizer) -> bool:
         with self._lock:
-            assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
+            assert isinstance(optimizer, hivemind.Optimizer)
             if self._is_running_global_step:
                 super().unscale_(optimizer)
                 self._inner_optimizer_states[id(optimizer.opt)] = deepcopy(self._per_optimizer_states[id(optimizer)])
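
The `unscale_` type check is correspondingly narrowed: anything that is not a `hivemind.Optimizer` now trips the assertion. A tiny sketch mirroring that check (not a full AMP training loop; see the hivemind docs for the recommended scaler usage):

```python
import torch
import hivemind

def check_unscale_target(optimizer) -> None:
    # Mirrors the assertion in hivemind.GradScaler.unscale_ after this commit:
    # the DecentralizedOptimizerBase branch is gone, only hivemind.Optimizer passes.
    assert isinstance(optimizer, hivemind.Optimizer), "GradScaler.unscale_ now requires hivemind.Optimizer"

plain_opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
try:
    check_unscale_target(plain_opt)
except AssertionError as err:
    print(f"rejected as expected: {err}")
```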

+ 0 - 229
hivemind/optim/simple.py

@@ -1,229 +0,0 @@
-import time
-from threading import Event, Lock, Thread
-from typing import Optional, Sequence, Tuple
-
-import torch
-
-from hivemind.dht import DHT
-from hivemind.optim.base import DecentralizedOptimizerBase
-from hivemind.optim.training_averager import TrainingAverager
-from hivemind.utils import get_dht_time, get_logger
-
-logger = get_logger(__name__)
-
-
-class DecentralizedOptimizer(DecentralizedOptimizerBase):
-    """
-    A simple optimizer that trains a shared model by averaging with peers in variety of ways. Supports
-    parameter/gradient averaging and syncing adaptive learning rates or any other internal statistics of optimizer.
-
-    :param opt: a pytorch optimizer configured to update model parameters.
-    :param dht: a running hivemind DHT daemon connected to other peers
-    :param average_parameters: whether to average model parameters
-    :param average_gradients: whether to average gradients
-    :param average_opt_statistics: if specified, average optimizer states with corresponding names in state_dict
-    :param averaging_steps_period: performs averaging after this many optimizer steps
-    :param averaging_time_period: if specified, optimizer will attempt to average weights at regular intervals of this
-      many seconds. (averaging step will only occur if the optimizer ran `averaging_steps_period` steps in that interval)
-    :param prefix: all DHT keys that point to optimization metadata will have this prefix
-    :param target_group_size: maximum group size for averaging (see DecentralizedAverager)
-    :param timeout: if DecentralizedAverager step is unable to form group in this many seconds, cancel step
-    :param kwargs: additional parameters passed to TrainingAverager
-    :note: if you're using an optimizer with adaptive learning rates (such as Adam), make sure to specify
-      necessary fields' names in `average_opt_statistics`. Otherwise you may encounter poor convergence.
-    :note: the base optimizer cannot add param groups after the DecentralizedOptimizer is created
-    """
-
-    def __init__(
-        self,
-        opt: torch.optim.Optimizer,
-        dht: DHT,
-        *,
-        prefix: str,
-        target_group_size: int,
-        average_parameters: bool,
-        average_gradients: bool,
-        average_opt_statistics: Sequence[str] = (),
-        averaging_steps_period: int = 1,
-        averaging_time_period: float = 0,
-        timeout: Optional[float] = None,
-        verbose: bool = False,
-        **kwargs,
-    ):
-        super().__init__(opt, dht)
-        assert averaging_steps_period > 0 and averaging_time_period >= 0, "Averaging period must be positive."
-        self.local_step, self.averaging_step_period = 0, averaging_steps_period
-
-        self.averager = TrainingAverager(
-            opt,
-            average_parameters=average_parameters,
-            average_gradients=average_gradients,
-            average_opt_statistics=average_opt_statistics,
-            dht=dht,
-            start=True,
-            prefix=prefix,
-            target_group_size=target_group_size,
-            **kwargs,
-        )
-        self.lock_parameters, self.update_event, self.stop_event = Lock(), Event(), Event()
-        self.lock_parameters.acquire()  # this lock is only released when averager can modify tensors in background
-
-        self.background_averaging_thread = Thread(
-            name=f"{self.__class__.__name__}",
-            daemon=True,
-            target=self._average_parameters_in_background,
-            args=[self.lock_parameters, self.update_event, self.stop_event, self.averager],
-            kwargs=dict(averaging_period=averaging_time_period, timeout=timeout, verbose=verbose),
-        )
-        self.background_averaging_thread.start()
-
-    def step(self, *args, **kwargs):
-        loss = self.opt.step(*args, **kwargs)
-        if self.lock_parameters.locked():
-            self.lock_parameters.release()
-        try:
-            self.local_step += 1
-            if self.local_step % self.averaging_step_period == 0:
-                self.update_event.set()
-            self.averager.pending_updates_done.wait()
-
-            if not self.averager.client_mode:
-                self.averager.state_sharing_priority = get_dht_time()
-
-            return loss
-        finally:
-            self.lock_parameters.acquire()
-
-    def zero_grad(self, *args, **kwargs):
-        return self.opt.zero_grad(*args, **kwargs)
-
-    def __del__(self):
-        self.stop_event.set()
-        self.update_event.set()
-
-    def shutdown(self):
-        self.stop_event.set()
-        self.update_event.set()
-        self.averager.shutdown()
-
-    @staticmethod
-    @torch.no_grad()
-    def _average_parameters_in_background(
-        lock_parameters: Lock,
-        update_event: Event,
-        stop_event: Event,
-        averager: TrainingAverager,
-        averaging_period: float,
-        verbose: bool,
-        **kwargs,
-    ):
-        """Iteratively find groups of peers, average parameters with these peers and update local model parameters."""
-        while not stop_event.is_set():
-            update_event.wait()
-            update_event.clear()
-            if stop_event.is_set():
-                break
-
-            if averaging_period:
-                current_time = get_dht_time()
-                # note: we use global DHT time to make sure peers start averaging at the ~same time (to form groups)
-                time_to_nearest_interval = max(0.0, averaging_period - current_time % averaging_period)
-                time.sleep(time_to_nearest_interval)
-
-            if verbose:
-                logger.info(f"Starting a new averaging round with current parameters")
-            try:
-                group_info = averager.step(lock_parameters, **kwargs)
-                if verbose:
-                    if group_info is not None:
-                        logger.info(f"Finished averaging round in with {len(group_info)} peers")
-                    else:
-                        logger.warning(f"Averaging round failed: could not find group")
-            except Exception as e:
-                logger.error(f"Averaging round failed: caught {e}")
-
-
-class DecentralizedSGD(DecentralizedOptimizer):
-    """
-    Decentralized Stochastic Gradient Descent algorithm like in Lian et al (2017) [1] based on Moshpit All-Reduce [2].
-
-    :param dht: a running hivemind DHT daemon connected to other peers
-    :param prefix: all DHT keys that point to optimization metadata will have this prefix
-    :param target_group_size: maximum group size for averaging (see DecentralizedAverager)
-    :param kwargs: additional parameters passed to DecentralizedOptimizer
-
-    - [1] Can Decentralized Algorithms Outperform Centralized Algorithms? A Case Study for Parallel Stochastic Gradient
-     Descent - https://proceedings.neurips.cc/paper/2017/hash/f75526659f31040afeb61cb7133e4e6d-Abstract.html
-    - [2] Moshpit SGD: Communication-Efficient Decentralized Training on Heterogeneous Unreliable Devices
-     https://arxiv.org/abs/2103.03239
-    """
-
-    def __init__(
-        self,
-        params,
-        lr: float,
-        *,
-        dht: DHT,
-        prefix: str,
-        target_group_size: int,
-        momentum: float = 0,
-        dampening: float = 0,
-        weight_decay: float = 0,
-        nesterov: bool = False,
-        **kwargs,
-    ):
-        opt = torch.optim.SGD(params, lr, momentum, dampening, weight_decay, nesterov)
-        super().__init__(
-            opt,
-            dht,
-            prefix=prefix,
-            target_group_size=target_group_size,
-            average_parameters=True,
-            average_gradients=False,
-            **kwargs,
-        )
-
-
-class DecentralizedAdam(DecentralizedOptimizer):
-    """
-    Decentralized Adam/AmsGrad as proposed in [1], [2]
-
-    :param dht: a running hivemind DHT daemon connected to other peers
-    :param prefix: all DHT keys that point to optimization metadata will have this prefix
-    :param target_group_size: maximum group size for averaging (see DecentralizedAverager)
-    :param averaging_steps_period: performs averaging after this many optimizer steps
-    :param kwargs: additional parameters passed to DecentralizedOptimizer
-
-    - [1] On the Convergence of Decentralized Adaptive Gradient Methods
-    - [2] Toward Communication Efficient Adaptive Gradient Method - https://dl.acm.org/doi/abs/10.1145/3412815.3416891
-    """
-
-    def __init__(
-        self,
-        params,
-        lr: float,
-        *,
-        dht: DHT,
-        prefix: str,
-        target_group_size: int,
-        averaging_steps_period: int,
-        betas: Tuple[float, float] = (0.9, 0.999),
-        eps: float = 1e-8,
-        weight_decay: float = 0,
-        amsgrad: bool = False,
-        **kwargs,
-    ):
-        opt = torch.optim.Adam(params, lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
-        opt_statistics = ("max_exp_avg_sq",) if amsgrad else ("exp_avg_sq",)
-
-        super().__init__(
-            opt,
-            dht,
-            prefix=prefix,
-            target_group_size=target_group_size,
-            average_parameters=True,
-            average_gradients=False,
-            average_opt_statistics=opt_statistics,
-            averaging_steps_period=averaging_steps_period,
-            **kwargs,
-        )

+ 16 - 80
tests/test_averaging.py

@@ -6,7 +6,7 @@ import pytest
 import torch
 
 import hivemind
-import hivemind.averaging.averager
+from hivemind.averaging import DecentralizedAverager
 from hivemind.averaging.allreduce import AveragingMode
 from hivemind.averaging.control import AveragingStage
 from hivemind.averaging.key_manager import GroupKeyManager
@@ -78,11 +78,11 @@ def _test_allreduce_once(n_clients, n_aux):
 
     dht_instances = launch_dht_instances(len(peer_tensors))
     averagers = [
-        hivemind.averaging.DecentralizedAverager(
+        DecentralizedAverager(
             tensors,
             dht=dht,
             target_group_size=4,
-            averaging_expiration=15,
+            min_matchmaking_time=15,
             prefix="mygroup",
             client_mode=mode == AveragingMode.CLIENT,
             auxiliary=mode == AveragingMode.AUX,
@@ -135,11 +135,11 @@ def test_allreduce_weighted(n_client_mode_peers: int = 2):
 
     dht_instances = launch_dht_instances(4)
     averagers = [
-        hivemind.averaging.DecentralizedAverager(
+        DecentralizedAverager(
             tensors,
             dht=dht,
             target_group_size=4,
-            averaging_expiration=15,
+            min_matchmaking_time=15,
             prefix="mygroup",
             client_mode=client_mode,
             start=True,
@@ -185,7 +185,7 @@ def compute_mean_std(averagers, unbiased=True):
 def test_allreduce_grid():
     dht_instances = launch_dht_instances(8)
     averagers = [
-        hivemind.averaging.DecentralizedAverager(
+        DecentralizedAverager(
             averaged_tensors=[torch.randn(3)],
             dht=dht,
             target_group_size=2,
@@ -221,11 +221,11 @@ def test_allreduce_grid():
 def test_allgather(n_averagers=8, target_group_size=4):
     dht_instances = launch_dht_instances(n_averagers)
     averagers = [
-        hivemind.averaging.DecentralizedAverager(
+        DecentralizedAverager(
             [torch.ones(1)],
             dht=dht,
             target_group_size=target_group_size,
-            averaging_expiration=15,
+            min_matchmaking_time=15,
             prefix="mygroup",
             initial_group_bits="000",
             start=True,
@@ -297,11 +297,11 @@ def test_load_balancing():
 def test_too_few_peers():
     dht_instances = launch_dht_instances(4)
     averagers = [
-        hivemind.averaging.DecentralizedAverager(
+        DecentralizedAverager(
             averaged_tensors=[torch.randn(3)],
             dht=dht,
             target_group_size=2,
-            averaging_expiration=1,
+            min_matchmaking_time=1,
             request_timeout=0.5,
             prefix="mygroup",
             initial_group_bits=bin(i)[2:].rjust(3, "0"),
@@ -327,11 +327,11 @@ def test_too_few_peers():
 def test_overcrowded(num_peers=16):
     dht_instances = launch_dht_instances(num_peers)
     averagers = [
-        hivemind.averaging.DecentralizedAverager(
+        DecentralizedAverager(
             averaged_tensors=[torch.randn(3)],
             dht=dht,
             target_group_size=2,
-            averaging_expiration=1,
+            min_matchmaking_time=1,
             request_timeout=0.5,
             prefix="mygroup",
             initial_group_bits="",
@@ -353,7 +353,7 @@ def test_load_state_from_peers():
     super_metadata = dict(x=123)
     super_tensors = (torch.randn(3), torch.randint(0, 5, (3,)))
 
-    class TestAverager(hivemind.averaging.DecentralizedAverager):
+    class TestAverager(DecentralizedAverager):
         def get_current_state(self):
             """
             Get current state and send it to a peer. executed in the host process. Meant to be overriden.
@@ -455,7 +455,7 @@ def test_load_state_priority():
 @pytest.mark.forked
 def test_getset_bits():
     dht = hivemind.DHT(start=True)
-    averager = hivemind.averaging.DecentralizedAverager(
+    averager = DecentralizedAverager(
         [torch.randn(3)],
         dht=dht,
         start=True,
@@ -469,7 +469,7 @@ def test_getset_bits():
 @pytest.mark.forked
 def test_averaging_trigger():
     averagers = tuple(
-        hivemind.averaging.DecentralizedAverager(
+        DecentralizedAverager(
             averaged_tensors=[torch.randn(3)],
             dht=dht,
             min_matchmaking_time=0.5,
@@ -514,7 +514,7 @@ def test_averaging_trigger():
 @pytest.mark.forked
 def test_averaging_cancel():
     averagers = tuple(
-        hivemind.averaging.DecentralizedAverager(
+        DecentralizedAverager(
             averaged_tensors=[torch.randn(3)],
             dht=dht,
             min_matchmaking_time=0.5,
@@ -540,67 +540,3 @@ def test_averaging_cancel():
 
     for averager in averagers:
         averager.shutdown()
-
-
-@pytest.mark.forked
-def test_training_averager(n_steps: int = 10, n_dims: int = 16):
-    torch.manual_seed(42)
-
-    dht_instances = launch_dht_instances(2)
-    common_kwargs = {
-        "start": True,
-        "prefix": "demo-run",
-        "target_group_size": 2,
-    }
-
-    x1 = torch.randn(n_dims, requires_grad=True)
-    opt1 = torch.optim.Adam([x1], lr=0.05)
-    averager1 = hivemind.TrainingAverager(
-        opt1,
-        average_gradients=True,
-        average_parameters=True,
-        average_opt_statistics=["exp_avg_sq"],
-        dht=dht_instances[0],
-        **common_kwargs
-    )
-
-    x2 = torch.randn(n_dims, requires_grad=True)
-    opt2 = torch.optim.Adam([x2], lr=0.05)
-    averager2 = hivemind.TrainingAverager(
-        opt2,
-        average_gradients=True,
-        average_parameters=True,
-        average_opt_statistics=["exp_avg_sq"],
-        dht=dht_instances[1],
-        **common_kwargs
-    )
-    a = torch.ones(n_dims)
-
-    for i in range(n_steps):
-        opt1.zero_grad()
-        opt2.zero_grad()
-        (x1 - a).pow(2).sum().backward()
-        (x2 - a).pow(2).sum().backward()
-        opt1.step()
-        opt2.step()
-
-        with torch.no_grad():
-            x_avg = 0.5 * (x1 + x2)
-            grad_avg = 0.5 * (x1.grad + x2.grad)
-            stats_avg = 0.5 * (opt1.state[x1]["exp_avg_sq"] + opt2.state[x2]["exp_avg_sq"])
-
-        # we set wait=False in order to prevent deadlock, when averager1 locks and waits for averager2
-        f1 = averager1.step(wait=False)
-        f2 = averager2.step(wait=False)
-        f1.result()
-        f2.result()
-
-        assert torch.allclose(x1, x_avg)
-        assert torch.allclose(x2, x_avg)
-        assert torch.allclose(x1.grad, grad_avg)
-        assert torch.allclose(x2.grad, grad_avg)
-        assert torch.allclose(opt1.state[x1]["exp_avg_sq"], stats_avg)
-        assert torch.allclose(opt2.state[x2]["exp_avg_sq"], stats_avg)
-
-    for instance in [averager1, averager2] + dht_instances:
-        instance.shutdown()

+ 0 - 85
tests/test_training.py

@@ -1,4 +1,3 @@
-import time
 from functools import partial
 
 import pytest
@@ -12,7 +11,6 @@ from hivemind.moe.client import RemoteMixtureOfExperts, RemoteSwitchMixtureOfExp
 from hivemind.moe.client.expert import create_remote_experts
 from hivemind.moe.expert_uid import ExpertInfo
 from hivemind.moe.server import background_server
-from hivemind.optim import DecentralizedAdam, DecentralizedSGD
 
 
 @pytest.mark.forked
@@ -133,86 +131,3 @@ def test_switch_training(max_steps: int = 10, threshold: float = 0.9, num_expert
 
         assert model.moe.grid_utilization.min().item() > (1 / num_experts) / 2
         assert accuracy >= threshold, f"too small accuracy: {accuracy}"
-
-
-@pytest.mark.forked
-def test_decentralized_optimizer_step():
-    dht_root = DHT(start=True)
-    initial_peers = dht_root.get_visible_maddrs()
-
-    param1 = torch.nn.Parameter(torch.zeros(32, 32), requires_grad=True)
-    opt1 = DecentralizedSGD(
-        [param1],
-        lr=0.1,
-        dht=DHT(initial_peers=initial_peers, start=True),
-        prefix="foo",
-        target_group_size=2,
-        verbose=True,
-    )
-
-    param2 = torch.nn.Parameter(torch.ones(32, 32), requires_grad=True)
-    opt2 = DecentralizedSGD(
-        [param2],
-        lr=0.05,
-        dht=DHT(initial_peers=initial_peers, start=True),
-        prefix="foo",
-        target_group_size=2,
-        verbose=True,
-    )
-
-    assert not torch.allclose(param1, param2)
-
-    (param1.sum() + 300 * param2.sum()).backward()
-
-    for i in range(5):
-        time.sleep(0.1)
-        opt1.step()
-        opt2.step()
-        opt1.zero_grad()
-        opt2.zero_grad()
-
-    assert torch.allclose(param1, param2)
-    reference = 0.5 * (0.0 - 0.1 * 1.0) + 0.5 * (1.0 - 0.05 * 300)
-    assert torch.allclose(param1, torch.full_like(param1, reference))
-
-
-@pytest.mark.skip(reason="Skipped until a more stable averager implementation is ready (TODO @justheuristic)")
-@pytest.mark.forked
-def test_decentralized_optimizer_averaging():
-    dht_root = DHT(start=True)
-    initial_peers = dht_root.get_visible_maddrs()
-
-    param1 = torch.nn.Parameter(torch.zeros(32, 32), requires_grad=True)
-    opt1 = DecentralizedAdam(
-        [param1],
-        lr=0.1,
-        averaging_steps_period=1,
-        dht=DHT(initial_peers=initial_peers, start=True),
-        prefix="foo",
-        target_group_size=2,
-        verbose=True,
-    )
-
-    param2 = torch.nn.Parameter(torch.ones(32, 32), requires_grad=True)
-    opt2 = DecentralizedAdam(
-        [param2],
-        lr=0.05,
-        averaging_steps_period=1,
-        dht=DHT(initial_peers=initial_peers, start=True),
-        prefix="foo",
-        target_group_size=2,
-        verbose=True,
-    )
-
-    assert not torch.allclose(param1, param2, atol=1e-3, rtol=0)
-    (param1.sum() + param2.sum()).backward()
-
-    for _ in range(100):
-        time.sleep(0.1)
-        opt1.step()
-        opt2.step()
-        opt1.zero_grad()
-        opt2.zero_grad()
-
-    assert torch.allclose(param1, param2, atol=1e-3, rtol=0)
-    assert torch.allclose(opt1.state[param1]["exp_avg_sq"], opt2.state[param2]["exp_avg_sq"], atol=1e-3, rtol=0)