Browse Source

Merge branch 'master' into server-p2p

Pavel Samygin 3 năm trước cách đây
mục cha
commit
106ae3db16

+ 1 - 1
.github/workflows/check-style.yml

@@ -13,7 +13,7 @@ jobs:
       - uses: psf/black@stable
         with:
           options: "--check --diff"
-          version: "22.1.0"
+          version: "22.3.0"
   isort:
     runs-on: ubuntu-latest
     steps:

+ 6 - 2
.github/workflows/run-tests.yml

@@ -12,7 +12,7 @@ jobs:
     strategy:
       matrix:
         python-version: [ 3.7, 3.8, 3.9 ]
-    timeout-minutes: 10
+    timeout-minutes: 12
     steps:
       - uses: actions/checkout@v2
       - name: Set up Python
@@ -42,6 +42,10 @@ jobs:
     timeout-minutes: 10
     steps:
       - uses: actions/checkout@v2
+      - uses: actions/setup-go@v3
+        with:
+          go-version: '1.16'
+          check-latest: true
       - name: Set up Python
         uses: actions/setup-python@v2
         with:
@@ -67,7 +71,7 @@ jobs:
   codecov_in_develop_mode:
 
     runs-on: ubuntu-latest
-    timeout-minutes: 10
+    timeout-minutes: 12
     steps:
       - uses: actions/checkout@v2
       - name: Set up Python

+ 1 - 1
README.md

@@ -63,7 +63,7 @@ By default, hivemind uses the precompiled binary of
 the [go-libp2p-daemon](https://github.com/learning-at-home/go-libp2p-daemon) library. If you face compatibility issues
 or want to build the binary yourself, you can recompile it by running `pip install . --global-option="--buildgo"`.
 Before running the compilation, please ensure that your machine has a recent version
-of [Go toolchain](https://golang.org/doc/install) (1.15 or higher).
+of [Go toolchain](https://golang.org/doc/install) (1.15 or 1.16 are supported).
 
 ### System requirements
 

+ 9 - 0
examples/albert/README.md

@@ -130,6 +130,15 @@ monitors on different servers and list all of them as `--initial_peers`. The sys
 as at least one externally accessible participant is available. For short- to mid-term experiments you can host the
 monitor on a [free-tier VM](https://www.quora.com/Are-there-any-free-online-virtual-machines).
 
+By default, the training monitor changes its address on restart, so you may launch two monitors on the same machine.
+If you'd like to fix the monitor's address (e.g., before sending it to your collaborators),
+you need to **(a)** make it listen a specific TCP/UDP port and **(b)** provide a path for storing the identity file
+(which allows [libp2p](https://libp2p.io/) to reuse the same peer ID after restart). You may do that like this:
+
+```bash
+./run_training_monitor.py --wandb_project YOUR_WANDB_PROJECT --host_maddrs /ip4/0.0.0.0/tcp/31337 --identity_path ./identity.key
+```
+
 ### Tuning for hardware/network
 
 The optimal training parameters for each peer depend on its GPU and internet connection. If a peer cannot accept

+ 1 - 1
examples/albert/arguments.py

@@ -38,7 +38,7 @@ class BaseTrainingArguments:
         default=None,
         metadata={
             "help": "Path to a pre-generated private key file. If defined, makes the peer ID deterministic. "
-            "May be generated using ``./p2p-keygen`` from ``go-libp2p-daemon``."
+            "If the file does not exist yet, writes a new private key to this file."
         },
     )
 

+ 2 - 4
hivemind/averaging/allreduce.py

@@ -1,6 +1,6 @@
 import asyncio
 from enum import Enum
-from typing import Any, AsyncIterator, Dict, Optional, Sequence, Set, Tuple, Type
+from typing import AsyncIterator, Optional, Sequence, Set, Tuple, Type
 
 import torch
 
@@ -50,7 +50,6 @@ class AllReduceRunner(ServicerBase):
     :param peer_fractions: for each peer, a target fraction of vector elements that this peer should average
       (the actual number of values by peer will be nearly proportional, but there are no exact guarantees)
     :param modes: AveragingMode for each peer in ordered_peer_ids (normal, client-only or auxiliary)
-    :param gathered: additional user-defined data collected from this group
     :param sender_timeout: during all_reduce, any sender that fails to send tensor chunk within this many seconds from
       previous chunk will be marked as failed and excluded from averaging. default: equal to next_chunk_timeout
     :param reducer_timeout: during all_reduce, any reducer that fails to send results chunk within this many seconds
@@ -73,7 +72,6 @@ class AllReduceRunner(ServicerBase):
         ordered_peer_ids: Sequence[PeerID],
         peer_fractions: Tuple[float, ...],
         modes: Optional[Sequence[AveragingMode]] = None,
-        gathered: Optional[Dict[PeerID, Any]] = None,
         sender_timeout: Optional[float] = None,
         reducer_timeout: Optional[float] = None,
         **kwargs,
@@ -99,7 +97,7 @@ class AllReduceRunner(ServicerBase):
             assert mode != AveragingMode.CLIENT or frac == 0, "client-mode peer should have zero all-reduce fraction"
 
         self.group_id, self.ordered_peer_ids = group_id, ordered_peer_ids
-        self.modes, self.peer_fractions, self.gathered = modes, peer_fractions, gathered
+        self.modes, self.peer_fractions = modes, peer_fractions
 
         if weight is None:
             weight = float(modes[self.ordered_peer_ids.index(self.peer_id)] != AveragingMode.AUX)

+ 67 - 64
hivemind/averaging/averager.py

@@ -22,13 +22,7 @@ from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.load_balancing import load_balance_peers
 from hivemind.averaging.matchmaking import Matchmaking, MatchmakingException
 from hivemind.averaging.partition import DEFAULT_PART_SIZE_BYTES
-from hivemind.compression import (
-    CompressionBase,
-    CompressionInfo,
-    NoCompression,
-    deserialize_torch_tensor,
-    serialize_torch_tensor,
-)
+from hivemind.compression import CompressionBase, CompressionInfo, NoCompression, deserialize_torch_tensor
 from hivemind.dht import DHT, DHTID
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
 from hivemind.p2p.p2p_daemon_bindings.utils import ControlFailure, DispatchFailure
@@ -36,7 +30,6 @@ from hivemind.proto import averaging_pb2
 from hivemind.utils import MPFuture, TensorDescriptor, get_logger
 from hivemind.utils.asyncio import (
     achain,
-    afirst,
     aiter_with_timeout,
     anext,
     as_aiter,
@@ -109,7 +102,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     """
 
     _matchmaking: Matchmaking
-    _pending_group_assembled: asyncio.Event
+    _pending_groups_registered: asyncio.Event
     _state_updated: asyncio.Event
     _p2p: P2P
     serializer = MSGPackSerializer
@@ -207,7 +200,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             reducer_timeout=reducer_timeout,
         )
         self._averaging_alpha, self._allreduce_timeout = averaging_alpha, allreduce_timeout
-        self._running_groups: Dict[GroupID, AllReduceRunner] = {}  # one or more assembled groups that run all-reduce
+        self._running_groups: Dict[GroupID, asyncio.Future[AllReduceRunner]] = {}
 
         self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)  # a control pipe used to communicate with daemon
 
@@ -309,8 +302,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     asyncio.create_task(self._declare_for_download_periodically())
 
                 self._state_updated = asyncio.Event()
-                self._pending_group_assembled = asyncio.Event()
-                self._pending_group_assembled.set()
+                self._pending_groups_registered = asyncio.Event()
+                self._pending_groups_registered.set()
             except Exception as e:
                 # Loglevel is DEBUG since normally the exception is propagated to the caller
                 logger.debug(e, exc_info=True)
@@ -441,7 +434,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
             while not step.done():
                 try:
-                    self._pending_group_assembled.clear()
+                    self._pending_groups_registered.clear()
                     step.stage = AveragingStage.LOOKING_FOR_GROUP
                     matchmaking_task = asyncio.create_task(find_peers_or_notify_cancel())
                     check_cancel_task = asyncio.create_task(step.wait_for_cancel())
@@ -458,17 +451,21 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     if group_info is None:
                         raise AllreduceException("Averaging step failed: could not find a group")
 
-                    step.stage = AveragingStage.RUNNING_ALLREDUCE
-
-                    step.set_result(
-                        await asyncio.wait_for(
-                            self._run_allreduce(
-                                group_info, tensor_infos=self.tensor_infos, weight=step.weight, **self.allreduce_kwargs
-                            ),
-                            timeout=self._allreduce_timeout,
+                    with self._register_allreduce_group(group_info):
+                        step.stage = AveragingStage.RUNNING_ALLREDUCE
+
+                        step.set_result(
+                            await asyncio.wait_for(
+                                self._aggregate_with_group(
+                                    group_info,
+                                    tensor_infos=self.tensor_infos,
+                                    weight=step.weight,
+                                    **self.allreduce_kwargs,
+                                ),
+                                timeout=self._allreduce_timeout,
+                            )
                         )
-                    )
-                    # averaging is finished, loop will now exit
+                        # averaging is finished, loop will now exit
 
                 except (
                     AllreduceException,
@@ -503,8 +500,21 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     )
                 )
 
-    async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
-        """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
+    @contextlib.contextmanager
+    def _register_allreduce_group(self, group_info: GroupInfo):
+        """Register a given group for one or more all-reduce rounds"""
+        try:
+            self._running_groups[group_info.group_id] = asyncio.Future()
+            self._pending_groups_registered.set()
+            yield
+        finally:
+            maybe_future = self._running_groups.pop(group_info.group_id, None)
+            if maybe_future is not None and not maybe_future.done():
+                logger.warning(f"All-reduce group {group_info.group_id} did not finish.")
+            self._pending_groups_registered.set()
+
+    async def _aggregate_with_group(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
+        """Run aggregation in a given group and update tensors in place, return gathered metadata"""
         try:
             bandwidths, mode_ids, user_gathered_bytes = zip(*map(self.serializer.loads, group_info.gathered))
             user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered_bytes)))
@@ -519,47 +529,39 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             )
 
             async with enter_asynchronously(self.get_tensors()) as local_tensors:
-                allreduce = AllReduceRunner(
-                    p2p=self._p2p,
-                    servicer_type=type(self),
-                    prefix=self.prefix,
-                    group_id=group_info.group_id,
-                    tensors=local_tensors,
-                    ordered_peer_ids=group_info.peer_ids,
-                    peer_fractions=peer_fractions,
-                    gathered=user_gathered,
-                    modes=modes,
-                    **kwargs,
-                )
-
-                with self.register_allreduce_group(group_info.group_id, allreduce):
-                    if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
-                        iter_results = allreduce.run()
-                        async for tensor, update in azip(as_aiter(*local_tensors), iter_results):
-                            # all-reduce is performed asynchronously while iterating
-                            tensor.add_(update, alpha=self._averaging_alpha)
-                        self._state_updated.set()
-
-                    else:
-                        async for _ in allreduce:  # trigger all-reduce by iterating
-                            raise ValueError("aux peers should not receive averaged tensors")
-
-                return allreduce.gathered
+                await self._run_allreduce_inplace_(local_tensors, group_info, peer_fractions=peer_fractions, **kwargs)
+                return user_gathered
         except BaseException as e:
             if isinstance(e, Exception):
                 logger.exception(e)
             raise MatchmakingException(f"Unable to run All-Reduce: {e}")
 
-    @contextlib.contextmanager
-    def register_allreduce_group(self, group_id: GroupID, allreduce: AllReduceRunner):
-        """registers a given all-reduce runner to listen for incoming connections"""
-        try:
-            self._running_groups[group_id] = allreduce
-            self._pending_group_assembled.set()
-            yield
-        finally:
-            self._running_groups.pop(group_id, None)
-            self._pending_group_assembled.set()
+    async def _run_allreduce_inplace_(
+        self, tensors: Sequence[torch.Tensor], group_info: GroupInfo, group_id: Optional[bytes] = None, **kwargs
+    ):
+        """Run one allreduce process to average tensors inplace. Can be called more than a few times in one aggregation process"""
+        group_id = group_info.group_id if group_id is None else group_id
+
+        runner = AllReduceRunner(
+            p2p=self._p2p,
+            servicer_type=type(self),
+            prefix=self.prefix,
+            tensors=tensors,
+            group_id=group_id,
+            ordered_peer_ids=group_info.peer_ids,
+            **kwargs,
+        )
+        assert group_id in self._running_groups, f"Group id {group_id} was not registered in _register_allreduce_group"
+        self._running_groups[group_id].set_result(runner)
+
+        if runner.modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
+            async for tensor, update in azip(as_aiter(*tensors), runner):
+                tensor.add_(update, alpha=self._averaging_alpha)
+                self.last_updated = get_dht_time()
+                self._state_updated.set()
+        else:
+            async for _ in runner:
+                raise ValueError("aux peers should not receive averaged tensors")
 
     @contextlib.contextmanager
     def get_tensors(self) -> Sequence[torch.Tensor]:
@@ -586,13 +588,14 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         if request.group_id not in self._running_groups:
             # this handles a special case when leader accepted us to group AND began allreduce right away,
             # but his response with group_id was delayed and other peers got to us first
-            await self._pending_group_assembled.wait()
+            await self._pending_groups_registered.wait()
 
-        group = self._running_groups.get(request.group_id)
-        if group is None:
+        future = self._running_groups.get(request.group_id)
+        if future is None:
             yield averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
             return
 
+        group = await future
         async for message in group.rpc_aggregate_part(achain(as_aiter(request), stream), context):
             yield message
 

+ 20 - 7
hivemind/optim/grad_averager.py

@@ -1,16 +1,20 @@
 import contextlib
-from typing import Iterable, Iterator, Optional
+from typing import Callable, Iterable, Iterator, Optional, Sequence, TypeVar
 
 import torch
 
-import hivemind
 from hivemind.averaging import DecentralizedAverager
 from hivemind.averaging.control import StepControl
-from hivemind.utils import DHTExpiration, get_dht_time, get_logger
+from hivemind.dht import DHT
+from hivemind.utils import DHTExpiration, get_logger
 
 logger = get_logger(__name__)
 
 
+TGradientAverager = TypeVar("TGradientAverager", bound="GradientAverager")
+GradientAveragerFactory = Callable[..., TGradientAverager]
+
+
 class GradientAverager(DecentralizedAverager):
     """
     An auxiliary averaging class that is responsible for accumulating gradients and aggregating them with peers.
@@ -36,6 +40,7 @@ class GradientAverager(DecentralizedAverager):
       if True, the averager will only join existing groups where at least one peer has client_mode=False.
       By default, this flag is copied from DHTNode inside the ``dht`` instance.
     :param warn: if True, warn when the averager did not reset accumulators after use or did not use averaging results
+    :param averaged_grads: if provided, it will be used as a set of averagable gradients
     :param kwargs: see DecentralizedAverager keyword arguments for additional parameters
 
 
@@ -69,12 +74,13 @@ class GradientAverager(DecentralizedAverager):
         self,
         parameters: Iterable[torch.nn.Parameter],
         *,
-        dht: hivemind.DHT,
+        dht: DHT,
         prefix: str,
         reuse_grad_buffers: bool = False,
         accumulate_grads_on: Optional[torch.device] = None,
         client_mode: bool = None,
         warn: bool = True,
+        averaged_grads: Sequence[torch.Tensor] = (),
         **kwargs,
     ):
         if reuse_grad_buffers and accumulate_grads_on is not None:
@@ -95,9 +101,16 @@ class GradientAverager(DecentralizedAverager):
         self._new_averaged_grads = False
 
         with torch.no_grad():
-            averaged_grads = tuple(
-                grad.detach().cpu().clone().share_memory_() for grad in self._grads_from_parameters()
-            )
+            if not averaged_grads:
+                averaged_grads = tuple(
+                    grad.detach().cpu().clone().share_memory_() for grad in self._grads_from_parameters()
+                )
+            else:
+                if all(
+                    params_grad.size() == grad.size()
+                    for param_grad, grad in zip(self._grads_from_parameters(), averaged_grad)
+                ):
+                    raise ValueError("Averaged gradients doesn't have same shape as gradients from parameters")
         super().__init__(averaged_tensors=averaged_grads, dht=dht, prefix=prefix, client_mode=client_mode, **kwargs)
 
     def _grads_from_parameters(self) -> Iterator[torch.Tensor]:

+ 15 - 5
hivemind/optim/optimizer.py

@@ -11,8 +11,9 @@ import torch
 from hivemind.averaging.control import AveragingStage, StepControl
 from hivemind.compression import CompressionBase, NoCompression
 from hivemind.dht import DHT
-from hivemind.optim.grad_averager import GradientAverager
+from hivemind.optim.grad_averager import GradientAverager, GradientAveragerFactory
 from hivemind.optim.grad_scaler import GradScaler
+from hivemind.optim.power_sgd_averager import PowerSGDGradientAverager
 from hivemind.optim.progress_tracker import LocalTrainingProgress, ProgressTracker
 from hivemind.optim.state_averager import (
     LRSchedulerBase,
@@ -147,6 +148,7 @@ class Optimizer(torch.optim.Optimizer):
     :param auxiliary: if True, optimizer.step will only assist other peers in averaging (for cpu-only workers)
 
     :param grad_compression: compression strategy used for averaging gradients, default = no compression
+    :param grad_averager_factory: if provided, creates gradient averager with required averaging strategy
     :param state_averaging_compression: compression for averaging params and state tensors, default = no compression
     :param load_state_compression: compression strategy for loading state from peers, default = no compression
     :param average_opt_statistics: names of optimizer statistics from state dict that should be averaged with peers
@@ -187,6 +189,7 @@ class Optimizer(torch.optim.Optimizer):
         client_mode: bool = None,
         auxiliary: bool = False,
         grad_compression: CompressionBase = NoCompression(),
+        grad_averager_factory: Optional[GradientAveragerFactory] = None,
         state_averaging_compression: CompressionBase = NoCompression(),
         load_state_compression: CompressionBase = NoCompression(),
         average_opt_statistics: Sequence[str] = (),
@@ -201,7 +204,8 @@ class Optimizer(torch.optim.Optimizer):
 
         client_mode = client_mode if client_mode is None else dht.client_mode
         delay_optimizer_step = delay_optimizer_step if delay_optimizer_step is not None else delay_grad_averaging
-        offload_optimizer = offload_optimizer if offload_optimizer is not None else (params is not None)
+        if offload_optimizer is None:
+            offload_optimizer = params is not None and not use_local_updates
         allreduce_timeout = allreduce_timeout if allreduce_timeout is not None else averaging_timeout
         next_chunk_timeout = next_chunk_timeout if next_chunk_timeout is not None else matchmaking_time
         assert not delay_grad_averaging or delay_optimizer_step, "delay_grad_averaging requires delay_optimizer_step"
@@ -225,6 +229,9 @@ class Optimizer(torch.optim.Optimizer):
         if use_local_updates:
             assert not reuse_grad_buffers, "if local_updates is True, gradients will not be accumulated"
             assert not delay_grad_averaging, "if local_updates is True, gradients will not be averaged"
+            assert (
+                grad_averager_factory is None
+            ), "if local_updates is True, provided grad_averager_factory will not be used"
 
         self.dht, self.run_id, self.client_mode, self.auxiliary = dht, run_id, client_mode, auxiliary
         self.batch_size_per_step, self.target_batch_size = batch_size_per_step, target_batch_size
@@ -257,7 +264,7 @@ class Optimizer(torch.optim.Optimizer):
         )
         if not use_local_updates:
             self.grad_averager = self._make_gradient_averager(
-                reuse_grad_buffers=reuse_grad_buffers, compression=grad_compression, **averager_opts or {}
+                grad_averager_factory, reuse_grad_buffers=reuse_grad_buffers, compression=grad_compression
             )
         else:
             self.grad_averager = None
@@ -290,9 +297,10 @@ class Optimizer(torch.optim.Optimizer):
             **kwargs,
         )
 
-    def _make_gradient_averager(self, **kwargs) -> GradientAverager:
+    def _make_gradient_averager(self, factory: Optional[GradientAveragerFactory], **kwargs) -> GradientAverager:
         assert hasattr(self, "state_averager"), "must initialize state averager first"
-        grad_averager = GradientAverager(
+        factory = factory if factory is not None else GradientAverager
+        grad_averager = factory(
             dht=self.dht,
             prefix=f"{self.run_id}_grad_averager",
             parameters=self.state_averager.main_parameters,
@@ -684,6 +692,8 @@ class Optimizer(torch.optim.Optimizer):
             while True:
                 try:
                     self.state_averager.load_state_from_peers(timeout=self.load_state_timeout, **kwargs)
+                    if self.grad_averager is not None:
+                        self.grad_averager.load_state_from_peers(timeout=self.load_state_timeout, **kwargs)
                     break
                 except KeyboardInterrupt:
                     raise

+ 223 - 0
hivemind/optim/power_sgd_averager.py

@@ -0,0 +1,223 @@
+import asyncio
+import contextlib
+import multiprocessing as mp
+from enum import Enum
+from typing import Any, Iterable, Optional, Sequence
+
+import torch
+
+from hivemind.averaging.allreduce import AveragingMode
+from hivemind.averaging.group_info import GroupInfo
+from hivemind.averaging.load_balancing import load_balance_peers
+from hivemind.averaging.matchmaking import Matchmaking, MatchmakingException
+from hivemind.compression import CompressionInfo, TensorRole
+from hivemind.dht import DHT
+from hivemind.optim.grad_averager import GradientAverager
+from hivemind.utils import get_logger
+from hivemind.utils.asyncio import enter_asynchronously
+from hivemind.utils.math import get_flatten_greedy_dims, orthogonalize_
+
+GatheredData = Any
+logger = get_logger(__name__)
+
+
+class AllReducePhases(Enum):
+    PHASE_P = 1
+    PHASE_Q = 2
+
+
+class PowerSGDGradientAverager(GradientAverager):
+    """
+    A gradient averager that implements PowerSGD compression: https://arxiv.org/abs/1905.13727
+    For basic properties and guaranties of gradient averagers, please refer to the base class docstring.
+    Put simply, this method approximates large gradient tensors (m,n) with a product of two
+    smaller matrices (m,r) by (r,n), where r is a parameter chosen by the user (see averager_rank).
+
+    As a result, PowerSGD only needs to aggregate O((m + n) * r) tensors instead of O(m * n).
+    High r, e.g. sqrt(max(m, n)) typically reduce communication by 2-8x without affecting convergence.
+    Low r, e.g. 1-8, further accelerate communication, but may converge worse depending on the task.
+
+    To maintain convergence with low r, this averager uses the error feedback strategy. Put simply,
+    if some part of the gradient is "lost in compression", it will be added to the next iteration.
+    This has two implications: (a) it needs more RAM in order to store the "feedback buffers"
+    and (b) if devices stay alive only for one step, training with small rank may converge slower.
+    This is because error feedback takes multiple steps to kick in.
+
+    Since not all gradients are matrices, PowerSGD views 3d+ tensors via tensor.flatten(1, -1).
+    If a tensor has less than 2 dimensions or does not compress efficiently, it will be aggregated
+    normally, i.e. without powerSGD. See min_compression_ratio for details.
+
+    :note: due to the above rule, PowerSGD is *not* shape-invariant. For instance, a
+     matrix of shape (256, 256) be compressed differently if you .reshape it to (32, 32, 32).
+
+    :param parameters: pytorch parameters for which to aggregate gradients
+    :param averager_rank: rank of compressed gradients
+    :param dht: a DHT isntance connected to the rest of the swarm. See hivemind.DHT docs
+    :param prefix: a unique DHT key used for matchmaking. E.g. this can be your experiment name with optional suffixes
+    :param reuse_grad_buffers: if True, use model's .grad buffers for accumulating gradients over multiple steps.
+      This is more memory efficient, but it requires that the user does *not* call zero_grad or clip_by_whatever at all
+    :param accumulate_grads_on: if specified, accumulate gradients on this device. By default, this will use the same
+      device as model parameters. One can specify a different device (e.g. 'cpu' vs 'cuda') to save device memory at
+      the cost of extra time per step. If reuse_grad_buffers is True, this parameter has no effect.
+    :param client_mode: if False, this averager will accept incoming requests from other peers.
+      if True, the averager will only join existing groups where at least one peer has client_mode=False.
+      By default, this flag is copied from DHTNode inside the ``dht`` instance.
+    :param warn: if True, warn when the averager did not reset accumulators after use or did not use averaging results
+    :param min_compression_ratio: apply PowerSGD to a tensor only if it reduces communication by at least this factor,
+      otherwise aggregate tensors as is
+    :param averaged_grads: if provided, it will be used as a set of averagable gradients
+    """
+
+    def __init__(
+        self,
+        parameters: Iterable[torch.nn.Parameter],
+        averager_rank: int,
+        *,
+        dht: DHT,
+        prefix: str,
+        reuse_grad_buffers: bool = False,
+        accumulate_grads_on: Optional[torch.device] = None,
+        client_mode: bool = None,
+        warn: bool = True,
+        min_compression_ratio: float = 0.5,
+        averaged_grads: Optional[Sequence[torch.Tensor]] = None,
+        **kwargs,
+    ):
+        self.rank = averager_rank
+        self.parameters = tuple(parameters)
+        self._uncompressed_gradients_indexes = set(
+            i
+            for i, grad in enumerate(self._grads_from_parameters())
+            if grad.ndim <= 1
+            or (1 - self.rank * sum(get_flatten_greedy_dims(grad)) / grad.numel()) < min_compression_ratio
+            # compute how much parameters are left after factorization
+        )
+        self._ms = [
+            torch.zeros_like(grad, device="cpu").share_memory_()
+            for idx, grad in enumerate(self._grads_from_parameters())
+            if idx not in self._uncompressed_gradients_indexes
+        ]
+        self._qs = [
+            torch.rand((get_flatten_greedy_dims(grad)[1], self.rank), device="cpu").share_memory_()
+            for idx, grad in enumerate(self._grads_from_parameters())
+            if idx not in self._uncompressed_gradients_indexes
+        ]
+
+        super().__init__(
+            self.parameters,
+            dht=dht,
+            prefix=prefix,
+            reuse_grad_buffers=reuse_grad_buffers,
+            accumulate_grads_on=accumulate_grads_on,
+            client_mode=client_mode,
+            warn=warn,
+            averaged_grads=averaged_grads,
+            **kwargs,
+        )
+
+    @contextlib.contextmanager
+    def _register_allreduce_group(self, group_info: GroupInfo):
+        """Register a given group for one or more all-reduce rounds"""
+        try:
+            for phase in list(AllReducePhases):
+                self._running_groups[group_info.group_id + phase.name.encode()] = asyncio.Future()
+            self._pending_groups_registered.set()
+            yield
+        finally:
+            for phase in list(AllReducePhases):
+                maybe_future = self._running_groups.pop(group_info.group_id + phase.name.encode(), None)
+                if maybe_future and not maybe_future.done():
+                    logger.warning(f"All-reduce group {group_info.group_id + phase.name.encode()} did not finish.")
+            self._pending_groups_registered.set()
+
+    async def _aggregate_with_group(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
+        """Run aggregation in a given group and update tensors in place, return gathered metadata"""
+        try:
+            bandwidths, mode_ids, user_gathered_bytes = zip(*map(self.serializer.loads, group_info.gathered))
+            user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered_bytes)))
+            modes = tuple(map(AveragingMode, mode_ids))
+
+            download_bandwidths = [
+                thr if mode != AveragingMode.CLIENT else 0.0 for thr, mode in zip(bandwidths, modes)
+            ]
+            peer_fractions = await asyncio.get_event_loop().run_in_executor(
+                None, load_balance_peers, self.total_size, download_bandwidths, min_vector_size
+            )
+
+            async with enter_asynchronously(self.get_tensors()) as averaged_grads:
+                averaged_grads_via_sgd = [
+                    grad for idx, grad in enumerate(averaged_grads) if idx not in self._uncompressed_gradients_indexes
+                ]
+                for grad, m in zip(averaged_grads_via_sgd, self._ms):
+                    m.add_(grad.to(m.device))
+
+                ps = [
+                    torch.zeros((get_flatten_greedy_dims(grad)[0], self.rank), device="cpu")
+                    for idx, grad in enumerate(averaged_grads_via_sgd)
+                ]
+                for p, q, m in zip(ps, self._qs, self._ms):
+                    # we use reshape for all matrixes because PowerSGD works only with 2d tensors
+                    torch.matmul(m.reshape(-1, q.size(0)), q, out=p)
+
+                p_group_id = group_info.group_id + AllReducePhases.PHASE_P.name.encode()
+                q_groud_id = group_info.group_id + AllReducePhases.PHASE_Q.name.encode()
+
+                await self._run_allreduce_inplace_(ps, group_info, p_group_id, peer_fractions=peer_fractions, **kwargs)
+
+                for p in ps:
+                    orthogonalize_(p)
+
+                for p, q, m in zip(ps, self._qs, self._ms):
+                    torch.matmul(m.reshape(-1, q.size(0)).t(), p, out=q)
+
+                phase_q_tensors = self._qs + [
+                    grad for idx, grad in enumerate(averaged_grads) if idx in self._uncompressed_gradients_indexes
+                ]
+
+                await self._run_allreduce_inplace_(
+                    phase_q_tensors, group_info, q_groud_id, peer_fractions=peer_fractions, **kwargs
+                )
+
+                for p, q, m, grad in zip(ps, self._qs, self._ms, averaged_grads_via_sgd):
+                    new_m = torch.matmul(p, q.t()).reshape(m.size())
+                    m.sub_(new_m)
+                    grad.copy_(new_m)
+
+                return user_gathered
+        except BaseException as e:
+            logger.exception(e)
+            raise MatchmakingException(f"Unable to run All-Reduce: {e}")
+
+    def get_current_state(self):
+        """
+        Get current gradient averager state and when requested by a newbie peer.
+        """
+        with torch.no_grad(), self.lock_averaged_tensors:
+            grad_averager_buffers = [q for q in self._qs]
+            grad_averager_buffers_infos = [
+                CompressionInfo.from_tensor(buffer, key=f"buffer_q_{key}", role=TensorRole.GRADIENT)
+                for buffer, key in zip(grad_averager_buffers, enumerate(grad_averager_buffers))
+            ]
+
+        metadata = dict(group_bits=self.get_group_bits())
+        return metadata, grad_averager_buffers, grad_averager_buffers_infos
+
+    def load_state_from_peers(self, **kwargs):
+        """
+        Attempt to download the latest optimizer state from peers and update gradient averager buffers.
+        :returns: whether or the averager succeeded in loading parameters
+        """
+        loaded_state = super().load_state_from_peers(**kwargs)
+        if loaded_state is None:
+            return
+
+        metadata, flat_tensors = loaded_state
+        logger.info("Starting loading gradient averager buffers from peers")
+
+        if len(flat_tensors) != len(self._qs):
+            logger.error("Failed to load state from peer, received invalid parameters, extras or metadata")
+            return
+
+        with torch.no_grad(), self.lock_averaged_tensors:
+            for local_q, loaded_q in zip(self._qs, flat_tensors):
+                local_q.copy_(loaded_q, non_blocking=True)

+ 38 - 12
hivemind/p2p/p2p_daemon.py

@@ -3,6 +3,7 @@ import json
 import logging
 import os
 import secrets
+import warnings
 from collections.abc import AsyncIterable as AsyncIterableABC
 from contextlib import closing, suppress
 from dataclasses import dataclass
@@ -17,8 +18,10 @@ import hivemind.hivemind_cli as cli
 import hivemind.p2p.p2p_daemon_bindings.p2pclient as p2pclient
 from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, P2PDaemonError, P2PHandlerError
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
+from hivemind.proto import crypto_pb2
 from hivemind.proto.p2pd_pb2 import RPCError
 from hivemind.utils.asyncio import as_aiter, asingle
+from hivemind.utils.crypto import RSAPrivateKey
 from hivemind.utils.logging import get_logger, golog_level_to_python, loglevel, python_level_to_golog
 
 logger = get_logger(__name__)
@@ -89,16 +92,16 @@ class P2P:
         identity_path: Optional[str] = None,
         idle_timeout: float = 30,
         nat_port_map: bool = True,
-        quic: bool = False,
         relay_hop_limit: int = 0,
         startup_timeout: float = 15,
         tls: bool = True,
         use_auto_relay: bool = False,
         use_ipfs: bool = False,
         use_relay: bool = True,
-        use_relay_hop: bool = False,
-        use_relay_discovery: bool = False,
         persistent_conn_max_msg_size: int = DEFAULT_MAX_MSG_SIZE,
+        quic: Optional[bool] = None,
+        use_relay_hop: Optional[bool] = None,
+        use_relay_discovery: Optional[bool] = None,
     ) -> "P2P":
         """
         Start a new p2pd process and connect to it.
@@ -112,20 +115,20 @@ class P2P:
                          Details: https://pkg.go.dev/github.com/libp2p/go-libp2p-kad-dht#ModeOpt
         :param force_reachability: Force reachability mode (public/private)
         :param host_maddrs: Multiaddrs to listen for external connections from other p2p instances
-        :param identity_path: Path to a pre-generated private key file. If defined, makes the peer ID deterministic.
-                              May be generated using ``./p2p-keygen`` from ``go-libp2p-daemon``.
+        :param identity_path: Path to a private key file. If defined, makes the peer ID deterministic.
+                              If the file does not exist yet, writes a new private key to this file.
         :param idle_timeout: kill daemon if client has been idle for a given number of
                              seconds before opening persistent streams
         :param nat_port_map: Enables NAT port mapping
-        :param quic: Enables the QUIC transport
         :param relay_hop_limit: sets the hop limit for hop relays
         :param startup_timeout: raise a P2PDaemonError if the daemon does not start in ``startup_timeout`` seconds
         :param tls: Enables TLS1.3 channel security protocol
         :param use_auto_relay: enables autorelay
         :param use_ipfs: Bootstrap to IPFS (incompatible with initial_peers)
         :param use_relay: enables circuit relay
-        :param use_relay_hop: enables hop for relay
-        :param use_relay_discovery: enables passive discovery for relay
+        :param quic: Deprecated, has no effect since libp2p 0.17.0
+        :param use_relay_hop: Deprecated, has no effect since libp2p 0.17.0
+        :param use_relay_discovery: Deprecated, has no effect since libp2p 0.17.0
         :return: a wrapper for the p2p daemon
         """
 
@@ -133,6 +136,14 @@ class P2P:
             initial_peers and use_ipfs
         ), "User-defined initial_peers and use_ipfs=True are incompatible, please choose one option"
 
+        if not all(arg is None for arg in [quic, use_relay_hop, use_relay_discovery]):
+            warnings.warn(
+                "Parameters `quic`, `use_relay_hop`, and `use_relay_discovery` of hivemind.P2P "
+                "have no effect since libp2p 0.17.0 and will be removed in hivemind 1.2.0+",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+
         self = cls()
         with path(cli, P2PD_FILENAME) as p:
             p2pd_path = p
@@ -147,7 +158,7 @@ class P2P:
                     raise ValueError("Please specify an explicit port in announce_maddrs: port 0 is not supported")
 
         need_bootstrap = bool(initial_peers) or use_ipfs
-        process_kwargs = cls.DHT_MODE_MAPPING.get(dht_mode, {"dht": 0})
+        process_kwargs = cls.DHT_MODE_MAPPING[dht_mode].copy()
         process_kwargs.update(cls.FORCE_REACHABILITY_MAPPING.get(force_reachability, {}))
         for param, value in [
             ("bootstrapPeers", initial_peers),
@@ -156,7 +167,11 @@ class P2P:
         ]:
             if value:
                 process_kwargs[param] = self._maddrs_to_str(value)
+
         if identity_path is not None:
+            if not os.path.isfile(identity_path):
+                logger.info(f"Generating new identity (libp2p private key) in `{identity_path}`")
+                self.generate_identity(identity_path)
             process_kwargs["id"] = identity_path
 
         proc_args = self._make_process_args(
@@ -168,10 +183,7 @@ class P2P:
             idleTimeout=f"{idle_timeout}s",
             listen=self._daemon_listen_maddr,
             natPortMap=nat_port_map,
-            quic=quic,
             relay=use_relay,
-            relayDiscovery=use_relay_discovery,
-            relayHop=use_relay_hop,
             relayHopLimit=relay_hop_limit,
             tls=tls,
             persistentConnMaxMsgSize=persistent_conn_max_msg_size,
@@ -205,6 +217,20 @@ class P2P:
         await self._ping_daemon()
         return self
 
+    @staticmethod
+    def generate_identity(identity_path: str) -> None:
+        private_key = RSAPrivateKey()
+        protobuf = crypto_pb2.PrivateKey(key_type=crypto_pb2.KeyType.RSA, data=private_key.to_bytes())
+
+        try:
+            with open(identity_path, "wb") as f:
+                f.write(protobuf.SerializeToString())
+        except FileNotFoundError:
+            raise FileNotFoundError(
+                f"The directory `{os.path.dirname(identity_path)}` for saving the identity does not exist"
+            )
+        os.chmod(identity_path, 0o400)
+
     @classmethod
     async def replicate(cls, daemon_listen_maddr: Multiaddr) -> "P2P":
         """

+ 24 - 0
hivemind/proto/crypto.proto

@@ -0,0 +1,24 @@
+// Originally taken from: https://github.com/mhchia/py-libp2p-daemon-bindings
+// Licence: MIT
+// Author: Kevin Mai-Husan Chia
+
+syntax = "proto2";
+
+package crypto.pb;
+
+enum KeyType {
+  RSA = 0;
+  Ed25519 = 1;
+  Secp256k1 = 2;
+  ECDSA = 3;
+}
+
+message PublicKey {
+  required KeyType key_type = 1;
+  required bytes data = 2;
+}
+
+message PrivateKey {
+  required KeyType key_type = 1;
+  required bytes data = 2;
+}

+ 3 - 3
hivemind/proto/p2pd.proto

@@ -1,6 +1,6 @@
-//Originally taken from: https://github.com/mhchia/py-libp2p-daemon-bindings
-//Licence: MIT
-//Author: Kevin Mai-Husan Chia
+// Originally taken from: https://github.com/mhchia/py-libp2p-daemon-bindings
+// Licence: MIT
+// Author: Kevin Mai-Husan Chia
 
 syntax = "proto2";
 

+ 9 - 6
hivemind/utils/crypto.py

@@ -60,19 +60,22 @@ class RSAPrivateKey(PrivateKey):
     def get_public_key(self) -> RSAPublicKey:
         return RSAPublicKey(self._private_key.public_key())
 
+    def to_bytes(self) -> bytes:
+        return self._private_key.private_bytes(
+            encoding=serialization.Encoding.DER,
+            format=serialization.PrivateFormat.TraditionalOpenSSL,
+            encryption_algorithm=serialization.NoEncryption(),
+        )
+
     def __getstate__(self):
         state = self.__dict__.copy()
         # Serializes the private key to make the class instances picklable
-        state["_private_key"] = self._private_key.private_bytes(
-            encoding=serialization.Encoding.PEM,
-            format=serialization.PrivateFormat.OpenSSH,
-            encryption_algorithm=serialization.NoEncryption(),
-        )
+        state["_private_key"] = self.to_bytes()
         return state
 
     def __setstate__(self, state):
         self.__dict__.update(state)
-        self._private_key = serialization.load_ssh_private_key(self._private_key, password=None)
+        self._private_key = serialization.load_der_private_key(self._private_key, password=None)
 
 
 class RSAPublicKey(PublicKey):

+ 24 - 0
hivemind/utils/math.py

@@ -0,0 +1,24 @@
+import torch
+import torch.nn.functional as F
+
+
+@torch.jit.script
+def orthogonalize_(matrix, eps: float = 1e-8):
+    """Orthogonalize a 2d tensor in-place over the last dimension"""
+    n, m = matrix.shape
+    for i in range(m):
+        col = matrix[:, i]
+        F.normalize(col, dim=0, eps=eps, out=col)
+        if i + 1 < m:
+            rest = matrix[:, i + 1 :]
+            rest.addmm_(col[:, None], (col @ rest)[None, :], alpha=-1)
+
+
+def get_flatten_greedy_dims(tensor: torch.Tensor, max_ndim: int = 2):
+    """get dims to flatten tensor upto max_ndim dimensions by merging small axes together"""
+    dims = list(tensor.shape)
+    while len(dims) > max_ndim:
+        squeeze_ix = min(range(len(dims) - 1), key=lambda i: dims[i] * dims[i + 1])
+        squeezed_dim = dims.pop(squeeze_ix)
+        dims[squeeze_ix] *= squeezed_dim
+    return dims

+ 1 - 1
pyproject.toml

@@ -1,6 +1,6 @@
 [tool.black]
 line-length = 119
-required-version = "22.1.0"
+required-version = "22.3.0"
 
 [tool.isort]
 profile = "black"

+ 1 - 1
requirements-dev.txt

@@ -6,6 +6,6 @@ coverage==6.0.2  # see https://github.com/pytest-dev/pytest-cov/issues/520
 tqdm
 scikit-learn
 torchvision
-black==22.1.0
+black==22.3.0
 isort==5.10.1
 psutil

+ 37 - 29
setup.py

@@ -3,7 +3,6 @@ import glob
 import hashlib
 import os
 import re
-import shlex
 import subprocess
 import tarfile
 import tempfile
@@ -14,20 +13,25 @@ from setuptools import find_packages, setup
 from setuptools.command.build_py import build_py
 from setuptools.command.develop import develop
 
-P2PD_VERSION = "v0.3.6"
-P2PD_CHECKSUM = "627d0c3b475a29331fdfd1667e828f6d"
-LIBP2P_TAR_URL = f"https://github.com/learning-at-home/go-libp2p-daemon/archive/refs/tags/{P2PD_VERSION}.tar.gz"
-P2PD_BINARY_URL = f"https://github.com/learning-at-home/go-libp2p-daemon/releases/download/{P2PD_VERSION}/p2pd"
+P2PD_VERSION = "v0.3.8"
+
+P2PD_SOURCE_URL = f"https://github.com/learning-at-home/go-libp2p-daemon/archive/refs/tags/{P2PD_VERSION}.tar.gz"
+P2PD_BINARY_URL = f"https://github.com/learning-at-home/go-libp2p-daemon/releases/download/{P2PD_VERSION}/"
+
+# The value is sha256 of the binary from the release page
+EXECUTABLES = {
+    "p2pd": "785058526d993f699c674dc2f9b66d565a52315a18b79b629998fab3ebd8e20f",
+}
+
 
 here = os.path.abspath(os.path.dirname(__file__))
 
 
-def md5(fname, chunk_size=4096):
-    hash_md5 = hashlib.md5()
-    with open(fname, "rb") as f:
-        for chunk in iter(lambda: f.read(chunk_size), b""):
-            hash_md5.update(chunk)
-    return hash_md5.hexdigest()
+def sha256(path):
+    if not os.path.exists(path):
+        return None
+    with open(path, "rb") as f:
+        return hashlib.sha256(f.read()).hexdigest()
 
 
 def proto_compile(output_path):
@@ -64,32 +68,36 @@ def build_p2p_daemon():
 
     with tempfile.TemporaryDirectory() as tempdir:
         dest = os.path.join(tempdir, "libp2p-daemon.tar.gz")
-        urllib.request.urlretrieve(LIBP2P_TAR_URL, dest)
+        urllib.request.urlretrieve(P2PD_SOURCE_URL, dest)
 
         with tarfile.open(dest, "r:gz") as tar:
             tar.extractall(tempdir)
 
-        result = subprocess.run(
-            f'go build -o {shlex.quote(os.path.join(here, "hivemind", "hivemind_cli", "p2pd"))}',
-            cwd=os.path.join(tempdir, f"go-libp2p-daemon-{P2PD_VERSION[1:]}", "p2pd"),
-            shell=True,
-        )
-
-        if result.returncode:
-            raise RuntimeError(
-                "Failed to build or install libp2p-daemon:" f" exited with status code: {result.returncode}"
+        for executable in EXECUTABLES:
+            result = subprocess.run(
+                ["go", "build", "-o", os.path.join(here, "hivemind", "hivemind_cli", executable)],
+                cwd=os.path.join(tempdir, f"go-libp2p-daemon-{P2PD_VERSION.lstrip('v')}", executable),
             )
+            if result.returncode != 0:
+                raise RuntimeError(f"Failed to build {executable}: exited with status code: {result.returncode}")
 
 
 def download_p2p_daemon():
-    install_path = os.path.join(here, "hivemind", "hivemind_cli")
-    binary_path = os.path.join(install_path, "p2pd")
-    if not os.path.exists(binary_path) or md5(binary_path) != P2PD_CHECKSUM:
-        print("Downloading Peer to Peer Daemon")
-        urllib.request.urlretrieve(P2PD_BINARY_URL, binary_path)
-        os.chmod(binary_path, 0o777)
-        if md5(binary_path) != P2PD_CHECKSUM:
-            raise RuntimeError(f"Downloaded p2pd binary from {P2PD_BINARY_URL} does not match with md5 checksum")
+    for executable, expected_hash in EXECUTABLES.items():
+        binary_path = os.path.join(here, "hivemind", "hivemind_cli", executable)
+
+        if sha256(binary_path) != expected_hash:
+            binary_url = os.path.join(P2PD_BINARY_URL, executable)
+            print(f"Downloading {binary_url}")
+
+            urllib.request.urlretrieve(binary_url, binary_path)
+            os.chmod(binary_path, 0o777)
+
+            actual_hash = sha256(binary_path)
+            if actual_hash != expected_hash:
+                raise RuntimeError(
+                    f"The sha256 checksum for {executable} does not match (expected: {expected_hash}, actual: {actual_hash})"
+                )
 
 
 class BuildPy(build_py):

+ 14 - 12
tests/test_allreduce_fault_tolerance.py

@@ -35,7 +35,7 @@ class FaultyAverager(hivemind.DecentralizedAverager):
         self.fault = fault
         super().__init__(*args, **kwargs)
 
-    async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
+    async def _aggregate_with_group(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
         """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
         try:
             bandwidths, mode_ids, user_gathered_bytes = zip(*map(self.serializer.loads, group_info.gathered))
@@ -60,24 +60,26 @@ class FaultyAverager(hivemind.DecentralizedAverager):
                     tensors=local_tensors,
                     ordered_peer_ids=group_info.peer_ids,
                     peer_fractions=peer_fractions,
-                    gathered=user_gathered,
                     modes=modes,
                     fault=self.fault,
                     **kwargs,
                 )
 
-                with self.register_allreduce_group(group_info.group_id, allreduce):
-                    if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
-                        async for tensor, update in azip(as_aiter(*local_tensors), allreduce):
-                            # all-reduce is performed asynchronously while iterating
-                            tensor.add_(update, alpha=self._averaging_alpha)
-                        self._state_updated.set()
+                self._running_groups[group_info.group_id].set_result(allreduce)
+                # TODO maybe this can be extracted into a method that checks if register_... context is active.
 
-                    else:
-                        async for _ in allreduce:  # trigger all-reduce by iterating
-                            raise ValueError("aux peers should not receive averaged tensors")
+                if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
+                    iter_results = allreduce.run()
+                    async for tensor, update in azip(as_aiter(*local_tensors), iter_results):
+                        # all-reduce is performed asynchronously while iterating
+                        tensor.add_(update, alpha=self._averaging_alpha)
+                    self._state_updated.set()
+
+                else:
+                    async for _ in allreduce:  # trigger all-reduce by iterating
+                        raise ValueError("aux peers should not receive averaged tensors")
 
-                return allreduce.gathered
+                return user_gathered
         except BaseException as e:
             logger.exception(e)
             raise MatchmakingException(f"Unable to run All-Reduce: {e}")

+ 62 - 12
tests/test_optimizer.py

@@ -11,24 +11,31 @@ import torch.nn.functional as F
 
 import hivemind
 from hivemind.averaging.control import AveragingStage
-from hivemind.optim.grad_averager import GradientAverager
+from hivemind.optim.grad_averager import GradientAverager, GradientAveragerFactory
 from hivemind.optim.optimizer import Optimizer
+from hivemind.optim.power_sgd_averager import PowerSGDGradientAverager
 from hivemind.optim.progress_tracker import ProgressTracker
 from hivemind.optim.state_averager import TrainingStateAverager
 from hivemind.utils.crypto import RSAPrivateKey
 
 
 @pytest.mark.forked
-def test_grad_averager():
+@pytest.mark.parametrize(
+    "grad_averager_factory",
+    [GradientAverager, partial(PowerSGDGradientAverager, averager_rank=1)],
+)
+def test_grad_averager(grad_averager_factory: GradientAveragerFactory):
+    parameter_shape = (5, 5)
+
     dht1 = hivemind.DHT(start=True)
-    model1 = nn.ParameterDict({"w": nn.Parameter(torch.zeros(3))})
-    averager1 = GradientAverager(
+    model1 = nn.ParameterDict({"w": nn.Parameter(torch.zeros(parameter_shape))})
+    averager1 = grad_averager_factory(
         model1.parameters(), dht=dht1, prefix="test", target_group_size=2, reuse_grad_buffers=False, start=True
     )
 
     dht2 = hivemind.DHT(start=True, initial_peers=dht1.get_visible_maddrs())
-    model2 = nn.ParameterDict({"w": nn.Parameter(torch.zeros(3))})
-    averager2 = GradientAverager(
+    model2 = nn.ParameterDict({"w": nn.Parameter(torch.zeros(parameter_shape))})
+    averager2 = grad_averager_factory(
         model2.parameters(), dht=dht2, prefix="test", target_group_size=2, reuse_grad_buffers=True, start=True
     )
 
@@ -38,12 +45,12 @@ def test_grad_averager():
     for i in range(10):
         time.sleep(0.1)
         if i % 3 == 0:
-            loss1 = F.mse_loss(model1.w, torch.ones(3))
+            loss1 = F.mse_loss(model1.w, torch.ones(parameter_shape))
             loss1.backward()
             averager1.accumulate_grads_(batch_size=2)  # total: 4 times * 2 samples = 8
             model1.zero_grad()
         else:
-            loss2 = F.mse_loss(model2.w, -torch.ones(3))
+            loss2 = F.mse_loss(model2.w, -torch.ones(parameter_shape))
             loss2.backward()
             averager2.accumulate_grads_(batch_size=3)  # total: 6 times * 3 samples = 18
             # note: we do not call zero grad here because reuse_grad_buffers=True
@@ -51,11 +58,11 @@ def test_grad_averager():
     assert control1.stage == control2.stage == AveragingStage.AWAITING_TRIGGER
     peer1_samples, peer1_times, peer2_samples, peer2_times = 8, 4, 18, 6
     assert averager1.local_samples_accumulated == peer1_samples and averager1.local_times_accumulated == peer1_times
-    ref_grads1 = torch.full((3,), -2 * 1 / 3 * averager1.local_times_accumulated)
+    ref_grads1 = torch.full(parameter_shape, -2 / np.prod(parameter_shape) * averager1.local_times_accumulated)
     assert torch.allclose(next(averager1._grad_accumulators()), ref_grads1)
 
     assert averager2.local_samples_accumulated == peer2_samples and averager2.local_times_accumulated == peer2_times
-    ref_grads2 = torch.full((3,), 2 * 1 / 3 * averager2.local_times_accumulated)
+    ref_grads2 = torch.full(parameter_shape, 2 / np.prod(parameter_shape) * averager2.local_times_accumulated)
     assert torch.allclose(next(averager2._grad_accumulators()), ref_grads2)
 
     averager1.step(control=control1, wait=False)
@@ -162,7 +169,11 @@ def test_load_state_from_peers():
     )
 
     avgr1 = TrainingStateAverager(
-        dht=dht1, params=model1.parameters(), allow_state_sharing=False, start=True, **common_kwargs
+        dht=dht1,
+        params=model1.parameters(),
+        allow_state_sharing=False,
+        start=True,
+        **common_kwargs,
     )
 
     avgr2 = TrainingStateAverager(dht=dht2, params=model2.parameters(), start=True, **common_kwargs)
@@ -286,12 +297,45 @@ def test_progress_tracker():
 
 
 @pytest.mark.forked
+@pytest.mark.parametrize(
+    "use_local_updates, delay_state_averaging, delay_optimizer_step, delay_grad_averaging, reuse_grad_buffers",
+    # fmt: off
+    [
+        (False, False, False, False, False),
+        (False, True, False, False, False),
+        (False, True, True, True, False),
+        (False, False, False, False, True),
+        (False, True, True, True, True),
+        (False, True, True, False, True),
+        (True, False, False, False, False),
+        (True, True, False, False, False,),
+    ],
+    # fmt: on
+)
 def test_optimizer(
+    use_local_updates: bool,
+    delay_state_averaging: bool,
+    delay_optimizer_step: bool,
+    delay_grad_averaging: bool,
+    reuse_grad_buffers: bool,
+):
+    _test_optimizer(
+        use_local_updates=use_local_updates,
+        delay_state_averaging=delay_state_averaging,
+        delay_grad_averaging=delay_grad_averaging,
+        delay_optimizer_step=delay_optimizer_step,
+        reuse_grad_buffers=reuse_grad_buffers,
+    )
+
+
+def _test_optimizer(
     num_peers: int = 1,
     num_clients: int = 0,
     target_batch_size: int = 32,
     total_epochs: int = 3,
+    use_local_updates: bool = False,
     reuse_grad_buffers: bool = True,
+    delay_state_averaging: bool = True,
     delay_grad_averaging: bool = True,
     delay_optimizer_step: bool = True,
     average_state_every: int = 1,
@@ -319,9 +363,11 @@ def test_optimizer(
             dht=hivemind.DHT(initial_peers=dht.get_visible_maddrs(), client_mode=client_mode, start=True),
             tracker_opts=dict(private_key=RSAPrivateKey(), max_refresh_period=1.0),
             averager_opts=dict(request_timeout=0.5),
+            use_local_updates=use_local_updates,
             matchmaking_time=1.0,
             averaging_timeout=5.0,
             reuse_grad_buffers=reuse_grad_buffers,
+            delay_state_averaging=delay_state_averaging,
             delay_grad_averaging=delay_grad_averaging,
             delay_optimizer_step=delay_optimizer_step,
             average_state_every=average_state_every,
@@ -380,6 +426,10 @@ def test_optimizer(
     assert 4 / 0.3 * 0.8 <= optimizer.tracker.performance_ema.samples_per_second <= 4 / 0.3 * 1.2
 
     assert not optimizer.state_averager.is_alive()
-    assert not optimizer.grad_averager.is_alive()
     assert not optimizer.tracker.is_alive()
+    if not use_local_updates:
+        assert not optimizer.grad_averager.is_alive()
+    else:
+        assert optimizer.grad_averager is None
+
     assert optimizer.scheduled_grads is None or optimizer.scheduled_grads.done()

+ 29 - 2
tests/test_p2p_daemon.py

@@ -1,6 +1,8 @@
 import asyncio
 import multiprocessing as mp
+import os
 import subprocess
+import tempfile
 from contextlib import closing
 from functools import partial
 from typing import List
@@ -45,6 +47,31 @@ async def test_startup_error_message():
         await P2P.create(startup_timeout=0.01)  # Test that startup_timeout works
 
 
+@pytest.mark.asyncio
+async def test_identity():
+    with tempfile.TemporaryDirectory() as tempdir:
+        id1_path = os.path.join(tempdir, "id1")
+        id2_path = os.path.join(tempdir, "id2")
+        p2ps = await asyncio.gather(*[P2P.create(identity_path=path) for path in [None, None, id1_path, id2_path]])
+
+        # We create the second daemon with id2 separately
+        # to avoid a race condition while saving a newly generated identity
+        p2ps.append(await P2P.create(identity_path=id2_path))
+
+        # Using the same identity (if any) should lead to the same peer ID
+        assert p2ps[-2].peer_id == p2ps[-1].peer_id
+
+        # The rest of peer IDs should be different
+        peer_ids = {instance.peer_id for instance in p2ps}
+        assert len(peer_ids) == 4
+
+        for instance in p2ps:
+            await instance.shutdown()
+
+    with pytest.raises(FileNotFoundError, match=r"The directory.+does not exist"):
+        P2P.generate_identity(id1_path)
+
+
 @pytest.mark.parametrize(
     "host_maddrs",
     [
@@ -55,11 +82,11 @@ async def test_startup_error_message():
 )
 @pytest.mark.asyncio
 async def test_transports(host_maddrs: List[Multiaddr]):
-    server = await P2P.create(quic=True, host_maddrs=host_maddrs)
+    server = await P2P.create(host_maddrs=host_maddrs)
     peers = await server.list_peers()
     assert len(peers) == 0
 
-    client = await P2P.create(quic=True, host_maddrs=host_maddrs, initial_peers=await server.get_visible_maddrs())
+    client = await P2P.create(host_maddrs=host_maddrs, initial_peers=await server.get_visible_maddrs())
     await client.wait_for_at_least_n_peers(1)
 
     peers = await client.list_peers()