
Implement CenteredClip in averager

Alexander Borzunov, 4 years ago
Commit
fee619527f

+ 112 - 0
hivemind/averaging/accumulators.py

@@ -0,0 +1,112 @@
+import dataclasses
+from abc import ABC
+from typing import Callable, Optional
+
+import torch
+
+
+class AccumulatorBase(ABC):
+    def accumulate_part(self, tensor: torch.Tensor, weight: float) -> None:
+        ...
+
+    def reduce(self) -> torch.Tensor:
+        ...
+
+
+AccumulatorFactory = Callable[[torch.Size, int], AccumulatorBase]
+
+
+class MeanAccumulator(AccumulatorBase):
+    def __init__(self, part_shape: torch.Size, _n_peers: int):
+        self._accumulator = torch.zeros(part_shape)
+        self._denominator = 0.0
+
+    def accumulate_part(self, tensor_part: torch.Tensor, weight: float) -> None:
+        self._accumulator.add_(tensor_part, alpha=weight)
+        self._denominator += weight
+
+    def reduce(self) -> torch.Tensor:
+        return self._accumulator.div_(self._denominator)
+
+
+class CenteredClipAccumulator(AccumulatorBase):
+    def __init__(self, part_shape: torch.Size, n_peers: int, **kwargs):
+        self._kwargs = kwargs
+
+        self._tensors = torch.empty((n_peers, *part_shape))
+        self._weights = torch.empty(n_peers)
+        self._index = 0
+
+    def accumulate_part(self, tensor_part: torch.Tensor, weight: float) -> None:
+        self._tensors[self._index] = tensor_part
+        self._weights[self._index] = weight
+        self._index += 1
+
+    def reduce(self) -> torch.Tensor:
+        clipped = centered_clip(self._tensors, self._weights, **self._kwargs)
+        return clipped.result
+
+
+@dataclasses.dataclass(frozen=True)
+class CenteredClipResult:
+    result: torch.Tensor
+    n_clipped: torch.Tensor
+    last_step_delta: torch.Tensor
+
+
+def centered_clip(
+    input_tensors: torch.Tensor,
+    weights: torch.Tensor,
+    tau: float = 1.0,
+    n_iters: int = 20,
+    stop_delta: Optional[float] = None,
+) -> CenteredClipResult:
+    """
+    Optimized implementation of CenteredClip from [Karimireddy, 2021].
+    Intended to be used in a decentralized fashion as in [Gorbunov, 2021].
+
+    :param stop_delta: Stop iterations early if the ``L_inf`` norm of the last step is less than ``stop_delta``.
+        Note: if this option is used, the step norm calculations may increase the time per iteration by ~25%.
+
+    References:
+
+    [Karimireddy, 2021] Karimireddy, Sai Praneeth, Lie He, and Martin Jaggi. "Learning from history for byzantine
+    robust optimization." International Conference on Machine Learning. PMLR, 2021.
+
+    [Gorbunov, 2021] Gorbunov, Eduard, Alexander Borzunov, Michael Diskin, and Max Ryabinin.
+    "Secure Distributed Training at Scale." arXiv preprint arXiv:2106.11257 (2021).
+    """
+
+    with torch.no_grad():
+        n_peers = input_tensors.shape[0]
+        result_shape = input_tensors.shape[1:]
+
+        input_tensors = input_tensors.flatten(start_dim=1)
+        weights /= weights.sum()
+
+        # This finds medians faster than torch.median() and torch.quantile(q=0.5),
+        # see https://github.com/pytorch/pytorch/issues/51450
+        sorted_tensors = input_tensors.sort(dim=0).values
+        result = sorted_tensors[n_peers // 2].clone()
+        delta = None
+
+        diff = torch.sub(input_tensors, result, out=sorted_tensors)  # Reuse memory from `sorted_tensors`
+        for _ in range(n_iters):
+            norms = diff.norm(dim=1)
+            coeffs = weights * torch.minimum(torch.tensor(1.0), tau / norms)
+
+            if stop_delta is not None:
+                prev_diff = result.copy_(diff[0])  # Reuse memory from `result` (copy_() returns `result` itself)
+
+            # We only need to update `diff` (not `result`) between iterations
+            diff.addmm_(-coeffs.repeat(n_peers, 1), diff)
+
+            if stop_delta is not None:
+                delta = prev_diff.sub_(diff[0]).abs_().max()
+                if delta < stop_delta:
+                    break
+        torch.sub(input_tensors[0], diff[0], out=result)
+
+        return CenteredClipResult(
+            result=result.reshape(result_shape), n_clipped=(tau < norms).sum(), last_step_delta=delta
+        )
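
For context (not part of the commit), here is a minimal sketch of how the new accumulators and centered_clip could be exercised on synthetic per-peer tensor parts. The class and function signatures come from the file above; the tensors, weights, and tau value are made-up example inputs.

import torch

from hivemind.averaging.accumulators import CenteredClipAccumulator, MeanAccumulator, centered_clip

n_peers, part_shape = 4, torch.Size([3])
parts = [torch.randn(part_shape) for _ in range(n_peers)]  # one tensor part per peer
weights = [1.0] * n_peers

mean_acc = MeanAccumulator(part_shape, n_peers)
clip_acc = CenteredClipAccumulator(part_shape, n_peers, tau=1.0)
for part, weight in zip(parts, weights):
    mean_acc.accumulate_part(part, weight)
    clip_acc.accumulate_part(part, weight)

print(mean_acc.reduce())  # weighted mean of the parts
print(clip_acc.reduce())  # CenteredClip estimate, less sensitive to outlier peers

# centered_clip() can also be called directly on an [n_peers, ...] stack of parts:
out = centered_clip(torch.stack(parts), torch.tensor(weights), tau=1.0, n_iters=20)
print(out.result, out.n_clipped)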

+ 4 - 1
hivemind/averaging/allreduce.py

@@ -4,6 +4,7 @@ from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Type
 
 import torch
 
+from hivemind.averaging.accumulators import AccumulatorFactory
 from hivemind.averaging.partition import AllreduceException, TensorPartContainer, TensorPartReducer
 from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
 from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase, StubBase
@@ -58,6 +59,7 @@ class AllReduceRunner(ServicerBase):
         tensors: Sequence[torch.Tensor],
         ordered_peer_ids: Sequence[PeerID],
         peer_fractions: Tuple[float, ...],
+        accumulator_factory: AccumulatorFactory,
         weights: Optional[Sequence[float]] = None,
         modes: Optional[Sequence[AveragingMode]] = None,
         gathered: Optional[Dict[PeerID, Any]] = None,
@@ -97,7 +99,8 @@
         self.tensor_part_reducer = TensorPartReducer(
             tuple(part.shape for part in self.parts_for_local_averaging),
             len(self.sender_peer_ids),
-            self.sender_weights,
+            weights=self.sender_weights,
+            accumulator_factory=accumulator_factory,
         )
 
     def __repr__(self):

+ 3 - 0
hivemind/averaging/averager.py

@@ -15,6 +15,7 @@ from typing import Any, AsyncIterator, Dict, Optional, Sequence, Tuple, Union
 import numpy as np
 import torch
 
+from hivemind.averaging.accumulators import AccumulatorFactory, MeanAccumulator
 from hivemind.averaging.allreduce import AllreduceException, AllReduceRunner, AveragingMode, GroupID
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.load_balancing import load_balance_peers
@@ -112,6 +113,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         compression: CompressionBase = NoCompression(),
         state_compression: CompressionBase = NoCompression(),
         tensor_infos: Optional[Sequence[CompressionInfo]] = None,
+        accumulator_factory: AccumulatorFactory = MeanAccumulator,
         bandwidth: Optional[float] = None,
         min_vector_size: int = 0,
         auxiliary: bool = False,
@@ -170,6 +172,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             compression=compression,
             part_size_bytes=part_size_bytes,
             min_vector_size=min_vector_size,
+            accumulator_factory=accumulator_factory,
         )
         self._averaging_alpha, self._allreduce_timeout = averaging_alpha, allreduce_timeout
         self._running_groups: Dict[GroupID, AllReduceRunner] = {}  # one or more assembled groups that run all-reduce
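
One way to opt into CenteredClip when constructing an averager (a hypothetical sketch, not from this commit): the dht, prefix, and target_group_size arguments below follow the usual DecentralizedAverager setup and are assumptions of this example, while functools.partial binds the CenteredClip hyperparameters so the factory still matches the (part_shape, n_peers) signature of AccumulatorFactory.

from functools import partial

import torch

import hivemind
from hivemind.averaging.accumulators import CenteredClipAccumulator
from hivemind.averaging.averager import DecentralizedAverager

dht = hivemind.DHT(start=True)
averager = DecentralizedAverager(
    averaged_tensors=[torch.zeros(16)],  # tensors to be averaged with the group
    dht=dht,
    prefix="robust_demo",                # assumed experiment prefix
    target_group_size=4,
    accumulator_factory=partial(CenteredClipAccumulator, tau=1.0),  # default stays MeanAccumulator
    start=True,
)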

+ 15 - 9
hivemind/averaging/partition.py

@@ -8,6 +8,7 @@ from typing import AsyncIterable, AsyncIterator, Optional, Sequence, Tuple, Type
 import numpy as np
 import torch
 
+from hivemind.averaging.accumulators import AccumulatorFactory
 from hivemind.compression import CompressionBase, CompressionInfo, NoCompression
 from hivemind.proto import runtime_pb2
 from hivemind.utils.asyncio import amap_in_executor
@@ -171,16 +172,23 @@ class TensorPartReducer:
     :note: even if local peer is not sending data, local parts will be used for shape information
     """
 
-    def __init__(self, part_shapes: Sequence[torch.Size], num_senders: int, weights: Optional[Sequence[float]] = None):
+    def __init__(
+        self,
+        part_shapes: Sequence[torch.Size],
+        num_senders: int,
+        *,
+        weights: Optional[Sequence[float]],
+        accumulator_factory: AccumulatorFactory,
+    ):
         self.part_shapes, self.num_senders, self.num_parts = part_shapes, num_senders, len(part_shapes)
         self.weights = tuple(weights or (1 for _ in range(num_senders)))
         assert len(self.weights) == self.num_senders, "The number of weights is inconsistent with num_senders"
         assert all(isinstance(weight, (int, float)) for weight in self.weights)
         self.current_part_index = -1  # index in local_parts of the part that should be loaded next
         self.current_part_accumulated_from = 0  # number of peers from which the current part was accumulated
-        self.accumulator = None  # this will contain the sum of current tensor part from group peers
-        self.denominator = 0.0  # total weight accumulated from all peers for current part
         self.current_part_future = asyncio.Future()
+        self.accumulator_factory = accumulator_factory
+        self.accumulator = None
         self.finished = asyncio.Event()
         self.reset_accumulators()
 
@@ -194,8 +202,7 @@
         self.current_part_index += 1
         self.current_part_accumulated_from = 0
         self.current_part_future = asyncio.Future()
-        self.accumulator = torch.zeros(self.part_shapes[self.current_part_index])
-        self.denominator = 0.0
+        self.accumulator = self.accumulator_factory(self.part_shapes[self.current_part_index], self.num_senders)
 
     async def accumulate_part(self, sender_index: int, part_index: int, tensor_part: torch.Tensor) -> torch.Tensor:
         """Add vector part to accumulator, wait for all other vectors to be added, then return the average part"""
@@ -211,13 +218,12 @@
 
         current_part_future = self.current_part_future
 
-        self.accumulator.add_(tensor_part, alpha=self.weights[sender_index])
-        self.denominator += self.weights[sender_index]
+        self.accumulator.accumulate_part(tensor_part, self.weights[sender_index])
        self.current_part_accumulated_from += 1
 
         assert self.current_part_accumulated_from <= self.num_senders
         if self.current_part_accumulated_from == self.num_senders:
-            current_part_future.set_result(self.accumulator.div_(self.denominator))
+            current_part_future.set_result(self.accumulator.reduce())
             self.reset_accumulators()
         return await current_part_future
 
@@ -225,7 +231,7 @@
         if not self.finished.is_set():
             if hasattr(self, "current_part_future"):
                 self.current_part_future.cancel()
-                del self.accumulator
+                self.accumulator = None
             self.finished.set()
 
     def __del__(self):
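
As a usage note (again a sketch, not part of the diff): with the keyword-only signature above, a reducer for a single part can be driven as below, with functools.partial binding the CenteredClip hyperparameters. The shapes, weights, and tau value are made up; the call pattern mirrors the updated test_reducer test.

import asyncio
from functools import partial

import torch

from hivemind.averaging.accumulators import CenteredClipAccumulator
from hivemind.averaging.partition import TensorPartReducer


async def reduce_one_part():
    num_senders = 3
    reducer = TensorPartReducer(
        [torch.Size([5])],   # a single tensor part of 5 elements
        num_senders,
        weights=None,        # None -> equal weights, as in the updated test
        accumulator_factory=partial(CenteredClipAccumulator, tau=1.0),
    )
    parts = [torch.randn(5) for _ in range(num_senders)]
    # each sender submits its part for part_index=0; all of them await the same reduced tensor
    reduced = await asyncio.gather(
        *(reducer.accumulate_part(sender, 0, part) for sender, part in enumerate(parts))
    )
    return reduced[0]


print(asyncio.run(reduce_one_part()))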

+ 3 - 1
tests/test_allreduce.py

@@ -7,6 +7,7 @@ import pytest
 import torch
 
 from hivemind import Quantile8BitQuantization, aenumerate
+from hivemind.averaging.accumulators import MeanAccumulator
 from hivemind.averaging.allreduce import AllReduceRunner, AveragingMode
 from hivemind.averaging.partition import TensorPartContainer, TensorPartReducer
 from hivemind.compression import deserialize_torch_tensor
@@ -119,7 +120,7 @@ async def test_partitioning_asynchronous():
 @pytest.mark.asyncio
 async def test_reducer(num_senders: int, num_parts: int, synchronize_prob: float):
     tensor_part_shapes = [torch.Size([i]) for i in range(num_parts)]
-    reducer = TensorPartReducer(tensor_part_shapes, num_senders)
+    reducer = TensorPartReducer(tensor_part_shapes, num_senders, weights=None, accumulator_factory=MeanAccumulator)
 
     local_tensors_by_sender = [[torch.randn(i) for i in range(num_parts)] for j in range(num_senders)]
 
@@ -196,6 +197,7 @@ async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions,
             tensors=[x.clone() for x in tensors_by_peer[p2p.peer_id]],
             ordered_peer_ids=peers,
             peer_fractions=peer_fractions,
+            accumulator_factory=MeanAccumulator,
             modes=peer_modes,
             weights=averaging_weights,
             part_size_bytes=part_size_bytes,