Преглед на файлове

Add PowerSGD for compressed gradient averaging (#432)

This PR implements the PowerSGD algorithm (https://arxiv.org/abs/1905.13727) for compressed gradient averaging.

Co-authored-by: justheuristic <justheuristic@gmail.com>
Co-authored-by: artek0chumak <artek0chumak@nvidia.ru-central1.internal>
Artem Chumachenko преди 3 години
родител
ревизия
8387718c2a

+ 2 - 4
hivemind/averaging/allreduce.py

@@ -1,6 +1,6 @@
 import asyncio
 from enum import Enum
-from typing import Any, AsyncIterator, Dict, Optional, Sequence, Set, Tuple, Type
+from typing import AsyncIterator, Optional, Sequence, Set, Tuple, Type
 
 import torch
 
@@ -50,7 +50,6 @@ class AllReduceRunner(ServicerBase):
     :param peer_fractions: for each peer, a target fraction of vector elements that this peer should average
       (the actual number of values by peer will be nearly proportional, but there are no exact guarantees)
     :param modes: AveragingMode for each peer in ordered_peer_ids (normal, client-only or auxiliary)
-    :param gathered: additional user-defined data collected from this group
     :param sender_timeout: during all_reduce, any sender that fails to send tensor chunk within this many seconds from
       previous chunk will be marked as failed and excluded from averaging. default: equal to next_chunk_timeout
     :param reducer_timeout: during all_reduce, any reducer that fails to send results chunk within this many seconds
@@ -73,7 +72,6 @@ class AllReduceRunner(ServicerBase):
         ordered_peer_ids: Sequence[PeerID],
         peer_fractions: Tuple[float, ...],
         modes: Optional[Sequence[AveragingMode]] = None,
-        gathered: Optional[Dict[PeerID, Any]] = None,
         sender_timeout: Optional[float] = None,
         reducer_timeout: Optional[float] = None,
         **kwargs,
@@ -99,7 +97,7 @@ class AllReduceRunner(ServicerBase):
             assert mode != AveragingMode.CLIENT or frac == 0, "client-mode peer should have zero all-reduce fraction"
 
         self.group_id, self.ordered_peer_ids = group_id, ordered_peer_ids
-        self.modes, self.peer_fractions, self.gathered = modes, peer_fractions, gathered
+        self.modes, self.peer_fractions = modes, peer_fractions
 
         if weight is None:
             weight = float(modes[self.ordered_peer_ids.index(self.peer_id)] != AveragingMode.AUX)

+ 67 - 64
hivemind/averaging/averager.py

@@ -22,13 +22,7 @@ from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.load_balancing import load_balance_peers
 from hivemind.averaging.matchmaking import Matchmaking, MatchmakingException
 from hivemind.averaging.partition import DEFAULT_PART_SIZE_BYTES
-from hivemind.compression import (
-    CompressionBase,
-    CompressionInfo,
-    NoCompression,
-    deserialize_torch_tensor,
-    serialize_torch_tensor,
-)
+from hivemind.compression import CompressionBase, CompressionInfo, NoCompression, deserialize_torch_tensor
 from hivemind.dht import DHT, DHTID
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
 from hivemind.p2p.p2p_daemon_bindings.utils import ControlFailure, DispatchFailure
@@ -36,7 +30,6 @@ from hivemind.proto import averaging_pb2
 from hivemind.utils import MPFuture, TensorDescriptor, get_logger
 from hivemind.utils.asyncio import (
     achain,
-    afirst,
     aiter_with_timeout,
     anext,
     as_aiter,
@@ -109,7 +102,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     """
 
     _matchmaking: Matchmaking
-    _pending_group_assembled: asyncio.Event
+    _pending_groups_registered: asyncio.Event
     _state_updated: asyncio.Event
     _p2p: P2P
     serializer = MSGPackSerializer
@@ -207,7 +200,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             reducer_timeout=reducer_timeout,
         )
         self._averaging_alpha, self._allreduce_timeout = averaging_alpha, allreduce_timeout
-        self._running_groups: Dict[GroupID, AllReduceRunner] = {}  # one or more assembled groups that run all-reduce
+        self._running_groups: Dict[GroupID, asyncio.Future[AllReduceRunner]] = {}
 
         self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)  # a control pipe used to communicate with daemon
 
@@ -309,8 +302,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     asyncio.create_task(self._declare_for_download_periodically())
 
                 self._state_updated = asyncio.Event()
-                self._pending_group_assembled = asyncio.Event()
-                self._pending_group_assembled.set()
+                self._pending_groups_registered = asyncio.Event()
+                self._pending_groups_registered.set()
             except Exception as e:
                 # Loglevel is DEBUG since normally the exception is propagated to the caller
                 logger.debug(e, exc_info=True)
@@ -441,7 +434,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
             while not step.done():
                 try:
-                    self._pending_group_assembled.clear()
+                    self._pending_groups_registered.clear()
                     step.stage = AveragingStage.LOOKING_FOR_GROUP
                     matchmaking_task = asyncio.create_task(find_peers_or_notify_cancel())
                     check_cancel_task = asyncio.create_task(step.wait_for_cancel())
@@ -458,17 +451,21 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     if group_info is None:
                         raise AllreduceException("Averaging step failed: could not find a group")
 
-                    step.stage = AveragingStage.RUNNING_ALLREDUCE
-
-                    step.set_result(
-                        await asyncio.wait_for(
-                            self._run_allreduce(
-                                group_info, tensor_infos=self.tensor_infos, weight=step.weight, **self.allreduce_kwargs
-                            ),
-                            timeout=self._allreduce_timeout,
+                    with self._register_allreduce_group(group_info):
+                        step.stage = AveragingStage.RUNNING_ALLREDUCE
+
+                        step.set_result(
+                            await asyncio.wait_for(
+                                self._aggregate_with_group(
+                                    group_info,
+                                    tensor_infos=self.tensor_infos,
+                                    weight=step.weight,
+                                    **self.allreduce_kwargs,
+                                ),
+                                timeout=self._allreduce_timeout,
+                            )
                         )
-                    )
-                    # averaging is finished, loop will now exit
+                        # averaging is finished, loop will now exit
 
                 except (
                     AllreduceException,
@@ -503,8 +500,21 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     )
                 )
 
-    async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
-        """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
+    @contextlib.contextmanager
+    def _register_allreduce_group(self, group_info: GroupInfo):
+        """Register a given group for one or more all-reduce rounds"""
+        try:
+            self._running_groups[group_info.group_id] = asyncio.Future()
+            self._pending_groups_registered.set()
+            yield
+        finally:
+            maybe_future = self._running_groups.pop(group_info.group_id, None)
+            if maybe_future is not None and not maybe_future.done():
+                logger.warning(f"All-reduce group {group_info.group_id} did not finish.")
+            self._pending_groups_registered.set()
+
+    async def _aggregate_with_group(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
+        """Run aggregation in a given group and update tensors in place, return gathered metadata"""
         try:
             bandwidths, mode_ids, user_gathered_bytes = zip(*map(self.serializer.loads, group_info.gathered))
             user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered_bytes)))
@@ -519,47 +529,39 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             )
 
             async with enter_asynchronously(self.get_tensors()) as local_tensors:
-                allreduce = AllReduceRunner(
-                    p2p=self._p2p,
-                    servicer_type=type(self),
-                    prefix=self.prefix,
-                    group_id=group_info.group_id,
-                    tensors=local_tensors,
-                    ordered_peer_ids=group_info.peer_ids,
-                    peer_fractions=peer_fractions,
-                    gathered=user_gathered,
-                    modes=modes,
-                    **kwargs,
-                )
-
-                with self.register_allreduce_group(group_info.group_id, allreduce):
-                    if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
-                        iter_results = allreduce.run()
-                        async for tensor, update in azip(as_aiter(*local_tensors), iter_results):
-                            # all-reduce is performed asynchronously while iterating
-                            tensor.add_(update, alpha=self._averaging_alpha)
-                        self._state_updated.set()
-
-                    else:
-                        async for _ in allreduce:  # trigger all-reduce by iterating
-                            raise ValueError("aux peers should not receive averaged tensors")
-
-                return allreduce.gathered
+                await self._run_allreduce_inplace_(local_tensors, group_info, peer_fractions=peer_fractions, **kwargs)
+                return user_gathered
         except BaseException as e:
             if isinstance(e, Exception):
                 logger.exception(e)
             raise MatchmakingException(f"Unable to run All-Reduce: {e}")
 
-    @contextlib.contextmanager
-    def register_allreduce_group(self, group_id: GroupID, allreduce: AllReduceRunner):
-        """registers a given all-reduce runner to listen for incoming connections"""
-        try:
-            self._running_groups[group_id] = allreduce
-            self._pending_group_assembled.set()
-            yield
-        finally:
-            self._running_groups.pop(group_id, None)
-            self._pending_group_assembled.set()
+    async def _run_allreduce_inplace_(
+        self, tensors: Sequence[torch.Tensor], group_info: GroupInfo, group_id: Optional[bytes] = None, **kwargs
+    ):
+        """Run one allreduce process to average tensors inplace. Can be called more than a few times in one aggregation process"""
+        group_id = group_info.group_id if group_id is None else group_id
+
+        runner = AllReduceRunner(
+            p2p=self._p2p,
+            servicer_type=type(self),
+            prefix=self.prefix,
+            tensors=tensors,
+            group_id=group_id,
+            ordered_peer_ids=group_info.peer_ids,
+            **kwargs,
+        )
+        assert group_id in self._running_groups, f"Group id {group_id} was not registered in _register_allreduce_group"
+        self._running_groups[group_id].set_result(runner)
+
+        if runner.modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
+            async for tensor, update in azip(as_aiter(*tensors), runner):
+                tensor.add_(update, alpha=self._averaging_alpha)
+                self.last_updated = get_dht_time()
+                self._state_updated.set()
+        else:
+            async for _ in runner:
+                raise ValueError("aux peers should not receive averaged tensors")
 
     @contextlib.contextmanager
     def get_tensors(self) -> Sequence[torch.Tensor]:
@@ -586,13 +588,14 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         if request.group_id not in self._running_groups:
             # this handles a special case when leader accepted us to group AND began allreduce right away,
             # but his response with group_id was delayed and other peers got to us first
-            await self._pending_group_assembled.wait()
+            await self._pending_groups_registered.wait()
 
-        group = self._running_groups.get(request.group_id)
-        if group is None:
+        future = self._running_groups.get(request.group_id)
+        if future is None:
             yield averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
             return
 
+        group = await future
         async for message in group.rpc_aggregate_part(achain(as_aiter(request), stream), context):
             yield message
 

+ 20 - 7
hivemind/optim/grad_averager.py

@@ -1,16 +1,20 @@
 import contextlib
-from typing import Iterable, Iterator, Optional
+from typing import Callable, Iterable, Iterator, Optional, Sequence, TypeVar
 
 import torch
 
-import hivemind
 from hivemind.averaging import DecentralizedAverager
 from hivemind.averaging.control import StepControl
-from hivemind.utils import DHTExpiration, get_dht_time, get_logger
+from hivemind.dht import DHT
+from hivemind.utils import DHTExpiration, get_logger
 
 logger = get_logger(__name__)
 
 
+TGradientAverager = TypeVar("TGradientAverager", bound="GradientAverager")
+GradientAveragerFactory = Callable[..., TGradientAverager]
+
+
 class GradientAverager(DecentralizedAverager):
     """
     An auxiliary averaging class that is responsible for accumulating gradients and aggregating them with peers.
@@ -36,6 +40,7 @@ class GradientAverager(DecentralizedAverager):
       if True, the averager will only join existing groups where at least one peer has client_mode=False.
       By default, this flag is copied from DHTNode inside the ``dht`` instance.
     :param warn: if True, warn when the averager did not reset accumulators after use or did not use averaging results
+    :param averaged_grads: if provided, it will be used as a set of averagable gradients
     :param kwargs: see DecentralizedAverager keyword arguments for additional parameters
 
 
@@ -69,12 +74,13 @@ class GradientAverager(DecentralizedAverager):
         self,
         parameters: Iterable[torch.nn.Parameter],
         *,
-        dht: hivemind.DHT,
+        dht: DHT,
         prefix: str,
         reuse_grad_buffers: bool = False,
         accumulate_grads_on: Optional[torch.device] = None,
         client_mode: bool = None,
         warn: bool = True,
+        averaged_grads: Sequence[torch.Tensor] = (),
         **kwargs,
     ):
         if reuse_grad_buffers and accumulate_grads_on is not None:
@@ -95,9 +101,16 @@ class GradientAverager(DecentralizedAverager):
         self._new_averaged_grads = False
 
         with torch.no_grad():
-            averaged_grads = tuple(
-                grad.detach().cpu().clone().share_memory_() for grad in self._grads_from_parameters()
-            )
+            if not averaged_grads:
+                averaged_grads = tuple(
+                    grad.detach().cpu().clone().share_memory_() for grad in self._grads_from_parameters()
+                )
+            else:
+                if all(
+                    params_grad.size() == grad.size()
+                    for param_grad, grad in zip(self._grads_from_parameters(), averaged_grad)
+                ):
+                    raise ValueError("Averaged gradients doesn't have same shape as gradients from parameters")
         super().__init__(averaged_tensors=averaged_grads, dht=dht, prefix=prefix, client_mode=client_mode, **kwargs)
 
     def _grads_from_parameters(self) -> Iterator[torch.Tensor]:

+ 12 - 5
hivemind/optim/optimizer.py

@@ -11,8 +11,9 @@ import torch
 from hivemind.averaging.control import AveragingStage, StepControl
 from hivemind.compression import CompressionBase, NoCompression
 from hivemind.dht import DHT
-from hivemind.optim.grad_averager import GradientAverager
+from hivemind.optim.grad_averager import GradientAverager, GradientAveragerFactory
 from hivemind.optim.grad_scaler import GradScaler
+from hivemind.optim.power_sgd_averager import PowerSGDGradientAverager
 from hivemind.optim.progress_tracker import LocalTrainingProgress, ProgressTracker
 from hivemind.optim.state_averager import (
     LRSchedulerBase,
@@ -147,6 +148,7 @@ class Optimizer(torch.optim.Optimizer):
     :param auxiliary: if True, optimizer.step will only assist other peers in averaging (for cpu-only workers)
 
     :param grad_compression: compression strategy used for averaging gradients, default = no compression
+    :param grad_averager_factory: if provided, creates gradient averager with required averaging strategy
     :param state_averaging_compression: compression for averaging params and state tensors, default = no compression
     :param load_state_compression: compression strategy for loading state from peers, default = no compression
     :param average_opt_statistics: names of optimizer statistics from state dict that should be averaged with peers
@@ -187,6 +189,7 @@ class Optimizer(torch.optim.Optimizer):
         client_mode: bool = None,
         auxiliary: bool = False,
         grad_compression: CompressionBase = NoCompression(),
+        grad_averager_factory: Optional[GradientAveragerFactory] = GradientAverager,
         state_averaging_compression: CompressionBase = NoCompression(),
         load_state_compression: CompressionBase = NoCompression(),
         average_opt_statistics: Sequence[str] = (),
@@ -226,6 +229,9 @@ class Optimizer(torch.optim.Optimizer):
         if use_local_updates:
             assert not reuse_grad_buffers, "if local_updates is True, gradients will not be accumulated"
             assert not delay_grad_averaging, "if local_updates is True, gradients will not be averaged"
+            assert (
+                grad_averager_factory is None
+            ), "if local_updates is True, provided grad_averager_factory will not be used"
 
         self.dht, self.run_id, self.client_mode, self.auxiliary = dht, run_id, client_mode, auxiliary
         self.batch_size_per_step, self.target_batch_size = batch_size_per_step, target_batch_size
@@ -256,9 +262,9 @@ class Optimizer(torch.optim.Optimizer):
             extra_tensors=extra_tensors,
             **averager_opts or {},
         )
-        if not use_local_updates:
+        if grad_averager_factory is not None and not use_local_updates:
             self.grad_averager = self._make_gradient_averager(
-                reuse_grad_buffers=reuse_grad_buffers, compression=grad_compression, **averager_opts or {}
+                reuse_grad_buffers=reuse_grad_buffers, grad_averager_factory=grad_averager_factory
             )
         else:
             self.grad_averager = None
@@ -291,9 +297,9 @@ class Optimizer(torch.optim.Optimizer):
             **kwargs,
         )
 
-    def _make_gradient_averager(self, **kwargs) -> GradientAverager:
+    def _make_gradient_averager(self, grad_averager_factory, **kwargs) -> GradientAverager:
         assert hasattr(self, "state_averager"), "must initialize state averager first"
-        grad_averager = GradientAverager(
+        grad_averager = grad_averager_factory(
             dht=self.dht,
             prefix=f"{self.run_id}_grad_averager",
             parameters=self.state_averager.main_parameters,
@@ -685,6 +691,7 @@ class Optimizer(torch.optim.Optimizer):
             while True:
                 try:
                     self.state_averager.load_state_from_peers(timeout=self.load_state_timeout, **kwargs)
+                    self.grad_averager.load_state_from_peers(timeout=self.load_state_timeout, **kwargs)
                     break
                 except KeyboardInterrupt:
                     raise

+ 223 - 0
hivemind/optim/power_sgd_averager.py

@@ -0,0 +1,223 @@
+import asyncio
+import contextlib
+import multiprocessing as mp
+from enum import Enum
+from typing import Any, Iterable, Optional, Sequence
+
+import torch
+
+from hivemind.averaging.allreduce import AveragingMode
+from hivemind.averaging.group_info import GroupInfo
+from hivemind.averaging.load_balancing import load_balance_peers
+from hivemind.averaging.matchmaking import Matchmaking, MatchmakingException
+from hivemind.compression import CompressionInfo, TensorRole
+from hivemind.dht import DHT
+from hivemind.optim.grad_averager import GradientAverager
+from hivemind.utils import get_logger
+from hivemind.utils.asyncio import enter_asynchronously
+from hivemind.utils.math import get_flatten_greedy_dims, orthogonalize_
+
+GatheredData = Any
+logger = get_logger(__name__)
+
+
+class AllReducePhases(Enum):
+    PHASE_P = 1
+    PHASE_Q = 2
+
+
+class PowerSGDGradientAverager(GradientAverager):
+    """
+    A gradient averager that implements PowerSGD compression: https://arxiv.org/abs/1905.13727
+    For basic properties and guaranties of gradient averagers, please refer to the base class docstring.
+    Put simply, this method approximates large gradient tensors (m,n) with a product of two
+    smaller matrices (m,r) by (r,n), where r is a parameter chosen by the user (see averager_rank).
+
+    As a result, PowerSGD only needs to aggregate O((m + n) * r) tensors instead of O(m * n).
+    High r, e.g. sqrt(max(m, n)) typically reduce communication by 2-8x without affecting convergence.
+    Low r, e.g. 1-8, further accelerate communication, but may converge worse depending on the task.
+
+    To maintain convergence with low r, this averager uses the error feedback strategy. Put simply,
+    if some part of the gradient is "lost in compression", it will be added to the next iteration.
+    This has two implications: (a) it needs more RAM in order to store the "feedback buffers"
+    and (b) if devices stay alive only for one step, training with small rank may converge slower.
+    This is because error feedback takes multiple steps to kick in.
+
+    Since not all gradients are matrices, PowerSGD views 3d+ tensors via tensor.flatten(1, -1).
+    If a tensor has less than 2 dimensions or does not compress efficiently, it will be aggregated
+    normally, i.e. without powerSGD. See min_compression_ratio for details.
+
+    :note: due to the above rule, PowerSGD is *not* shape-invariant. For instance, a
+     matrix of shape (256, 256) be compressed differently if you .reshape it to (32, 32, 32).
+
+    :param parameters: pytorch parameters for which to aggregate gradients
+    :param averager_rank: rank of compressed gradients
+    :param dht: a DHT isntance connected to the rest of the swarm. See hivemind.DHT docs
+    :param prefix: a unique DHT key used for matchmaking. E.g. this can be your experiment name with optional suffixes
+    :param reuse_grad_buffers: if True, use model's .grad buffers for accumulating gradients over multiple steps.
+      This is more memory efficient, but it requires that the user does *not* call zero_grad or clip_by_whatever at all
+    :param accumulate_grads_on: if specified, accumulate gradients on this device. By default, this will use the same
+      device as model parameters. One can specify a different device (e.g. 'cpu' vs 'cuda') to save device memory at
+      the cost of extra time per step. If reuse_grad_buffers is True, this parameter has no effect.
+    :param client_mode: if False, this averager will accept incoming requests from other peers.
+      if True, the averager will only join existing groups where at least one peer has client_mode=False.
+      By default, this flag is copied from DHTNode inside the ``dht`` instance.
+    :param warn: if True, warn when the averager did not reset accumulators after use or did not use averaging results
+    :param min_compression_ratio: apply PowerSGD to a tensor only if it reduces communication by at least this factor,
+      otherwise aggregate tensors as is
+    :param averaged_grads: if provided, it will be used as a set of averagable gradients
+    """
+
+    def __init__(
+        self,
+        parameters: Iterable[torch.nn.Parameter],
+        averager_rank: int,
+        *,
+        dht: DHT,
+        prefix: str,
+        reuse_grad_buffers: bool = False,
+        accumulate_grads_on: Optional[torch.device] = None,
+        client_mode: bool = None,
+        warn: bool = True,
+        min_compression_ratio: float = 0.5,
+        averaged_grads: Optional[Sequence[torch.Tensor]] = None,
+        **kwargs,
+    ):
+        self.rank = averager_rank
+        self.parameters = tuple(parameters)
+        self._uncompressed_gradients_indexes = set(
+            i
+            for i, grad in enumerate(self._grads_from_parameters())
+            if grad.ndim <= 1
+            or (1 - self.rank * sum(get_flatten_greedy_dims(grad)) / grad.numel()) < min_compression_ratio
+            # compute how much parameters are left after factorization
+        )
+        self._ms = [
+            torch.zeros_like(grad, device="cpu").share_memory_()
+            for idx, grad in enumerate(self._grads_from_parameters())
+            if idx not in self._uncompressed_gradients_indexes
+        ]
+        self._qs = [
+            torch.rand((get_flatten_greedy_dims(grad)[1], self.rank), device="cpu").share_memory_()
+            for idx, grad in enumerate(self._grads_from_parameters())
+            if idx not in self._uncompressed_gradients_indexes
+        ]
+
+        super().__init__(
+            self.parameters,
+            dht=dht,
+            prefix=prefix,
+            reuse_grad_buffers=reuse_grad_buffers,
+            accumulate_grads_on=accumulate_grads_on,
+            client_mode=client_mode,
+            warn=warn,
+            averaged_grads=averaged_grads,
+            **kwargs,
+        )
+
+    @contextlib.contextmanager
+    def _register_allreduce_group(self, group_info: GroupInfo):
+        """Register a given group for one or more all-reduce rounds"""
+        try:
+            for phase in list(AllReducePhases):
+                self._running_groups[group_info.group_id + phase.name.encode()] = asyncio.Future()
+            self._pending_groups_registered.set()
+            yield
+        finally:
+            for phase in list(AllReducePhases):
+                maybe_future = self._running_groups.pop(group_info.group_id + phase.name.encode(), None)
+                if maybe_future and not maybe_future.done():
+                    logger.warning(f"All-reduce group {group_info.group_id + phase.name.encode()} did not finish.")
+            self._pending_groups_registered.set()
+
+    async def _aggregate_with_group(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
+        """Run aggregation in a given group and update tensors in place, return gathered metadata"""
+        try:
+            bandwidths, mode_ids, user_gathered_bytes = zip(*map(self.serializer.loads, group_info.gathered))
+            user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered_bytes)))
+            modes = tuple(map(AveragingMode, mode_ids))
+
+            download_bandwidths = [
+                thr if mode != AveragingMode.CLIENT else 0.0 for thr, mode in zip(bandwidths, modes)
+            ]
+            peer_fractions = await asyncio.get_event_loop().run_in_executor(
+                None, load_balance_peers, self.total_size, download_bandwidths, min_vector_size
+            )
+
+            async with enter_asynchronously(self.get_tensors()) as averaged_grads:
+                averaged_grads_via_sgd = [
+                    grad for idx, grad in enumerate(averaged_grads) if idx not in self._uncompressed_gradients_indexes
+                ]
+                for grad, m in zip(averaged_grads_via_sgd, self._ms):
+                    m.add_(grad.to(m.device))
+
+                ps = [
+                    torch.zeros((get_flatten_greedy_dims(grad)[0], self.rank), device="cpu")
+                    for idx, grad in enumerate(averaged_grads_via_sgd)
+                ]
+                for p, q, m in zip(ps, self._qs, self._ms):
+                    # we use reshape for all matrixes because PowerSGD works only with 2d tensors
+                    torch.matmul(m.reshape(-1, q.size(0)), q, out=p)
+
+                p_group_id = group_info.group_id + AllReducePhases.PHASE_P.name.encode()
+                q_groud_id = group_info.group_id + AllReducePhases.PHASE_Q.name.encode()
+
+                await self._run_allreduce_inplace_(ps, group_info, p_group_id, peer_fractions=peer_fractions, **kwargs)
+
+                for p in ps:
+                    orthogonalize_(p)
+
+                for p, q, m in zip(ps, self._qs, self._ms):
+                    torch.matmul(m.reshape(-1, q.size(0)).t(), p, out=q)
+
+                phase_q_tensors = self._qs + [
+                    grad for idx, grad in enumerate(averaged_grads) if idx in self._uncompressed_gradients_indexes
+                ]
+
+                await self._run_allreduce_inplace_(
+                    phase_q_tensors, group_info, q_groud_id, peer_fractions=peer_fractions, **kwargs
+                )
+
+                for p, q, m, grad in zip(ps, self._qs, self._ms, averaged_grads_via_sgd):
+                    new_m = torch.matmul(p, q.t()).reshape(m.size())
+                    m.sub_(new_m)
+                    grad.copy_(new_m)
+
+                return user_gathered
+        except BaseException as e:
+            logger.exception(e)
+            raise MatchmakingException(f"Unable to run All-Reduce: {e}")
+
+    def get_current_state(self):
+        """
+        Get current gradient averager state and when requested by a newbie peer.
+        """
+        with torch.no_grad(), self.lock_averaged_tensors:
+            grad_averager_buffers = [q for q in self._qs]
+            grad_averager_buffers_infos = [
+                CompressionInfo.from_tensor(buffer, key=f"buffer_q_{key}", role=TensorRole.GRADIENT)
+                for buffer, key in zip(grad_averager_buffers, enumerate(grad_averager_buffers))
+            ]
+
+        metadata = dict(group_bits=self.get_group_bits())
+        return metadata, grad_averager_buffers, grad_averager_buffers_infos
+
+    def load_state_from_peers(self, **kwargs):
+        """
+        Attempt to download the latest optimizer state from peers and update gradient averager buffers.
+        :returns: whether or the averager succeeded in loading parameters
+        """
+        loaded_state = super().load_state_from_peers(**kwargs)
+        if loaded_state is None:
+            return
+
+        metadata, flat_tensors = loaded_state
+        logger.info("Starting loading gradient averager buffers from peers")
+
+        if len(flat_tensors) != len(self._qs):
+            logger.error("Failed to load state from peer, received invalid parameters, extras or metadata")
+            return
+
+        with torch.no_grad(), self.lock_averaged_tensors:
+            for local_q, loaded_q in zip(self._qs, flat_tensors):
+                local_q.copy_(loaded_q, non_blocking=True)

+ 24 - 0
hivemind/utils/math.py

@@ -0,0 +1,24 @@
+import torch
+import torch.nn.functional as F
+
+
+@torch.jit.script
+def orthogonalize_(matrix, eps: float = 1e-8):
+    """Orthogonalize a 2d tensor in-place over the last dimension"""
+    n, m = matrix.shape
+    for i in range(m):
+        col = matrix[:, i]
+        F.normalize(col, dim=0, eps=eps, out=col)
+        if i + 1 < m:
+            rest = matrix[:, i + 1 :]
+            rest.addmm_(col[:, None], (col @ rest)[None, :], alpha=-1)
+
+
+def get_flatten_greedy_dims(tensor: torch.Tensor, max_ndim: int = 2):
+    """get dims to flatten tensor upto max_ndim dimensions by merging small axes together"""
+    dims = list(tensor.shape)
+    while len(dims) > max_ndim:
+        squeeze_ix = min(range(len(dims) - 1), key=lambda i: dims[i] * dims[i + 1])
+        squeezed_dim = dims.pop(squeeze_ix)
+        dims[squeeze_ix] *= squeezed_dim
+    return dims

+ 14 - 12
tests/test_allreduce_fault_tolerance.py

@@ -35,7 +35,7 @@ class FaultyAverager(hivemind.DecentralizedAverager):
         self.fault = fault
         super().__init__(*args, **kwargs)
 
-    async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
+    async def _aggregate_with_group(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
         """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
         try:
             bandwidths, mode_ids, user_gathered_bytes = zip(*map(self.serializer.loads, group_info.gathered))
@@ -60,24 +60,26 @@ class FaultyAverager(hivemind.DecentralizedAverager):
                     tensors=local_tensors,
                     ordered_peer_ids=group_info.peer_ids,
                     peer_fractions=peer_fractions,
-                    gathered=user_gathered,
                     modes=modes,
                     fault=self.fault,
                     **kwargs,
                 )
 
-                with self.register_allreduce_group(group_info.group_id, allreduce):
-                    if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
-                        async for tensor, update in azip(as_aiter(*local_tensors), allreduce):
-                            # all-reduce is performed asynchronously while iterating
-                            tensor.add_(update, alpha=self._averaging_alpha)
-                        self._state_updated.set()
+                self._running_groups[group_info.group_id].set_result(allreduce)
+                # TODO maybe this can be extracted into a method that checks if register_... context is active.
 
-                    else:
-                        async for _ in allreduce:  # trigger all-reduce by iterating
-                            raise ValueError("aux peers should not receive averaged tensors")
+                if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
+                    iter_results = allreduce.run()
+                    async for tensor, update in azip(as_aiter(*local_tensors), iter_results):
+                        # all-reduce is performed asynchronously while iterating
+                        tensor.add_(update, alpha=self._averaging_alpha)
+                    self._state_updated.set()
+
+                else:
+                    async for _ in allreduce:  # trigger all-reduce by iterating
+                        raise ValueError("aux peers should not receive averaged tensors")
 
-                return allreduce.gathered
+                return user_gathered
         except BaseException as e:
             logger.exception(e)
             raise MatchmakingException(f"Unable to run All-Reduce: {e}")

+ 22 - 11
tests/test_optimizer.py

@@ -11,24 +11,31 @@ import torch.nn.functional as F
 
 import hivemind
 from hivemind.averaging.control import AveragingStage
-from hivemind.optim.grad_averager import GradientAverager
+from hivemind.optim.grad_averager import GradientAverager, GradientAveragerFactory
 from hivemind.optim.optimizer import Optimizer
+from hivemind.optim.power_sgd_averager import PowerSGDGradientAverager
 from hivemind.optim.progress_tracker import ProgressTracker
 from hivemind.optim.state_averager import TrainingStateAverager
 from hivemind.utils.crypto import RSAPrivateKey
 
 
 @pytest.mark.forked
-def test_grad_averager():
+@pytest.mark.parametrize(
+    "grad_averager_factory",
+    [GradientAverager, partial(PowerSGDGradientAverager, averager_rank=1)],
+)
+def test_grad_averager(grad_averager_factory: GradientAveragerFactory):
+    parameter_shape = (5, 5)
+
     dht1 = hivemind.DHT(start=True)
-    model1 = nn.ParameterDict({"w": nn.Parameter(torch.zeros(3))})
-    averager1 = GradientAverager(
+    model1 = nn.ParameterDict({"w": nn.Parameter(torch.zeros(parameter_shape))})
+    averager1 = grad_averager_factory(
         model1.parameters(), dht=dht1, prefix="test", target_group_size=2, reuse_grad_buffers=False, start=True
     )
 
     dht2 = hivemind.DHT(start=True, initial_peers=dht1.get_visible_maddrs())
-    model2 = nn.ParameterDict({"w": nn.Parameter(torch.zeros(3))})
-    averager2 = GradientAverager(
+    model2 = nn.ParameterDict({"w": nn.Parameter(torch.zeros(parameter_shape))})
+    averager2 = grad_averager_factory(
         model2.parameters(), dht=dht2, prefix="test", target_group_size=2, reuse_grad_buffers=True, start=True
     )
 
@@ -38,12 +45,12 @@ def test_grad_averager():
     for i in range(10):
         time.sleep(0.1)
         if i % 3 == 0:
-            loss1 = F.mse_loss(model1.w, torch.ones(3))
+            loss1 = F.mse_loss(model1.w, torch.ones(parameter_shape))
             loss1.backward()
             averager1.accumulate_grads_(batch_size=2)  # total: 4 times * 2 samples = 8
             model1.zero_grad()
         else:
-            loss2 = F.mse_loss(model2.w, -torch.ones(3))
+            loss2 = F.mse_loss(model2.w, -torch.ones(parameter_shape))
             loss2.backward()
             averager2.accumulate_grads_(batch_size=3)  # total: 6 times * 3 samples = 18
             # note: we do not call zero grad here because reuse_grad_buffers=True
@@ -51,11 +58,11 @@ def test_grad_averager():
     assert control1.stage == control2.stage == AveragingStage.AWAITING_TRIGGER
     peer1_samples, peer1_times, peer2_samples, peer2_times = 8, 4, 18, 6
     assert averager1.local_samples_accumulated == peer1_samples and averager1.local_times_accumulated == peer1_times
-    ref_grads1 = torch.full((3,), -2 * 1 / 3 * averager1.local_times_accumulated)
+    ref_grads1 = torch.full(parameter_shape, -2 / np.prod(parameter_shape) * averager1.local_times_accumulated)
     assert torch.allclose(next(averager1._grad_accumulators()), ref_grads1)
 
     assert averager2.local_samples_accumulated == peer2_samples and averager2.local_times_accumulated == peer2_times
-    ref_grads2 = torch.full((3,), 2 * 1 / 3 * averager2.local_times_accumulated)
+    ref_grads2 = torch.full(parameter_shape, 2 / np.prod(parameter_shape) * averager2.local_times_accumulated)
     assert torch.allclose(next(averager2._grad_accumulators()), ref_grads2)
 
     averager1.step(control=control1, wait=False)
@@ -162,7 +169,11 @@ def test_load_state_from_peers():
     )
 
     avgr1 = TrainingStateAverager(
-        dht=dht1, params=model1.parameters(), allow_state_sharing=False, start=True, **common_kwargs
+        dht=dht1,
+        params=model1.parameters(),
+        allow_state_sharing=False,
+        start=True,
+        **common_kwargs,
     )
 
     avgr2 = TrainingStateAverager(dht=dht2, params=model2.parameters(), start=True, **common_kwargs)