
Decentralized adaptive optimizers (#243)

* transforms ParameterAveragingOptimizer into DecentralizedOptimizer, which supports many averaging options
* adds a preconfigured DecentralizedAdam (DADAM); a usage sketch follows below
* adds CollaborativeAdaptiveOptimizer, which averages adaptive learning rates
Roman Zhytar 4 years ago
parent
commit
e833a7efb9
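
Below is a minimal usage sketch of the preconfigured DecentralizedAdam added in this commit. It follows the constructor in hivemind/optim/simple.py and the DHT setup pattern from tests/test_training.py; the model, learning rate, prefix and group size are illustrative placeholders, and with a single local peer the averaging rounds will simply fail to find a group.

import torch
from hivemind import DHT, DecentralizedAdam

# single-process sketch: one root DHT plus one peer connected to it
dht_root = DHT(start=True)
dht = DHT(initial_peers=[f"127.0.0.1:{dht_root.port}"], start=True)

model = torch.nn.Linear(16, 4)

# averages parameters and Adam's "exp_avg_sq" statistic with peers
# once every `averaging_steps_period` local optimizer steps
opt = DecentralizedAdam(model.parameters(), lr=1e-3, dht=dht, prefix="demo_run",
                        target_group_size=2, averaging_steps_period=10)

for _ in range(100):
    loss = model(torch.randn(8, 16)).pow(2).mean()  # dummy objective
    loss.backward()
    opt.step()
    opt.zero_grad()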

+ 41 - 15
hivemind/client/averaging/training.py

@@ -1,7 +1,8 @@
 """ An extension of averager that supports common optimization use cases. """
 """ An extension of averager that supports common optimization use cases. """
 from itertools import chain
 from itertools import chain
 from threading import Lock
 from threading import Lock
-from typing import Sequence, Dict, Iterator
+from typing import Sequence, Dict, Iterator, Optional
+from contextlib import nullcontext
 
 
 import torch
 import torch
 
 
@@ -30,6 +31,7 @@ class TrainingAverager(DecentralizedAverager):
     :note: you can use extra_tensors for averaging tensors that are updated outside of opt.step (e.g. batchnorm stats)
     :param kwargs: any additional parameters will be forwarded to DecentralizedAverager
     """
+
     def __init__(self, opt: torch.optim.Optimizer, *, average_parameters: bool, average_gradients: bool,
                  average_opt_statistics: Sequence[str] = (), extra_tensors: Sequence[torch.Tensor] = (),
                  initialize_optimizer: bool = True, **kwargs):
@@ -46,27 +48,51 @@ class TrainingAverager(DecentralizedAverager):
         super().__init__(averaged_tensors=averaged_tensors, **kwargs)
 
     @torch.no_grad()
-    def step(self, wait: bool = True, **kwargs):
-        """ Average optimizer weights and gradients with peers. """
+    def step(self, data_lock: Optional[Lock] = None, wait: bool = True, **kwargs):
+        """ Average optimizer weights and gradients with peers.
+        :param data_lock: averager locks it when model parameters are modified. Otherwise it's assumed that no model
+        modifications occur during averaging step
+        :param wait: if True waits, otherwise returns Future
+        """
         if not wait:
-            return run_in_background(self.step, wait=True, **kwargs)
+            return run_in_background(self.step, data_lock, wait=True, **kwargs)
+
+        # if data_lock is supplied, tensors might change during averaging, so we need to copy them
+        use_old_local_tensors = data_lock is not None
+        if data_lock is None:
+            data_lock = nullcontext()
 
         local_tensors = list(self.local_tensors())
         with self.lock_averager_step:
-            # fill averager's tensors with current local tensors, scaled by peer's weight
-            with self.get_tensors() as averaged_tensors:
-                assert len(local_tensors) == len(averaged_tensors)
+            # fill averager's tensors with current local tensors
+            with data_lock, self.get_tensors() as averaged_tensors:
+                if use_old_local_tensors:
+                    old_local_tensors = tuple(x.cpu().float().clone() for x in local_tensors)
+                assert len(local_tensors) == len(
+                    averaged_tensors), "The number of optimized parameters should not change."
                 for averaged_tensor, local_tensor in zip(averaged_tensors, local_tensors):
-                    averaged_tensor[...] = local_tensor.detach().cpu().float()
+                    averaged_tensor[...] = local_tensor.cpu().float()
 
-            # find a group and hopefully average tensors with peers
+            # find a group and hopefully average tensors with peers, scaled by peer's weight
             gathered = super().step(**kwargs)
-
-            # load averaged tensors back into model
-            with self.get_tensors() as averaged_tensors:
-                assert len(averaged_tensors) == len(local_tensors)
-                for averaged_tensor, local_tensor in zip(averaged_tensors, local_tensors):
-                    local_tensor[...] = averaged_tensor.to(dtype=local_tensor.dtype, device=local_tensor.device)
+            if gathered is not None:
+                # load averaged tensors back into model
+                with data_lock, self.get_tensors() as averaged_tensors:
+                    if len(averaged_tensors) != len(local_tensors):
+                        raise RuntimeError("The number of optimized parameters should not change.")
+
+                    if use_old_local_tensors:
+                        # since tensors might have changed, we subtract old_local_tensor and add averaged. This prevents
+                        # losing local updates that might have occurred during averaging
+                        for averaged_tensor, local_tensor, old_local_tensor in zip(averaged_tensors, local_tensors,
+                                                                                   old_local_tensors):
+                            local_tensor[...] += averaged_tensor.to(dtype=local_tensor.dtype,
+                                                                    device=local_tensor.device) - \
+                                                 old_local_tensor.to(dtype=local_tensor.dtype,
+                                                                     device=local_tensor.device)
+                    else:
+                        for averaged_tensor, local_tensor in zip(averaged_tensors, local_tensors):
+                            local_tensor[...] = averaged_tensor.to(dtype=local_tensor.dtype, device=local_tensor.device)
 
             self.local_step += 1
             return gathered
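
The key behavioral change above: when a data_lock is supplied, the averager snapshots the parameters before averaging and afterwards applies the delta local += averaged - old_local instead of overwriting, so gradient steps taken while the all-reduce was in flight are not lost. A minimal trainer-side sketch of how this might be used (hypothetical names: averager is an already-running TrainingAverager, opt is the wrapped optimizer, compute_loss stands in for the training objective):

from threading import Lock
import torch

data_lock = Lock()  # shared between the training loop and the averager

def local_training_step(batch):
    # local parameter updates hold the same lock the averager uses when copying tensors
    with data_lock:
        loss = compute_loss(batch)  # hypothetical loss function
        loss.backward()
        opt.step()
        opt.zero_grad()

# run averaging without blocking the training loop; the averager only acquires
# data_lock while writing tensors in and out, and applies (averaged - old_local)
# so steps taken in between are preserved
averaging_future = averager.step(data_lock, wait=False)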

+ 2 - 1
hivemind/optim/__init__.py

@@ -1,4 +1,5 @@
 from hivemind.optim.base import DecentralizedOptimizerBase
 from hivemind.optim.collaborative import CollaborativeOptimizer
 from hivemind.optim.performance_ema import PerformanceEMA
-from hivemind.optim.simple import DecentralizedSGD, ParameterAveragingOptimizer
+from hivemind.optim.simple import DecentralizedOptimizer, DecentralizedSGD, DecentralizedAdam
+from hivemind.optim.adaptive import CollaborativeAdaptiveOptimizer

+ 25 - 0
hivemind/optim/adaptive.py

@@ -0,0 +1,25 @@
+from typing import Sequence
+
+import torch.optim
+
+from hivemind.optim.collaborative import CollaborativeOptimizer
+from hivemind import TrainingAverager
+
+
+class CollaborativeAdaptiveOptimizer(CollaborativeOptimizer):
+    """
+    Behaves exactly as CollaborativeOptimizer except:
+     - averages adaptive learning rates of an optimizer
+     - doesn't average gradients
+    :param average_opt_statistics: average optimizer statistics with corresponding names in statedict
+    :param kwargs: options for CollaborativeOptimizer
+    """
+
+    def __init__(self, opt: torch.optim.Optimizer, average_opt_statistics: Sequence[str], **kwargs):
+        super().__init__(opt, average_opt_statistics=average_opt_statistics, **kwargs)
+
+    def _make_averager(self, average_opt_statistics, **kwargs):
+        return TrainingAverager(self.opt, dht=self.dht, average_parameters=True, average_gradients=False,
+                                average_opt_statistics=average_opt_statistics,
+                                prefix=f"{self.prefix}_averaging", allreduce_timeout=self.averaging_timeout,
+                                listen=not self.client_mode, **kwargs)
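
A hedged usage sketch of the new CollaborativeAdaptiveOptimizer wrapping Adam. The averaged statistic name "exp_avg_sq" matches PyTorch's Adam state; the CollaborativeOptimizer keyword arguments shown (target_batch_size, batch_size_per_step) are assumptions about its usual interface and are not part of this diff.

import torch
from hivemind import DHT
from hivemind.optim import CollaborativeAdaptiveOptimizer

model = torch.nn.Linear(16, 4)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
dht = DHT(start=True)  # in practice, pass initial_peers of an existing swarm

# average Adam's second moment together with parameters; gradients themselves
# are not averaged by the underlying TrainingAverager (average_gradients=False)
collab_opt = CollaborativeAdaptiveOptimizer(
    opt, average_opt_statistics=("exp_avg_sq",),
    dht=dht, prefix="demo_run",
    target_batch_size=4096, batch_size_per_step=32)  # assumed CollaborativeOptimizer kwargs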

+ 73 - 52
hivemind/optim/simple.py

@@ -1,55 +1,67 @@
 import time
 from threading import Thread, Lock, Event
-from typing import Optional
+from typing import Optional, Sequence, Tuple
 
 import torch
 
 from hivemind.dht import DHT
-from hivemind.client.averaging import DecentralizedAverager
+from hivemind.client import TrainingAverager
 from hivemind.optim.base import DecentralizedOptimizerBase
 from hivemind.utils import get_logger, get_dht_time
 
 logger = get_logger(__name__)
 
 
-class ParameterAveragingOptimizer(DecentralizedOptimizerBase):
+class DecentralizedOptimizer(DecentralizedOptimizerBase):
     """
-    A simple optimizer that trains a shared model by averaging model parameters with peers in the background.
+    A simple optimizer that trains a shared model by averaging with peers in variety of ways. Supports
+    parameter/gradient averaging and syncing adaptive learning rates or any other internal statistics of optimizer.
 
     :param opt: a pytorch optimizer configured to update model parameters.
     :param dht: a running hivemind DHT daemon connected to other peers
-    :param averaging_period: if specified, optimizer will attempt to average weights at regular intervals of this many
-      seconds. (averaging step will only occur if the optimizer ran at least one step in that interval)
+    :param average_parameters: whether to average model parameters
+    :param average_gradients: whether to average gradients
+    :param average_opt_statistics: if specified, average optimizer states with corresponding names in state_dict
+    :param averaging_steps_period: performs averaging after this many optimizer steps
+    :param averaging_time_period: if specified, optimizer will attempt to average weights at regular intervals of this
+      many seconds. (averaging step will only occur if the optimizer ran `averaging_steps_period` steps in that interval)
     :param prefix: all DHT keys that point to optimization metadata will have this prefix
     :param target_group_size: maximum group size for averaging (see DecentralizedAverager)
     :param timeout: if DecentralizedAverager step is unable to form group in this many seconds, cancel step
-    :param kwargs: additional parameters passed to DecentralizedAverager
-
-    :note: using DecentralizedOptimizer with adaptive learning rates may result in poor convergence due to
-      out-of-sync adaptive learning rates (such as adam second momentum or schedule step). Please ensure that these
-      statistics are synchronized or use a more advanced DecentralizedOptimizer version, if applicable.
+    :param kwargs: additional parameters passed to TrainingAverager
+    :note: if you're using an optimizer with adaptive learning rates (such as Adam), make sure to specify
+      necessary fields' names in `average_opt_statistics`. Otherwise you may encounter poor convergence.
    :note: the base optimizer cannot add param groups after the DecentralizedOptimizer is created
     """
 
-    def __init__(self, opt: torch.optim.Optimizer, dht: DHT, prefix: str, target_group_size: int,
-                 averaging_period: float = 0, timeout: Optional[float] = None, verbose: bool = False, **kwargs):
+    def __init__(self, opt: torch.optim.Optimizer, dht: DHT, *, prefix: str, target_group_size: int,
+                 average_parameters: bool, average_gradients: bool, average_opt_statistics: Sequence[str] = (),
+                 averaging_steps_period: int = 1, averaging_time_period: float = 0,
+                 timeout: Optional[float] = None, verbose: bool = False, **kwargs):
         super().__init__(opt, dht)
-        with torch.no_grad():
-            averaged_tensors = tuple(p.cpu().float().clone().requires_grad_(False)
-                                     for group in self.param_groups for p in group['params'])
-        self.averager = DecentralizedAverager(averaged_tensors, dht, start=True, prefix=prefix,
-                                              target_group_size=target_group_size, **kwargs)
+        assert averaging_steps_period > 0 and averaging_time_period >= 0, "Averaging period must be positive."
+        self.local_step, self.averaging_step_period = 0, averaging_steps_period
+
+        self.averager = TrainingAverager(opt, average_parameters=average_parameters,
+                                         average_gradients=average_gradients,
+                                         average_opt_statistics=average_opt_statistics,
+                                         dht=dht, start=True, prefix=prefix,
+                                         target_group_size=target_group_size, **kwargs)
         self.lock_parameters, self.update_event, self.stop_event = Lock(), Event(), Event()
+
         self.background_averaging_thread = Thread(
             name=f'{self.__class__.__name__}', daemon=True, target=self._average_parameters_in_background,
-            args=[self.lock_parameters, self.update_event, self.stop_event, self.averager, self.opt],
-            kwargs=dict(averaging_period=averaging_period, timeout=timeout, verbose=verbose))
+            args=[self.lock_parameters, self.update_event, self.stop_event, self.averager],
+            kwargs=dict(averaging_period=averaging_time_period, timeout=timeout, verbose=verbose))
         self.background_averaging_thread.start()
 
     def step(self, *args, **kwargs):
-        self.update_event.set()
         with self.lock_parameters:
-            return self.opt.step(*args, **kwargs)
+            loss = self.opt.step(*args, **kwargs)
+        self.local_step += 1
+        if self.local_step % self.averaging_step_period == 0:
+            self.update_event.set()
+        return loss
 
     def zero_grad(self, *args, **kwargs):
         return self.opt.zero_grad(*args, **kwargs)
@@ -59,13 +71,15 @@ class ParameterAveragingOptimizer(DecentralizedOptimizerBase):
         self.update_event.set()
 
     def shutdown(self):
+        self.stop_event.set()
+        self.update_event.set()
         self.averager.shutdown()
 
     @staticmethod
     @torch.no_grad()
     def _average_parameters_in_background(
-            lock_parameters: Lock, update_event: Event, stop_event: Event, averager: DecentralizedAverager,
-            opt: torch.optim.Optimizer, averaging_period: float, verbose: bool, **kwargs):
+            lock_parameters: Lock, update_event: Event, stop_event: Event, averager: TrainingAverager,
+            averaging_period: float, verbose: bool, **kwargs):
         """ Iteratively find groups of peers, average parameters with these peers and update local model parameters. """
         while not stop_event.is_set():
             update_event.wait()
@@ -79,46 +93,27 @@ class ParameterAveragingOptimizer(DecentralizedOptimizerBase):
                 time_to_nearest_interval = max(0.0, averaging_period - current_time % averaging_period)
                 time.sleep(time_to_nearest_interval)
 
-            with lock_parameters, averager.get_tensors() as averaged_tensors:
-                local_tensors = tuple(p for group in opt.param_groups for p in group['params'])
-                assert len(local_tensors) == len(averaged_tensors), "The number of optimized parameters should not change."
-
-                for local_tensor, averaged_tensor in zip(local_tensors, averaged_tensors):
-                    averaged_tensor[...] = local_tensor.cpu().float()
-
-                old_local_tensors = tuple(local_tensor.cpu().detach().clone() for local_tensor in local_tensors)
-
+            if verbose:
+                logger.info(f"Starting a new averaging round with current parameters.")
             try:
+                group_info = averager.step(lock_parameters, **kwargs)
                 if verbose:
-                    logger.info(f"Starting a new averaging round with current parameters.")
-                group_info = averager.step(**kwargs)
-
-                if group_info is not None:
-                    with lock_parameters, averager.get_tensors() as averaged_tensors:
-                        for local_tensor, old_local_tensor, averaged_tensor in zip(
-                            local_tensors, old_local_tensors, averaged_tensors
-                        ):
-                            local_tensor[...] += averaged_tensor.to(dtype=local_tensor.dtype) - old_local_tensor
-                    if verbose:
+                    if group_info is not None:
                         logger.info(f"Finished averaging round in with {len(group_info)} peers.")
-                else:
-                    if verbose:
+                    else:
                         logger.warning(f"Averaging round failed: could not find group.")
             except Exception as e:
                 logger.error(f"Averaging round failed: caught {e}.")
 
 
-class DecentralizedSGD(ParameterAveragingOptimizer):
+class DecentralizedSGD(DecentralizedOptimizer):
     """
     Decentralized Stochastic Gradient Descent algorithm like in Lian et al (2017) [1] based on Moshpit All-Reduce [2].
 
-    :param opt: a pytorch optimizer configured to update model parameters.
     :param dht: a running hivemind DHT daemon connected to other peers
-    :param averaging_period: if specified, optimizer will attempt to average weights at regular intervals of this many
-      seconds. (averaging step will only occur if the optimizer ran at least one step in that interval)
     :param prefix: all DHT keys that point to optimization metadata will have this prefix
     :param target_group_size: maximum group size for averaging (see DecentralizedAverager)
-    :param kwargs: additional parameters passed to DecentralizedAverager
+    :param kwargs: additional parameters passed to DecentralizedOptimizer
 
     - [1] Can Decentralized Algorithms Outperform Centralized Algorithms? A Case Study for Parallel Stochastic Gradient
      Descent - https://proceedings.neurips.cc/paper/2017/hash/f75526659f31040afeb61cb7133e4e6d-Abstract.html
@@ -126,7 +121,33 @@ class DecentralizedSGD(ParameterAveragingOptimizer):
      https://arxiv.org/abs/2103.03239
     """
 
-    def __init__(self, params, lr: float, *, dht: DHT, prefix: str, target_group_size: int, averaging_period: float = 0,
+    def __init__(self, params, lr: float, *, dht: DHT, prefix: str, target_group_size: int,
                  momentum: float = 0, dampening: float = 0, weight_decay: float = 0, nesterov: bool = False, **kwargs):
         opt = torch.optim.SGD(params, lr, momentum, dampening, weight_decay, nesterov)
-        super().__init__(opt, dht, prefix, target_group_size, averaging_period, **kwargs)
+        super().__init__(opt, dht, prefix=prefix, target_group_size=target_group_size, average_parameters=True,
+                         average_gradients=False, **kwargs)
+
+
+class DecentralizedAdam(DecentralizedOptimizer):
+    """
+    Decentralized Adam/AmsGrad as proposed in [1], [2]
+
+    :param dht: a running hivemind DHT daemon connected to other peers
+    :param prefix: all DHT keys that point to optimization metadata will have this prefix
+    :param target_group_size: maximum group size for averaging (see DecentralizedAverager)
+    :param averaging_steps_period: performs averaging after this many optimizer steps
+    :param kwargs: additional parameters passed to DecentralizedOptimizer
+
+    - [1] On the Convergence of Decentralized Adaptive Gradient Methods
+    - [2] Toward Communication Efficient Adaptive Gradient Method - https://dl.acm.org/doi/abs/10.1145/3412815.3416891
+    """
+
+    def __init__(self, params, lr: float, *, dht: DHT, prefix: str, target_group_size: int, averaging_steps_period: int,
+                 betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0,
+                 amsgrad: bool = False, **kwargs):
+        opt = torch.optim.Adam(params, lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
+        opt_statistics = ("max_exp_avg_sq",) if amsgrad else ("exp_avg_sq",)
+
+        super().__init__(opt, dht, prefix=prefix, target_group_size=target_group_size, average_parameters=True,
+                         average_gradients=False, average_opt_statistics=opt_statistics,
+                         averaging_steps_period=averaging_steps_period, **kwargs)
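
The generalized DecentralizedOptimizer can also wrap other adaptive optimizers by naming the state_dict entries to average. A short sketch with RMSprop, whose running squared-gradient estimate is stored under "square_avg" in PyTorch; the DHT setup, prefix and hyperparameters below are illustrative placeholders:

import torch
from hivemind import DHT
from hivemind.optim import DecentralizedOptimizer

model = torch.nn.Linear(16, 4)
base_opt = torch.optim.RMSprop(model.parameters(), lr=1e-3)
dht = DHT(start=True)  # in practice, pass initial_peers of an existing swarm

# average parameters and RMSprop's "square_avg" statistic, triggering an
# averaging round after every 10 local optimizer steps
opt = DecentralizedOptimizer(base_opt, dht, prefix="demo_run", target_group_size=16,
                             average_parameters=True, average_gradients=False,
                             average_opt_statistics=("square_avg",),
                             averaging_steps_period=10)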

+ 26 - 1
tests/test_training.py

@@ -8,7 +8,7 @@ import torch.nn.functional as F
 from sklearn.datasets import load_digits
 
 from hivemind import RemoteExpert, RemoteMixtureOfExperts, RemoteSwitchMixtureOfExperts, background_server, DHT, \
-    DecentralizedSGD
+    DecentralizedSGD, DecentralizedAdam
 
 
 @pytest.mark.forked
@@ -135,3 +135,28 @@ def test_decentralized_optimizer_step():
     assert torch.allclose(param1, param2)
     reference = 0.5 * (0.0 - 0.1 * 1.0) + 0.5 * (1.0 - 0.05 * 300)
     assert torch.allclose(param1, torch.full_like(param1, reference))
+
+
+@pytest.mark.forked
+def test_decentralized_optimizer_averaging():
+    dht_root = DHT(start=True)
+    initial_peers = [f"127.0.0.1:{dht_root.port}"]
+
+    param1 = torch.nn.Parameter(torch.zeros(32, 32), requires_grad=True)
+    opt1 = DecentralizedAdam([param1], lr=0.1, averaging_steps_period=1, dht=DHT(initial_peers=initial_peers, start=True),
+                            prefix='foo', target_group_size=2, verbose=True)
+
+    param2 = torch.nn.Parameter(torch.ones(32, 32), requires_grad=True)
+    opt2 = DecentralizedAdam([param2], lr=0.05, averaging_steps_period=1, dht=DHT(initial_peers=initial_peers, start=True),
+                            prefix='foo', target_group_size=2, verbose=True)
+
+    assert not torch.allclose(param1, param2)
+
+    (param1.sum() + param2.sum()).backward()
+
+    opt1.step()
+    opt2.step()
+
+    time.sleep(0.5)
+    assert torch.allclose(param1, param2)
+    assert torch.allclose(opt1.state[param1]["exp_avg_sq"], opt2.state[param2]["exp_avg_sq"])