
Add PowerSGD for compressed gradient averaging (#432)

This PR implements the PowerSGD algorithm (https://arxiv.org/abs/1905.13727) for compressed gradient averaging.
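
For readers who want to try it, here is a minimal usage sketch (not part of the diff; hyperparameters are illustrative): PowerSGD is selected by passing a grad_averager_factory to hivemind.Optimizer, the same way the new test does it via functools.partial:

    from functools import partial

    import torch
    import hivemind
    from hivemind.optim.power_sgd_averager import PowerSGDGradientAverager

    dht = hivemind.DHT(start=True)
    model = torch.nn.Linear(64, 64)

    opt = hivemind.Optimizer(
        dht=dht,
        run_id="powersgd_demo",  # hypothetical experiment name
        params=model.parameters(),
        optimizer=partial(torch.optim.SGD, lr=0.1),
        target_batch_size=256,
        batch_size_per_step=32,
        # rank-1 PowerSGD compression for gradient averaging, introduced in this PR
        grad_averager_factory=partial(PowerSGDGradientAverager, averager_rank=1),
    )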

Co-authored-by: justheuristic <justheuristic@gmail.com>
Co-authored-by: artek0chumak <artek0chumak@nvidia.ru-central1.internal>
Artem Chumachenko, 3 years ago
commit 8387718c2a

+ 2 - 4
hivemind/averaging/allreduce.py

@@ -1,6 +1,6 @@
 import asyncio
 from enum import Enum
-from typing import Any, AsyncIterator, Dict, Optional, Sequence, Set, Tuple, Type
+from typing import AsyncIterator, Optional, Sequence, Set, Tuple, Type

 import torch

@@ -50,7 +50,6 @@ class AllReduceRunner(ServicerBase):
     :param peer_fractions: for each peer, a target fraction of vector elements that this peer should average
       (the actual number of values by peer will be nearly proportional, but there are no exact guarantees)
     :param modes: AveragingMode for each peer in ordered_peer_ids (normal, client-only or auxiliary)
-    :param gathered: additional user-defined data collected from this group
     :param sender_timeout: during all_reduce, any sender that fails to send tensor chunk within this many seconds from
       previous chunk will be marked as failed and excluded from averaging. default: equal to next_chunk_timeout
     :param reducer_timeout: during all_reduce, any reducer that fails to send results chunk within this many seconds
@@ -73,7 +72,6 @@ class AllReduceRunner(ServicerBase):
         ordered_peer_ids: Sequence[PeerID],
         peer_fractions: Tuple[float, ...],
         modes: Optional[Sequence[AveragingMode]] = None,
-        gathered: Optional[Dict[PeerID, Any]] = None,
         sender_timeout: Optional[float] = None,
         reducer_timeout: Optional[float] = None,
         **kwargs,
@@ -99,7 +97,7 @@ class AllReduceRunner(ServicerBase):
             assert mode != AveragingMode.CLIENT or frac == 0, "client-mode peer should have zero all-reduce fraction"

         self.group_id, self.ordered_peer_ids = group_id, ordered_peer_ids
-        self.modes, self.peer_fractions, self.gathered = modes, peer_fractions, gathered
+        self.modes, self.peer_fractions = modes, peer_fractions

         if weight is None:
             weight = float(modes[self.ordered_peer_ids.index(self.peer_id)] != AveragingMode.AUX)

+ 67 - 64
hivemind/averaging/averager.py

@@ -22,13 +22,7 @@ from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.load_balancing import load_balance_peers
 from hivemind.averaging.matchmaking import Matchmaking, MatchmakingException
 from hivemind.averaging.partition import DEFAULT_PART_SIZE_BYTES
-from hivemind.compression import (
-    CompressionBase,
-    CompressionInfo,
-    NoCompression,
-    deserialize_torch_tensor,
-    serialize_torch_tensor,
-)
+from hivemind.compression import CompressionBase, CompressionInfo, NoCompression, deserialize_torch_tensor
 from hivemind.dht import DHT, DHTID
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
 from hivemind.p2p.p2p_daemon_bindings.utils import ControlFailure, DispatchFailure
@@ -36,7 +30,6 @@ from hivemind.proto import averaging_pb2
 from hivemind.utils import MPFuture, TensorDescriptor, get_logger
 from hivemind.utils.asyncio import (
     achain,
-    afirst,
     aiter_with_timeout,
     anext,
     as_aiter,
@@ -109,7 +102,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     """

     _matchmaking: Matchmaking
-    _pending_group_assembled: asyncio.Event
+    _pending_groups_registered: asyncio.Event
     _state_updated: asyncio.Event
     _p2p: P2P
     serializer = MSGPackSerializer
@@ -207,7 +200,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             reducer_timeout=reducer_timeout,
         )
         self._averaging_alpha, self._allreduce_timeout = averaging_alpha, allreduce_timeout
-        self._running_groups: Dict[GroupID, AllReduceRunner] = {}  # one or more assembled groups that run all-reduce
+        self._running_groups: Dict[GroupID, asyncio.Future[AllReduceRunner]] = {}

         self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)  # a control pipe used to communicate with daemon

@@ -309,8 +302,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     asyncio.create_task(self._declare_for_download_periodically())

                 self._state_updated = asyncio.Event()
-                self._pending_group_assembled = asyncio.Event()
-                self._pending_group_assembled.set()
+                self._pending_groups_registered = asyncio.Event()
+                self._pending_groups_registered.set()
             except Exception as e:
                 # Loglevel is DEBUG since normally the exception is propagated to the caller
                 logger.debug(e, exc_info=True)
@@ -441,7 +434,7 @@

             while not step.done():
                 try:
-                    self._pending_group_assembled.clear()
+                    self._pending_groups_registered.clear()
                     step.stage = AveragingStage.LOOKING_FOR_GROUP
                     matchmaking_task = asyncio.create_task(find_peers_or_notify_cancel())
                     check_cancel_task = asyncio.create_task(step.wait_for_cancel())
@@ -458,17 +451,21 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     if group_info is None:
                         raise AllreduceException("Averaging step failed: could not find a group")

-                    step.stage = AveragingStage.RUNNING_ALLREDUCE
-
-                    step.set_result(
-                        await asyncio.wait_for(
-                            self._run_allreduce(
-                                group_info, tensor_infos=self.tensor_infos, weight=step.weight, **self.allreduce_kwargs
-                            ),
-                            timeout=self._allreduce_timeout,
+                    with self._register_allreduce_group(group_info):
+                        step.stage = AveragingStage.RUNNING_ALLREDUCE
+
+                        step.set_result(
+                            await asyncio.wait_for(
+                                self._aggregate_with_group(
+                                    group_info,
+                                    tensor_infos=self.tensor_infos,
+                                    weight=step.weight,
+                                    **self.allreduce_kwargs,
+                                ),
+                                timeout=self._allreduce_timeout,
+                            )
                         )
-                    )
-                    # averaging is finished, loop will now exit
+                        # averaging is finished, loop will now exit

                 except (
                     AllreduceException,
@@ -503,8 +500,21 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     )
                 )

-    async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
-        """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
+    @contextlib.contextmanager
+    def _register_allreduce_group(self, group_info: GroupInfo):
+        """Register a given group for one or more all-reduce rounds"""
+        try:
+            self._running_groups[group_info.group_id] = asyncio.Future()
+            self._pending_groups_registered.set()
+            yield
+        finally:
+            maybe_future = self._running_groups.pop(group_info.group_id, None)
+            if maybe_future is not None and not maybe_future.done():
+                logger.warning(f"All-reduce group {group_info.group_id} did not finish.")
+            self._pending_groups_registered.set()
+
+    async def _aggregate_with_group(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
+        """Run aggregation in a given group and update tensors in place, return gathered metadata"""
         try:
             bandwidths, mode_ids, user_gathered_bytes = zip(*map(self.serializer.loads, group_info.gathered))
             user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered_bytes)))
@@ -519,47 +529,39 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             )

             async with enter_asynchronously(self.get_tensors()) as local_tensors:
-                allreduce = AllReduceRunner(
-                    p2p=self._p2p,
-                    servicer_type=type(self),
-                    prefix=self.prefix,
-                    group_id=group_info.group_id,
-                    tensors=local_tensors,
-                    ordered_peer_ids=group_info.peer_ids,
-                    peer_fractions=peer_fractions,
-                    gathered=user_gathered,
-                    modes=modes,
-                    **kwargs,
-                )
-
-                with self.register_allreduce_group(group_info.group_id, allreduce):
-                    if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
-                        iter_results = allreduce.run()
-                        async for tensor, update in azip(as_aiter(*local_tensors), iter_results):
-                            # all-reduce is performed asynchronously while iterating
-                            tensor.add_(update, alpha=self._averaging_alpha)
-                        self._state_updated.set()
-
-                    else:
-                        async for _ in allreduce:  # trigger all-reduce by iterating
-                            raise ValueError("aux peers should not receive averaged tensors")
-
-                return allreduce.gathered
+                await self._run_allreduce_inplace_(local_tensors, group_info, peer_fractions=peer_fractions, **kwargs)
+                return user_gathered
         except BaseException as e:
             if isinstance(e, Exception):
                 logger.exception(e)
             raise MatchmakingException(f"Unable to run All-Reduce: {e}")

-    @contextlib.contextmanager
-    def register_allreduce_group(self, group_id: GroupID, allreduce: AllReduceRunner):
-        """registers a given all-reduce runner to listen for incoming connections"""
-        try:
-            self._running_groups[group_id] = allreduce
-            self._pending_group_assembled.set()
-            yield
-        finally:
-            self._running_groups.pop(group_id, None)
-            self._pending_group_assembled.set()
+    async def _run_allreduce_inplace_(
+        self, tensors: Sequence[torch.Tensor], group_info: GroupInfo, group_id: Optional[bytes] = None, **kwargs
+    ):
+        """Run one allreduce process to average tensors in place; may be called multiple times within a single aggregation"""
+        group_id = group_info.group_id if group_id is None else group_id
+
+        runner = AllReduceRunner(
+            p2p=self._p2p,
+            servicer_type=type(self),
+            prefix=self.prefix,
+            tensors=tensors,
+            group_id=group_id,
+            ordered_peer_ids=group_info.peer_ids,
+            **kwargs,
+        )
+        assert group_id in self._running_groups, f"Group id {group_id} was not registered in _register_allreduce_group"
+        self._running_groups[group_id].set_result(runner)
+
+        if runner.modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
+            async for tensor, update in azip(as_aiter(*tensors), runner):
+                tensor.add_(update, alpha=self._averaging_alpha)
+                self.last_updated = get_dht_time()
+                self._state_updated.set()
+        else:
+            async for _ in runner:
+                raise ValueError("aux peers should not receive averaged tensors")

     @contextlib.contextmanager
     def get_tensors(self) -> Sequence[torch.Tensor]:
@@ -586,13 +588,14 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         if request.group_id not in self._running_groups:
             # this handles a special case when leader accepted us to group AND began allreduce right away,
             # but his response with group_id was delayed and other peers got to us first
-            await self._pending_group_assembled.wait()
+            await self._pending_groups_registered.wait()

-        group = self._running_groups.get(request.group_id)
-        if group is None:
+        future = self._running_groups.get(request.group_id)
+        if future is None:
             yield averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
             return

+        group = await future
         async for message in group.rpc_aggregate_part(achain(as_aiter(request), stream), context):
             yield message
 
 

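The net effect of this refactoring: _run_allreduce is split into _register_allreduce_group (which pre-registers futures that rpc_aggregate_part can await) and _run_allreduce_inplace_ (which builds one AllReduceRunner and applies its updates), so a subclass can run several all-reduce rounds inside a single matchmaking group by keying each round as group_id + suffix. A rough sketch of that pattern (hypothetical subclass, mirroring what PowerSGDGradientAverager does later in this PR):

    import asyncio
    import contextlib

    from hivemind.averaging import DecentralizedAverager
    from hivemind.averaging.allreduce import AveragingMode
    from hivemind.averaging.group_info import GroupInfo
    from hivemind.averaging.load_balancing import load_balance_peers
    from hivemind.utils.asyncio import enter_asynchronously

    PHASES = (b".round1", b".round2")


    class TwoRoundAverager(DecentralizedAverager):
        """Hypothetical subclass that averages its tensors in two consecutive all-reduce rounds."""

        @contextlib.contextmanager
        def _register_allreduce_group(self, group_info: GroupInfo):
            # pre-register one future per round so that rpc_aggregate_part can await the matching runner
            try:
                for suffix in PHASES:
                    self._running_groups[group_info.group_id + suffix] = asyncio.Future()
                self._pending_groups_registered.set()
                yield
            finally:
                for suffix in PHASES:
                    self._running_groups.pop(group_info.group_id + suffix, None)
                self._pending_groups_registered.set()

        async def _aggregate_with_group(self, group_info: GroupInfo, min_vector_size: int, **kwargs):
            # same bandwidth/fraction bookkeeping as in DecentralizedAverager._aggregate_with_group
            bandwidths, mode_ids, user_gathered_bytes = zip(*map(self.serializer.loads, group_info.gathered))
            user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered_bytes)))
            modes = tuple(map(AveragingMode, mode_ids))
            download_bandwidths = [thr if mode != AveragingMode.CLIENT else 0.0 for thr, mode in zip(bandwidths, modes)]
            peer_fractions = await asyncio.get_event_loop().run_in_executor(
                None, load_balance_peers, self.total_size, download_bandwidths, min_vector_size
            )
            async with enter_asynchronously(self.get_tensors()) as local_tensors:
                for suffix in PHASES:  # each round resolves the future registered under group_id + suffix
                    await self._run_allreduce_inplace_(
                        local_tensors, group_info, group_info.group_id + suffix, peer_fractions=peer_fractions, **kwargs
                    )
            return user_gathered
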
+ 20 - 7
hivemind/optim/grad_averager.py

@@ -1,16 +1,20 @@
 import contextlib
-from typing import Iterable, Iterator, Optional
+from typing import Callable, Iterable, Iterator, Optional, Sequence, TypeVar

 import torch

-import hivemind
 from hivemind.averaging import DecentralizedAverager
 from hivemind.averaging.control import StepControl
-from hivemind.utils import DHTExpiration, get_dht_time, get_logger
+from hivemind.dht import DHT
+from hivemind.utils import DHTExpiration, get_logger

 logger = get_logger(__name__)


+TGradientAverager = TypeVar("TGradientAverager", bound="GradientAverager")
+GradientAveragerFactory = Callable[..., TGradientAverager]
+
+
 class GradientAverager(DecentralizedAverager):
     """
     An auxiliary averaging class that is responsible for accumulating gradients and aggregating them with peers.
@@ -36,6 +40,7 @@ class GradientAverager(DecentralizedAverager):
       if True, the averager will only join existing groups where at least one peer has client_mode=False.
       By default, this flag is copied from DHTNode inside the ``dht`` instance.
     :param warn: if True, warn when the averager did not reset accumulators after use or did not use averaging results
+    :param averaged_grads: if provided, these tensors will be used as the buffers for averaged gradients
     :param kwargs: see DecentralizedAverager keyword arguments for additional parameters

@@ -69,12 +74,13 @@ class GradientAverager(DecentralizedAverager):
         self,
         parameters: Iterable[torch.nn.Parameter],
         *,
-        dht: hivemind.DHT,
+        dht: DHT,
         prefix: str,
         reuse_grad_buffers: bool = False,
         accumulate_grads_on: Optional[torch.device] = None,
         client_mode: bool = None,
         warn: bool = True,
+        averaged_grads: Sequence[torch.Tensor] = (),
         **kwargs,
     ):
         if reuse_grad_buffers and accumulate_grads_on is not None:
@@ -95,9 +101,16 @@ class GradientAverager(DecentralizedAverager):
         self._new_averaged_grads = False

         with torch.no_grad():
-            averaged_grads = tuple(
-                grad.detach().cpu().clone().share_memory_() for grad in self._grads_from_parameters()
-            )
+            if not averaged_grads:
+                averaged_grads = tuple(
+                    grad.detach().cpu().clone().share_memory_() for grad in self._grads_from_parameters()
+                )
+            else:
+                if any(
+                    param_grad.size() != grad.size()
+                    for param_grad, grad in zip(self._grads_from_parameters(), averaged_grads)
+                ):
+                    raise ValueError("Averaged gradients don't have the same shape as gradients from parameters")
         super().__init__(averaged_tensors=averaged_grads, dht=dht, prefix=prefix, client_mode=client_mode, **kwargs)

     def _grads_from_parameters(self) -> Iterator[torch.Tensor]:

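A small illustration of the new averaged_grads argument (hypothetical usage; buffer shapes must match the parameters, as enforced by the check above):

    import torch
    import hivemind
    from hivemind.optim.grad_averager import GradientAverager

    dht = hivemind.DHT(start=True)
    model = torch.nn.Linear(16, 4)

    # externally managed buffers for the averaged gradients, one per parameter and with matching shapes
    external_buffers = tuple(torch.zeros_like(p, device="cpu").share_memory_() for p in model.parameters())

    averager = GradientAverager(
        model.parameters(),
        dht=dht,
        prefix="demo_grad_averager",  # hypothetical matchmaking prefix
        averaged_grads=external_buffers,
        start=True,
    )
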
+ 12 - 5
hivemind/optim/optimizer.py

@@ -11,8 +11,9 @@ import torch
 from hivemind.averaging.control import AveragingStage, StepControl
 from hivemind.compression import CompressionBase, NoCompression
 from hivemind.dht import DHT
-from hivemind.optim.grad_averager import GradientAverager
+from hivemind.optim.grad_averager import GradientAverager, GradientAveragerFactory
 from hivemind.optim.grad_scaler import GradScaler
+from hivemind.optim.power_sgd_averager import PowerSGDGradientAverager
 from hivemind.optim.progress_tracker import LocalTrainingProgress, ProgressTracker
 from hivemind.optim.state_averager import (
     LRSchedulerBase,
@@ -147,6 +148,7 @@ class Optimizer(torch.optim.Optimizer):
     :param auxiliary: if True, optimizer.step will only assist other peers in averaging (for cpu-only workers)

     :param grad_compression: compression strategy used for averaging gradients, default = no compression
+    :param grad_averager_factory: if provided, creates a gradient averager with the required averaging strategy
     :param state_averaging_compression: compression for averaging params and state tensors, default = no compression
     :param load_state_compression: compression strategy for loading state from peers, default = no compression
     :param average_opt_statistics: names of optimizer statistics from state dict that should be averaged with peers
@@ -187,6 +189,7 @@ class Optimizer(torch.optim.Optimizer):
         client_mode: bool = None,
         auxiliary: bool = False,
         grad_compression: CompressionBase = NoCompression(),
+        grad_averager_factory: Optional[GradientAveragerFactory] = GradientAverager,
         state_averaging_compression: CompressionBase = NoCompression(),
         load_state_compression: CompressionBase = NoCompression(),
         average_opt_statistics: Sequence[str] = (),
@@ -226,6 +229,9 @@ class Optimizer(torch.optim.Optimizer):
         if use_local_updates:
             assert not reuse_grad_buffers, "if local_updates is True, gradients will not be accumulated"
             assert not delay_grad_averaging, "if local_updates is True, gradients will not be averaged"
+            assert (
+                grad_averager_factory is None
+            ), "if local_updates is True, provided grad_averager_factory will not be used"

         self.dht, self.run_id, self.client_mode, self.auxiliary = dht, run_id, client_mode, auxiliary
         self.batch_size_per_step, self.target_batch_size = batch_size_per_step, target_batch_size
@@ -256,9 +262,9 @@ class Optimizer(torch.optim.Optimizer):
             extra_tensors=extra_tensors,
             **averager_opts or {},
         )
-        if not use_local_updates:
+        if grad_averager_factory is not None and not use_local_updates:
             self.grad_averager = self._make_gradient_averager(
-                reuse_grad_buffers=reuse_grad_buffers, compression=grad_compression, **averager_opts or {}
+                reuse_grad_buffers=reuse_grad_buffers, grad_averager_factory=grad_averager_factory
             )
         else:
             self.grad_averager = None
@@ -291,9 +297,9 @@ class Optimizer(torch.optim.Optimizer):
             **kwargs,
         )

-    def _make_gradient_averager(self, **kwargs) -> GradientAverager:
+    def _make_gradient_averager(self, grad_averager_factory, **kwargs) -> GradientAverager:
         assert hasattr(self, "state_averager"), "must initialize state averager first"
-        grad_averager = GradientAverager(
+        grad_averager = grad_averager_factory(
             dht=self.dht,
             prefix=f"{self.run_id}_grad_averager",
             parameters=self.state_averager.main_parameters,
@@ -685,6 +691,7 @@ class Optimizer(torch.optim.Optimizer):
             while True:
                 try:
                     self.state_averager.load_state_from_peers(timeout=self.load_state_timeout, **kwargs)
+                    self.grad_averager.load_state_from_peers(timeout=self.load_state_timeout, **kwargs)
                     break
                 except KeyboardInterrupt:
                     raise

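Note that grad_averager_factory defaults to GradientAverager, so with use_local_updates=True the caller has to disable it explicitly to satisfy the new assertion. A sketch of that configuration (hypothetical run id and hyperparameters):

    from functools import partial

    import torch
    import hivemind

    dht = hivemind.DHT(start=True)
    model = torch.nn.Linear(8, 8)

    # with local updates there is no gradient averaging step, so the factory must be set to None explicitly
    opt = hivemind.Optimizer(
        dht=dht,
        run_id="local_updates_demo",
        params=model.parameters(),
        optimizer=partial(torch.optim.SGD, lr=0.1),
        target_batch_size=256,
        batch_size_per_step=32,
        use_local_updates=True,
        grad_averager_factory=None,
    )
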
+ 223 - 0
hivemind/optim/power_sgd_averager.py

@@ -0,0 +1,223 @@
+import asyncio
+import contextlib
+import multiprocessing as mp
+from enum import Enum
+from typing import Any, Iterable, Optional, Sequence
+
+import torch
+
+from hivemind.averaging.allreduce import AveragingMode
+from hivemind.averaging.group_info import GroupInfo
+from hivemind.averaging.load_balancing import load_balance_peers
+from hivemind.averaging.matchmaking import Matchmaking, MatchmakingException
+from hivemind.compression import CompressionInfo, TensorRole
+from hivemind.dht import DHT
+from hivemind.optim.grad_averager import GradientAverager
+from hivemind.utils import get_logger
+from hivemind.utils.asyncio import enter_asynchronously
+from hivemind.utils.math import get_flatten_greedy_dims, orthogonalize_
+
+GatheredData = Any
+logger = get_logger(__name__)
+
+
+class AllReducePhases(Enum):
+    PHASE_P = 1
+    PHASE_Q = 2
+
+
+class PowerSGDGradientAverager(GradientAverager):
+    """
+    A gradient averager that implements PowerSGD compression: https://arxiv.org/abs/1905.13727
+    For basic properties and guarantees of gradient averagers, please refer to the base class docstring.
+    Put simply, this method approximates large gradient tensors (m,n) with a product of two
+    smaller matrices (m,r) by (r,n), where r is a parameter chosen by the user (see averager_rank).
+
+    As a result, PowerSGD only needs to aggregate O((m + n) * r) tensors instead of O(m * n).
+    A high r, e.g. sqrt(max(m, n)), typically reduces communication by 2-8x without affecting convergence.
+    A low r, e.g. 1-8, further accelerates communication, but may converge worse depending on the task.
+
+    To maintain convergence with low r, this averager uses the error feedback strategy. Put simply,
+    if some part of the gradient is "lost in compression", it will be added to the next iteration.
+    This has two implications: (a) it needs more RAM in order to store the "feedback buffers"
+    and (b) if devices stay alive only for one step, training with small rank may converge slower.
+    This is because error feedback takes multiple steps to kick in.
+
+    Since not all gradients are matrices, PowerSGD flattens 3d+ tensors to 2d by greedily merging their
+    smallest adjacent dimensions (see get_flatten_greedy_dims). If a tensor has fewer than 2 dimensions or does
+    not compress efficiently, it will be aggregated normally, i.e. without PowerSGD. See min_compression_ratio for details.
+
+    :note: due to the above rule, PowerSGD is *not* shape-invariant. For instance, a
+     matrix of shape (256, 256) will be compressed differently if you .reshape it to (32, 32, 32).
+
+    :param parameters: pytorch parameters for which to aggregate gradients
+    :param averager_rank: rank of compressed gradients
+    :param dht: a DHT instance connected to the rest of the swarm. See hivemind.DHT docs
+    :param prefix: a unique DHT key used for matchmaking. E.g. this can be your experiment name with optional suffixes
+    :param reuse_grad_buffers: if True, use model's .grad buffers for accumulating gradients over multiple steps.
+      This is more memory efficient, but it requires that the user does *not* call zero_grad or clip_by_whatever at all
+    :param accumulate_grads_on: if specified, accumulate gradients on this device. By default, this will use the same
+      device as model parameters. One can specify a different device (e.g. 'cpu' vs 'cuda') to save device memory at
+      the cost of extra time per step. If reuse_grad_buffers is True, this parameter has no effect.
+    :param client_mode: if False, this averager will accept incoming requests from other peers.
+      if True, the averager will only join existing groups where at least one peer has client_mode=False.
+      By default, this flag is copied from DHTNode inside the ``dht`` instance.
+    :param warn: if True, warn when the averager did not reset accumulators after use or did not use averaging results
+    :param min_compression_ratio: apply PowerSGD to a tensor only if it reduces communication by at least this factor,
+      otherwise aggregate tensors as is
+    :param averaged_grads: if provided, these tensors will be used as the buffers for averaged gradients
+    """
+
+    def __init__(
+        self,
+        parameters: Iterable[torch.nn.Parameter],
+        averager_rank: int,
+        *,
+        dht: DHT,
+        prefix: str,
+        reuse_grad_buffers: bool = False,
+        accumulate_grads_on: Optional[torch.device] = None,
+        client_mode: bool = None,
+        warn: bool = True,
+        min_compression_ratio: float = 0.5,
+        averaged_grads: Optional[Sequence[torch.Tensor]] = None,
+        **kwargs,
+    ):
+        self.rank = averager_rank
+        self.parameters = tuple(parameters)
+        self._uncompressed_gradients_indexes = set(
+            i
+            for i, grad in enumerate(self._grads_from_parameters())
+            if grad.ndim <= 1
+            or (1 - self.rank * sum(get_flatten_greedy_dims(grad)) / grad.numel()) < min_compression_ratio
+            # keep the tensor uncompressed if the rank-r factorization saves less than min_compression_ratio of its values
+        )
+        self._ms = [
+            torch.zeros_like(grad, device="cpu").share_memory_()
+            for idx, grad in enumerate(self._grads_from_parameters())
+            if idx not in self._uncompressed_gradients_indexes
+        ]
+        self._qs = [
+            torch.rand((get_flatten_greedy_dims(grad)[1], self.rank), device="cpu").share_memory_()
+            for idx, grad in enumerate(self._grads_from_parameters())
+            if idx not in self._uncompressed_gradients_indexes
+        ]
+
+        super().__init__(
+            self.parameters,
+            dht=dht,
+            prefix=prefix,
+            reuse_grad_buffers=reuse_grad_buffers,
+            accumulate_grads_on=accumulate_grads_on,
+            client_mode=client_mode,
+            warn=warn,
+            averaged_grads=averaged_grads,
+            **kwargs,
+        )
+
+    @contextlib.contextmanager
+    def _register_allreduce_group(self, group_info: GroupInfo):
+        """Register a given group for one or more all-reduce rounds"""
+        try:
+            for phase in list(AllReducePhases):
+                self._running_groups[group_info.group_id + phase.name.encode()] = asyncio.Future()
+            self._pending_groups_registered.set()
+            yield
+        finally:
+            for phase in list(AllReducePhases):
+                maybe_future = self._running_groups.pop(group_info.group_id + phase.name.encode(), None)
+                if maybe_future and not maybe_future.done():
+                    logger.warning(f"All-reduce group {group_info.group_id + phase.name.encode()} did not finish.")
+            self._pending_groups_registered.set()
+
+    async def _aggregate_with_group(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
+        """Run aggregation in a given group and update tensors in place, return gathered metadata"""
+        try:
+            bandwidths, mode_ids, user_gathered_bytes = zip(*map(self.serializer.loads, group_info.gathered))
+            user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered_bytes)))
+            modes = tuple(map(AveragingMode, mode_ids))
+
+            download_bandwidths = [
+                thr if mode != AveragingMode.CLIENT else 0.0 for thr, mode in zip(bandwidths, modes)
+            ]
+            peer_fractions = await asyncio.get_event_loop().run_in_executor(
+                None, load_balance_peers, self.total_size, download_bandwidths, min_vector_size
+            )
+
+            async with enter_asynchronously(self.get_tensors()) as averaged_grads:
+                averaged_grads_via_sgd = [
+                    grad for idx, grad in enumerate(averaged_grads) if idx not in self._uncompressed_gradients_indexes
+                ]
+                for grad, m in zip(averaged_grads_via_sgd, self._ms):
+                    m.add_(grad.to(m.device))
+
+                ps = [
+                    torch.zeros((get_flatten_greedy_dims(grad)[0], self.rank), device="cpu")
+                    for idx, grad in enumerate(averaged_grads_via_sgd)
+                ]
+                for p, q, m in zip(ps, self._qs, self._ms):
+                    # we use reshape for all matrices because PowerSGD works only with 2d tensors
+                    torch.matmul(m.reshape(-1, q.size(0)), q, out=p)
+
+                p_group_id = group_info.group_id + AllReducePhases.PHASE_P.name.encode()
+                q_group_id = group_info.group_id + AllReducePhases.PHASE_Q.name.encode()
+
+                await self._run_allreduce_inplace_(ps, group_info, p_group_id, peer_fractions=peer_fractions, **kwargs)
+
+                for p in ps:
+                    orthogonalize_(p)
+
+                for p, q, m in zip(ps, self._qs, self._ms):
+                    torch.matmul(m.reshape(-1, q.size(0)).t(), p, out=q)
+
+                phase_q_tensors = self._qs + [
+                    grad for idx, grad in enumerate(averaged_grads) if idx in self._uncompressed_gradients_indexes
+                ]
+
+                await self._run_allreduce_inplace_(
+                    phase_q_tensors, group_info, q_group_id, peer_fractions=peer_fractions, **kwargs
+                )
+
+                for p, q, m, grad in zip(ps, self._qs, self._ms, averaged_grads_via_sgd):
+                    new_m = torch.matmul(p, q.t()).reshape(m.size())
+                    m.sub_(new_m)
+                    grad.copy_(new_m)
+
+                return user_gathered
+        except BaseException as e:
+            logger.exception(e)
+            raise MatchmakingException(f"Unable to run All-Reduce: {e}")
+
+    def get_current_state(self):
+        """
+        Get the current gradient averager state, used when a newly joined peer requests it.
+        """
+        with torch.no_grad(), self.lock_averaged_tensors:
+            grad_averager_buffers = [q for q in self._qs]
+            grad_averager_buffers_infos = [
+                CompressionInfo.from_tensor(buffer, key=f"buffer_q_{key}", role=TensorRole.GRADIENT)
+                for key, buffer in enumerate(grad_averager_buffers)
+            ]
+
+        metadata = dict(group_bits=self.get_group_bits())
+        return metadata, grad_averager_buffers, grad_averager_buffers_infos
+
+    def load_state_from_peers(self, **kwargs):
+        """
+        Attempt to download the latest optimizer state from peers and update gradient averager buffers.
+        :returns: whether or not the averager succeeded in loading parameters
+        """
+        loaded_state = super().load_state_from_peers(**kwargs)
+        if loaded_state is None:
+            return
+
+        metadata, flat_tensors = loaded_state
+        logger.info("Starting loading gradient averager buffers from peers")
+
+        if len(flat_tensors) != len(self._qs):
+            logger.error("Failed to load state from peer, received invalid parameters, extras or metadata")
+            return
+
+        with torch.no_grad(), self.lock_averaged_tensors:
+            for local_q, loaded_q in zip(self._qs, flat_tensors):
+                local_q.copy_(loaded_q, non_blocking=True)
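
For intuition, the arithmetic that _aggregate_with_group performs for each compressed gradient can be sketched in a single process (illustrative shapes and random data, no networking; the two all-reduce calls in the real code happen right after P and Q are computed):

    import torch

    from hivemind.utils.math import get_flatten_greedy_dims, orthogonalize_

    rank = 4
    grad = torch.randn(64, 3, 32)               # a gradient tensor with more than 2 dimensions
    n, m = get_flatten_greedy_dims(grad)        # greedily flattened to 2d, here (64, 96)

    M = torch.zeros(n, m)                       # error-feedback buffer (an entry of self._ms, kept between steps)
    Q = torch.rand(m, rank)                     # persistent right factor (an entry of self._qs)

    M += grad.reshape(n, m)                     # accumulate the fresh gradient into the feedback buffer
    P = M @ Q                                   # phase P: compute the (n, rank) factor, then all-reduce P across peers
    orthogonalize_(P)                           # orthogonalize the averaged P in place
    Q = M.t() @ P                               # phase Q: compute the (m, rank) factor, then all-reduce Q across peers
    approx = P @ Q.t()                          # low-rank approximation of the accumulated gradient
    M -= approx                                 # keep the compression error for the next step (error feedback)
    averaged_grad = approx.reshape(grad.shape)  # this is what gets copied back into the averaged gradient buffer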

+ 24 - 0
hivemind/utils/math.py

@@ -0,0 +1,24 @@
+import torch
+import torch.nn.functional as F
+
+
+@torch.jit.script
+def orthogonalize_(matrix, eps: float = 1e-8):
+    """Orthogonalize a 2d tensor in-place over the last dimension"""
+    n, m = matrix.shape
+    for i in range(m):
+        col = matrix[:, i]
+        F.normalize(col, dim=0, eps=eps, out=col)
+        if i + 1 < m:
+            rest = matrix[:, i + 1 :]
+            rest.addmm_(col[:, None], (col @ rest)[None, :], alpha=-1)
+
+
+def get_flatten_greedy_dims(tensor: torch.Tensor, max_ndim: int = 2):
+    """Get the dims to flatten a tensor to at most max_ndim dimensions by greedily merging its smallest adjacent axes"""
+    dims = list(tensor.shape)
+    while len(dims) > max_ndim:
+        squeeze_ix = min(range(len(dims) - 1), key=lambda i: dims[i] * dims[i + 1])
+        squeezed_dim = dims.pop(squeeze_ix)
+        dims[squeeze_ix] *= squeezed_dim
+    return dims
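
A quick check of how these two helpers behave (illustrative, based on the definitions above):

    import torch

    from hivemind.utils.math import get_flatten_greedy_dims, orthogonalize_

    # merge the smallest adjacent axes until at most two dimensions remain
    assert get_flatten_greedy_dims(torch.empty(64, 3, 3, 32)) == [64, 288]
    assert get_flatten_greedy_dims(torch.empty(10, 20)) == [10, 20]

    # in-place Gram-Schmidt over columns: the result has (approximately) orthonormal columns
    matrix = torch.randn(128, 4)
    orthogonalize_(matrix)
    assert torch.allclose(matrix.t() @ matrix, torch.eye(4), atol=1e-5)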

+ 14 - 12
tests/test_allreduce_fault_tolerance.py

@@ -35,7 +35,7 @@ class FaultyAverager(hivemind.DecentralizedAverager):
         self.fault = fault
         self.fault = fault
         super().__init__(*args, **kwargs)

-    async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
+    async def _aggregate_with_group(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
         """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
         try:
             bandwidths, mode_ids, user_gathered_bytes = zip(*map(self.serializer.loads, group_info.gathered))
                     tensors=local_tensors,
                     tensors=local_tensors,
                     ordered_peer_ids=group_info.peer_ids,
                     peer_fractions=peer_fractions,
                     modes=modes,
                     modes=modes,
                     fault=self.fault,
                     **kwargs,
                 )

-                with self.register_allreduce_group(group_info.group_id, allreduce):
-                        async for tensor, update in azip(as_aiter(*local_tensors), allreduce):
-                            # all-reduce is performed asynchronously while iterating
-                            tensor.add_(update, alpha=self._averaging_alpha)
-                        self._state_updated.set()
+                self._running_groups[group_info.group_id].set_result(allreduce)
+                # TODO maybe this can be extracted into a method that checks if register_... context is active.
 
 
-                    else:
-                        async for _ in allreduce:  # trigger all-reduce by iterating
-                            raise ValueError("aux peers should not receive averaged tensors")
+                if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
+                    iter_results = allreduce.run()
+                    async for tensor, update in azip(as_aiter(*local_tensors), iter_results):
+                        # all-reduce is performed asynchronously while iterating
+                        tensor.add_(update, alpha=self._averaging_alpha)
+                    self._state_updated.set()
+
+                else:
+                    async for _ in allreduce:  # trigger all-reduce by iterating
+                        raise ValueError("aux peers should not receive averaged tensors")
 
 
-                return allreduce.gathered
+                return user_gathered
         except BaseException as e:
         except BaseException as e:
             logger.exception(e)
             raise MatchmakingException(f"Unable to run All-Reduce: {e}")
+ 22 - 11
tests/test_optimizer.py

@@ -11,24 +11,31 @@ import torch.nn.functional as F
 
 
 import hivemind
 import hivemind
 from hivemind.averaging.control import AveragingStage
 from hivemind.averaging.control import AveragingStage
-from hivemind.optim.grad_averager import GradientAverager
+from hivemind.optim.grad_averager import GradientAverager, GradientAveragerFactory
 from hivemind.optim.optimizer import Optimizer
 from hivemind.optim.optimizer import Optimizer
 from hivemind.optim.progress_tracker import ProgressTracker
 from hivemind.optim.progress_tracker import ProgressTracker
 from hivemind.optim.state_averager import TrainingStateAverager
 from hivemind.utils.crypto import RSAPrivateKey


 @pytest.mark.forked
+@pytest.mark.parametrize(
+    "grad_averager_factory",
+    [GradientAverager, partial(PowerSGDGradientAverager, averager_rank=1)],
+)
+def test_grad_averager(grad_averager_factory: GradientAveragerFactory):
+    parameter_shape = (5, 5)
+
     dht1 = hivemind.DHT(start=True)
     dht1 = hivemind.DHT(start=True)
-    averager1 = GradientAverager(
+    model1 = nn.ParameterDict({"w": nn.Parameter(torch.zeros(parameter_shape))})
+    averager1 = grad_averager_factory(
         model1.parameters(), dht=dht1, prefix="test", target_group_size=2, reuse_grad_buffers=False, start=True
         model1.parameters(), dht=dht1, prefix="test", target_group_size=2, reuse_grad_buffers=False, start=True
     )

     dht2 = hivemind.DHT(start=True, initial_peers=dht1.get_visible_maddrs())
-    averager2 = GradientAverager(
+    model2 = nn.ParameterDict({"w": nn.Parameter(torch.zeros(parameter_shape))})
+    averager2 = grad_averager_factory(
         model2.parameters(), dht=dht2, prefix="test", target_group_size=2, reuse_grad_buffers=True, start=True
         model2.parameters(), dht=dht2, prefix="test", target_group_size=2, reuse_grad_buffers=True, start=True
     )

     for i in range(10):
     for i in range(10):
         time.sleep(0.1)
         if i % 3 == 0:
+            loss1 = F.mse_loss(model1.w, torch.ones(parameter_shape))
             loss1.backward()
             loss1.backward()
             averager1.accumulate_grads_(batch_size=2)  # total: 4 times * 2 samples = 8
             model1.zero_grad()
         else:
+            loss2 = F.mse_loss(model2.w, -torch.ones(parameter_shape))
             loss2.backward()
             loss2.backward()
             averager2.accumulate_grads_(batch_size=3)  # total: 6 times * 3 samples = 18
             # note: we do not call zero grad here because reuse_grad_buffers=True
     assert control1.stage == control2.stage == AveragingStage.AWAITING_TRIGGER
     assert control1.stage == control2.stage == AveragingStage.AWAITING_TRIGGER
     peer1_samples, peer1_times, peer2_samples, peer2_times = 8, 4, 18, 6
     assert averager1.local_samples_accumulated == peer1_samples and averager1.local_times_accumulated == peer1_times
+    ref_grads1 = torch.full(parameter_shape, -2 / np.prod(parameter_shape) * averager1.local_times_accumulated)
     assert torch.allclose(next(averager1._grad_accumulators()), ref_grads1)
     assert torch.allclose(next(averager1._grad_accumulators()), ref_grads1)

     assert averager2.local_samples_accumulated == peer2_samples and averager2.local_times_accumulated == peer2_times
+    ref_grads2 = torch.full(parameter_shape, 2 / np.prod(parameter_shape) * averager2.local_times_accumulated)
     assert torch.allclose(next(averager2._grad_accumulators()), ref_grads2)
     assert torch.allclose(next(averager2._grad_accumulators()), ref_grads2)

     averager1.step(control=control1, wait=False)
     )
     )

     avgr1 = TrainingStateAverager(
+        dht=dht1,
+        params=model1.parameters(),
+        allow_state_sharing=False,
+        start=True,
+        **common_kwargs,
     )
     )

     avgr2 = TrainingStateAverager(dht=dht2, params=model2.parameters(), start=True, **common_kwargs)