
Implement basic decentralized optimizers (#210)

* Implement basic ParameterAveragingOptimizer
* Implement DecentralizedSGD as a special case
* support averaging period
* add basic optimizer test

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
justheuristic 4 years ago
parent
commit
ca3aadb8f8

+ 1 - 0
hivemind/client/__init__.py

@@ -1,3 +1,4 @@
 from hivemind.client.expert import RemoteExpert
 from hivemind.client.moe import RemoteMixtureOfExperts
 from hivemind.client.averaging import DecentralizedAverager
+from hivemind.client.optim import ParameterAveragingOptimizer, DecentralizedSGD
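
With this re-export, the new classes become importable from the client package and, as the test further below does, from the top-level package (the latter assumes hivemind/__init__.py re-exports hivemind.client, which the test relies on):

    from hivemind.client import ParameterAveragingOptimizer, DecentralizedSGD
    # or, as in tests/test_training.py below:
    from hivemind import DecentralizedSGD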

+ 3 - 2
hivemind/client/averaging/__init__.py

@@ -6,6 +6,7 @@ import asyncio
 import contextlib
 import ctypes
 import multiprocessing as mp
+import os
 import threading
 import uuid
 import weakref
@@ -196,14 +197,14 @@ class DecentralizedAverager(mp.Process, averaging_pb2_grpc.DecentralizedAveragin
     def shutdown(self) -> None:
         """ Shut down the averager process """
         # TODO notify peers before terminating
-        if self.is_alive():
+        if self._parent_pid != os.getpid() or self.is_alive():
             self._pipe.send(('_SHUTDOWN', None))
             self.terminate()
         else:
             logger.warning("DHT shutdown has no effect: the process is not alive")
 
     def __del__(self):
-        if self.is_alive():
+        if self._parent_pid != os.getpid() or self.is_alive():
             self.shutdown()
 
     def step(self, gather: Optional[DataForGather] = None, allow_retries: bool = True, timeout: Optional[float] = None,
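
The added self._parent_pid != os.getpid() check lets shutdown proceed without calling is_alive() when it is invoked from a process other than the one that created the averager: multiprocessing only allows polling a process handle from its parent, so a copy of the handle inherited by a forked child would otherwise hit an assertion inside __del__. A minimal standalone sketch of that restriction (assumes Linux's default fork start method; not part of this commit):

    import multiprocessing as mp
    import os

    worker = mp.Process(target=lambda: None)   # stand-in for the averager process
    worker.start()

    if os.fork() == 0:                         # a forked child inherits the handle
        try:
            worker.is_alive()                  # BaseProcess.is_alive() asserts it runs in the parent
        except AssertionError as e:
            print(e)                           # "can only test a child process"
        os._exit(0)

    worker.join()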

+ 1 - 0
hivemind/client/optim/__init__.py

@@ -0,0 +1 @@
+from hivemind.client.optim.simple import ParameterAveragingOptimizer, DecentralizedSGD

+ 35 - 0
hivemind/client/optim/base.py

@@ -0,0 +1,35 @@
+from typing import Any
+
+import torch
+
+from hivemind.dht import DHT
+
+
+class DecentralizedOptimizerBase(torch.optim.Optimizer):
+    """ A shared interface for all hivemind optimizers. Cooperates with DHT peers to train a shared model """
+    def __init__(self, opt: torch.optim.Optimizer, dht: DHT):
+        self.opt, self.dht = opt, dht
+
+    @property
+    def state(self):
+        return self.opt.state
+
+    @property
+    def param_groups(self):
+        return self.opt.param_groups
+
+    def add_param_group(self, param_group: dict) -> None:
+        raise ValueError(f"{self.__class__.__name__} does not support calling add_param_group after creation. "
+                         f"Please provide all parameter groups at init.")
+
+    def state_dict(self) -> dict:
+        return self.opt.state_dict()
+
+    def load_state_dict(self, state_dict: dict):
+        return self.opt.load_state_dict(state_dict)
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(opt={repr(self.opt)}, dht={repr(self.dht)})"
+
+    def shutdown(self):
+        raise NotImplementedError()
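
Because the base class delegates state, param_groups, state_dict and load_state_dict to the wrapped optimizer, a subclass only has to decide how and when peers synchronize. A hypothetical minimal subclass (illustration only; the name and behavior are not part of this commit) could look like this:

    from hivemind.client.optim.base import DecentralizedOptimizerBase

    class LocalOnlyOptimizer(DecentralizedOptimizerBase):
        """ Hypothetical subclass that performs no peer communication at all """

        def step(self, *args, **kwargs):
            # a real subclass would trigger or await averaging around this call
            # (see ParameterAveragingOptimizer in hivemind/client/optim/simple.py)
            return self.opt.step(*args, **kwargs)

        def zero_grad(self, *args, **kwargs):
            return self.opt.zero_grad(*args, **kwargs)

        def shutdown(self):
            pass  # nothing to stop: no background threads or averagers were started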

+ 128 - 0
hivemind/client/optim/simple.py

@@ -0,0 +1,128 @@
+import time
+from threading import Thread, Lock, Event
+from typing import Optional
+
+import torch
+
+from hivemind.dht import DHT
+from hivemind.client.averaging import DecentralizedAverager
+from hivemind.client.optim.base import DecentralizedOptimizerBase
+from hivemind.utils import get_logger, get_dht_time
+
+logger = get_logger(__name__)
+
+
+class ParameterAveragingOptimizer(DecentralizedOptimizerBase):
+    """
+    A simple optimizer that trains a shared model by averaging model parameters with peers in the background.
+
+    :param opt: a pytorch optimizer configured to update model parameters.
+    :param dht: a running hivemind DHT daemon connected to other peers
+    :param averaging_period: if above zero, the optimizer will attempt to average weights at regular intervals of this
+      many seconds (an averaging step will only occur if the optimizer has run at least one step in that interval)
+    :param prefix: all DHT keys that point to optimization metadata will have this prefix
+    :param target_group_size: maximum group size for averaging (see DecentralizedAverager)
+    :param timeout: if DecentralizedAverager step is unable to form group in this many seconds, cancel step
+    :param kwargs: additional parameters passed to DecentralizedAverager
+
+    :note: using this optimizer with adaptive learning rates may result in poor convergence due to
+      out-of-sync adaptive statistics (such as Adam's second moments or the learning rate schedule step). Please ensure
+      that these statistics are synchronized or use a more advanced decentralized optimizer, if applicable.
+    :note: the base optimizer cannot add param groups after the DecentralizedOptimizer is created
+    """
+
+    def __init__(self, opt: torch.optim.Optimizer, dht: DHT, prefix: str, target_group_size: int,
+                 averaging_period: float = 0, timeout: Optional[float] = None, verbose: bool = False, **kwargs):
+        super().__init__(opt, dht)
+        with torch.no_grad():
+            averaged_tensors = tuple(p.cpu().float().clone().requires_grad_(False)
+                                     for group in self.param_groups for p in group['params'])
+        self.averager = DecentralizedAverager(averaged_tensors, dht, start=True, prefix=prefix,
+                                              target_group_size=target_group_size, **kwargs)
+        self.lock_parameters, self.update_event, self.stop_event = Lock(), Event(), Event()
+        self.background_averaging_thread = Thread(
+            name=f'{self.__class__.__name__}', daemon=True, target=self._average_parameters_in_background,
+            args=[self.lock_parameters, self.update_event, self.stop_event, self.averager, self.opt],
+            kwargs=dict(averaging_period=averaging_period, timeout=timeout, verbose=verbose))
+        self.background_averaging_thread.start()
+
+    def step(self, *args, **kwargs):
+        self.update_event.set()
+        with self.lock_parameters:
+            return self.opt.step(*args, **kwargs)
+
+    def zero_grad(self, *args, **kwargs):
+        return self.opt.zero_grad(*args, **kwargs)
+
+    def __del__(self):
+        self.stop_event.set()
+        self.update_event.set()
+
+    def shutdown(self):
+        self.averager.shutdown()
+
+    @staticmethod
+    @torch.no_grad()
+    def _average_parameters_in_background(
+            lock_parameters: Lock, update_event: Event, stop_event: Event, averager: DecentralizedAverager,
+            opt: torch.optim.Optimizer, averaging_period: float, verbose: bool, **kwargs):
+        """ Iteratively find groups of peers, average parameters with these peers and update local model parameters. """
+        while not stop_event.is_set():
+            update_event.wait()
+            update_event.clear()
+            if stop_event.is_set():
+                break
+
+            if averaging_period:
+                current_time = get_dht_time()
+                # note: we use global DHT time to make sure peers start averaging at the ~same time (to form groups)
+                time_to_nearest_interval = max(0.0, averaging_period - current_time % averaging_period)
+                time.sleep(time_to_nearest_interval)
+
+            with lock_parameters, averager.get_tensors() as averaged_tensors:
+                local_tensors = tuple(p for group in opt.param_groups for p in group['params'])
+                assert len(local_tensors) == len(averaged_tensors), "The number of optimized parameters should not change."
+
+                for local_tensor, averaged_tensor in zip(local_tensors, averaged_tensors):
+                    averaged_tensor[...] = local_tensor.cpu().float()
+
+            try:
+                if verbose:
+                    logger.info(f"Starting a new averaging round with current parameters.")
+                group_info = averager.step(**kwargs)
+
+                if group_info is not None:
+                    with lock_parameters, averager.get_tensors() as averaged_tensors:
+                        for local_tensor, averaged_tensor in zip(local_tensors, averaged_tensors):
+                            local_tensor[...] = averaged_tensor.to(dtype=local_tensor.dtype)
+                    if verbose:
+                        logger.info(f"Finished averaging round with {len(group_info)} peers.")
+                else:
+                    if verbose:
+                        logger.warning(f"Averaging round failed: could not find group.")
+            except Exception as e:
+                logger.error(f"Averaging round failed: caught {e}.")
+
+
+class DecentralizedSGD(ParameterAveragingOptimizer):
+    """
+    Decentralized Stochastic Gradient Descent as in Lian et al. (2017) [1], with averaging based on Moshpit All-Reduce [2].
+
+    :param opt: a pytorch optimizer configured to update model parameters.
+    :param dht: a running hivemind DHT daemon connected to other peers
+    :param averaging_period: if above zero, the optimizer will attempt to average weights at regular intervals of this
+      many seconds (an averaging step will only occur if the optimizer has run at least one step in that interval)
+    :param prefix: all DHT keys that point to optimization metadata will have this prefix
+    :param target_group_size: maximum group size for averaging (see DecentralizedAverager)
+    :param kwargs: additional parameters passed to DecentralizedAverager
+
+    - [1] Can Decentralized Algorithms Outperform Centralized Algorithms? A Case Study for Parallel Stochastic Gradient
+     Descent - https://proceedings.neurips.cc/paper/2017/hash/f75526659f31040afeb61cb7133e4e6d-Abstract.html
+    - [2] Moshpit SGD: Communication-Efficient Decentralized Training on Heterogeneous Unreliable Devices
+     https://arxiv.org/abs/2103.03239
+    """
+
+    def __init__(self, params, lr: float, *, dht: DHT, prefix: str, target_group_size: int, averaging_period: float = 0,
+                 momentum: float = 0, dampening: float = 0, weight_decay: float = 0, nesterov: bool = False, **kwargs):
+        opt = torch.optim.SGD(params, lr, momentum, dampening, weight_decay, nesterov)
+        super().__init__(opt, dht, prefix, target_group_size, averaging_period, **kwargs)
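
A usage sketch for the two new classes (the model, peer address and prefix below are placeholders, not part of this commit): construct DecentralizedSGD directly, or wrap an existing pytorch optimizer in ParameterAveragingOptimizer, then train as usual while averaging runs in a background thread.

    import torch
    from hivemind import DHT
    from hivemind.client.optim import ParameterAveragingOptimizer, DecentralizedSGD

    model = torch.nn.Linear(16, 2)                            # placeholder model
    dht = DHT(initial_peers=["127.0.0.1:1337"], start=True)   # placeholder peer address

    opt = DecentralizedSGD(model.parameters(), lr=0.1, dht=dht,
                           prefix='my_experiment', target_group_size=4, averaging_period=30)
    # equivalently, any existing (preferably non-adaptive) optimizer can be wrapped directly:
    # opt = ParameterAveragingOptimizer(torch.optim.SGD(model.parameters(), lr=0.1), dht,
    #                                   prefix='my_experiment', target_group_size=4, averaging_period=30)

    loss = model(torch.randn(8, 16)).pow(2).mean()
    loss.backward()
    opt.step()         # wakes the background thread, which then tries to form a group and average
    opt.zero_grad()

With a nonzero averaging_period, the background thread sleeps until the next multiple of the period in global DHT time before calling averager.step(), so peers that share the same period tend to look for groups at roughly the same moment.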

+ 28 - 1
tests/test_training.py

@@ -1,12 +1,13 @@
 from functools import partial
 
+import time
 import pytest
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from sklearn.datasets import load_digits
 
-from hivemind import RemoteExpert, background_server
+from hivemind import RemoteExpert, background_server, DHT, DecentralizedSGD
 
 
 @pytest.mark.forked
@@ -36,3 +37,29 @@ def test_training(max_steps: int = 100, threshold: float = 0.9):
                 break
 
         assert accuracy >= threshold, f"too small accuracy: {accuracy}"
+
+
+@pytest.mark.forked
+def test_decentralized_optimizer_step():
+    dht_root = DHT(start=True)
+    initial_peers = [f"127.0.0.1:{dht_root.port}"]
+
+    param1 = torch.nn.Parameter(torch.zeros(32, 32), requires_grad=True)
+    opt1 = DecentralizedSGD([param1], lr=0.1, dht=DHT(initial_peers=initial_peers, start=True),
+                            prefix='foo', target_group_size=2, verbose=True)
+
+    param2 = torch.nn.Parameter(torch.ones(32, 32), requires_grad=True)
+    opt2 = DecentralizedSGD([param2], lr=0.05, dht=DHT(initial_peers=initial_peers, start=True),
+                            prefix='foo', target_group_size=2, verbose=True)
+
+    assert not torch.allclose(param1, param2)
+
+    (param1.sum() + 300 * param2.sum()).backward()
+
+    opt1.step()
+    opt2.step()
+
+    time.sleep(0.5)
+    assert torch.allclose(param1, param2)
+    reference = 0.5 * (0.0 - 0.1 * 1.0) + 0.5 * (1.0 - 0.05 * 300)
+    assert torch.allclose(param1, torch.full_like(param1, reference))
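
For reference, the expected value in the final assertion can be checked by hand (this restates the arithmetic already encoded in the test):

    param1 starts at 0, its gradient is 1 per element, lr=0.1:    0 - 0.1 * 1    = -0.1
    param2 starts at 1, its gradient is 300 per element, lr=0.05: 1 - 0.05 * 300 = -14.0
    a successful averaging round sets both tensors to the mean:   0.5 * (-0.1) + 0.5 * (-14.0) = -7.05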