
Decentralized adaptive optimizers (#243)

* transforms ParameterAveragingOptimizer into DecentralizedOptimizer, which supports many averaging options
* adds a preconfigured DecentralizedAdam (DADAM)
* adds CollaborativeAdaptiveOptimizer, which accounts for adaptive learning rates
Roman Zhytar, 4 years ago
parent commit e833a7efb9
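
For orientation, here is a minimal usage sketch of the new DecentralizedOptimizer introduced in this commit. The constructor arguments are taken from the diff below; the model and DHT setup around them are purely illustrative:

    import torch
    from hivemind import DHT
    from hivemind.optim import DecentralizedOptimizer

    model = torch.nn.Linear(16, 4)
    dht = DHT(start=True)
    opt = DecentralizedOptimizer(
        torch.optim.Adam(model.parameters(), lr=1e-3), dht,
        prefix='my_experiment', target_group_size=4,
        average_parameters=True,                  # sync model weights with peers
        average_gradients=False,                  # keep gradients local
        average_opt_statistics=('exp_avg_sq',),   # also sync Adam's second moments
        averaging_steps_period=10,                # trigger averaging every 10 local steps
    )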

+ 41 - 15
hivemind/client/averaging/training.py

@@ -1,7 +1,8 @@
 """ An extension of averager that supports common optimization use cases. """
 from itertools import chain
 from threading import Lock
-from typing import Sequence, Dict, Iterator
+from typing import Sequence, Dict, Iterator, Optional
+from contextlib import nullcontext
 
 import torch
 
@@ -30,6 +31,7 @@ class TrainingAverager(DecentralizedAverager):
     :note: you can use extra_tensors for averaging tensors that are updated outside of opt.step (e.g. batchnorm stats)
     :param kwargs: any additional parameters will be forwarded to DecentralizedAverager
     """
+
     def __init__(self, opt: torch.optim.Optimizer, *, average_parameters: bool, average_gradients: bool,
                  average_opt_statistics: Sequence[str] = (), extra_tensors: Sequence[torch.Tensor] = (),
                  initialize_optimizer: bool = True, **kwargs):
@@ -46,27 +48,51 @@ class TrainingAverager(DecentralizedAverager):
         super().__init__(averaged_tensors=averaged_tensors, **kwargs)
 
     @torch.no_grad()
-    def step(self, wait: bool = True, **kwargs):
-        """ Average optimizer weights and gradients with peers. """
+    def step(self, data_lock: Optional[Lock] = None, wait: bool = True, **kwargs):
+        """ Average optimizer weights and gradients with peers.
+        :param data_lock: a lock acquired by the averager whenever it reads or writes model parameters;
+          if None, it is assumed that no parameter modifications occur during the averaging step
+        :param wait: if True, blocks until averaging is complete; otherwise returns a Future
+        """
         if not wait:
-            return run_in_background(self.step, wait=True, **kwargs)
+            return run_in_background(self.step, data_lock, wait=True, **kwargs)
+
+        # if data_lock is supplied, tensors might change during averaging, so we need to copy them
+        use_old_local_tensors = data_lock is not None
+        if data_lock is None:
+            data_lock = nullcontext()
 
         local_tensors = list(self.local_tensors())
         with self.lock_averager_step:
-            # fill averager's tensors with current local tensors, scaled by peer's weight
-            with self.get_tensors() as averaged_tensors:
-                assert len(local_tensors) == len(averaged_tensors)
+            # fill averager's tensors with current local tensors
+            with data_lock, self.get_tensors() as averaged_tensors:
+                if use_old_local_tensors:
+                    old_local_tensors = tuple(x.cpu().float().clone() for x in local_tensors)
+                assert len(local_tensors) == len(
+                    averaged_tensors), "The number of optimized parameters should not change."
                 for averaged_tensor, local_tensor in zip(averaged_tensors, local_tensors):
-                    averaged_tensor[...] = local_tensor.detach().cpu().float()
+                    averaged_tensor[...] = local_tensor.cpu().float()
 
-            # find a group and hopefully average tensors with peers
+            # find a group and hopefully average tensors with peers, scaled by peer's weight
             gathered = super().step(**kwargs)
-
-            # load averaged tensors back into model
-            with self.get_tensors() as averaged_tensors:
-                assert len(averaged_tensors) == len(local_tensors)
-                for averaged_tensor, local_tensor in zip(averaged_tensors, local_tensors):
-                    local_tensor[...] = averaged_tensor.to(dtype=local_tensor.dtype, device=local_tensor.device)
+            if gathered is not None:
+                # load averaged tensors back into model
+                with data_lock, self.get_tensors() as averaged_tensors:
+                    if len(averaged_tensors) != len(local_tensors):
+                        raise RuntimeError("The number of optimized parameters should not change.")
+
+                    if use_old_local_tensors:
+                        # since tensors might have changed, we subtract old_local_tensor and add averaged. This prevents
+                        # losing local updates that might have occurred during averaging
+                        for averaged_tensor, local_tensor, old_local_tensor in zip(averaged_tensors, local_tensors,
+                                                                                   old_local_tensors):
+                            local_tensor[...] += averaged_tensor.to(dtype=local_tensor.dtype,
+                                                                    device=local_tensor.device) - \
+                                                 old_local_tensor.to(dtype=local_tensor.dtype,
+                                                                     device=local_tensor.device)
+                    else:
+                        for averaged_tensor, local_tensor in zip(averaged_tensors, local_tensors):
+                            local_tensor[...] = averaged_tensor.to(dtype=local_tensor.dtype, device=local_tensor.device)
 
             self.local_step += 1
             return gathered
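
The new data_lock argument makes it possible to keep training while an averaging round is in flight: the averager snapshots the old values under the lock, averages with peers, and then applies only the (averaged - old) delta. A minimal sketch of the intended pattern; the model, optimizer, DHT, and prefix here are chosen purely for illustration:

    import torch
    from threading import Lock
    from hivemind import DHT, TrainingAverager

    model = torch.nn.Linear(8, 2)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    averager = TrainingAverager(opt, average_parameters=True, average_gradients=False,
                                dht=DHT(start=True), start=True,
                                prefix='demo_averaging', target_group_size=2)
    lock = Lock()

    # training thread: hold the shared lock around local updates
    with lock:
        model(torch.randn(4, 8)).sum().backward()
        opt.step()

    # averaging thread: pass the same lock so local opt.step updates made during
    # the all-reduce are not overwritten by the incoming averaged values
    averager.step(data_lock=lock)

In practice the two blocks run in separate threads on each of at least two peers; with a single peer the averaging step simply fails to find a group.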

+ 2 - 1
hivemind/optim/__init__.py

@@ -1,4 +1,5 @@
 from hivemind.optim.base import DecentralizedOptimizerBase
 from hivemind.optim.collaborative import CollaborativeOptimizer
 from hivemind.optim.performance_ema import PerformanceEMA
-from hivemind.optim.simple import DecentralizedSGD, ParameterAveragingOptimizer
+from hivemind.optim.simple import DecentralizedOptimizer, DecentralizedSGD, DecentralizedAdam
+from hivemind.optim.adaptive import CollaborativeAdaptiveOptimizer

+ 25 - 0
hivemind/optim/adaptive.py

@@ -0,0 +1,25 @@
+from typing import Sequence
+
+import torch.optim
+
+from hivemind.optim.collaborative import CollaborativeOptimizer
+from hivemind import TrainingAverager
+
+
+class CollaborativeAdaptiveOptimizer(CollaborativeOptimizer):
+    """
+    Behaves exactly like CollaborativeOptimizer, except that it:
+     - averages the optimizer's adaptive learning rate statistics
+     - doesn't average gradients
+    :param average_opt_statistics: names of the optimizer statistics in its state_dict that should be averaged
+    :param kwargs: options for CollaborativeOptimizer
+    """
+
+    def __init__(self, opt: torch.optim.Optimizer, average_opt_statistics: Sequence[str], **kwargs):
+        super().__init__(opt, average_opt_statistics=average_opt_statistics, **kwargs)
+
+    def _make_averager(self, average_opt_statistics, **kwargs):
+        return TrainingAverager(self.opt, dht=self.dht, average_parameters=True, average_gradients=False,
+                                average_opt_statistics=average_opt_statistics,
+                                prefix=f"{self.prefix}_averaging", allreduce_timeout=self.averaging_timeout,
+                                listen=not self.client_mode, **kwargs)
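
A hedged usage sketch for the new class: average_opt_statistics is the only argument defined in this file, while the remaining keyword arguments (dht, prefix, target_batch_size, batch_size_per_step) are assumed to come from the parent CollaborativeOptimizer, whose signature is not part of this diff:

    import torch
    from hivemind import DHT
    from hivemind.optim import CollaborativeAdaptiveOptimizer

    model = torch.nn.Linear(16, 4)
    opt = CollaborativeAdaptiveOptimizer(
        torch.optim.Adam(model.parameters(), lr=1e-3),
        average_opt_statistics=('exp_avg_sq',),       # Adam's second-moment buffers
        dht=DHT(start=True), prefix='my_run',         # assumed CollaborativeOptimizer arguments
        target_batch_size=4096, batch_size_per_step=32,
    )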

+ 73 - 52
hivemind/optim/simple.py

@@ -1,55 +1,67 @@
 import time
 from threading import Thread, Lock, Event
-from typing import Optional
+from typing import Optional, Sequence, Tuple
 
 import torch
 
 from hivemind.dht import DHT
-from hivemind.client.averaging import DecentralizedAverager
+from hivemind.client import TrainingAverager
 from hivemind.optim.base import DecentralizedOptimizerBase
 from hivemind.utils import get_logger, get_dht_time
 
 logger = get_logger(__name__)
 
 
-class ParameterAveragingOptimizer(DecentralizedOptimizerBase):
+class DecentralizedOptimizer(DecentralizedOptimizerBase):
     """
-    A simple optimizer that trains a shared model by averaging model parameters with peers in the background.
+    A simple optimizer that trains a shared model by averaging with peers in a variety of ways. Supports
+    parameter/gradient averaging and syncing adaptive learning rates or any other internal optimizer statistics.
 
     :param opt: a pytorch optimizer configured to update model parameters.
     :param dht: a running hivemind DHT daemon connected to other peers
-    :param averaging_period: if specified, optimizer will attempt to average weights at regular intervals of this many
-      seconds. (averaging step will only occur if the optimizer ran at least one step in that interval)
+    :param average_parameters: whether to average model parameters
+    :param average_gradients: whether to average gradients
+    :param average_opt_statistics: if specified, average optimizer states with corresponding names in state_dict
+    :param averaging_steps_period: performs averaging after this many optimizer steps
+    :param averaging_time_period: if specified, optimizer will attempt to average weights at regular intervals of this
+      many seconds. (averaging step will only occur if the optimizer ran `averaging_steps_period` steps in that interval)
     :param prefix: all DHT keys that point to optimization metadata will have this prefix
     :param target_group_size: maximum group size for averaging (see DecentralizedAverager)
     :param timeout: if DecentralizedAverager step is unable to form group in this many seconds, cancel step
-    :param kwargs: additional parameters passed to DecentralizedAverager
-
-    :note: using DecentralizedOptimizer with adaptive learning rates may result in poor convergence due to
-      out-of-sync adaptive learning rates (such as adam second momentum or schedule step). Please ensure that these
-      statistics are synchronized or use a more advanced DecentralizedOptimizer version, if applicable.
+    :param kwargs: additional parameters passed to TrainingAverager
+    :note: if you're using an optimizer with adaptive learning rates (such as Adam), make sure to list the names
+      of the relevant state fields in `average_opt_statistics`; otherwise you may encounter poor convergence.
     :note: the base optimizer cannot add param groups after the DecentralizedOptimizer is created
     """
 
-    def __init__(self, opt: torch.optim.Optimizer, dht: DHT, prefix: str, target_group_size: int,
-                 averaging_period: float = 0, timeout: Optional[float] = None, verbose: bool = False, **kwargs):
+    def __init__(self, opt: torch.optim.Optimizer, dht: DHT, *, prefix: str, target_group_size: int,
+                 average_parameters: bool, average_gradients: bool, average_opt_statistics: Sequence[str] = (),
+                 averaging_steps_period: int = 1, averaging_time_period: float = 0,
+                 timeout: Optional[float] = None, verbose: bool = False, **kwargs):
         super().__init__(opt, dht)
-        with torch.no_grad():
-            averaged_tensors = tuple(p.cpu().float().clone().requires_grad_(False)
-                                     for group in self.param_groups for p in group['params'])
-        self.averager = DecentralizedAverager(averaged_tensors, dht, start=True, prefix=prefix,
-                                              target_group_size=target_group_size, **kwargs)
+        assert averaging_steps_period > 0, "Averaging steps period must be positive."
+        assert averaging_time_period >= 0, "Averaging time period must be non-negative."
+        self.local_step, self.averaging_step_period = 0, averaging_steps_period
+
+        self.averager = TrainingAverager(opt, average_parameters=average_parameters,
+                                         average_gradients=average_gradients,
+                                         average_opt_statistics=average_opt_statistics,
+                                         dht=dht, start=True, prefix=prefix,
+                                         target_group_size=target_group_size, **kwargs)
         self.lock_parameters, self.update_event, self.stop_event = Lock(), Event(), Event()
+
         self.background_averaging_thread = Thread(
             name=f'{self.__class__.__name__}', daemon=True, target=self._average_parameters_in_background,
-            args=[self.lock_parameters, self.update_event, self.stop_event, self.averager, self.opt],
-            kwargs=dict(averaging_period=averaging_period, timeout=timeout, verbose=verbose))
+            args=[self.lock_parameters, self.update_event, self.stop_event, self.averager],
+            kwargs=dict(averaging_period=averaging_time_period, timeout=timeout, verbose=verbose))
         self.background_averaging_thread.start()
 
     def step(self, *args, **kwargs):
-        self.update_event.set()
         with self.lock_parameters:
-            return self.opt.step(*args, **kwargs)
+            loss = self.opt.step(*args, **kwargs)
+        self.local_step += 1
+        if self.local_step % self.averaging_step_period == 0:
+            self.update_event.set()
+        return loss
 
     def zero_grad(self, *args, **kwargs):
         return self.opt.zero_grad(*args, **kwargs)
@@ -59,13 +71,15 @@ class ParameterAveragingOptimizer(DecentralizedOptimizerBase):
         self.update_event.set()
 
     def shutdown(self):
+        self.stop_event.set()
+        self.update_event.set()
         self.averager.shutdown()
 
     @staticmethod
     @torch.no_grad()
     def _average_parameters_in_background(
-            lock_parameters: Lock, update_event: Event, stop_event: Event, averager: DecentralizedAverager,
-            opt: torch.optim.Optimizer, averaging_period: float, verbose: bool, **kwargs):
+            lock_parameters: Lock, update_event: Event, stop_event: Event, averager: TrainingAverager,
+            averaging_period: float, verbose: bool, **kwargs):
         """ Iteratively find groups of peers, average parameters with these peers and update local model parameters. """
         while not stop_event.is_set():
             update_event.wait()
@@ -79,46 +93,27 @@ class ParameterAveragingOptimizer(DecentralizedOptimizerBase):
                 time_to_nearest_interval = max(0.0, averaging_period - current_time % averaging_period)
                 time.sleep(time_to_nearest_interval)
 
-            with lock_parameters, averager.get_tensors() as averaged_tensors:
-                local_tensors = tuple(p for group in opt.param_groups for p in group['params'])
-                assert len(local_tensors) == len(averaged_tensors), "The number of optimized parameters should not change."
-
-                for local_tensor, averaged_tensor in zip(local_tensors, averaged_tensors):
-                    averaged_tensor[...] = local_tensor.cpu().float()
-
-                old_local_tensors = tuple(local_tensor.cpu().detach().clone() for local_tensor in local_tensors)
-
+            if verbose:
+                logger.info(f"Starting a new averaging round with current parameters.")
             try:
+                group_info = averager.step(lock_parameters, **kwargs)
                 if verbose:
-                    logger.info(f"Starting a new averaging round with current parameters.")
-                group_info = averager.step(**kwargs)
-
-                if group_info is not None:
-                    with lock_parameters, averager.get_tensors() as averaged_tensors:
-                        for local_tensor, old_local_tensor, averaged_tensor in zip(
-                            local_tensors, old_local_tensors, averaged_tensors
-                        ):
-                            local_tensor[...] += averaged_tensor.to(dtype=local_tensor.dtype) - old_local_tensor
-                    if verbose:
+                    if group_info is not None:
                         logger.info(f"Finished averaging round in with {len(group_info)} peers.")
-                else:
-                    if verbose:
+                    else:
                         logger.warning(f"Averaging round failed: could not find group.")
             except Exception as e:
                 logger.error(f"Averaging round failed: caught {e}.")
 
 
-class DecentralizedSGD(ParameterAveragingOptimizer):
+class DecentralizedSGD(DecentralizedOptimizer):
     """
     Decentralized Stochastic Gradient Descent algorithm like in Lian et al (2017) [1] based on Moshpit All-Reduce [2].
 
-    :param opt: a pytorch optimizer configured to update model parameters.
     :param dht: a running hivemind DHT daemon connected to other peers
-    :param averaging_period: if specified, optimizer will attempt to average weights at regular intervals of this many
-      seconds. (averaging step will only occur if the optimizer ran at least one step in that interval)
     :param prefix: all DHT keys that point to optimization metadata will have this prefix
     :param target_group_size: maximum group size for averaging (see DecentralizedAverager)
-    :param kwargs: additional parameters passed to DecentralizedAverager
+    :param kwargs: additional parameters passed to DecentralizedOptimizer
 
     - [1] Can Decentralized Algorithms Outperform Centralized Algorithms? A Case Study for Parallel Stochastic Gradient
      Descent - https://proceedings.neurips.cc/paper/2017/hash/f75526659f31040afeb61cb7133e4e6d-Abstract.html
@@ -126,7 +121,33 @@ class DecentralizedSGD(ParameterAveragingOptimizer):
      https://arxiv.org/abs/2103.03239
     """
 
-    def __init__(self, params, lr: float, *, dht: DHT, prefix: str, target_group_size: int, averaging_period: float = 0,
+    def __init__(self, params, lr: float, *, dht: DHT, prefix: str, target_group_size: int,
                  momentum: float = 0, dampening: float = 0, weight_decay: float = 0, nesterov: bool = False, **kwargs):
         opt = torch.optim.SGD(params, lr, momentum, dampening, weight_decay, nesterov)
-        super().__init__(opt, dht, prefix, target_group_size, averaging_period, **kwargs)
+        super().__init__(opt, dht, prefix=prefix, target_group_size=target_group_size, average_parameters=True,
+                         average_gradients=False, **kwargs)
+
+
+class DecentralizedAdam(DecentralizedOptimizer):
+    """
+    Decentralized Adam/AMSGrad as proposed in [1], [2]
+
+    :param dht: a running hivemind DHT daemon connected to other peers
+    :param prefix: all DHT keys that point to optimization metadata will have this prefix
+    :param target_group_size: maximum group size for averaging (see DecentralizedAverager)
+    :param averaging_steps_period: performs averaging after this many optimizer steps
+    :param kwargs: additional parameters passed to DecentralizedOptimizer
+
+    - [1] On the Convergence of Decentralized Adaptive Gradient Methods
+    - [2] Toward Communication Efficient Adaptive Gradient Method - https://dl.acm.org/doi/abs/10.1145/3412815.3416891
+    """
+
+    def __init__(self, params, lr: float, *, dht: DHT, prefix: str, target_group_size: int, averaging_steps_period: int,
+                 betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0,
+                 amsgrad: bool = False, **kwargs):
+        opt = torch.optim.Adam(params, lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
+        opt_statistics = ("max_exp_avg_sq",) if amsgrad else ("exp_avg_sq",)
+
+        super().__init__(opt, dht, prefix=prefix, target_group_size=target_group_size, average_parameters=True,
+                         average_gradients=False, average_opt_statistics=opt_statistics,
+                         averaging_steps_period=averaging_steps_period, **kwargs)
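
For completeness, a short sketch of the reworked DecentralizedSGD constructor: the old positional averaging_period argument is gone, and the averaging cadence is now controlled through averaging_steps_period / averaging_time_period forwarded to DecentralizedOptimizer. The model and DHT setup are illustrative; the DecentralizedAdam counterpart is exercised by the test below:

    import torch
    from hivemind import DHT, DecentralizedSGD

    model = torch.nn.Linear(16, 4)
    opt = DecentralizedSGD(model.parameters(), lr=0.1, momentum=0.9,
                           dht=DHT(start=True), prefix='sgd_run', target_group_size=4,
                           averaging_steps_period=5)   # average parameters every 5 local steps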

+ 26 - 1
tests/test_training.py

@@ -8,7 +8,7 @@ import torch.nn.functional as F
 from sklearn.datasets import load_digits
 
 from hivemind import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts, background_server, DHT, \
-    DecentralizedSGD
+    DecentralizedSGD, DecentralizedAdam
 
 
 @pytest.mark.forked
@@ -135,3 +135,28 @@ def test_decentralized_optimizer_step():
     assert torch.allclose(param1, param2)
     reference = 0.5 * (0.0 - 0.1 * 1.0) + 0.5 * (1.0 - 0.05 * 300)
     assert torch.allclose(param1, torch.full_like(param1, reference))
+
+
+@pytest.mark.forked
+def test_decentralized_optimizer_averaging():
+    dht_root = DHT(start=True)
+    initial_peers = [f"127.0.0.1:{dht_root.port}"]
+
+    param1 = torch.nn.Parameter(torch.zeros(32, 32), requires_grad=True)
+    opt1 = DecentralizedAdam([param1], lr=0.1, averaging_steps_period=1, dht=DHT(initial_peers=initial_peers, start=True),
+                            prefix='foo', target_group_size=2, verbose=True)
+
+    param2 = torch.nn.Parameter(torch.ones(32, 32), requires_grad=True)
+    opt2 = DecentralizedAdam([param2], lr=0.05, averaging_steps_period=1, dht=DHT(initial_peers=initial_peers, start=True),
+                            prefix='foo', target_group_size=2, verbose=True)
+
+    assert not torch.allclose(param1, param2)
+
+    (param1.sum() + param2.sum()).backward()
+
+    opt1.step()
+    opt2.step()
+
+    time.sleep(0.5)
+    assert torch.allclose(param1, param2)
+    assert torch.allclose(opt1.state[param1]["exp_avg_sq"], opt2.state[param2]["exp_avg_sq"])