Эх сурвалжийг харах

Implement CenteredClip in averager

Alexander Borzunov 4 жил өмнө
parent
commit
fee619527f

+ 112 - 0
hivemind/averaging/accumulators.py

@@ -0,0 +1,112 @@
+import dataclasses
+from abc import ABC
+from typing import Callable, Optional
+
+import torch
+
+
+class AccumulatorBase(ABC):
+    def accumulate_part(self, tensor: torch.Tensor, weight: float) -> None:
+        ...
+
+    def reduce(self) -> torch.Tensor:
+        ...
+
+
+AccumulatorFactory = Callable[[torch.Size, int], AccumulatorBase]
+
+
+class MeanAccumulator(AccumulatorBase):
+    def __init__(self, part_shape: torch.Size, _n_peers: int):
+        self._accumulator = torch.zeros(part_shape)
+        self._denominator = 0.0
+
+    def accumulate_part(self, tensor_part: torch.Tensor, weight: float) -> None:
+        self._accumulator.add_(tensor_part, alpha=weight)
+        self._denominator += weight
+
+    def reduce(self) -> torch.Tensor:
+        return self._accumulator.div_(self._denominator)
+
+
+class CenteredClipAccumulator(AccumulatorBase):
+    def __init__(self, part_shape: torch.Size, n_peers: int, **kwargs):
+        self._kwargs = kwargs
+
+        self._tensors = torch.empty([n_peers] + part_shape)
+        self._weights = torch.empty(n_peers)
+        self._index = 0
+
+    def accumulate_part(self, tensor_part: torch.Tensor, weight: float) -> None:
+        self._tensors[self._index] = tensor_part
+        self._weights[self._index] = weight
+        self._index += 1
+
+    def reduce(self) -> torch.Tensor:
+        clipped = centered_clip(self._tensors, self._weights, **self._kwargs)
+        return clipped.result
+
+
+@dataclasses.dataclass(frozen=True)
+class CenteredClipResult:
+    result: torch.Tensor
+    n_clipped: torch.Tensor
+    last_step_delta: torch.Tensor
+
+
+def centered_clip(
+    input_tensors: torch.Tensor,
+    weights: torch.Tensor,
+    tau: float = 1.0,
+    n_iters: int = 20,
+    stop_delta: Optional[float] = None,
+) -> CenteredClipResult:
+    """
+    Optimized implementation of CenteredClip from [Karimireddy, 2021].
+    Intended to be used in a decentralized fashion as in [Gorbunov, 2021].
+
+    :stop_delta: Stop iterations early if the ``L_inf`` norm of the last step is less than ``stop_delta``.
+                 Note: if this option is used, the step norm calculations may increase the time per iteration by ~25%.
+
+    References:
+
+    [Karimireddy, 2021] Karimireddy, Sai Praneeth, Lie He, and Martin Jaggi. "Learning from history for byzantine
+    robust optimization." International Conference on Machine Learning. PMLR, 2021.
+
+    [Gorbunov, 2021] Gorbunov, Eduard, Alexander Borzunov, Michael Diskin, and Max Ryabinin.
+    "Secure Distributed Training at Scale." arXiv preprint arXiv:2106.11257 (2021).
+    """
+
+    with torch.no_grad():
+        n_peers = input_tensors.shape[0]
+        result_shape = input_tensors.shape[1:]
+
+        input_tensors = input_tensors.flatten(start_dim=1)
+        weights /= weights.sum()
+
+        # This finds medians faster than torch.median() and torch.quantile(q=0.5),
+        # see https://github.com/pytorch/pytorch/issues/51450
+        sorted_tensors = input_tensors.sort(dim=0).values
+        result = sorted_tensors[n_peers // 2].clone()
+        delta = None
+
+        diff = torch.sub(input_tensors, result, out=sorted_tensors)  # Reuse memory from `sorted_tensors`
+        for _ in range(n_iters):
+            norms = diff.norm(dim=1)
+            coeffs = weights * torch.minimum(torch.tensor(1.0), tau / norms)
+
+            if stop_delta is not None:
+                prev_diff = result[...] = diff[0]  # Reuse memory from `result`
+
+            # We only need to update `diff` (not `result`) between iterations
+            diff.addmm_(-coeffs.repeat(n_peers, 1), diff)
+
+            if stop_delta is not None:
+                delta = prev_diff.sub_(diff[0]).max()
+                if delta < stop_delta:
+                    break
+        torch.sub(input_tensors[0], diff[0], out=result)
+
+        return CenteredClipResult(
+            result=result.reshape(result_shape), n_clipped=(tau < norms).sum(), last_step_delta=delta
+        )

+ 4 - 1
hivemind/averaging/allreduce.py

@@ -4,6 +4,7 @@ from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Type
 
 import torch
 
+from hivemind.averaging.accumulators import AccumulatorFactory
 from hivemind.averaging.partition import AllreduceException, TensorPartContainer, TensorPartReducer
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase, StubBase
@@ -58,6 +59,7 @@ class AllReduceRunner(ServicerBase):
         tensors: Sequence[torch.Tensor],
         ordered_peer_ids: Sequence[PeerID],
         peer_fractions: Tuple[float, ...],
+        accumulator_factory: AccumulatorFactory,
         weights: Optional[Sequence[float]] = None,
         modes: Optional[Sequence[AveragingMode]] = None,
         gathered: Optional[Dict[PeerID, Any]] = None,
@@ -97,7 +99,8 @@ class AllReduceRunner(ServicerBase):
         self.tensor_part_reducer = TensorPartReducer(
             tuple(part.shape for part in self.parts_for_local_averaging),
             len(self.sender_peer_ids),
-            self.sender_weights,
+            weights=self.sender_weights,
+            accumulator_factory=accumulator_factory,
         )
 
     def __repr__(self):

+ 3 - 0
hivemind/averaging/averager.py

@@ -15,6 +15,7 @@ from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Union
 import numpy as np
 import torch
 
+from hivemind.averaging.accumulators import AccumulatorFactory, MeanAccumulator
 from hivemind.averaging.allreduce import AllreduceException, AllReduceRunner, AveragingMode, GroupID
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.load_balancing import load_balance_peers
@@ -112,6 +113,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         compression: CompressionBase = NoCompression(),
         state_compression: CompressionBase = NoCompression(),
         tensor_infos: Optional[Sequence[CompressionInfo]] = None,
+        accumulator_factory: AccumulatorFactory = MeanAccumulator,
         bandwidth: Optional[float] = None,
         min_vector_size: int = 0,
         auxiliary: bool = False,
@@ -170,6 +172,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             compression=compression,
             part_size_bytes=part_size_bytes,
             min_vector_size=min_vector_size,
+            accumulator_factory=accumulator_factory,
         )
         self._averaging_alpha, self._allreduce_timeout = averaging_alpha, allreduce_timeout
         self._running_groups: Dict[GroupID, AllReduceRunner] = {}  # one or more assembled groups that run all-reduce

+ 15 - 9
hivemind/averaging/partition.py

@@ -8,6 +8,7 @@ from typing import AsyncIterable, AsyncIterator, Optional, Sequence, Tuple, Type
 import numpy as np
 import torch
 
+from hivemind.averaging.accumulators import AccumulatorFactory
 from hivemind.compression import CompressionBase, CompressionInfo, NoCompression
 from hivemind.proto import runtime_pb2
 from hivemind.utils.asyncio import amap_in_executor
@@ -171,16 +172,23 @@ class TensorPartReducer:
     :note: even if local peer is not sending data, local parts will be used for shape information
     """
 
-    def __init__(self, part_shapes: Sequence[torch.Size], num_senders: int, weights: Optional[Sequence[float]] = None):
+    def __init__(
+        self,
+        part_shapes: Sequence[torch.Size],
+        num_senders: int,
+        *,
+        weights: Optional[Sequence[float]],
+        accumulator_factory: AccumulatorFactory,
+    ):
         self.part_shapes, self.num_senders, self.num_parts = part_shapes, num_senders, len(part_shapes)
         self.weights = tuple(weights or (1 for _ in range(num_senders)))
         assert len(self.weights) == self.num_senders, "The number of weights is inconsistent with num_senders"
         assert all(isinstance(weight, (int, float)) for weight in self.weights)
         self.current_part_index = -1  # index in local_parts of the part that should be loaded next
         self.current_part_accumulated_from = 0  # number of peers from which the current part was accumulated
-        self.accumulator = None  # this will contain the sum of current tensor part from group peers
-        self.denominator = 0.0  # total weight accumulated from all peers for current part
         self.current_part_future = asyncio.Future()
+        self.accumulator_factory = accumulator_factory
+        self.accumulator = None
         self.finished = asyncio.Event()
         self.reset_accumulators()
 
@@ -194,8 +202,7 @@ class TensorPartReducer:
         self.current_part_index += 1
         self.current_part_accumulated_from = 0
         self.current_part_future = asyncio.Future()
-        self.accumulator = torch.zeros(self.part_shapes[self.current_part_index])
-        self.denominator = 0.0
+        self.accumulator = self.accumulator_factory(self.part_shapes[self.current_part_index], self.num_senders)
 
     async def accumulate_part(self, sender_index: int, part_index: int, tensor_part: torch.Tensor) -> torch.Tensor:
         """Add vector part to accumulator, wait for all other vectors to be added, then return the average part"""
@@ -211,13 +218,12 @@ class TensorPartReducer:
 
         current_part_future = self.current_part_future
 
-        self.accumulator.add_(tensor_part, alpha=self.weights[sender_index])
-        self.denominator += self.weights[sender_index]
+        self.accumulator.accumulate_part(tensor_part, self.weights[sender_index])
         self.current_part_accumulated_from += 1
 
         assert self.current_part_accumulated_from <= self.num_senders
         if self.current_part_accumulated_from == self.num_senders:
-            current_part_future.set_result(self.accumulator.div_(self.denominator))
+            current_part_future.set_result(self.accumulator.reduce())
             self.reset_accumulators()
         return await current_part_future
 
@@ -225,7 +231,7 @@ class TensorPartReducer:
         if not self.finished.is_set():
             if hasattr(self, "current_part_future"):
                 self.current_part_future.cancel()
-                del self.accumulator
+                self.accumulator = None
             self.finished.set()
 
     def __del__(self):

+ 3 - 1
tests/test_allreduce.py

@@ -7,6 +7,7 @@ import pytest
 import torch
 
 from hivemind import Quantile8BitQuantization, aenumerate
+from hivemind.averaging.accumulators import MeanAccumulator
 from hivemind.averaging.allreduce import AllReduceRunner, AveragingMode
 from hivemind.averaging.partition import TensorPartContainer, TensorPartReducer
 from hivemind.compression import deserialize_torch_tensor
@@ -119,7 +120,7 @@ async def test_partitioning_asynchronous():
 @pytest.mark.asyncio
 async def test_reducer(num_senders: int, num_parts: int, synchronize_prob: float):
     tensor_part_shapes = [torch.Size([i]) for i in range(num_parts)]
-    reducer = TensorPartReducer(tensor_part_shapes, num_senders)
+    reducer = TensorPartReducer(tensor_part_shapes, num_senders, weights=None, accumulator_factory=MeanAccumulator)
 
     local_tensors_by_sender = [[torch.randn(i) for i in range(num_parts)] for j in range(num_senders)]
 
@@ -196,6 +197,7 @@ async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions,
             tensors=[x.clone() for x in tensors_by_peer[p2p.peer_id]],
             ordered_peer_ids=peers,
             peer_fractions=peer_fractions,
+            accumulator_factory=MeanAccumulator,
             modes=peer_modes,
             weights=averaging_weights,
             part_size_bytes=part_size_bytes,