
Remove deprecated code in hivemind.optim and hivemind.averaging before the 1.1.0 release (#480)

* Deprecate averaging_expiration

* Remove DecentralizedOptimizerBase with subclasses

* Update docs
Max Ryabinin 3 years ago
commit 4a3d8fb843

+ 5 - 6
benchmarks/benchmark_averaging.py

@@ -6,10 +6,9 @@ import time
 import torch

 import hivemind
-from hivemind.proto import runtime_pb2
+from hivemind.compression import Float16Compression
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
-from hivemind.utils.networking import LOCALHOST

 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
@@ -38,7 +37,7 @@ def benchmark_averaging(
     num_peers: int,
     target_group_size: int,
     num_rounds: int,
-    averaging_expiration: float,
+    min_matchmaking_time: float,
     request_timeout: float,
     round_timeout: float,
     hid_size: int,
@@ -64,9 +63,9 @@ def benchmark_averaging(
             dht,
             prefix="my_tensor",
             initial_group_bits=initial_bits,
-            compression_type=runtime_pb2.CompressionType.FLOAT16,
+            compression=Float16Compression(),
             target_group_size=target_group_size,
-            averaging_expiration=averaging_expiration,
+            min_matchmaking_time=min_matchmaking_time,
             request_timeout=request_timeout,
             start=True,
         )
@@ -108,7 +107,7 @@ if __name__ == "__main__":
     parser.add_argument("--num_rounds", type=int, default=5, required=False)
     parser.add_argument("--hid_size", type=int, default=256, required=False)
     parser.add_argument("--num_layers", type=int, default=3, required=False)
-    parser.add_argument("--averaging_expiration", type=float, default=5, required=False)
+    parser.add_argument("--min_matchmaking_time", type=float, default=5, required=False)
     parser.add_argument("--round_timeout", type=float, default=15, required=False)
     parser.add_argument("--request_timeout", type=float, default=1, required=False)
     parser.add_argument("--spawn_dtime", type=float, default=0.1, required=False)

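As a quick reference, the updated benchmark arguments boil down to the following standalone construction. This is a sketch of mine, not code from the commit: it reuses only the keyword arguments visible in the diff above (`compression=Float16Compression()` replacing the protobuf `CompressionType` flag, and `min_matchmaking_time` replacing `averaging_expiration`); the tensor shape and timeout values are illustrative.

```python
import torch

import hivemind
from hivemind.compression import Float16Compression

dht = hivemind.DHT(start=True)
averager = hivemind.averaging.DecentralizedAverager(
    averaged_tensors=[torch.randn(16)],  # illustrative tensor, not taken from the benchmark
    dht=dht,
    prefix="my_tensor",
    compression=Float16Compression(),    # was: compression_type=runtime_pb2.CompressionType.FLOAT16
    target_group_size=4,
    min_matchmaking_time=5.0,            # was: averaging_expiration=5.0
    request_timeout=3.0,
    start=True,
)
# ... call averager.step() from several peers, then clean up:
averager.shutdown()
dht.shutdown()
```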
+ 1 - 17
docs/modules/optim.rst

@@ -21,20 +21,4 @@

 .. currentmodule:: hivemind.optim.grad_scaler
 .. autoclass:: GradScaler
-   :member-order: bysource
-
-
-**CollaborativeOptimizer**
---------------------------
-
-
-.. automodule:: hivemind.optim.collaborative
-.. currentmodule:: hivemind.optim
-
-.. autoclass:: CollaborativeOptimizer
-   :members: step
-   :member-order: bysource
-
-.. autoclass:: CollaborativeAdaptiveOptimizer
-   :members:
-   :member-order: bysource
+   :member-order: bysource

+ 3 - 3
examples/albert/README.md

@@ -3,7 +3,7 @@
 This tutorial will walk you through the steps to set up collaborative training with the ALBERT-large-v2 model and the
 WikiText103 dataset. It uses Hugging Face [datasets](https://github.com/huggingface/datasets)
 and [transformers](https://github.com/huggingface/transformers/) libraries to compute local updates,
-using `hivemind.CollaborativeOptimizer` to exchange information between peers.
+using `hivemind.Optimizer` to exchange information between peers.

 ## Preparation

@@ -143,7 +143,7 @@ you need to **(a)** make it listen a specific TCP/UDP port and **(b)** provide a

 The optimal training parameters for each peer depend on its GPU and internet connection. If a peer cannot accept
 incoming connections (e.g. when in colab or behind a firewall), add `--client_mode` to the training script (see example
-below). In case of high network latency, you may want to increase `--averaging_expiration` by a few seconds or
+below). In case of high network latency, you may want to increase `--matchmaking_time` by a few seconds or
 set `--batch_size_lead` to start averaging a bit earlier than the rest of the collaboration. GPU-wise, each peer should
 be able to process one local microbatch each 0.5–1 seconds (see trainer's progress bar). To achieve that, we
 recommend tuning `--per_device_train_batch_size` and `--gradient_accumulation_steps`.
@@ -182,7 +182,7 @@ Here's an example of a full trainer script for Google Colab:
 !ulimit -n 4096 && ./hivemind/examples/albert/run_trainer.py \
     --initial_peers ONE_OR_MORE_PEERS \
     --logging_dir ./logs --logging_first_step --output_dir ./outputs --overwrite_output_dir \
-    --client_mode --averaging_expiration 10 --batch_size_lead 300 --gradient_accumulation_steps 1
+    --client_mode --matchmaking_time 10 --batch_size_lead 300 --gradient_accumulation_steps 1
 ```

 ### Using IPFS

+ 1 - 11
hivemind/__init__.py

@@ -9,17 +9,7 @@ from hivemind.moe import (
     Server,
     register_expert_class,
 )
-from hivemind.optim import (
-    CollaborativeAdaptiveOptimizer,
-    CollaborativeOptimizer,
-    DecentralizedAdam,
-    DecentralizedOptimizer,
-    DecentralizedOptimizerBase,
-    DecentralizedSGD,
-    GradScaler,
-    Optimizer,
-    TrainingAverager,
-)
+from hivemind.optim import GradScaler, Optimizer, TrainingAverager
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, PeerInfo
 from hivemind.utils import *


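With this change, the optimizer-related names re-exported at package level shrink to three. A small sanity check (mine, not part of the commit) shows what downstream imports can and cannot rely on after the upgrade:

```python
import hivemind

# still exported at the top level after this commit
assert hasattr(hivemind, "Optimizer")
assert hasattr(hivemind, "GradScaler")
assert hasattr(hivemind, "TrainingAverager")

# the removed classes are gone; importing them now fails
try:
    from hivemind import CollaborativeOptimizer  # noqa: F401
except ImportError:
    print("CollaborativeOptimizer was removed; use hivemind.Optimizer instead")
```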
+ 0 - 6
hivemind/averaging/averager.py

@@ -117,7 +117,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         target_group_size: Optional[int] = None,
         min_group_size: int = 2,
         initial_group_bits: str = "",
-        averaging_expiration: Optional[float] = None,
         min_matchmaking_time: float = 5.0,
         request_timeout: float = 3.0,
         averaging_alpha: float = 1.0,
@@ -145,11 +144,6 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         assert all(bit in "01" for bit in initial_group_bits)
         assert not client_mode or not auxiliary, "auxiliary peers must accept incoming connections"

-        if averaging_expiration is not None:
-            logger.warning("averaging_expiration is deprecated and will be removed soon, use min_matchmaking_time")
-            assert min_matchmaking_time == 5.0, "Can't set both averaging_expiration and min_matchmaking_time"
-            min_matchmaking_time = averaging_expiration
-
         super().__init__()
         self.dht = dht
         self.prefix = prefix

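With the compatibility shim above gone, `DecentralizedAverager` no longer translates the old keyword for you. Callers that still keep `averaging_expiration` in their configs can reproduce the removed behaviour with a tiny adapter on their side; this helper is hypothetical, not part of hivemind:

```python
def translate_legacy_averager_kwargs(kwargs: dict) -> dict:
    """Map the removed `averaging_expiration` keyword to `min_matchmaking_time` (caller-side shim)."""
    kwargs = dict(kwargs)
    if "averaging_expiration" in kwargs:
        assert "min_matchmaking_time" not in kwargs, "Can't set both averaging_expiration and min_matchmaking_time"
        kwargs["min_matchmaking_time"] = kwargs.pop("averaging_expiration")
    return kwargs


# usage: averager = DecentralizedAverager(tensors, dht=dht, **translate_legacy_averager_kwargs(cfg))
```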
+ 0 - 4
hivemind/optim/__init__.py

@@ -1,7 +1,3 @@
-from hivemind.optim.adaptive import CollaborativeAdaptiveOptimizer
-from hivemind.optim.base import DecentralizedOptimizerBase
-from hivemind.optim.collaborative import CollaborativeOptimizer
 from hivemind.optim.grad_scaler import GradScaler, HivemindGradScaler
 from hivemind.optim.optimizer import Optimizer
-from hivemind.optim.simple import DecentralizedAdam, DecentralizedOptimizer, DecentralizedSGD
 from hivemind.optim.training_averager import TrainingAverager

+ 0 - 34
hivemind/optim/adaptive.py

@@ -1,34 +0,0 @@
-from typing import Sequence
-
-import torch.optim
-
-from hivemind.optim.collaborative import CollaborativeOptimizer
-from hivemind.optim.training_averager import TrainingAverager
-
-
-class CollaborativeAdaptiveOptimizer(CollaborativeOptimizer):
-    """
-    Behaves exactly as CollaborativeOptimizer except:
-
-    * averages adaptive learning rates of an optimizer
-    * doesn't average gradients
-
-    :param average_opt_statistics: average optimizer statistics with corresponding names in statedict
-    :param kwargs: options for CollaborativeOptimizer
-    """
-
-    def __init__(self, opt: torch.optim.Optimizer, average_opt_statistics: Sequence[str], **kwargs):
-        super().__init__(opt, average_opt_statistics=average_opt_statistics, **kwargs)
-
-    def _make_averager(self, average_opt_statistics, **kwargs):
-        return TrainingAverager(
-            self.opt,
-            dht=self.dht,
-            average_parameters=True,
-            average_gradients=False,
-            average_opt_statistics=average_opt_statistics,
-            prefix=f"{self.prefix}_averaging",
-            allreduce_timeout=self.averaging_timeout,
-            client_mode=self.client_mode,
-            **kwargs,
-        )

+ 0 - 44
hivemind/optim/base.py

@@ -1,44 +0,0 @@
-from warnings import warn
-
-import torch
-
-from hivemind.dht import DHT
-
-
-class DecentralizedOptimizerBase(torch.optim.Optimizer):
-    """A shared interface for all hivemind optimizers. Cooperates with DHT peers to train a shared model"""
-
-    def __init__(self, opt: torch.optim.Optimizer, dht: DHT):
-        self.opt, self.dht = opt, dht
-        warn(
-            "DecentralizedOptimizerBase and its subclasses have been deprecated and will be removed "
-            "in hivemind 1.1.0. Use hivemind.Optimizer instead",
-            FutureWarning,
-            stacklevel=2,
-        )
-
-    @property
-    def state(self):
-        return self.opt.state
-
-    @property
-    def param_groups(self):
-        return self.opt.param_groups
-
-    def add_param_group(self, param_group: dict) -> None:
-        raise ValueError(
-            f"{self.__class__.__name__} does not support calling add_param_group after creation."
-            f"Please provide all parameter groups at init."
-        )
-
-    def state_dict(self) -> dict:
-        return self.opt.state_dict()
-
-    def load_state_dict(self, state_dict: dict):
-        return self.opt.load_state_dict(state_dict)
-
-    def __repr__(self):
-        return f"{self.__class__.__name__}(opt={repr(self.opt)}, dht={repr(self.dht)})"
-
-    def shutdown(self):
-        raise NotImplementedError()

+ 0 - 558
hivemind/optim/collaborative.py

@@ -1,558 +0,0 @@
-from __future__ import annotations
-
-import logging
-from dataclasses import dataclass
-from threading import Event, Lock, Thread
-from typing import Dict, Iterator, Optional
-
-import numpy as np
-import torch
-from pydantic import BaseModel, StrictBool, StrictFloat, confloat, conint
-
-from hivemind.dht import DHT
-from hivemind.dht.crypto import RSASignatureValidator
-from hivemind.dht.schema import BytesWithPublicKey, SchemaValidator
-from hivemind.optim.base import DecentralizedOptimizerBase
-from hivemind.optim.grad_scaler import HivemindGradScaler
-from hivemind.optim.training_averager import TrainingAverager
-from hivemind.utils import get_dht_time, get_logger
-from hivemind.utils.performance_ema import PerformanceEMA
-
-logger = get_logger(__name__)
-LRSchedulerBase = getattr(torch.optim.lr_scheduler, "_LRScheduler", None)
-
-
-@dataclass(frozen=False)
-class CollaborationState:
-    optimizer_step: int
-    samples_accumulated: int
-    target_batch_size: int
-    num_peers: int
-    num_clients: int
-    eta_next_step: float
-    next_fetch_time: float
-
-    @property
-    def ready_for_step(self):
-        return self.samples_accumulated >= self.target_batch_size or get_dht_time() >= self.eta_next_step
-
-    def register_step(self, local_step: int):
-        self.optimizer_step = max(local_step, self.optimizer_step)
-        self.samples_accumulated = 0
-        self.eta_next_step = float("inf")
-
-
-class TrainingState(BaseModel):
-    peer_id: bytes
-    step: conint(ge=0, strict=True)
-    samples_accumulated: conint(ge=0, strict=True)
-    samples_per_second: confloat(ge=0.0, strict=True)
-    time: StrictFloat
-    client_mode: StrictBool
-
-
-class TrainingProgressSchema(BaseModel):
-    progress: Dict[BytesWithPublicKey, Optional[TrainingState]]
-
-
-class CollaborativeOptimizer(DecentralizedOptimizerBase):
-    """
-    An optimizer that performs model updates after collaboratively accumulating a target (large) batch size across peers.
-
-    These optimizers use DHT to track how much progress did the collaboration make towards target batch size.
-    Once enough samples were accumulated, optimizers will compute a weighted average of their statistics.
-
-    :note: **For new projects, please use hivemind.Optimizer**. CollaborativeOptimizer is an older version of that.
-      Currently, hivemind.Optimizer supports all the features of CollaborativeOptimizer and many advanced ones.
-      CollaborativeOptimizer will still be supported for a while, but it will be deprecated in v1.1.0.
-
-    :note: This optimizer behaves unlike regular pytorch optimizers in two ways:
-
-      * calling .step will periodically zero-out gradients w.r.t. model parameters after each step
-      * it may take multiple .step calls without updating model parameters, waiting for peers to accumulate enough samples
-
-
-    :param opt: a standard pytorch optimizer, preferably a large-batch one such as LAMB, LARS, etc.
-    :param dht: a running hivemind.DHT daemon connected to other peers
-    :param prefix: a common prefix for all metadata stored by CollaborativeOptimizer in the DHT
-    :param target_batch_size: perform optimizer step after all peers collectively accumulate this many samples
-    :param batch_size_per_step: before each call to .step, user should accumulate gradients over this many samples
-    :param min_refresh_period: wait for at least this many seconds before fetching new collaboration state
-    :param max_refresh_period: wait for at most this many seconds before fetching new collaboration state
-    :param default_refresh_period: if no peers are detected, attempt to fetch collaboration state this often (seconds)
-    :param expected_drift_peers: assume that this many new peers can join between steps
-    :param expected_drift_rate: assumes that this fraction of current collaboration can join/leave between steps
-    :note: The expected collaboration drift parameters are used to adjust the frequency with which this optimizer will
-      refresh the collaboration-wide statistics (to avoid missing the moment when to run the next step)
-    :param bandwidth: peer's network bandwidth for the purpose of load balancing (recommended: internet speed in mbps)
-    :param step_tolerance: a peer can temporarily be delayed by this many steps without being deemed out of sync
-    :param performance_ema_alpha: smoothing value used to estimate this peer's performance (training samples per second)
-    :param averaging_expiration: peer's requests for averaging will be valid for this many seconds
-    :param metadata_expiration: peer's metadata (e.g. samples processed) is stored onto DHT for this many seconds
-    :param averaging_timeout: if an averaging step hangs for this long, it will be cancelled.
-    :param load_state_timeout: wait for at most this many seconds before giving up on load_state_from_peers
-    :param scheduler: if specified, use this scheduler to update optimizer learning rate
-    :param reuse_grad_buffers: if True, use model's .grad buffers for gradient accumulation.
-      This is more memory efficient, but it requires that the user does *NOT* call model/opt zero_grad at all
-    :param accumulate_grads_on: if specified, accumulate gradients on this device. By default, this will use the same
-     device as model parameters. One can specify a different device (e.g. 'cpu' vs 'cuda') to save device memory at
-     the cost of extra time per step. If reuse_gradient_accumulators is True, this parameter has no effect.
-    :param client_mode: if True, runs training without incoming connections, in a firewall-compatible mode
-    :param kwargs: additional parameters forwarded to DecentralizedAverager
-    :note: If you are using CollaborativeOptimizer with lr_scheduler, it is recommended to pass this scheduler
-      explicitly into this class. Otherwise, scheduler may not be synchronized between peers.
-    """
-
-    def __init__(
-        self,
-        opt: torch.optim.Optimizer,
-        *,
-        dht: DHT,
-        prefix: str,
-        target_batch_size: int,
-        batch_size_per_step: Optional[int] = None,
-        scheduler: Optional[LRSchedulerBase] = None,
-        min_refresh_period: float = 0.5,
-        max_refresh_period: float = 30,
-        default_refresh_period: float = 3,
-        expected_drift_peers: float = 3,
-        expected_drift_rate: float = 0.2,
-        performance_ema_alpha: float = 0.1,
-        metadata_expiration: float = 60.0,
-        averaging_timeout: Optional[float] = None,
-        load_state_timeout: float = 600.0,
-        step_tolerance: int = 1,
-        reuse_grad_buffers: bool = False,
-        accumulate_grads_on: Optional[torch.device] = None,
-        client_mode: bool = False,
-        verbose: bool = False,
-        **kwargs,
-    ):
-        super().__init__(opt, dht)
-
-        signature_validator = RSASignatureValidator()
-        self._local_public_key = signature_validator.local_public_key
-        dht.add_validators([SchemaValidator(TrainingProgressSchema, prefix=prefix), signature_validator])
-
-        if reuse_grad_buffers and accumulate_grads_on is not None:
-            logger.warning("Setting 'accumulate_grads_on' has no effect if reuse_grad_buffers=True")
-        self.prefix, self.scheduler = prefix, scheduler
-        self.target_batch_size, self.batch_size_per_step = target_batch_size, batch_size_per_step
-        self.min_refresh_period, self.max_refresh_period, self.default_refresh_period = (
-            min_refresh_period,
-            max_refresh_period,
-            default_refresh_period,
-        )
-        self.expected_drift_peers, self.expected_drift_rate = expected_drift_peers, expected_drift_rate
-        self.averaging_timeout = averaging_timeout
-        self.load_state_timeout = load_state_timeout
-        self.metadata_expiration = metadata_expiration
-        self._grads, self.reuse_grad_buffers, self.accumulate_grads_on = None, reuse_grad_buffers, accumulate_grads_on
-        self.client_mode, self.step_tolerance = client_mode, step_tolerance
-        self.status_loglevel = logging.INFO if verbose else logging.DEBUG
-        self.averager = self._make_averager(**kwargs)
-
-        self._step_supports_amp_scaling = self.reuse_grad_buffers  # enable custom execution with torch GradScaler
-
-        self.training_progress_key = f"{self.prefix}_progress"
-        self.local_samples_accumulated = 0  # a number of local samples accumulated since last optimizer update
-        self.local_updates_accumulated = 0  # a number of calls to step() since last optimizer update
-        self.performance_ema = PerformanceEMA(alpha=performance_ema_alpha)
-        self.last_step_time = None
-
-        self.collaboration_state = self._fetch_state()
-        self.lock_collaboration_state, self.collaboration_state_updated = Lock(), Event()
-        self.lock_local_progress, self.should_report_progress = Lock(), Event()
-        self.progress_reporter = Thread(target=self.report_training_progress, daemon=True, name=f"{self}.reporter")
-        self.progress_reporter.start()
-        self.collaboration_state_updater = Thread(
-            target=self.check_collaboration_state_periodically, daemon=True, name=f"{self}.collaboration_state_updater"
-        )
-        self.collaboration_state_updater.start()
-
-    def _make_averager(self, **kwargs):
-        return TrainingAverager(
-            self.opt,
-            dht=self.dht,
-            average_parameters=True,
-            average_gradients=True,
-            prefix=f"{self.prefix}_averaging",
-            allreduce_timeout=self.averaging_timeout,
-            client_mode=self.client_mode,
-            **kwargs,
-        )
-
-    @property
-    def local_step(self) -> int:
-        return self.averager.local_step
-
-    @property
-    def is_synchronized(self) -> bool:
-        return self.local_step >= self.collaboration_state.optimizer_step
-
-    @property
-    def is_within_tolerance(self) -> bool:
-        return self.local_step >= self.collaboration_state.optimizer_step - self.step_tolerance
-
-    def is_alive(self) -> bool:
-        return self.averager.is_alive()
-
-    def load_state_from_peers(self, **kwargs):
-        """Attempt to fetch the newest collaboration state from other peers"""
-        with self.lock_collaboration_state:
-            while True:
-                try:
-                    self.averager.load_state_from_peers(timeout=self.load_state_timeout, **kwargs)
-                    break
-                except KeyboardInterrupt:
-                    raise
-                except BaseException as e:
-                    logger.exception(f"Failed to load state from peers: {e}, retrying ...")
-                    continue
-
-            self.local_samples_accumulated = self.local_updates_accumulated = 0
-            self.reset_accumulated_grads_()
-            self.update_scheduler()
-
-    def state_dict(self) -> dict:
-        state_dict = super().state_dict()
-        state_dict["state"]["collaborative_step"] = self.local_step
-        return state_dict
-
-    def load_state_dict(self, state_dict: dict):
-        if "collaborative_step" in state_dict["state"]:
-            self.averager.local_step = state_dict["state"].pop("collaborative_step")
-        return super().load_state_dict(state_dict)
-
-    def step(self, batch_size: Optional[int] = None, grad_scaler: Optional[HivemindGradScaler] = None, **kwargs):
-        """
-        Report accumulating gradients w.r.t. batch_size additional samples, optionally update model parameters
-
-        :param batch_size: optional override for batch_size_per_step from init
-        :param grad_scaler: if amp is enabled, this **must** be a hivemind-aware gradient scaler
-        :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
-        """
-        if grad_scaler is not None and not isinstance(grad_scaler, HivemindGradScaler):
-            raise ValueError("CollaborativeOptimizer requires a hivemind-aware gradient scaler (HivemindGradScaler)")
-        if self.batch_size_per_step is None:
-            if batch_size is None:
-                raise ValueError("Please either set batch_size_per_step parameter at init or when calling .step")
-            logger.log(self.status_loglevel, f"Setting default batch_size_per_step to {batch_size}")
-            self.batch_size_per_step = batch_size
-        batch_size = batch_size if batch_size is not None else self.batch_size_per_step
-
-        if not self.is_synchronized and not self.is_within_tolerance:
-            logger.log(self.status_loglevel, "Peer is out of sync")
-            self.load_state_from_peers()
-            return
-        elif not self.is_synchronized and self.is_within_tolerance:
-            self.averager.local_step = self.collaboration_state.optimizer_step
-            logger.log(self.status_loglevel, f"Catching up with collaboration step {self.local_step}")
-
-        if grad_scaler is not None and not grad_scaler.are_grads_finite(self):
-            logger.log(self.status_loglevel, "Encountered incorrect value in fp16 grads, resetting local gradients")
-            self.local_samples_accumulated = self.local_steps_accumulated = 0
-            self.reset_accumulated_grads_()
-            self.should_report_progress.set()
-            return
-
-        if self.last_step_time is not None and get_dht_time() - self.last_step_time > self.metadata_expiration:
-            logger.warning(
-                f"Training step took {get_dht_time() - self.last_step_time}, "
-                f"but metadata expired in {self.metadata_expiration} s."
-            )
-
-        self.accumulate_grads_(batch_size)
-
-        with self.lock_local_progress:
-            self.local_samples_accumulated += batch_size
-            self.local_updates_accumulated += 1
-            self.performance_ema.update(task_size=batch_size)
-            self.should_report_progress.set()
-
-        if not self.collaboration_state.ready_for_step:
-            return
-
-        logger.log(self.status_loglevel, f"Beginning global optimizer step #{self.collaboration_state.optimizer_step}")
-        with self.performance_ema.pause(), self.lock_collaboration_state:
-            self.collaboration_state = self._fetch_state()
-            self.collaboration_state_updated.set()
-
-            # divide accumulators by local steps to recover the true average grad w.r.t. local_samples_accumulated
-            self.apply_accumulated_grads_(scale_by=1.0 / self.local_updates_accumulated)
-            if grad_scaler is not None:
-                with grad_scaler.running_global_step():
-                    assert grad_scaler.unscale_(self)
-
-            current_step, group_info = self.averager.local_step, None
-
-            if self.collaboration_state.num_peers > 1:
-                mean_samples_per_worker = self.target_batch_size / self.collaboration_state.num_peers
-                weight = self.local_samples_accumulated / mean_samples_per_worker
-                try:
-                    group_info = self.averager.step(
-                        weight=weight, gather=current_step, timeout=self.averaging_timeout, **kwargs
-                    )
-                    if group_info:
-                        logger.log(self.status_loglevel, f"Averaged tensors successfully with {len(group_info)} peers")
-
-                        # update our current step if we averaged with another peer that was at a more recent step
-                        for peer, peer_step in group_info.items():
-                            if isinstance(peer_step, int):
-                                current_step = max(current_step, peer_step)
-                            else:
-                                logger.warning(f"Peer {peer} sent malformed data about current step: {peer_step}")
-
-                except BaseException as e:
-                    logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}")
-
-            else:
-                logger.log(
-                    self.status_loglevel,
-                    f"Skipped averaging: collaboration consists of " f"{self.collaboration_state.num_peers} peer(s)",
-                )
-
-            if grad_scaler is not None:
-                with grad_scaler.running_global_step():
-                    assert grad_scaler.step(self)
-            else:
-                self.opt.step()
-
-            self.reset_accumulated_grads_()
-            self.local_samples_accumulated = self.local_updates_accumulated = 0
-            self.collaboration_state.register_step(current_step + 1)
-            self.averager.local_step = current_step + 1
-            self.collaboration_state_updated.set()
-            self.update_scheduler()
-
-            if grad_scaler is not None:
-                with grad_scaler.running_global_step():
-                    assert grad_scaler.update()
-
-            if not self.averager.client_mode:
-                self.averager.state_sharing_priority = self.local_step
-
-        logger.log(self.status_loglevel, f"Optimizer step: done!")
-
-        return group_info
-
-    def step_aux(self, **kwargs):
-        """
-        Find and assist other peers in averaging without sending local gradients.
-
-        :note: this .step is different from normal pytorch optimizers in several key ways. See __init__ for details.
-        """
-
-        if not self.collaboration_state.ready_for_step:
-            return
-
-        logger.log(self.status_loglevel, f"Beginning global optimizer step #{self.collaboration_state.optimizer_step}")
-        self.collaboration_state = self._fetch_state()
-        self.collaboration_state_updated.set()
-
-        with self.lock_collaboration_state:
-            current_step, group_info = self.averager.local_step, None
-
-            try:
-                group_info = self.averager.step(timeout=self.averaging_timeout, gather=current_step, **kwargs)
-                if group_info:
-                    logger.log(self.status_loglevel, f"Averaged tensors successfully with {len(group_info)} peers")
-
-                    # update our current step if we averaged with another peer that was at a more recent step
-                    for peer, peer_step in group_info.items():
-                        if isinstance(peer_step, int):
-                            current_step = max(current_step, peer_step)
-                        else:
-                            logger.warning(f"Peer {peer} sent malformed data about current step: {peer_step}")
-            except BaseException as e:
-                logger.log(self.status_loglevel, f"Skipped averaging: averaging round failed with {repr(e)}")
-
-            self.collaboration_state.register_step(current_step + 1)
-            self.averager.local_step = current_step + 1
-            self.collaboration_state_updated.set()
-
-        logger.log(self.status_loglevel, f"Optimizer step: done!")
-
-        return group_info
-
-    def _grad_buffers(self) -> Iterator[torch.Tensor]:
-        """pytorch-internal gradient buffers"""
-        for param_group in self.opt.param_groups:
-            for param in param_group["params"]:
-                if param.grad is None:
-                    yield torch.zeros_like(param)
-                else:
-                    yield param.grad
-
-    @torch.no_grad()
-    def accumulated_grads(self) -> Iterator[torch.Tensor]:
-        """local gradient accumulators"""
-        if self.reuse_grad_buffers:
-            yield from self._grad_buffers()
-            return
-
-        if self._grads is None:
-            self._grads = [torch.zeros_like(grad, device=self.accumulate_grads_on) for grad in self._grad_buffers()]
-        yield from self._grads
-
-    @torch.no_grad()
-    def accumulate_grads_(self, batch_size: int):
-        """add current gradients to grad accumulators (if any)"""
-        if self.reuse_grad_buffers:
-            # user is responsible for accumulating gradients in .grad buffers
-            assert batch_size == self.batch_size_per_step, "Custom batch size is not supported if reuse_grad_buffers"
-        else:
-            alpha = float(batch_size) / self.batch_size_per_step
-            for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
-                grad_acc.add_(grad_buf.to(grad_acc.device), alpha=alpha)
-
-    @torch.no_grad()
-    def apply_accumulated_grads_(self, scale_by: Optional[float] = None):
-        if not self.reuse_grad_buffers:
-            for grad_buf, grad_acc in zip(self._grad_buffers(), self.accumulated_grads()):
-                grad_buf.copy_(grad_acc.to(grad_buf.device), non_blocking=True)
-        if scale_by is not None:
-            for grad_buf in self._grad_buffers():
-                grad_buf.mul_(scale_by)
-
-    @torch.no_grad()
-    def reset_accumulated_grads_(self):
-        for grad_buf in self.accumulated_grads():
-            grad_buf.zero_()
-
-    def report_training_progress(self):
-        """Periodically publish metadata and the current number of samples accumulated towards the next step"""
-        while self.is_alive():
-            self.should_report_progress.wait()
-            self.should_report_progress.clear()
-            with self.lock_local_progress:
-                current_time = get_dht_time()
-                local_state_info = TrainingState(
-                    peer_id=self.averager.peer_id.to_bytes(),
-                    step=self.local_step,
-                    samples_accumulated=self.local_samples_accumulated,
-                    samples_per_second=self.performance_ema.samples_per_second,
-                    time=current_time,
-                    client_mode=self.averager.client_mode,
-                )
-
-            self.dht.store(
-                key=self.training_progress_key,
-                subkey=self._local_public_key,
-                value=local_state_info.dict(),
-                expiration_time=current_time + self.metadata_expiration,
-                return_future=True,
-            )
-
-    def check_collaboration_state_periodically(self):
-        """
-        Periodically check the training progress from all peers. Trigger update after target_batch_size total samples
-        """
-        while self.is_alive():
-            time_to_next_update = max(0.0, self.collaboration_state.next_fetch_time - get_dht_time())
-            if self.collaboration_state_updated.wait(time_to_next_update):
-                self.collaboration_state_updated.clear()
-                continue  # if state was updated externally, reset timer
-
-            with self.lock_collaboration_state:
-                self.collaboration_state = self._fetch_state()
-
-    def _fetch_state(self) -> CollaborationState:
-        """Read performance statistics reported by peers, estimate progress towards next batch"""
-        response, _expiration = self.dht.get(self.training_progress_key, latest=True) or (None, -float("inf"))
-        current_time = get_dht_time()
-
-        if not isinstance(response, dict) or len(response) == 0:
-            logger.log(self.status_loglevel, f"Found no active peers: {response}")
-            samples_left_to_target_batch_size = max(0, self.target_batch_size - self.local_samples_accumulated)
-            local_eta_next_step = samples_left_to_target_batch_size / self.performance_ema.samples_per_second
-
-            return CollaborationState(
-                self.local_step,
-                self.local_samples_accumulated,
-                self.target_batch_size,
-                num_peers=0,
-                num_clients=0,
-                eta_next_step=current_time + local_eta_next_step,
-                next_fetch_time=current_time + self.default_refresh_period,
-            )
-
-        valid_peer_states = [
-            TrainingState.parse_obj(peer_state.value)
-            for peer_state in response.values()
-            if peer_state.value is not None
-        ]
-
-        num_peers = len(valid_peer_states)
-        num_clients = sum(state.client_mode for state in valid_peer_states)
-        global_optimizer_step = self.local_step
-        for state in valid_peer_states:
-            if not state.client_mode:
-                global_optimizer_step = max(global_optimizer_step, state.step)
-
-        total_samples_accumulated = estimated_current_samples = total_samples_per_second = 0
-
-        for state in valid_peer_states:
-            total_samples_per_second += state.samples_per_second
-            if state.step == global_optimizer_step:
-                total_samples_accumulated += state.samples_accumulated
-                estimated_current_samples += (
-                    state.samples_accumulated + max(0, current_time - state.time) * state.samples_per_second
-                )
-            # note: we deliberately count only valid peers for samples_accumulated, but all peers for performance;
-            # the rationale behind this is that outdated peers will synchronize and begin contributing shortly.
-
-        estimated_samples_remaining = self.target_batch_size - estimated_current_samples
-        estimated_time_to_next_step = max(0, estimated_samples_remaining) / total_samples_per_second
-
-        expected_max_peers = max(num_peers + self.expected_drift_peers, num_peers * (1 + self.expected_drift_rate))
-        time_to_next_fetch = float(
-            np.clip(
-                a=estimated_time_to_next_step * num_peers / expected_max_peers,
-                a_min=self.min_refresh_period,
-                a_max=self.max_refresh_period,
-            )
-        )
-        logger.log(
-            self.status_loglevel,
-            f"{self.prefix} accumulated {total_samples_accumulated} samples from "
-            f"{num_peers} peers for step #{global_optimizer_step}. "
-            f"ETA {estimated_time_to_next_step:.2f} sec (refresh in {time_to_next_fetch:.2f} sec)",
-        )
-        return CollaborationState(
-            global_optimizer_step,
-            total_samples_accumulated,
-            target_batch_size=self.target_batch_size,
-            num_peers=num_peers,
-            num_clients=num_clients,
-            eta_next_step=current_time + estimated_time_to_next_step,
-            next_fetch_time=current_time + time_to_next_fetch,
-        )
-
-    def zero_grad(self, *args, **kwargs):
-        if self.reuse_grad_buffers:
-            raise ValueError(
-                f"When running {self.__class__.__name__} with reuse_grad_buffers=True, user should never "
-                f"call zero_grad manually. Gradients will be refreshed internally."
-            )
-        return self.opt.zero_grad(*args, **kwargs)
-
-    def update_scheduler(self):
-        if self.scheduler:
-            while self.scheduler._step_count < self.local_step:
-                self.scheduler.step()
-
-    def shutdown(self):
-        logger.debug("Shutting down averager...")
-        self.averager.shutdown()
-        logger.debug("Sending goodbye to peers...")
-        self.dht.store(
-            self.training_progress_key,
-            subkey=self._local_public_key,
-            value=None,
-            expiration_time=get_dht_time() + self.metadata_expiration,
-        )
-        logger.debug(f"{self.__class__.__name__} is shut down")
-
-    def __del__(self):
-        self.shutdown()

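The removed docstring above already points new projects to `hivemind.Optimizer`. The sketch below shows one possible migration; the keyword names (`run_id`, `matchmaking_time`, `averaging_timeout`) follow my recollection of the hivemind 1.x quickstart rather than this diff, so treat them as assumptions and verify them against the `hivemind.Optimizer` docstring.

```python
import torch

import hivemind

dht = hivemind.DHT(start=True)
model = torch.nn.Linear(16, 1)
base_opt = torch.optim.Adam(model.parameters(), lr=1e-3)

opt = hivemind.Optimizer(
    dht=dht,
    run_id="my_run",             # roughly replaces CollaborativeOptimizer's `prefix` (assumed name)
    optimizer=base_opt,
    target_batch_size=4096,      # global samples per collaborative step, as before
    batch_size_per_step=32,      # samples processed locally per .step() call
    matchmaking_time=3.0,        # replaces averaging_expiration (assumed name)
    averaging_timeout=30.0,
    verbose=True,
)

# the calling convention is similar: report locally processed samples on every step
for _ in range(3):
    loss = model(torch.randn(32, 16)).pow(2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()

opt.shutdown()
dht.shutdown()
```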
+ 1 - 1
hivemind/optim/grad_scaler.py

@@ -50,7 +50,7 @@ class GradScaler(TorchGradScaler):

     def unscale_(self, optimizer: TorchOptimizer) -> bool:
         with self._lock:
-            assert isinstance(optimizer, (hivemind.Optimizer, hivemind.DecentralizedOptimizerBase))
+            assert isinstance(optimizer, hivemind.Optimizer)
             if self._is_running_global_step:
                 super().unscale_(optimizer)
                 self._inner_optimizer_states[id(optimizer.opt)] = deepcopy(self._per_optimizer_states[id(optimizer)])

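After this change, `GradScaler.unscale_` accepts only `hivemind.Optimizer`; the deprecated `DecentralizedOptimizerBase` subclasses no longer pass the check. A minimal snippet (mine, not from the commit) showing that anything else is rejected by the assertion visible in the diff above:

```python
import torch

import hivemind

scaler = hivemind.GradScaler()
plain_opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(3))], lr=0.1)

try:
    scaler.unscale_(plain_opt)  # not a hivemind.Optimizer, so the assert fires
except AssertionError:
    print("GradScaler.unscale_ now only accepts hivemind.Optimizer")
```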
+ 0 - 229
hivemind/optim/simple.py

@@ -1,229 +0,0 @@
-import time
-from threading import Event, Lock, Thread
-from typing import Optional, Sequence, Tuple
-
-import torch
-
-from hivemind.dht import DHT
-from hivemind.optim.base import DecentralizedOptimizerBase
-from hivemind.optim.training_averager import TrainingAverager
-from hivemind.utils import get_dht_time, get_logger
-
-logger = get_logger(__name__)
-
-
-class DecentralizedOptimizer(DecentralizedOptimizerBase):
-    """
-    A simple optimizer that trains a shared model by averaging with peers in variety of ways. Supports
-    parameter/gradient averaging and syncing adaptive learning rates or any other internal statistics of optimizer.
-
-    :param opt: a pytorch optimizer configured to update model parameters.
-    :param dht: a running hivemind DHT daemon connected to other peers
-    :param average_parameters: whether to average model parameters
-    :param average_gradients: whether to average gradients
-    :param average_opt_statistics: if specified, average optimizer states with corresponding names in state_dict
-    :param averaging_steps_period: performs averaging after this many optimizer steps
-    :param averaging_time_period: if specified, optimizer will attempt to average weights at regular intervals of this
-      many seconds. (averaging step will only occur if the optimizer ran `averaging_steps_period` steps in that interval)
-    :param prefix: all DHT keys that point to optimization metadata will have this prefix
-    :param target_group_size: maximum group size for averaging (see DecentralizedAverager)
-    :param timeout: if DecentralizedAverager step is unable to form group in this many seconds, cancel step
-    :param kwargs: additional parameters passed to TrainingAverager
-    :note: if you're using an optimizer with adaptive learning rates (such as Adam), make sure to specify
-      necessary fields' names in `average_opt_statistics`. Otherwise you may encounter poor convergence.
-    :note: the base optimizer cannot add param groups after the DecentralizedOptimizer is created
-    """
-
-    def __init__(
-        self,
-        opt: torch.optim.Optimizer,
-        dht: DHT,
-        *,
-        prefix: str,
-        target_group_size: int,
-        average_parameters: bool,
-        average_gradients: bool,
-        average_opt_statistics: Sequence[str] = (),
-        averaging_steps_period: int = 1,
-        averaging_time_period: float = 0,
-        timeout: Optional[float] = None,
-        verbose: bool = False,
-        **kwargs,
-    ):
-        super().__init__(opt, dht)
-        assert averaging_steps_period > 0 and averaging_time_period >= 0, "Averaging period must be positive."
-        self.local_step, self.averaging_step_period = 0, averaging_steps_period
-
-        self.averager = TrainingAverager(
-            opt,
-            average_parameters=average_parameters,
-            average_gradients=average_gradients,
-            average_opt_statistics=average_opt_statistics,
-            dht=dht,
-            start=True,
-            prefix=prefix,
-            target_group_size=target_group_size,
-            **kwargs,
-        )
-        self.lock_parameters, self.update_event, self.stop_event = Lock(), Event(), Event()
-        self.lock_parameters.acquire()  # this lock is only released when averager can modify tensors in background
-
-        self.background_averaging_thread = Thread(
-            name=f"{self.__class__.__name__}",
-            daemon=True,
-            target=self._average_parameters_in_background,
-            args=[self.lock_parameters, self.update_event, self.stop_event, self.averager],
-            kwargs=dict(averaging_period=averaging_time_period, timeout=timeout, verbose=verbose),
-        )
-        self.background_averaging_thread.start()
-
-    def step(self, *args, **kwargs):
-        loss = self.opt.step(*args, **kwargs)
-        if self.lock_parameters.locked():
-            self.lock_parameters.release()
-        try:
-            self.local_step += 1
-            if self.local_step % self.averaging_step_period == 0:
-                self.update_event.set()
-            self.averager.pending_updates_done.wait()
-
-            if not self.averager.client_mode:
-                self.averager.state_sharing_priority = get_dht_time()
-
-            return loss
-        finally:
-            self.lock_parameters.acquire()
-
-    def zero_grad(self, *args, **kwargs):
-        return self.opt.zero_grad(*args, **kwargs)
-
-    def __del__(self):
-        self.stop_event.set()
-        self.update_event.set()
-
-    def shutdown(self):
-        self.stop_event.set()
-        self.update_event.set()
-        self.averager.shutdown()
-
-    @staticmethod
-    @torch.no_grad()
-    def _average_parameters_in_background(
-        lock_parameters: Lock,
-        update_event: Event,
-        stop_event: Event,
-        averager: TrainingAverager,
-        averaging_period: float,
-        verbose: bool,
-        **kwargs,
-    ):
-        """Iteratively find groups of peers, average parameters with these peers and update local model parameters."""
-        while not stop_event.is_set():
-            update_event.wait()
-            update_event.clear()
-            if stop_event.is_set():
-                break
-
-            if averaging_period:
-                current_time = get_dht_time()
-                # note: we use global DHT time to make sure peers start averaging at the ~same time (to form groups)
-                time_to_nearest_interval = max(0.0, averaging_period - current_time % averaging_period)
-                time.sleep(time_to_nearest_interval)
-
-            if verbose:
-                logger.info(f"Starting a new averaging round with current parameters")
-            try:
-                group_info = averager.step(lock_parameters, **kwargs)
-                if verbose:
-                    if group_info is not None:
-                        logger.info(f"Finished averaging round in with {len(group_info)} peers")
-                    else:
-                        logger.warning(f"Averaging round failed: could not find group")
-            except Exception as e:
-                logger.error(f"Averaging round failed: caught {e}")
-
-
-class DecentralizedSGD(DecentralizedOptimizer):
-    """
-    Decentralized Stochastic Gradient Descent algorithm like in Lian et al (2017) [1] based on Moshpit All-Reduce [2].
-
-    :param dht: a running hivemind DHT daemon connected to other peers
-    :param prefix: all DHT keys that point to optimization metadata will have this prefix
-    :param target_group_size: maximum group size for averaging (see DecentralizedAverager)
-    :param kwargs: additional parameters passed to DecentralizedOptimizer
-
-    - [1] Can Decentralized Algorithms Outperform Centralized Algorithms? A Case Study for Parallel Stochastic Gradient
-     Descent - https://proceedings.neurips.cc/paper/2017/hash/f75526659f31040afeb61cb7133e4e6d-Abstract.html
-    - [2] Moshpit SGD: Communication-Efficient Decentralized Training on Heterogeneous Unreliable Devices
-     https://arxiv.org/abs/2103.03239
-    """
-
-    def __init__(
-        self,
-        params,
-        lr: float,
-        *,
-        dht: DHT,
-        prefix: str,
-        target_group_size: int,
-        momentum: float = 0,
-        dampening: float = 0,
-        weight_decay: float = 0,
-        nesterov: bool = False,
-        **kwargs,
-    ):
-        opt = torch.optim.SGD(params, lr, momentum, dampening, weight_decay, nesterov)
-        super().__init__(
-            opt,
-            dht,
-            prefix=prefix,
-            target_group_size=target_group_size,
-            average_parameters=True,
-            average_gradients=False,
-            **kwargs,
-        )
-
-
-class DecentralizedAdam(DecentralizedOptimizer):
-    """
-    Decentralized Adam/AmsGrad as proposed in [1], [2]
-
-    :param dht: a running hivemind DHT daemon connected to other peers
-    :param prefix: all DHT keys that point to optimization metadata will have this prefix
-    :param target_group_size: maximum group size for averaging (see DecentralizedAverager)
-    :param averaging_steps_period: performs averaging after this many optimizer steps
-    :param kwargs: additional parameters passed to DecentralizedOptimizer
-
-    - [1] On the Convergence of Decentralized Adaptive Gradient Methods
-    - [2] Toward Communication Efficient Adaptive Gradient Method - https://dl.acm.org/doi/abs/10.1145/3412815.3416891
-    """
-
-    def __init__(
-        self,
-        params,
-        lr: float,
-        *,
-        dht: DHT,
-        prefix: str,
-        target_group_size: int,
-        averaging_steps_period: int,
-        betas: Tuple[float, float] = (0.9, 0.999),
-        eps: float = 1e-8,
-        weight_decay: float = 0,
-        amsgrad: bool = False,
-        **kwargs,
-    ):
-        opt = torch.optim.Adam(params, lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
-        opt_statistics = ("max_exp_avg_sq",) if amsgrad else ("exp_avg_sq",)
-
-        super().__init__(
-            opt,
-            dht,
-            prefix=prefix,
-            target_group_size=target_group_size,
-            average_parameters=True,
-            average_gradients=False,
-            average_opt_statistics=opt_statistics,
-            averaging_steps_period=averaging_steps_period,
-            **kwargs,
-        )

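For code that used the deleted `DecentralizedSGD` or `DecentralizedAdam` (local optimizer steps with parameters averaged in the background), the closest replacement is `hivemind.Optimizer` with local updates enabled. This is a hedged sketch: `use_local_updates` and the other keyword names come from my recollection of the hivemind 1.x API, not from this diff, so verify them against the `hivemind.Optimizer` docstring.

```python
import torch

import hivemind

dht = hivemind.DHT(start=True)
param = torch.nn.Parameter(torch.zeros(32, 32))

opt = hivemind.Optimizer(
    dht=dht,
    run_id="foo",                               # roughly replaces the old `prefix` (assumed name)
    optimizer=torch.optim.SGD([param], lr=0.1),
    target_batch_size=256,                      # still schedules when peers average (assumed)
    batch_size_per_step=32,
    use_local_updates=True,                     # apply local steps, average parameters in background (assumed name)
    verbose=True,
)

param.pow(2).sum().backward()
opt.step()
opt.zero_grad()

opt.shutdown()
dht.shutdown()
```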
+ 16 - 80
tests/test_averaging.py

@@ -6,7 +6,7 @@ import pytest
 import torch

 import hivemind
-import hivemind.averaging.averager
+from hivemind.averaging import DecentralizedAverager
 from hivemind.averaging.allreduce import AveragingMode
 from hivemind.averaging.control import AveragingStage
 from hivemind.averaging.key_manager import GroupKeyManager
@@ -78,11 +78,11 @@ def _test_allreduce_once(n_clients, n_aux):

     dht_instances = launch_dht_instances(len(peer_tensors))
     averagers = [
-        hivemind.averaging.DecentralizedAverager(
+        DecentralizedAverager(
             tensors,
             dht=dht,
             target_group_size=4,
-            averaging_expiration=15,
+            min_matchmaking_time=15,
             prefix="mygroup",
             client_mode=mode == AveragingMode.CLIENT,
             auxiliary=mode == AveragingMode.AUX,
@@ -135,11 +135,11 @@ def test_allreduce_weighted(n_client_mode_peers: int = 2):

     dht_instances = launch_dht_instances(4)
     averagers = [
-        hivemind.averaging.DecentralizedAverager(
+        DecentralizedAverager(
             tensors,
             dht=dht,
             target_group_size=4,
-            averaging_expiration=15,
+            min_matchmaking_time=15,
             prefix="mygroup",
             client_mode=client_mode,
             start=True,
@@ -185,7 +185,7 @@ def compute_mean_std(averagers, unbiased=True):
 def test_allreduce_grid():
     dht_instances = launch_dht_instances(8)
     averagers = [
-        hivemind.averaging.DecentralizedAverager(
+        DecentralizedAverager(
             averaged_tensors=[torch.randn(3)],
             dht=dht,
             target_group_size=2,
@@ -221,11 +221,11 @@ def test_allreduce_grid():
 def test_allgather(n_averagers=8, target_group_size=4):
     dht_instances = launch_dht_instances(n_averagers)
     averagers = [
-        hivemind.averaging.DecentralizedAverager(
+        DecentralizedAverager(
             [torch.ones(1)],
             dht=dht,
             target_group_size=target_group_size,
-            averaging_expiration=15,
+            min_matchmaking_time=15,
             prefix="mygroup",
             initial_group_bits="000",
             start=True,
@@ -297,11 +297,11 @@ def test_load_balancing():
 def test_too_few_peers():
     dht_instances = launch_dht_instances(4)
     averagers = [
-        hivemind.averaging.DecentralizedAverager(
+        DecentralizedAverager(
             averaged_tensors=[torch.randn(3)],
             dht=dht,
             target_group_size=2,
-            averaging_expiration=1,
+            min_matchmaking_time=1,
             request_timeout=0.5,
             prefix="mygroup",
             initial_group_bits=bin(i)[2:].rjust(3, "0"),
@@ -327,11 +327,11 @@ def test_too_few_peers():
 def test_overcrowded(num_peers=16):
     dht_instances = launch_dht_instances(num_peers)
     averagers = [
-        hivemind.averaging.DecentralizedAverager(
+        DecentralizedAverager(
             averaged_tensors=[torch.randn(3)],
             dht=dht,
             target_group_size=2,
-            averaging_expiration=1,
+            min_matchmaking_time=1,
             request_timeout=0.5,
             prefix="mygroup",
             initial_group_bits="",
@@ -353,7 +353,7 @@ def test_load_state_from_peers():
     super_metadata = dict(x=123)
     super_tensors = (torch.randn(3), torch.randint(0, 5, (3,)))

-    class TestAverager(hivemind.averaging.DecentralizedAverager):
+    class TestAverager(DecentralizedAverager):
         def get_current_state(self):
             """
             Get current state and send it to a peer. executed in the host process. Meant to be overriden.
@@ -455,7 +455,7 @@ def test_load_state_priority():
 @pytest.mark.forked
 def test_getset_bits():
     dht = hivemind.DHT(start=True)
-    averager = hivemind.averaging.DecentralizedAverager(
+    averager = DecentralizedAverager(
         [torch.randn(3)],
         dht=dht,
         start=True,
@@ -469,7 +469,7 @@ def test_getset_bits():
 @pytest.mark.forked
 def test_averaging_trigger():
     averagers = tuple(
-        hivemind.averaging.DecentralizedAverager(
+        DecentralizedAverager(
             averaged_tensors=[torch.randn(3)],
             dht=dht,
             min_matchmaking_time=0.5,
@@ -514,7 +514,7 @@ def test_averaging_trigger():
 @pytest.mark.forked
 def test_averaging_cancel():
     averagers = tuple(
-        hivemind.averaging.DecentralizedAverager(
+        DecentralizedAverager(
             averaged_tensors=[torch.randn(3)],
             dht=dht,
             min_matchmaking_time=0.5,
@@ -540,67 +540,3 @@ def test_averaging_cancel():

     for averager in averagers:
         averager.shutdown()
-
-
-@pytest.mark.forked
-def test_training_averager(n_steps: int = 10, n_dims: int = 16):
-    torch.manual_seed(42)
-
-    dht_instances = launch_dht_instances(2)
-    common_kwargs = {
-        "start": True,
-        "prefix": "demo-run",
-        "target_group_size": 2,
-    }
-
-    x1 = torch.randn(n_dims, requires_grad=True)
-    opt1 = torch.optim.Adam([x1], lr=0.05)
-    averager1 = hivemind.TrainingAverager(
-        opt1,
-        average_gradients=True,
-        average_parameters=True,
-        average_opt_statistics=["exp_avg_sq"],
-        dht=dht_instances[0],
-        **common_kwargs
-    )
-
-    x2 = torch.randn(n_dims, requires_grad=True)
-    opt2 = torch.optim.Adam([x2], lr=0.05)
-    averager2 = hivemind.TrainingAverager(
-        opt2,
-        average_gradients=True,
-        average_parameters=True,
-        average_opt_statistics=["exp_avg_sq"],
-        dht=dht_instances[1],
-        **common_kwargs
-    )
-    a = torch.ones(n_dims)
-
-    for i in range(n_steps):
-        opt1.zero_grad()
-        opt2.zero_grad()
-        (x1 - a).pow(2).sum().backward()
-        (x2 - a).pow(2).sum().backward()
-        opt1.step()
-        opt2.step()
-
-        with torch.no_grad():
-            x_avg = 0.5 * (x1 + x2)
-            grad_avg = 0.5 * (x1.grad + x2.grad)
-            stats_avg = 0.5 * (opt1.state[x1]["exp_avg_sq"] + opt2.state[x2]["exp_avg_sq"])
-
-        # we set wait=False in order to prevent deadlock, when averager1 locks and waits for averager2
-        f1 = averager1.step(wait=False)
-        f2 = averager2.step(wait=False)
-        f1.result()
-        f2.result()
-
-        assert torch.allclose(x1, x_avg)
-        assert torch.allclose(x2, x_avg)
-        assert torch.allclose(x1.grad, grad_avg)
-        assert torch.allclose(x2.grad, grad_avg)
-        assert torch.allclose(opt1.state[x1]["exp_avg_sq"], stats_avg)
-        assert torch.allclose(opt2.state[x2]["exp_avg_sq"], stats_avg)
-
-    for instance in [averager1, averager2] + dht_instances:
-        instance.shutdown()

+ 0 - 85
tests/test_training.py

@@ -1,4 +1,3 @@
-import time
 from functools import partial

 import pytest
@@ -12,7 +11,6 @@ from hivemind.moe.client import RemoteMixtureOfExperts, RemoteSwitchMixtureOfExp
 from hivemind.moe.client.expert import create_remote_experts
 from hivemind.moe.expert_uid import ExpertInfo
 from hivemind.moe.server import background_server
-from hivemind.optim import DecentralizedAdam, DecentralizedSGD


 @pytest.mark.forked
@@ -133,86 +131,3 @@ def test_switch_training(max_steps: int = 10, threshold: float = 0.9, num_expert

         assert model.moe.grid_utilization.min().item() > (1 / num_experts) / 2
         assert accuracy >= threshold, f"too small accuracy: {accuracy}"
-
-
-@pytest.mark.forked
-def test_decentralized_optimizer_step():
-    dht_root = DHT(start=True)
-    initial_peers = dht_root.get_visible_maddrs()
-
-    param1 = torch.nn.Parameter(torch.zeros(32, 32), requires_grad=True)
-    opt1 = DecentralizedSGD(
-        [param1],
-        lr=0.1,
-        dht=DHT(initial_peers=initial_peers, start=True),
-        prefix="foo",
-        target_group_size=2,
-        verbose=True,
-    )
-
-    param2 = torch.nn.Parameter(torch.ones(32, 32), requires_grad=True)
-    opt2 = DecentralizedSGD(
-        [param2],
-        lr=0.05,
-        dht=DHT(initial_peers=initial_peers, start=True),
-        prefix="foo",
-        target_group_size=2,
-        verbose=True,
-    )
-
-    assert not torch.allclose(param1, param2)
-
-    (param1.sum() + 300 * param2.sum()).backward()
-
-    for i in range(5):
-        time.sleep(0.1)
-        opt1.step()
-        opt2.step()
-        opt1.zero_grad()
-        opt2.zero_grad()
-
-    assert torch.allclose(param1, param2)
-    reference = 0.5 * (0.0 - 0.1 * 1.0) + 0.5 * (1.0 - 0.05 * 300)
-    assert torch.allclose(param1, torch.full_like(param1, reference))
-
-
-@pytest.mark.skip(reason="Skipped until a more stable averager implementation is ready (TODO @justheuristic)")
-@pytest.mark.forked
-def test_decentralized_optimizer_averaging():
-    dht_root = DHT(start=True)
-    initial_peers = dht_root.get_visible_maddrs()
-
-    param1 = torch.nn.Parameter(torch.zeros(32, 32), requires_grad=True)
-    opt1 = DecentralizedAdam(
-        [param1],
-        lr=0.1,
-        averaging_steps_period=1,
-        dht=DHT(initial_peers=initial_peers, start=True),
-        prefix="foo",
-        target_group_size=2,
-        verbose=True,
-    )
-
-    param2 = torch.nn.Parameter(torch.ones(32, 32), requires_grad=True)
-    opt2 = DecentralizedAdam(
-        [param2],
-        lr=0.05,
-        averaging_steps_period=1,
-        dht=DHT(initial_peers=initial_peers, start=True),
-        prefix="foo",
-        target_group_size=2,
-        verbose=True,
-    )
-
-    assert not torch.allclose(param1, param2, atol=1e-3, rtol=0)
-    (param1.sum() + param2.sum()).backward()
-
-    for _ in range(100):
-        time.sleep(0.1)
-        opt1.step()
-        opt2.step()
-        opt1.zero_grad()
-        opt2.zero_grad()
-
-    assert torch.allclose(param1, param2, atol=1e-3, rtol=0)
-    assert torch.allclose(opt1.state[param1]["exp_avg_sq"], opt2.state[param2]["exp_avg_sq"], atol=1e-3, rtol=0)