
Add GradientAverager with support for delayed averaging (#404)

This PR implements GradientAverager, a subclass of DecentralizedAverager that supports accumulating and aggregating gradients. The class also supports pre-scheduling and delayed averaging (for DPU, #394) for use in hivemind.Optimizer (#400).
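In short, this splits a global step into two phases: schedule_step() begins matchmaking ahead of time and returns a StepControl, while a later step(control=...) triggers the actual all-reduce. A condensed sketch of that flow, adapted from the class docstring example in the diff below (model, opt, compute_loss, and the "my_run" prefix are placeholders):

    grad_averager = GradientAverager(model.parameters(), dht=hivemind.DHT(start=True), prefix="my_run", start=True)

    # phase 1: pre-schedule the next averaging round, then keep accumulating local gradients
    control = grad_averager.schedule_step(scheduled_time=hivemind.get_dht_time() + 60)
    loss = compute_loss(model, batch_size=32)
    loss.backward()
    grad_averager.accumulate_grads_(batch_size=32)

    # phase 2: trigger the pre-scheduled all-reduce, then apply the averaged gradients
    grad_averager.step(control=control)
    with grad_averager.use_averaged_gradients():  # temporarily puts averaged gradients into param.grad
        opt.step()
    grad_averager.reset_accumulated_grads_()  # prepare for the next round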

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
Co-authored-by: Aleksandr Borzunov <hxrussia@gmail.com>
justheuristic 3 years ago
parent commit 8fa0a8e6ae

+ 0 - 0
hivemind/optim/experimental/__init__.py


+ 219 - 0
hivemind/optim/experimental/grad_averager.py

@@ -0,0 +1,219 @@
+import contextlib
+from typing import Iterable, Iterator, Optional
+
+import torch
+
+import hivemind
+from hivemind.averaging import DecentralizedAverager
+from hivemind.averaging.control import StepControl
+from hivemind.utils import DHTExpiration, get_logger
+
+logger = get_logger(__name__)
+
+
+class GradientAverager(DecentralizedAverager):
+    """
+    An auxiliary averaging class that is responsible for accumulating gradients and aggregating them with peers.
+    GradientAverager is meant to be used within hivemind.Optimizer, but it can be used standalone (see example below).
+
+    GradientAverager manages three sets of buffers:
+    (1) model gradients - the gradients associated with local model parameters by PyTorch (param.grad).
+        These tensors are typically stored on device and updated by torch autograd
+    (2) gradient accumulators - an [optional] set of buffers where local gradients are accumulated.
+      - note: if reuse_grad_buffers is True, the averager will use gradients from parameters as local accumulators,
+        which reduces RAM usage but requires the user to avoid calling zero_grad / clip_grad manually
+    (3) averaged gradients - gradient buffers that are aggregated in-place with peers, always in host memory
+
+    :param parameters: pytorch parameters for which to aggregate gradients
+    :param dht: a DHT instance connected to the rest of the swarm. See hivemind.DHT docs
+    :param prefix: a unique DHT key used for matchmaking. E.g. this can be your experiment name with optional suffixes
+    :param reuse_grad_buffers: if True, use model's .grad buffers for accumulating gradients over multiple steps.
+      This is more memory efficient, but it requires that the user does *not* call zero_grad or clip gradients manually
+    :param accumulate_grads_on: if specified, accumulate gradients on this device. By default, this will use the same
+      device as model parameters. One can specify a different device (e.g. 'cpu' vs 'cuda') to save device memory at
+      the cost of extra time per step. If reuse_grad_buffers is True, this parameter has no effect.
+    :param client_mode: if False, this averager will accept incoming requests from other peers.
+      if True, the averager will only join existing groups where at least one peer has client_mode=False.
+      By default, this flag is copied from DHTNode inside the ``dht`` instance.
+    :param warn: if True, warn when the averager did not reset accumulators after use or did not use averaging results
+    :param kwargs: see DecentralizedAverager keyword arguments for additional parameters
+
+
+    Example:
+
+    >>> model = SuchModelMuchLayers()
+    >>> opt = torch.optim.Adam(model.parameters())
+    >>> grad_averager = GradientAverager(model.parameters(), dht=hivemind.DHT(...))
+    >>> next_step_time = hivemind.get_dht_time() + 60   # runs global steps every 60 seconds
+    >>> next_step_control = None
+    >>> while True:
+    >>>    # accumulate as many gradients as you can before next_step_time
+    >>>    loss = compute_loss(model, batch_size=32)
+    >>>    loss.backward()
+    >>>    grad_averager.accumulate_grads_(batch_size=32)
+    >>>    # [optional] next step in 5 seconds, start looking for peers in advance
+    >>>    if next_step_time - hivemind.get_dht_time() <= 5:
+    >>>        next_step_control = grad_averager.schedule_step(scheduled_time=next_step_time)
+    >>>    # aggregate gradients and perform optimizer step
+    >>>    if hivemind.get_dht_time() >= next_step_time:
+    >>>        grad_averager.step(control=next_step_control)
+    >>>        with grad_averager.use_averaged_gradients():  # this will fill param.grad with aggregated gradients
+    >>>            opt.step()  # update model parameters using averaged gradients
+    >>>        grad_averager.reset_accumulated_grads_()  # prepare for next step
+    >>>        next_step_time = hivemind.get_dht_time() + 60
+    >>>        next_step_control = None
+
+    """
+
+    def __init__(
+        self,
+        parameters: Iterable[torch.nn.Parameter],
+        *,
+        dht: hivemind.DHT,
+        prefix: str,
+        reuse_grad_buffers: bool = False,
+        accumulate_grads_on: Optional[torch.device] = None,
+        client_mode: Optional[bool] = None,
+        warn: bool = True,
+        **kwargs,
+    ):
+        if reuse_grad_buffers and accumulate_grads_on is not None:
+            logger.warning("Setting 'accumulate_grads_on' has no effect if reuse_grad_buffers=True")
+        client_mode = client_mode if client_mode is not None else dht.client_mode
+        self._parameters = tuple(parameters)
+        self.reuse_grad_buffers = reuse_grad_buffers
+        self.warn = warn
+        self.local_samples_accumulated = 0
+        self.local_times_accumulated = 0
+        self._anchor_batch_size = None
+        self._local_accumulators = None
+        if not reuse_grad_buffers:
+            self._local_accumulators = tuple(
+                torch.zeros_like(grad, device=accumulate_grads_on) for grad in self._grads_from_parameters()
+            )
+        self._accumulators_used_in_step = False
+        self._new_averaged_grads = False
+
+        with torch.no_grad():
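+            # buffer set (3) from the class docstring: averaged gradients live in host shared memory for in-place aggregation with peers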
+            averaged_grads = tuple(
+                grad.detach().cpu().clone().share_memory_() for grad in self._grads_from_parameters()
+            )
+        super().__init__(averaged_tensors=averaged_grads, dht=dht, prefix=prefix, client_mode=client_mode, **kwargs)
+
+    def _grads_from_parameters(self) -> Iterator[torch.Tensor]:
+        """gradient buffers associated with parameters"""
+        for param in self._parameters:
+            if param.grad is None:
+                param.grad = torch.zeros_like(param)
+            yield param.grad
+
+    @torch.no_grad()
+    def _grad_accumulators(self) -> Iterator[torch.Tensor]:
+        """averager-based gradient accumulators"""
+        assert (self._local_accumulators is None) == self.reuse_grad_buffers
+        yield from self._grads_from_parameters() if self.reuse_grad_buffers else self._local_accumulators
+
+    @torch.no_grad()
+    def accumulate_grads_(self, batch_size: int):
+        """add current gradients to local grad accumulators (if used)"""
+        if self._accumulators_used_in_step and self.warn:
+            logger.warning(
+                "[warn=True] Gradient accumulators were not reset since the last averaging round. Please "
+                "call .reset_accumulated_grads_ after every step or use .step(reset_accumulators=True)."
+            )
+            self._accumulators_used_in_step = False  # warn once per round
+        if self._anchor_batch_size is None:
+            # remember the first batch size to correctly re-scale gradients if subsequent batches have a different size
+            self._anchor_batch_size = batch_size
+        self.local_samples_accumulated += batch_size
+        self.local_times_accumulated += 1
+        if self.reuse_grad_buffers:
+            pass  # user is responsible for accumulating gradients in .grad buffers
+        else:
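+            # re-scale the current gradients relative to the first (anchor) batch size before adding them to the accumulators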
+            alpha = float(batch_size) / self._anchor_batch_size
+            for grad_buf, grad_acc in zip(self._grads_from_parameters(), self._grad_accumulators()):
+                grad_acc.add_(grad_buf.to(grad_acc.device), alpha=alpha)
+
+    def schedule_step(self, scheduled_time: Optional[DHTExpiration] = None, **kwargs) -> StepControl:
+        """
+        Begin matchmaking: look for a group of peers and prepare for averaging gradients at a specified time.
+
+        :param scheduled_time: expected time when to perform all-reduce. Can be changed using control.scheduled_time
+        :param kwargs: any additional keyword args from DecentralizedAverager.step, such as gather, allow_retries, etc
+        :note: setting weight at this stage is not supported, please leave this parameter as None
+        :returns: step_control - a handle that can be passed into GradientAverager.step to use the pre-scheduled group
+        :note: in the current implementation, each step_control can only be used in one step.
+        """
+        assert kwargs.get("weight") is None, "setting weight in schedule_step is not supported"
+        return super().step(scheduled_time=scheduled_time, wait=False, require_trigger=True, **kwargs)
+
+    def step(
+        self,
+        weight: Optional[float] = None,
+        reset_accumulators: bool = True,
+        control: Optional[StepControl] = None,
+        wait: bool = True,
+        **kwargs,
+    ):
+        """
+        Average accumulated gradients with peers, optionally load averaged gradients and reset accumulators
+
+        :param weight: overrides the averaging weight; by default, weight equals the number of accumulated samples
+        :param reset_accumulators: by default, set local gradient accumulators to zeros after averaging succeeds
+        :param control: reuse a pre-arranged group of peers (or a matchmaking in progress) from averager.schedule_step
+        :param wait: if True, wait for the step to finish (or fail); otherwise, run all-reduce in the background
+        """
+        if control is None:
+            control = self.schedule_step(**kwargs)
+        elif len(kwargs) > 0:
+            raise RuntimeError(f"Averaging with a pre-scheduled group, parameters {kwargs} will have no effect.")
+        assert not control.triggered, f"This {type(control)} instance was already used."
+        self._load_accumulators_into_averager_()
+        self._accumulators_used_in_step = True
+        self._new_averaged_grads = True
+
+        control.weight = self.local_samples_accumulated if weight is None else weight
+        if reset_accumulators:
+            self.reset_accumulated_grads_()
+
+        control.allow_allreduce()
+        return control.result() if wait else control
+
+    @torch.no_grad()
+    def _load_accumulators_into_averager_(self):
+        """load locally accumulated gradients into the averager for aggregation"""
+        if self._new_averaged_grads and self.warn:
+            logger.warning(
+                "[warn=True] Starting new averaging round, but previous round results were not used."
+                "This may be a sign of incorrect optimizer behavior."
+            )
+            self._new_averaged_grads = False  # warn once per round
+        # divide locally accumulated gradients by the number of times they were accumulated
+        grad_scale = (1.0 / self.local_times_accumulated) if self.local_times_accumulated != 0 else 0.0
+        with self.get_tensors() as averaged_grads:
+            for grad_acc, averaged_grad in zip(self._grad_accumulators(), averaged_grads):
+                averaged_grad.copy_(grad_acc, non_blocking=True).mul_(grad_scale)
+
+    @torch.no_grad()
+    def reset_accumulated_grads_(self):
+        """reset averager-internal gradient accumulators and the denominator"""
+        self._accumulators_used_in_step = False
+        self.local_samples_accumulated = self.local_times_accumulated = 0
+        self._anchor_batch_size = None
+        for grad_buf in self._grad_accumulators():
+            grad_buf.zero_()
+
+    @contextlib.contextmanager
+    @torch.no_grad()
+    def use_averaged_gradients(self):
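+        """temporarily substitute the model's param.grad buffers with the averaged gradients (originals are restored on exit)"""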
+        self._new_averaged_grads = False
+        with self.get_tensors() as averaged_grads:
+            try:
+                assert len(averaged_grads) == len(self._parameters)
+                old_grads = [param.grad for param in self._parameters]
+                for param, new_grad in zip(self._parameters, averaged_grads):
+                    param.grad = new_grad
+                yield
+            finally:
+                for param, old_grad in zip(self._parameters, old_grads):
+                    param.grad = old_grad

+ 68 - 0
tests/test_optimizer.py

@@ -0,0 +1,68 @@
+import time
+
+import pytest
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import hivemind
+from hivemind.averaging.control import AveragingStage
+from hivemind.optim.experimental.grad_averager import GradientAverager
+
+
+@pytest.mark.forked
+def test_grad_averager():
+    dht1 = hivemind.DHT(start=True)
+    model1 = nn.ParameterDict({"w": nn.Parameter(torch.zeros(3))})
+    averager1 = GradientAverager(
+        model1.parameters(), dht=dht1, prefix="test", target_group_size=2, reuse_grad_buffers=False, start=True
+    )
+
+    dht2 = hivemind.DHT(start=True, initial_peers=dht1.get_visible_maddrs())
+    model2 = nn.ParameterDict({"w": nn.Parameter(torch.zeros(3))})
+    averager2 = GradientAverager(
+        model2.parameters(), dht=dht2, prefix="test", target_group_size=2, reuse_grad_buffers=True, start=True
+    )
+
+    control1 = averager1.schedule_step(hivemind.get_dht_time() + 5)
+    control2 = averager2.schedule_step(hivemind.get_dht_time() + 5)
+
+    for i in range(10):
+        time.sleep(0.1)
+        if i % 3 == 0:
+            loss1 = F.mse_loss(model1.w, torch.ones(3))
+            loss1.backward()
+            averager1.accumulate_grads_(batch_size=2)  # total: 4 times * 2 samples = 8
+            model1.zero_grad()
+        else:
+            loss2 = F.mse_loss(model2.w, -torch.ones(3))
+            loss2.backward()
+            averager2.accumulate_grads_(batch_size=3)  # total: 6 times * 3 samples = 18
+            # note: we do not call zero grad here because reuse_grad_buffers=True
+
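+    # both rounds were pre-scheduled with require_trigger=True: all-reduce does not start until step() triggers it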
+    assert control1.stage == control2.stage == AveragingStage.AWAITING_TRIGGER
+    peer1_samples, peer1_times, peer2_samples, peer2_times = 8, 4, 18, 6
+    assert averager1.local_samples_accumulated == peer1_samples and averager1.local_times_accumulated == peer1_times
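+    # at w=0, the gradient of F.mse_loss(w, target) is 2 * (w - target) / 3, i.e. -2/3 per element for peer1 and +2/3 for peer2, added once per backward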
+    ref_grads1 = torch.full((3,), -2 * 1 / 3 * averager1.local_times_accumulated)
+    assert torch.allclose(next(averager1._grad_accumulators()), ref_grads1)
+
+    assert averager2.local_samples_accumulated == peer2_samples and averager2.local_times_accumulated == peer2_times
+    ref_grads2 = torch.full((3,), 2 * 1 / 3 * averager2.local_times_accumulated)
+    assert torch.allclose(next(averager2._grad_accumulators()), ref_grads2)
+
+    averager1.step(control=control1, wait=False)
+    averager2.step(control=control2, wait=False)
+    for step in (control1, control2):
+        step.result()  # wait for all-reduce to finish
+
+    peer1_weight = peer1_samples / (peer1_samples + peer2_samples)
+    peer2_weight = peer2_samples / (peer1_samples + peer2_samples)
+    ref_average = peer1_weight * (ref_grads1 / peer1_times) + peer2_weight * (ref_grads2 / peer2_times)
+    with averager1.use_averaged_gradients():
+        assert torch.allclose(model1.w.grad, ref_average)
+    with averager2.use_averaged_gradients():
+        assert torch.allclose(model2.w.grad, ref_average)
+
+    # outside the use_averaged_gradients() context, the original local gradients are restored
+    assert not torch.allclose(model1.w.grad, ref_average)
+    assert not torch.allclose(model2.w.grad, ref_average)