
Merge branch 'master' into server-p2p

Pavel Samygin, 3 years ago
parent
commit 106ae3db16

+ 1 - 1
.github/workflows/check-style.yml

@@ -13,7 +13,7 @@ jobs:
       - uses: psf/black@stable
         with:
           options: "--check --diff"
-          version: "22.1.0"
+          version: "22.3.0"
   isort:
     runs-on: ubuntu-latest
     steps:

+ 6 - 2
.github/workflows/run-tests.yml

@@ -12,7 +12,7 @@ jobs:
     strategy:
       matrix:
         python-version: [ 3.7, 3.8, 3.9 ]
-    timeout-minutes: 10
+    timeout-minutes: 12
     steps:
       - uses: actions/checkout@v2
       - name: Set up Python
@@ -42,6 +42,10 @@ jobs:
     timeout-minutes: 10
     steps:
       - uses: actions/checkout@v2
+      - uses: actions/setup-go@v3
+        with:
+          go-version: '1.16'
+          check-latest: true
       - name: Set up Python
         uses: actions/setup-python@v2
         with:
@@ -67,7 +71,7 @@ jobs:
   codecov_in_develop_mode:

     runs-on: ubuntu-latest
-    timeout-minutes: 10
+    timeout-minutes: 12
     steps:
       - uses: actions/checkout@v2
       - name: Set up Python

+ 1 - 1
README.md

@@ -63,7 +63,7 @@ By default, hivemind uses the precompiled binary of
 the [go-libp2p-daemon](https://github.com/learning-at-home/go-libp2p-daemon) library. If you face compatibility issues
 or want to build the binary yourself, you can recompile it by running `pip install . --global-option="--buildgo"`.
 Before running the compilation, please ensure that your machine has a recent version
-of [Go toolchain](https://golang.org/doc/install) (1.15 or higher).
+of [Go toolchain](https://golang.org/doc/install) (1.15 or 1.16 are supported).

 ### System requirements


+ 9 - 0
examples/albert/README.md

@@ -130,6 +130,15 @@ monitors on different servers and list all of them as `--initial_peers`. The sys
 as at least one externally accessible participant is available. For short- to mid-term experiments you can host the
 monitor on a [free-tier VM](https://www.quora.com/Are-there-any-free-online-virtual-machines).

+By default, the training monitor changes its address on restart, so you may launch two monitors on the same machine.
+If you'd like to fix the monitor's address (e.g., before sending it to your collaborators),
+you need to **(a)** make it listen on a specific TCP/UDP port and **(b)** provide a path for storing the identity file
+(which allows [libp2p](https://libp2p.io/) to reuse the same peer ID after restart). You may do that like this:
+
+```bash
+./run_training_monitor.py --wandb_project YOUR_WANDB_PROJECT --host_maddrs /ip4/0.0.0.0/tcp/31337 --identity_path ./identity.key
+```
+
 ### Tuning for hardware/network

 The optimal training parameters for each peer depend on its GPU and internet connection. If a peer cannot accept

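For reference, the same effect can be achieved from Python when you start the DHT yourself. This is a hedged sketch (the port and file name are arbitrary), assuming this branch of hivemind is installed:

```python
import hivemind

# Reusing identity_path keeps the libp2p peer ID, and therefore the full multiaddress
# that collaborators pass as --initial_peers, stable across restarts.
dht = hivemind.DHT(
    host_maddrs=["/ip4/0.0.0.0/tcp/31337"],
    identity_path="identity.key",  # created on the first run, reused afterwards
    start=True,
)
print("Pass this to collaborators as --initial_peers:")
print([str(addr) for addr in dht.get_visible_maddrs()])
```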
+ 1 - 1
examples/albert/arguments.py

@@ -38,7 +38,7 @@ class BaseTrainingArguments:
         default=None,
         metadata={
             "help": "Path to a pre-generated private key file. If defined, makes the peer ID deterministic. "
-            "May be generated using ``./p2p-keygen`` from ``go-libp2p-daemon``."
+            "If the file does not exist yet, writes a new private key to this file."
         },
     )


+ 2 - 4
hivemind/averaging/allreduce.py

@@ -1,6 +1,6 @@
 import asyncio
 from enum import Enum
-from typing import Any, AsyncIterator, Dict, Optional, Sequence, Set, Tuple, Type
+from typing import AsyncIterator, Optional, Sequence, Set, Tuple, Type

 import torch

@@ -50,7 +50,6 @@ class AllReduceRunner(ServicerBase):
     :param peer_fractions: for each peer, a target fraction of vector elements that this peer should average
       (the actual number of values by peer will be nearly proportional, but there are no exact guarantees)
     :param modes: AveragingMode for each peer in ordered_peer_ids (normal, client-only or auxiliary)
-    :param gathered: additional user-defined data collected from this group
     :param sender_timeout: during all_reduce, any sender that fails to send tensor chunk within this many seconds from
       previous chunk will be marked as failed and excluded from averaging. default: equal to next_chunk_timeout
     :param reducer_timeout: during all_reduce, any reducer that fails to send results chunk within this many seconds
@@ -73,7 +72,6 @@ class AllReduceRunner(ServicerBase):
         ordered_peer_ids: Sequence[PeerID],
         peer_fractions: Tuple[float, ...],
         modes: Optional[Sequence[AveragingMode]] = None,
-        gathered: Optional[Dict[PeerID, Any]] = None,
         sender_timeout: Optional[float] = None,
         reducer_timeout: Optional[float] = None,
         **kwargs,
@@ -99,7 +97,7 @@ class AllReduceRunner(ServicerBase):
             assert mode != AveragingMode.CLIENT or frac == 0, "client-mode peer should have zero all-reduce fraction"

         self.group_id, self.ordered_peer_ids = group_id, ordered_peer_ids
-        self.modes, self.peer_fractions, self.gathered = modes, peer_fractions, gathered
+        self.modes, self.peer_fractions = modes, peer_fractions

         if weight is None:
             weight = float(modes[self.ordered_peer_ids.index(self.peer_id)] != AveragingMode.AUX)

+ 67 - 64
hivemind/averaging/averager.py

@@ -22,13 +22,7 @@ from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.load_balancing import load_balance_peers
 from hivemind.averaging.matchmaking import Matchmaking, MatchmakingException
 from hivemind.averaging.partition import DEFAULT_PART_SIZE_BYTES
-from hivemind.compression import (
-    CompressionBase,
-    CompressionInfo,
-    NoCompression,
-    deserialize_torch_tensor,
-    serialize_torch_tensor,
-)
+from hivemind.compression import CompressionBase, CompressionInfo, NoCompression, deserialize_torch_tensor
 from hivemind.dht import DHT, DHTID
 from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, ServicerBase
 from hivemind.p2p.p2p_daemon_bindings.utils import ControlFailure, DispatchFailure
@@ -36,7 +30,6 @@ from hivemind.proto import averaging_pb2
 from hivemind.utils import MPFuture, TensorDescriptor, get_logger
 from hivemind.utils.asyncio import (
     achain,
-    afirst,
     aiter_with_timeout,
     anext,
     as_aiter,
@@ -109,7 +102,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
     """
     """
 
 
     _matchmaking: Matchmaking
     _matchmaking: Matchmaking
-    _pending_group_assembled: asyncio.Event
+    _pending_groups_registered: asyncio.Event
     _state_updated: asyncio.Event
     _state_updated: asyncio.Event
     _p2p: P2P
     _p2p: P2P
     serializer = MSGPackSerializer
     serializer = MSGPackSerializer
@@ -207,7 +200,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             reducer_timeout=reducer_timeout,
             reducer_timeout=reducer_timeout,
         )
         )
         self._averaging_alpha, self._allreduce_timeout = averaging_alpha, allreduce_timeout
         self._averaging_alpha, self._allreduce_timeout = averaging_alpha, allreduce_timeout
-        self._running_groups: Dict[GroupID, AllReduceRunner] = {}  # one or more assembled groups that run all-reduce
+        self._running_groups: Dict[GroupID, asyncio.Future[AllReduceRunner]] = {}
 
 
         self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)  # a control pipe used to communicate with daemon
         self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)  # a control pipe used to communicate with daemon
 
 
@@ -309,8 +302,8 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     asyncio.create_task(self._declare_for_download_periodically())
                     asyncio.create_task(self._declare_for_download_periodically())
 
 
                 self._state_updated = asyncio.Event()
                 self._state_updated = asyncio.Event()
-                self._pending_group_assembled = asyncio.Event()
-                self._pending_group_assembled.set()
+                self._pending_groups_registered = asyncio.Event()
+                self._pending_groups_registered.set()
             except Exception as e:
             except Exception as e:
                 # Loglevel is DEBUG since normally the exception is propagated to the caller
                 # Loglevel is DEBUG since normally the exception is propagated to the caller
                 logger.debug(e, exc_info=True)
                 logger.debug(e, exc_info=True)
@@ -441,7 +434,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
 
             while not step.done():
             while not step.done():
                 try:
                 try:
-                    self._pending_group_assembled.clear()
+                    self._pending_groups_registered.clear()
                     step.stage = AveragingStage.LOOKING_FOR_GROUP
                     step.stage = AveragingStage.LOOKING_FOR_GROUP
                     matchmaking_task = asyncio.create_task(find_peers_or_notify_cancel())
                     matchmaking_task = asyncio.create_task(find_peers_or_notify_cancel())
                     check_cancel_task = asyncio.create_task(step.wait_for_cancel())
                     check_cancel_task = asyncio.create_task(step.wait_for_cancel())
@@ -458,17 +451,21 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     if group_info is None:
                     if group_info is None:
                         raise AllreduceException("Averaging step failed: could not find a group")
                         raise AllreduceException("Averaging step failed: could not find a group")
 
 
-                    step.stage = AveragingStage.RUNNING_ALLREDUCE
-
-                    step.set_result(
-                        await asyncio.wait_for(
-                            self._run_allreduce(
-                                group_info, tensor_infos=self.tensor_infos, weight=step.weight, **self.allreduce_kwargs
-                            ),
-                            timeout=self._allreduce_timeout,
+                    with self._register_allreduce_group(group_info):
+                        step.stage = AveragingStage.RUNNING_ALLREDUCE
+
+                        step.set_result(
+                            await asyncio.wait_for(
+                                self._aggregate_with_group(
+                                    group_info,
+                                    tensor_infos=self.tensor_infos,
+                                    weight=step.weight,
+                                    **self.allreduce_kwargs,
+                                ),
+                                timeout=self._allreduce_timeout,
+                            )
                         )
                         )
-                    )
-                    # averaging is finished, loop will now exit
+                        # averaging is finished, loop will now exit
 
 
                 except (
                 except (
                     AllreduceException,
                     AllreduceException,
@@ -503,8 +500,21 @@ class DecentralizedAverager(mp.Process, ServicerBase):
                     )
                 )

-    async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
-        """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
+    @contextlib.contextmanager
+    def _register_allreduce_group(self, group_info: GroupInfo):
+        """Register a given group for one or more all-reduce rounds"""
+        try:
+            self._running_groups[group_info.group_id] = asyncio.Future()
+            self._pending_groups_registered.set()
+            yield
+        finally:
+            maybe_future = self._running_groups.pop(group_info.group_id, None)
+            if maybe_future is not None and not maybe_future.done():
+                logger.warning(f"All-reduce group {group_info.group_id} did not finish.")
+            self._pending_groups_registered.set()
+
+    async def _aggregate_with_group(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
+        """Run aggregation in a given group and update tensors in place, return gathered metadata"""
         try:
             bandwidths, mode_ids, user_gathered_bytes = zip(*map(self.serializer.loads, group_info.gathered))
             user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered_bytes)))
@@ -519,47 +529,39 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             )

             async with enter_asynchronously(self.get_tensors()) as local_tensors:
-                allreduce = AllReduceRunner(
-                    p2p=self._p2p,
-                    servicer_type=type(self),
-                    prefix=self.prefix,
-                    group_id=group_info.group_id,
-                    tensors=local_tensors,
-                    ordered_peer_ids=group_info.peer_ids,
-                    peer_fractions=peer_fractions,
-                    gathered=user_gathered,
-                    modes=modes,
-                    **kwargs,
-                )
-
-                with self.register_allreduce_group(group_info.group_id, allreduce):
-                    if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
-                        iter_results = allreduce.run()
-                        async for tensor, update in azip(as_aiter(*local_tensors), iter_results):
-                            # all-reduce is performed asynchronously while iterating
-                            tensor.add_(update, alpha=self._averaging_alpha)
-                        self._state_updated.set()
-
-                    else:
-                        async for _ in allreduce:  # trigger all-reduce by iterating
-                            raise ValueError("aux peers should not receive averaged tensors")
-
-                return allreduce.gathered
+                await self._run_allreduce_inplace_(local_tensors, group_info, peer_fractions=peer_fractions, **kwargs)
+                return user_gathered
         except BaseException as e:
             if isinstance(e, Exception):
                 logger.exception(e)
             raise MatchmakingException(f"Unable to run All-Reduce: {e}")

-    @contextlib.contextmanager
-    def register_allreduce_group(self, group_id: GroupID, allreduce: AllReduceRunner):
-        """registers a given all-reduce runner to listen for incoming connections"""
-        try:
-            self._running_groups[group_id] = allreduce
-            self._pending_group_assembled.set()
-            yield
-        finally:
-            self._running_groups.pop(group_id, None)
-            self._pending_group_assembled.set()
+    async def _run_allreduce_inplace_(
+        self, tensors: Sequence[torch.Tensor], group_info: GroupInfo, group_id: Optional[bytes] = None, **kwargs
+    ):
+        """Run one allreduce process to average tensors in place. Can be called more than once during one aggregation process."""
+        group_id = group_info.group_id if group_id is None else group_id
+
+        runner = AllReduceRunner(
+            p2p=self._p2p,
+            servicer_type=type(self),
+            prefix=self.prefix,
+            tensors=tensors,
+            group_id=group_id,
+            ordered_peer_ids=group_info.peer_ids,
+            **kwargs,
+        )
+        assert group_id in self._running_groups, f"Group id {group_id} was not registered in _register_allreduce_group"
+        self._running_groups[group_id].set_result(runner)
+
+        if runner.modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
+            async for tensor, update in azip(as_aiter(*tensors), runner):
+                tensor.add_(update, alpha=self._averaging_alpha)
+                self.last_updated = get_dht_time()
+                self._state_updated.set()
+        else:
+            async for _ in runner:
+                raise ValueError("aux peers should not receive averaged tensors")

     @contextlib.contextmanager
     def get_tensors(self) -> Sequence[torch.Tensor]:
@@ -586,13 +588,14 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         if request.group_id not in self._running_groups:
             # this handles a special case when leader accepted us to group AND began allreduce right away,
             # but his response with group_id was delayed and other peers got to us first
-            await self._pending_group_assembled.wait()
+            await self._pending_groups_registered.wait()

-        group = self._running_groups.get(request.group_id)
-        if group is None:
+        future = self._running_groups.get(request.group_id)
+        if future is None:
             yield averaging_pb2.AveragingData(code=averaging_pb2.BAD_GROUP_ID)
             return

+        group = await future
         async for message in group.rpc_aggregate_part(achain(as_aiter(request), stream), context):
             yield message


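The switch from storing `AllReduceRunner` instances to storing `asyncio.Future`s in `_running_groups` is what lets `rpc_aggregate_part` wait for a runner that has been registered but not yet constructed. A toy illustration of that pattern (plain asyncio, not hivemind code):

```python
import asyncio


async def handle_request(running_groups, group_id):
    runner = await running_groups[group_id]  # blocks until set_result() below
    print("got runner:", runner)


async def main():
    running_groups = {b"group": asyncio.Future()}  # group registered before the runner exists
    request = asyncio.create_task(handle_request(running_groups, b"group"))
    await asyncio.sleep(0.1)  # simulate the delay while the runner is being built
    running_groups[b"group"].set_result("AllReduceRunner placeholder")
    await request


asyncio.run(main())
```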
+ 20 - 7
hivemind/optim/grad_averager.py

@@ -1,16 +1,20 @@
 import contextlib
-from typing import Iterable, Iterator, Optional
+from typing import Callable, Iterable, Iterator, Optional, Sequence, TypeVar

 import torch

-import hivemind
 from hivemind.averaging import DecentralizedAverager
 from hivemind.averaging.control import StepControl
-from hivemind.utils import DHTExpiration, get_dht_time, get_logger
+from hivemind.dht import DHT
+from hivemind.utils import DHTExpiration, get_logger

 logger = get_logger(__name__)


+TGradientAverager = TypeVar("TGradientAverager", bound="GradientAverager")
+GradientAveragerFactory = Callable[..., TGradientAverager]
+
+
 class GradientAverager(DecentralizedAverager):
     """
     An auxiliary averaging class that is responsible for accumulating gradients and aggregating them with peers.
@@ -36,6 +40,7 @@ class GradientAverager(DecentralizedAverager):
       if True, the averager will only join existing groups where at least one peer has client_mode=False.
       By default, this flag is copied from DHTNode inside the ``dht`` instance.
     :param warn: if True, warn when the averager did not reset accumulators after use or did not use averaging results
+    :param averaged_grads: if provided, it will be used as a set of averageable gradients
     :param kwargs: see DecentralizedAverager keyword arguments for additional parameters


@@ -69,12 +74,13 @@ class GradientAverager(DecentralizedAverager):
         self,
         parameters: Iterable[torch.nn.Parameter],
         *,
-        dht: hivemind.DHT,
+        dht: DHT,
         prefix: str,
         reuse_grad_buffers: bool = False,
         accumulate_grads_on: Optional[torch.device] = None,
         client_mode: bool = None,
         warn: bool = True,
+        averaged_grads: Sequence[torch.Tensor] = (),
         **kwargs,
     ):
         if reuse_grad_buffers and accumulate_grads_on is not None:
@@ -95,9 +101,16 @@ class GradientAverager(DecentralizedAverager):
         self._new_averaged_grads = False

         with torch.no_grad():
-            averaged_grads = tuple(
-                grad.detach().cpu().clone().share_memory_() for grad in self._grads_from_parameters()
-            )
+            if not averaged_grads:
+                averaged_grads = tuple(
+                    grad.detach().cpu().clone().share_memory_() for grad in self._grads_from_parameters()
+                )
+            else:
+                if any(
+                    param_grad.size() != grad.size()
+                    for param_grad, grad in zip(self._grads_from_parameters(), averaged_grads)
+                ):
+                    raise ValueError("Averaged gradients do not have the same shape as gradients from parameters")
         super().__init__(averaged_tensors=averaged_grads, dht=dht, prefix=prefix, client_mode=client_mode, **kwargs)

     def _grads_from_parameters(self) -> Iterator[torch.Tensor]:

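`GradientAveragerFactory` is simply a callable that builds a `GradientAverager` (or a subclass) from keyword arguments such as `dht`, `prefix`, and `parameters`. A hedged sketch of constructing one with `functools.partial`; the model, prefix, and options are placeholders for illustration:

```python
from functools import partial

import torch

import hivemind
from hivemind.optim.grad_averager import GradientAverager

model = torch.nn.Linear(512, 512)  # hypothetical model
dht = hivemind.DHT(start=True)

# Pre-binding options produces a callable with the GradientAveragerFactory signature
factory = partial(GradientAverager, reuse_grad_buffers=False)
averager = factory(model.parameters(), dht=dht, prefix="demo_grad_averager", start=True)
```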
+ 15 - 5
hivemind/optim/optimizer.py

@@ -11,8 +11,9 @@ import torch
 from hivemind.averaging.control import AveragingStage, StepControl
 from hivemind.compression import CompressionBase, NoCompression
 from hivemind.dht import DHT
-from hivemind.optim.grad_averager import GradientAverager
+from hivemind.optim.grad_averager import GradientAverager, GradientAveragerFactory
 from hivemind.optim.grad_scaler import GradScaler
+from hivemind.optim.power_sgd_averager import PowerSGDGradientAverager
 from hivemind.optim.progress_tracker import LocalTrainingProgress, ProgressTracker
 from hivemind.optim.state_averager import (
     LRSchedulerBase,
@@ -147,6 +148,7 @@ class Optimizer(torch.optim.Optimizer):
     :param auxiliary: if True, optimizer.step will only assist other peers in averaging (for cpu-only workers)

     :param grad_compression: compression strategy used for averaging gradients, default = no compression
+    :param grad_averager_factory: if provided, creates gradient averager with required averaging strategy
     :param state_averaging_compression: compression for averaging params and state tensors, default = no compression
     :param load_state_compression: compression strategy for loading state from peers, default = no compression
     :param average_opt_statistics: names of optimizer statistics from state dict that should be averaged with peers
@@ -187,6 +189,7 @@ class Optimizer(torch.optim.Optimizer):
         client_mode: bool = None,
         auxiliary: bool = False,
         grad_compression: CompressionBase = NoCompression(),
+        grad_averager_factory: Optional[GradientAveragerFactory] = None,
         state_averaging_compression: CompressionBase = NoCompression(),
         load_state_compression: CompressionBase = NoCompression(),
         average_opt_statistics: Sequence[str] = (),
@@ -201,7 +204,8 @@ class Optimizer(torch.optim.Optimizer):

         client_mode = client_mode if client_mode is None else dht.client_mode
         delay_optimizer_step = delay_optimizer_step if delay_optimizer_step is not None else delay_grad_averaging
-        offload_optimizer = offload_optimizer if offload_optimizer is not None else (params is not None)
+        if offload_optimizer is None:
+            offload_optimizer = params is not None and not use_local_updates
         allreduce_timeout = allreduce_timeout if allreduce_timeout is not None else averaging_timeout
         next_chunk_timeout = next_chunk_timeout if next_chunk_timeout is not None else matchmaking_time
         assert not delay_grad_averaging or delay_optimizer_step, "delay_grad_averaging requires delay_optimizer_step"
@@ -225,6 +229,9 @@ class Optimizer(torch.optim.Optimizer):
         if use_local_updates:
             assert not reuse_grad_buffers, "if local_updates is True, gradients will not be accumulated"
             assert not delay_grad_averaging, "if local_updates is True, gradients will not be averaged"
+            assert (
+                grad_averager_factory is None
+            ), "if local_updates is True, provided grad_averager_factory will not be used"

         self.dht, self.run_id, self.client_mode, self.auxiliary = dht, run_id, client_mode, auxiliary
         self.batch_size_per_step, self.target_batch_size = batch_size_per_step, target_batch_size
@@ -257,7 +264,7 @@ class Optimizer(torch.optim.Optimizer):
         )
         if not use_local_updates:
             self.grad_averager = self._make_gradient_averager(
-                reuse_grad_buffers=reuse_grad_buffers, compression=grad_compression, **averager_opts or {}
+                grad_averager_factory, reuse_grad_buffers=reuse_grad_buffers, compression=grad_compression
             )
         else:
             self.grad_averager = None
@@ -290,9 +297,10 @@ class Optimizer(torch.optim.Optimizer):
             **kwargs,
         )

-    def _make_gradient_averager(self, **kwargs) -> GradientAverager:
+    def _make_gradient_averager(self, factory: Optional[GradientAveragerFactory], **kwargs) -> GradientAverager:
         assert hasattr(self, "state_averager"), "must initialize state averager first"
-        grad_averager = GradientAverager(
+        factory = factory if factory is not None else GradientAverager
+        grad_averager = factory(
             dht=self.dht,
             prefix=f"{self.run_id}_grad_averager",
             parameters=self.state_averager.main_parameters,
@@ -684,6 +692,8 @@ class Optimizer(torch.optim.Optimizer):
             while True:
                 try:
                     self.state_averager.load_state_from_peers(timeout=self.load_state_timeout, **kwargs)
+                    if self.grad_averager is not None:
+                        self.grad_averager.load_state_from_peers(timeout=self.load_state_timeout, **kwargs)
                     break
                 except KeyboardInterrupt:
                     raise

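Putting the two new pieces together, a hedged usage sketch: passing a `grad_averager_factory` to `hivemind.Optimizer` so that gradients are averaged with PowerSGD instead of the default `GradientAverager`. The run id, batch sizes, and rank below are placeholder values:

```python
from functools import partial

import torch

import hivemind
from hivemind.optim.power_sgd_averager import PowerSGDGradientAverager

model = torch.nn.Linear(512, 512)  # hypothetical model
dht = hivemind.DHT(start=True)

opt = hivemind.Optimizer(
    dht=dht,
    run_id="demo_run",
    params=model.parameters(),
    optimizer=partial(torch.optim.Adam, lr=1e-3),
    target_batch_size=4096,
    batch_size_per_step=32,
    grad_averager_factory=partial(PowerSGDGradientAverager, averager_rank=32),
)
```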
+ 223 - 0
hivemind/optim/power_sgd_averager.py

@@ -0,0 +1,223 @@
+import asyncio
+import contextlib
+import multiprocessing as mp
+from enum import Enum
+from typing import Any, Iterable, Optional, Sequence
+
+import torch
+
+from hivemind.averaging.allreduce import AveragingMode
+from hivemind.averaging.group_info import GroupInfo
+from hivemind.averaging.load_balancing import load_balance_peers
+from hivemind.averaging.matchmaking import Matchmaking, MatchmakingException
+from hivemind.compression import CompressionInfo, TensorRole
+from hivemind.dht import DHT
+from hivemind.optim.grad_averager import GradientAverager
+from hivemind.utils import get_logger
+from hivemind.utils.asyncio import enter_asynchronously
+from hivemind.utils.math import get_flatten_greedy_dims, orthogonalize_
+
+GatheredData = Any
+logger = get_logger(__name__)
+
+
+class AllReducePhases(Enum):
+    PHASE_P = 1
+    PHASE_Q = 2
+
+
+class PowerSGDGradientAverager(GradientAverager):
+    """
+    A gradient averager that implements PowerSGD compression: https://arxiv.org/abs/1905.13727
+    For basic properties and guarantees of gradient averagers, please refer to the base class docstring.
+    Put simply, this method approximates large gradient tensors (m,n) with a product of two
+    smaller matrices (m,r) by (r,n), where r is a parameter chosen by the user (see averager_rank).
+
+    As a result, PowerSGD only needs to aggregate O((m + n) * r) tensors instead of O(m * n).
+    A high r, e.g. sqrt(max(m, n)), typically reduces communication by 2-8x without affecting convergence.
+    A low r, e.g. 1-8, further accelerates communication, but may converge worse depending on the task.
+
+    To maintain convergence with low r, this averager uses the error feedback strategy. Put simply,
+    if some part of the gradient is "lost in compression", it will be added to the next iteration.
+    This has two implications: (a) it needs more RAM in order to store the "feedback buffers"
+    and (b) if devices stay alive only for one step, training with small rank may converge slower.
+    This is because error feedback takes multiple steps to kick in.
+
+    Since not all gradients are matrices, PowerSGD views 3d+ tensors via tensor.flatten(1, -1).
+    If a tensor has fewer than 2 dimensions or does not compress efficiently, it will be aggregated
+    normally, i.e. without PowerSGD. See min_compression_ratio for details.
+
+    :note: due to the above rule, PowerSGD is *not* shape-invariant. For instance, a
+     matrix of shape (256, 256) will be compressed differently if you .reshape it to (32, 32, 32).
+
+    :param parameters: pytorch parameters for which to aggregate gradients
+    :param averager_rank: rank of compressed gradients
+    :param dht: a DHT instance connected to the rest of the swarm. See hivemind.DHT docs
+    :param prefix: a unique DHT key used for matchmaking. E.g. this can be your experiment name with optional suffixes
+    :param reuse_grad_buffers: if True, use model's .grad buffers for accumulating gradients over multiple steps.
+      This is more memory efficient, but it requires that the user does *not* call zero_grad or clip_by_whatever at all
+    :param accumulate_grads_on: if specified, accumulate gradients on this device. By default, this will use the same
+      device as model parameters. One can specify a different device (e.g. 'cpu' vs 'cuda') to save device memory at
+      the cost of extra time per step. If reuse_grad_buffers is True, this parameter has no effect.
+    :param client_mode: if False, this averager will accept incoming requests from other peers.
+      if True, the averager will only join existing groups where at least one peer has client_mode=False.
+      By default, this flag is copied from DHTNode inside the ``dht`` instance.
+    :param warn: if True, warn when the averager did not reset accumulators after use or did not use averaging results
+    :param min_compression_ratio: apply PowerSGD to a tensor only if it reduces communication by at least this factor,
+      otherwise aggregate tensors as is
+    :param averaged_grads: if provided, it will be used as a set of averageable gradients
+    """
+
+    def __init__(
+        self,
+        parameters: Iterable[torch.nn.Parameter],
+        averager_rank: int,
+        *,
+        dht: DHT,
+        prefix: str,
+        reuse_grad_buffers: bool = False,
+        accumulate_grads_on: Optional[torch.device] = None,
+        client_mode: bool = None,
+        warn: bool = True,
+        min_compression_ratio: float = 0.5,
+        averaged_grads: Optional[Sequence[torch.Tensor]] = None,
+        **kwargs,
+    ):
+        self.rank = averager_rank
+        self.parameters = tuple(parameters)
+        self._uncompressed_gradients_indexes = set(
+            i
+            for i, grad in enumerate(self._grads_from_parameters())
+            if grad.ndim <= 1
+            or (1 - self.rank * sum(get_flatten_greedy_dims(grad)) / grad.numel()) < min_compression_ratio
+            # compute how many parameters are left after factorization
+        )
+        self._ms = [
+            torch.zeros_like(grad, device="cpu").share_memory_()
+            for idx, grad in enumerate(self._grads_from_parameters())
+            if idx not in self._uncompressed_gradients_indexes
+        ]
+        self._qs = [
+            torch.rand((get_flatten_greedy_dims(grad)[1], self.rank), device="cpu").share_memory_()
+            for idx, grad in enumerate(self._grads_from_parameters())
+            if idx not in self._uncompressed_gradients_indexes
+        ]
+
+        super().__init__(
+            self.parameters,
+            dht=dht,
+            prefix=prefix,
+            reuse_grad_buffers=reuse_grad_buffers,
+            accumulate_grads_on=accumulate_grads_on,
+            client_mode=client_mode,
+            warn=warn,
+            averaged_grads=averaged_grads,
+            **kwargs,
+        )
+
+    @contextlib.contextmanager
+    def _register_allreduce_group(self, group_info: GroupInfo):
+        """Register a given group for one or more all-reduce rounds"""
+        try:
+            for phase in list(AllReducePhases):
+                self._running_groups[group_info.group_id + phase.name.encode()] = asyncio.Future()
+            self._pending_groups_registered.set()
+            yield
+        finally:
+            for phase in list(AllReducePhases):
+                maybe_future = self._running_groups.pop(group_info.group_id + phase.name.encode(), None)
+                if maybe_future and not maybe_future.done():
+                    logger.warning(f"All-reduce group {group_info.group_id + phase.name.encode()} did not finish.")
+            self._pending_groups_registered.set()
+
+    async def _aggregate_with_group(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
+        """Run aggregation in a given group and update tensors in place, return gathered metadata"""
+        try:
+            bandwidths, mode_ids, user_gathered_bytes = zip(*map(self.serializer.loads, group_info.gathered))
+            user_gathered = dict(zip(group_info.peer_ids, map(self.serializer.loads, user_gathered_bytes)))
+            modes = tuple(map(AveragingMode, mode_ids))
+
+            download_bandwidths = [
+                thr if mode != AveragingMode.CLIENT else 0.0 for thr, mode in zip(bandwidths, modes)
+            ]
+            peer_fractions = await asyncio.get_event_loop().run_in_executor(
+                None, load_balance_peers, self.total_size, download_bandwidths, min_vector_size
+            )
+
+            async with enter_asynchronously(self.get_tensors()) as averaged_grads:
+                averaged_grads_via_sgd = [
+                    grad for idx, grad in enumerate(averaged_grads) if idx not in self._uncompressed_gradients_indexes
+                ]
+                for grad, m in zip(averaged_grads_via_sgd, self._ms):
+                    m.add_(grad.to(m.device))
+
+                ps = [
+                    torch.zeros((get_flatten_greedy_dims(grad)[0], self.rank), device="cpu")
+                    for idx, grad in enumerate(averaged_grads_via_sgd)
+                ]
+                for p, q, m in zip(ps, self._qs, self._ms):
+                    # we use reshape for all matrices because PowerSGD works only with 2d tensors
+                    torch.matmul(m.reshape(-1, q.size(0)), q, out=p)
+
+                p_group_id = group_info.group_id + AllReducePhases.PHASE_P.name.encode()
+                q_group_id = group_info.group_id + AllReducePhases.PHASE_Q.name.encode()
+
+                await self._run_allreduce_inplace_(ps, group_info, p_group_id, peer_fractions=peer_fractions, **kwargs)
+
+                for p in ps:
+                    orthogonalize_(p)
+
+                for p, q, m in zip(ps, self._qs, self._ms):
+                    torch.matmul(m.reshape(-1, q.size(0)).t(), p, out=q)
+
+                phase_q_tensors = self._qs + [
+                    grad for idx, grad in enumerate(averaged_grads) if idx in self._uncompressed_gradients_indexes
+                ]
+
+                await self._run_allreduce_inplace_(
+                    phase_q_tensors, group_info, q_group_id, peer_fractions=peer_fractions, **kwargs
+                )
+
+                for p, q, m, grad in zip(ps, self._qs, self._ms, averaged_grads_via_sgd):
+                    new_m = torch.matmul(p, q.t()).reshape(m.size())
+                    m.sub_(new_m)
+                    grad.copy_(new_m)
+
+                return user_gathered
+        except BaseException as e:
+            logger.exception(e)
+            raise MatchmakingException(f"Unable to run All-Reduce: {e}")
+
+    def get_current_state(self):
+        """
+        Get the current gradient averager state, to be sent when requested by a newbie peer.
+        """
+        with torch.no_grad(), self.lock_averaged_tensors:
+            grad_averager_buffers = [q for q in self._qs]
+            grad_averager_buffers_infos = [
+                CompressionInfo.from_tensor(buffer, key=f"buffer_q_{key}", role=TensorRole.GRADIENT)
+                for key, buffer in enumerate(grad_averager_buffers)
+            ]
+
+        metadata = dict(group_bits=self.get_group_bits())
+        return metadata, grad_averager_buffers, grad_averager_buffers_infos
+
+    def load_state_from_peers(self, **kwargs):
+        """
+        Attempt to download the latest optimizer state from peers and update gradient averager buffers.
+        :returns: whether or not the averager succeeded in loading parameters
+        """
+        loaded_state = super().load_state_from_peers(**kwargs)
+        if loaded_state is None:
+            return
+
+        metadata, flat_tensors = loaded_state
+        logger.info("Starting loading gradient averager buffers from peers")
+
+        if len(flat_tensors) != len(self._qs):
+            logger.error("Failed to load state from peer, received invalid parameters, extras or metadata")
+            return
+
+        with torch.no_grad(), self.lock_averaged_tensors:
+            for local_q, loaded_q in zip(self._qs, flat_tensors):
+                local_q.copy_(loaded_q, non_blocking=True)

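For intuition, here is a self-contained sketch of the two compression phases that `_aggregate_with_group` runs for a single compressed gradient, with the all-reduce calls and error feedback omitted (`M`, `P`, `Q` mirror the `self._ms`, `ps`, and `self._qs` buffers above; the shape and rank are arbitrary):

```python
import torch

from hivemind.utils.math import get_flatten_greedy_dims, orthogonalize_

rank = 4
grad = torch.randn(256, 3, 3, 64)       # a conv-style gradient tensor
n, m = get_flatten_greedy_dims(grad)    # -> [256, 576] for this shape
M = grad.reshape(n, m)                  # in the averager, the error-feedback buffer is added here

Q = torch.randn(m, rank)
P = M @ Q                               # phase P: peers all-reduce the P matrices first
orthogonalize_(P)                       # make the columns of P orthonormal
Q = M.t() @ P                           # phase Q: peers then all-reduce Q (plus uncompressed grads)

approx = P @ Q.t()                      # rank-4 approximation of M, written back into the gradient
print(f"sent {(n + m) * rank} values instead of {n * m}")
print("relative error:", ((M - approx).norm() / M.norm()).item())
```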
+ 38 - 12
hivemind/p2p/p2p_daemon.py

@@ -3,6 +3,7 @@ import json
 import logging
 import os
 import secrets
+import warnings
 from collections.abc import AsyncIterable as AsyncIterableABC
 from contextlib import closing, suppress
 from dataclasses import dataclass
@@ -17,8 +18,10 @@ import hivemind.hivemind_cli as cli
 import hivemind.p2p.p2p_daemon_bindings.p2pclient as p2pclient
 from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, P2PDaemonError, P2PHandlerError
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
+from hivemind.proto import crypto_pb2
 from hivemind.proto.p2pd_pb2 import RPCError
 from hivemind.utils.asyncio import as_aiter, asingle
+from hivemind.utils.crypto import RSAPrivateKey
 from hivemind.utils.logging import get_logger, golog_level_to_python, loglevel, python_level_to_golog

 logger = get_logger(__name__)
@@ -89,16 +92,16 @@ class P2P:
         identity_path: Optional[str] = None,
         idle_timeout: float = 30,
         nat_port_map: bool = True,
-        quic: bool = False,
         relay_hop_limit: int = 0,
         startup_timeout: float = 15,
         tls: bool = True,
         use_auto_relay: bool = False,
         use_ipfs: bool = False,
         use_relay: bool = True,
-        use_relay_hop: bool = False,
-        use_relay_discovery: bool = False,
         persistent_conn_max_msg_size: int = DEFAULT_MAX_MSG_SIZE,
+        quic: Optional[bool] = None,
+        use_relay_hop: Optional[bool] = None,
+        use_relay_discovery: Optional[bool] = None,
     ) -> "P2P":
         """
         Start a new p2pd process and connect to it.
@@ -112,20 +115,20 @@ class P2P:
                          Details: https://pkg.go.dev/github.com/libp2p/go-libp2p-kad-dht#ModeOpt
         :param force_reachability: Force reachability mode (public/private)
         :param host_maddrs: Multiaddrs to listen for external connections from other p2p instances
-        :param identity_path: Path to a pre-generated private key file. If defined, makes the peer ID deterministic.
-                              May be generated using ``./p2p-keygen`` from ``go-libp2p-daemon``.
+        :param identity_path: Path to a private key file. If defined, makes the peer ID deterministic.
+                              If the file does not exist yet, writes a new private key to this file.
         :param idle_timeout: kill daemon if client has been idle for a given number of
                              seconds before opening persistent streams
         :param nat_port_map: Enables NAT port mapping
-        :param quic: Enables the QUIC transport
         :param relay_hop_limit: sets the hop limit for hop relays
         :param startup_timeout: raise a P2PDaemonError if the daemon does not start in ``startup_timeout`` seconds
         :param tls: Enables TLS1.3 channel security protocol
         :param use_auto_relay: enables autorelay
         :param use_ipfs: Bootstrap to IPFS (incompatible with initial_peers)
         :param use_relay: enables circuit relay
-        :param use_relay_hop: enables hop for relay
-        :param use_relay_discovery: enables passive discovery for relay
+        :param quic: Deprecated, has no effect since libp2p 0.17.0
+        :param use_relay_hop: Deprecated, has no effect since libp2p 0.17.0
+        :param use_relay_discovery: Deprecated, has no effect since libp2p 0.17.0
         :return: a wrapper for the p2p daemon
         """

@@ -133,6 +136,14 @@ class P2P:
             initial_peers and use_ipfs
         ), "User-defined initial_peers and use_ipfs=True are incompatible, please choose one option"

+        if not all(arg is None for arg in [quic, use_relay_hop, use_relay_discovery]):
+            warnings.warn(
+                "Parameters `quic`, `use_relay_hop`, and `use_relay_discovery` of hivemind.P2P "
+                "have no effect since libp2p 0.17.0 and will be removed in hivemind 1.2.0+",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+
         self = cls()
         with path(cli, P2PD_FILENAME) as p:
             p2pd_path = p
@@ -147,7 +158,7 @@ class P2P:
                     raise ValueError("Please specify an explicit port in announce_maddrs: port 0 is not supported")

         need_bootstrap = bool(initial_peers) or use_ipfs
-        process_kwargs = cls.DHT_MODE_MAPPING.get(dht_mode, {"dht": 0})
+        process_kwargs = cls.DHT_MODE_MAPPING[dht_mode].copy()
         process_kwargs.update(cls.FORCE_REACHABILITY_MAPPING.get(force_reachability, {}))
         for param, value in [
             ("bootstrapPeers", initial_peers),
@@ -156,7 +167,11 @@ class P2P:
         ]:
             if value:
                 process_kwargs[param] = self._maddrs_to_str(value)
+
         if identity_path is not None:
+            if not os.path.isfile(identity_path):
+                logger.info(f"Generating new identity (libp2p private key) in `{identity_path}`")
+                self.generate_identity(identity_path)
             process_kwargs["id"] = identity_path

         proc_args = self._make_process_args(
@@ -168,10 +183,7 @@ class P2P:
             idleTimeout=f"{idle_timeout}s",
             listen=self._daemon_listen_maddr,
             natPortMap=nat_port_map,
-            quic=quic,
             relay=use_relay,
-            relayDiscovery=use_relay_discovery,
-            relayHop=use_relay_hop,
             relayHopLimit=relay_hop_limit,
             tls=tls,
             persistentConnMaxMsgSize=persistent_conn_max_msg_size,
@@ -205,6 +217,20 @@ class P2P:
         await self._ping_daemon()
         return self

+    @staticmethod
+    def generate_identity(identity_path: str) -> None:
+        private_key = RSAPrivateKey()
+        protobuf = crypto_pb2.PrivateKey(key_type=crypto_pb2.KeyType.RSA, data=private_key.to_bytes())
+
+        try:
+            with open(identity_path, "wb") as f:
+                f.write(protobuf.SerializeToString())
+        except FileNotFoundError:
+            raise FileNotFoundError(
+                f"The directory `{os.path.dirname(identity_path)}` for saving the identity does not exist"
+            )
+        os.chmod(identity_path, 0o400)
+
     @classmethod
     async def replicate(cls, daemon_listen_maddr: Multiaddr) -> "P2P":
         """

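A minimal sketch of the new identity handling from the client side (assuming this branch is installed; the file name is arbitrary): the first call writes `identity.key` via `generate_identity`, and subsequent calls reuse it, so the peer ID stays the same.

```python
import asyncio

from hivemind.p2p import P2P


async def main():
    p2p = await P2P.create(identity_path="identity.key")  # generated on the first run, reused afterwards
    print("peer id:", p2p.peer_id)
    await p2p.shutdown()


asyncio.run(main())
```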
+ 24 - 0
hivemind/proto/crypto.proto

@@ -0,0 +1,24 @@
+// Originally taken from: https://github.com/mhchia/py-libp2p-daemon-bindings
+// Licence: MIT
+// Author: Kevin Mai-Husan Chia
+
+syntax = "proto2";
+
+package crypto.pb;
+
+enum KeyType {
+  RSA = 0;
+  Ed25519 = 1;
+  Secp256k1 = 2;
+  ECDSA = 3;
+}
+
+message PublicKey {
+  required KeyType key_type = 1;
+  required bytes data = 2;
+}
+
+message PrivateKey {
+  required KeyType key_type = 1;
+  required bytes data = 2;
+}

+ 3 - 3
hivemind/proto/p2pd.proto

@@ -1,6 +1,6 @@
-//Originally taken from: https://github.com/mhchia/py-libp2p-daemon-bindings
-//Licence: MIT
-//Author: Kevin Mai-Husan Chia
+// Originally taken from: https://github.com/mhchia/py-libp2p-daemon-bindings
+// Licence: MIT
+// Author: Kevin Mai-Husan Chia

 syntax = "proto2";


+ 9 - 6
hivemind/utils/crypto.py

@@ -60,19 +60,22 @@ class RSAPrivateKey(PrivateKey):
     def get_public_key(self) -> RSAPublicKey:
     def get_public_key(self) -> RSAPublicKey:
         return RSAPublicKey(self._private_key.public_key())
         return RSAPublicKey(self._private_key.public_key())
 
 
+    def to_bytes(self) -> bytes:
+        return self._private_key.private_bytes(
+            encoding=serialization.Encoding.DER,
+            format=serialization.PrivateFormat.TraditionalOpenSSL,
+            encryption_algorithm=serialization.NoEncryption(),
+        )
+
     def __getstate__(self):
     def __getstate__(self):
         state = self.__dict__.copy()
         state = self.__dict__.copy()
         # Serializes the private key to make the class instances picklable
         # Serializes the private key to make the class instances picklable
-        state["_private_key"] = self._private_key.private_bytes(
-            encoding=serialization.Encoding.PEM,
-            format=serialization.PrivateFormat.OpenSSH,
-            encryption_algorithm=serialization.NoEncryption(),
-        )
+        state["_private_key"] = self.to_bytes()
         return state
         return state
 
 
     def __setstate__(self, state):
     def __setstate__(self, state):
         self.__dict__.update(state)
         self.__dict__.update(state)
-        self._private_key = serialization.load_ssh_private_key(self._private_key, password=None)
+        self._private_key = serialization.load_der_private_key(self._private_key, password=None)
 
 
 
 
 class RSAPublicKey(PublicKey):
 class RSAPublicKey(PublicKey):
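
The switch to DER + TraditionalOpenSSL means the same bytes now back both pickling and the libp2p identity file. A minimal sketch of the round trip (not part of this diff):

import pickle

from hivemind.utils.crypto import RSAPrivateKey

key = RSAPrivateKey()
restored = pickle.loads(pickle.dumps(key))  # goes through __getstate__/__setstate__ above

# to_bytes() is deterministic for a given key, so both copies serialize identically
assert key.to_bytes() == restored.to_bytes()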

+ 24 - 0
hivemind/utils/math.py

@@ -0,0 +1,24 @@
+import torch
+import torch.nn.functional as F
+
+
+@torch.jit.script
+def orthogonalize_(matrix, eps: float = 1e-8):
+    """Orthogonalize a 2d tensor in-place over the last dimension"""
+    n, m = matrix.shape
+    for i in range(m):
+        col = matrix[:, i]
+        F.normalize(col, dim=0, eps=eps, out=col)
+        if i + 1 < m:
+            rest = matrix[:, i + 1 :]
+            rest.addmm_(col[:, None], (col @ rest)[None, :], alpha=-1)
+
+
+def get_flatten_greedy_dims(tensor: torch.Tensor, max_ndim: int = 2):
+    """get dims to flatten tensor upto max_ndim dimensions by merging small axes together"""
+    dims = list(tensor.shape)
+    while len(dims) > max_ndim:
+        squeeze_ix = min(range(len(dims) - 1), key=lambda i: dims[i] * dims[i + 1])
+        squeezed_dim = dims.pop(squeeze_ix)
+        dims[squeeze_ix] *= squeezed_dim
+    return dims
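
A hedged sketch of how these helpers combine in PowerSGD-style compression (illustration only, not the PowerSGDGradientAverager code): flatten a gradient to 2-D, project it onto a low-rank basis, and orthogonalize that basis in place.

import torch

from hivemind.utils.math import get_flatten_greedy_dims, orthogonalize_

grad = torch.randn(8, 4, 3, 3)  # e.g. a conv weight gradient
matrix = grad.reshape(get_flatten_greedy_dims(grad))  # (8, 4, 3, 3) -> (32, 9)

rank = 1
q = torch.randn(matrix.shape[1], rank)
p = matrix @ q          # (32, rank) projection
orthogonalize_(p)       # normalize/orthogonalize the columns of p in place
q = matrix.T @ p        # (9, rank)
approx = (p @ q.T).reshape(grad.shape)  # low-rank approximation of the original gradient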

+ 1 - 1
pyproject.toml

@@ -1,6 +1,6 @@
 [tool.black]
 [tool.black]
 line-length = 119
 line-length = 119
-required-version = "22.1.0"
+required-version = "22.3.0"
 
 
 [tool.isort]
 [tool.isort]
 profile = "black"
 profile = "black"

+ 1 - 1
requirements-dev.txt

@@ -6,6 +6,6 @@ coverage==6.0.2  # see https://github.com/pytest-dev/pytest-cov/issues/520
 tqdm
 tqdm
 scikit-learn
 scikit-learn
 torchvision
 torchvision
-black==22.1.0
+black==22.3.0
 isort==5.10.1
 isort==5.10.1
 psutil
 psutil

+ 37 - 29
setup.py

@@ -3,7 +3,6 @@ import glob
 import hashlib
 import hashlib
 import os
 import os
 import re
 import re
-import shlex
 import subprocess
 import subprocess
 import tarfile
 import tarfile
 import tempfile
 import tempfile
@@ -14,20 +13,25 @@ from setuptools import find_packages, setup
 from setuptools.command.build_py import build_py
 from setuptools.command.build_py import build_py
 from setuptools.command.develop import develop
 from setuptools.command.develop import develop
 
 
-P2PD_VERSION = "v0.3.6"
-P2PD_CHECKSUM = "627d0c3b475a29331fdfd1667e828f6d"
-LIBP2P_TAR_URL = f"https://github.com/learning-at-home/go-libp2p-daemon/archive/refs/tags/{P2PD_VERSION}.tar.gz"
-P2PD_BINARY_URL = f"https://github.com/learning-at-home/go-libp2p-daemon/releases/download/{P2PD_VERSION}/p2pd"
+P2PD_VERSION = "v0.3.8"
+
+P2PD_SOURCE_URL = f"https://github.com/learning-at-home/go-libp2p-daemon/archive/refs/tags/{P2PD_VERSION}.tar.gz"
+P2PD_BINARY_URL = f"https://github.com/learning-at-home/go-libp2p-daemon/releases/download/{P2PD_VERSION}/"
+
+# The values are sha256 checksums of the binaries from the release page
+EXECUTABLES = {
+    "p2pd": "785058526d993f699c674dc2f9b66d565a52315a18b79b629998fab3ebd8e20f",
+}
+
 
 
 here = os.path.abspath(os.path.dirname(__file__))
 here = os.path.abspath(os.path.dirname(__file__))
 
 
 
 
-def md5(fname, chunk_size=4096):
-    hash_md5 = hashlib.md5()
-    with open(fname, "rb") as f:
-        for chunk in iter(lambda: f.read(chunk_size), b""):
-            hash_md5.update(chunk)
-    return hash_md5.hexdigest()
+def sha256(path):
+    if not os.path.exists(path):
+        return None
+    with open(path, "rb") as f:
+        return hashlib.sha256(f.read()).hexdigest()
 
 
 
 
 def proto_compile(output_path):
 def proto_compile(output_path):
@@ -64,32 +68,36 @@ def build_p2p_daemon():
 
 
     with tempfile.TemporaryDirectory() as tempdir:
     with tempfile.TemporaryDirectory() as tempdir:
         dest = os.path.join(tempdir, "libp2p-daemon.tar.gz")
         dest = os.path.join(tempdir, "libp2p-daemon.tar.gz")
-        urllib.request.urlretrieve(LIBP2P_TAR_URL, dest)
+        urllib.request.urlretrieve(P2PD_SOURCE_URL, dest)
 
 
         with tarfile.open(dest, "r:gz") as tar:
         with tarfile.open(dest, "r:gz") as tar:
             tar.extractall(tempdir)
             tar.extractall(tempdir)
 
 
-        result = subprocess.run(
-            f'go build -o {shlex.quote(os.path.join(here, "hivemind", "hivemind_cli", "p2pd"))}',
-            cwd=os.path.join(tempdir, f"go-libp2p-daemon-{P2PD_VERSION[1:]}", "p2pd"),
-            shell=True,
-        )
-
-        if result.returncode:
-            raise RuntimeError(
-                "Failed to build or install libp2p-daemon:" f" exited with status code: {result.returncode}"
+        for executable in EXECUTABLES:
+            result = subprocess.run(
+                ["go", "build", "-o", os.path.join(here, "hivemind", "hivemind_cli", executable)],
+                cwd=os.path.join(tempdir, f"go-libp2p-daemon-{P2PD_VERSION.lstrip('v')}", executable),
             )
             )
+            if result.returncode != 0:
+                raise RuntimeError(f"Failed to build {executable}: exited with status code: {result.returncode}")
 
 
 
 
 def download_p2p_daemon():
 def download_p2p_daemon():
-    install_path = os.path.join(here, "hivemind", "hivemind_cli")
-    binary_path = os.path.join(install_path, "p2pd")
-    if not os.path.exists(binary_path) or md5(binary_path) != P2PD_CHECKSUM:
-        print("Downloading Peer to Peer Daemon")
-        urllib.request.urlretrieve(P2PD_BINARY_URL, binary_path)
-        os.chmod(binary_path, 0o777)
-        if md5(binary_path) != P2PD_CHECKSUM:
-            raise RuntimeError(f"Downloaded p2pd binary from {P2PD_BINARY_URL} does not match with md5 checksum")
+    for executable, expected_hash in EXECUTABLES.items():
+        binary_path = os.path.join(here, "hivemind", "hivemind_cli", executable)
+
+        if sha256(binary_path) != expected_hash:
+            binary_url = os.path.join(P2PD_BINARY_URL, executable)
+            print(f"Downloading {binary_url}")
+
+            urllib.request.urlretrieve(binary_url, binary_path)
+            os.chmod(binary_path, 0o777)
+
+            actual_hash = sha256(binary_path)
+            if actual_hash != expected_hash:
+                raise RuntimeError(
+                    f"The sha256 checksum for {executable} does not match (expected: {expected_hash}, actual: {actual_hash})"
+                )
 
 
 
 
 class BuildPy(build_py):
 class BuildPy(build_py):
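
When bumping P2PD_VERSION, the EXECUTABLES table needs fresh checksums. A small sketch that reproduces the sha256() check above for a release binary (verify the result against the release page before trusting it):

import hashlib
import urllib.request

P2PD_VERSION = "v0.3.8"
url = f"https://github.com/learning-at-home/go-libp2p-daemon/releases/download/{P2PD_VERSION}/p2pd"

path, _ = urllib.request.urlretrieve(url)
with open(path, "rb") as f:
    print(hashlib.sha256(f.read()).hexdigest())  # paste this value into EXECUTABLES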

+ 14 - 12
tests/test_allreduce_fault_tolerance.py

@@ -35,7 +35,7 @@ class FaultyAverager(hivemind.DecentralizedAverager):
         self.fault = fault
         self.fault = fault
         super().__init__(*args, **kwargs)
         super().__init__(*args, **kwargs)
 
 
-    async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
+    async def _aggregate_with_group(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
         """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
         """Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
         try:
         try:
             bandwidths, mode_ids, user_gathered_bytes = zip(*map(self.serializer.loads, group_info.gathered))
             bandwidths, mode_ids, user_gathered_bytes = zip(*map(self.serializer.loads, group_info.gathered))
@@ -60,24 +60,26 @@ class FaultyAverager(hivemind.DecentralizedAverager):
                     tensors=local_tensors,
                     tensors=local_tensors,
                     ordered_peer_ids=group_info.peer_ids,
                     ordered_peer_ids=group_info.peer_ids,
                     peer_fractions=peer_fractions,
                     peer_fractions=peer_fractions,
-                    gathered=user_gathered,
                     modes=modes,
                     modes=modes,
                     fault=self.fault,
                     fault=self.fault,
                     **kwargs,
                     **kwargs,
                 )
                 )
 
 
-                with self.register_allreduce_group(group_info.group_id, allreduce):
-                    if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
-                        async for tensor, update in azip(as_aiter(*local_tensors), allreduce):
-                            # all-reduce is performed asynchronously while iterating
-                            tensor.add_(update, alpha=self._averaging_alpha)
-                        self._state_updated.set()
+                self._running_groups[group_info.group_id].set_result(allreduce)
+                # TODO maybe this can be extracted into a method that checks if register_... context is active.
 
 
-                    else:
-                        async for _ in allreduce:  # trigger all-reduce by iterating
-                            raise ValueError("aux peers should not receive averaged tensors")
+                if modes[group_info.peer_ids.index(self.peer_id)] != AveragingMode.AUX:
+                    iter_results = allreduce.run()
+                    async for tensor, update in azip(as_aiter(*local_tensors), iter_results):
+                        # all-reduce is performed asynchronously while iterating
+                        tensor.add_(update, alpha=self._averaging_alpha)
+                    self._state_updated.set()
+
+                else:
+                    async for _ in allreduce:  # trigger all-reduce by iterating
+                        raise ValueError("aux peers should not receive averaged tensors")
 
 
-                return allreduce.gathered
+                return user_gathered
         except BaseException as e:
         except BaseException as e:
             logger.exception(e)
             logger.exception(e)
             raise MatchmakingException(f"Unable to run All-Reduce: {e}")
             raise MatchmakingException(f"Unable to run All-Reduce: {e}")

+ 62 - 12
tests/test_optimizer.py

@@ -11,24 +11,31 @@ import torch.nn.functional as F
 
 
 import hivemind
 import hivemind
 from hivemind.averaging.control import AveragingStage
 from hivemind.averaging.control import AveragingStage
-from hivemind.optim.grad_averager import GradientAverager
+from hivemind.optim.grad_averager import GradientAverager, GradientAveragerFactory
 from hivemind.optim.optimizer import Optimizer
 from hivemind.optim.optimizer import Optimizer
+from hivemind.optim.power_sgd_averager import PowerSGDGradientAverager
 from hivemind.optim.progress_tracker import ProgressTracker
 from hivemind.optim.progress_tracker import ProgressTracker
 from hivemind.optim.state_averager import TrainingStateAverager
 from hivemind.optim.state_averager import TrainingStateAverager
 from hivemind.utils.crypto import RSAPrivateKey
 from hivemind.utils.crypto import RSAPrivateKey
 
 
 
 
 @pytest.mark.forked
 @pytest.mark.forked
-def test_grad_averager():
+@pytest.mark.parametrize(
+    "grad_averager_factory",
+    [GradientAverager, partial(PowerSGDGradientAverager, averager_rank=1)],
+)
+def test_grad_averager(grad_averager_factory: GradientAveragerFactory):
+    parameter_shape = (5, 5)
+
     dht1 = hivemind.DHT(start=True)
     dht1 = hivemind.DHT(start=True)
-    model1 = nn.ParameterDict({"w": nn.Parameter(torch.zeros(3))})
-    averager1 = GradientAverager(
+    model1 = nn.ParameterDict({"w": nn.Parameter(torch.zeros(parameter_shape))})
+    averager1 = grad_averager_factory(
         model1.parameters(), dht=dht1, prefix="test", target_group_size=2, reuse_grad_buffers=False, start=True
         model1.parameters(), dht=dht1, prefix="test", target_group_size=2, reuse_grad_buffers=False, start=True
     )
     )
 
 
     dht2 = hivemind.DHT(start=True, initial_peers=dht1.get_visible_maddrs())
     dht2 = hivemind.DHT(start=True, initial_peers=dht1.get_visible_maddrs())
-    model2 = nn.ParameterDict({"w": nn.Parameter(torch.zeros(3))})
-    averager2 = GradientAverager(
+    model2 = nn.ParameterDict({"w": nn.Parameter(torch.zeros(parameter_shape))})
+    averager2 = grad_averager_factory(
         model2.parameters(), dht=dht2, prefix="test", target_group_size=2, reuse_grad_buffers=True, start=True
         model2.parameters(), dht=dht2, prefix="test", target_group_size=2, reuse_grad_buffers=True, start=True
     )
     )
 
 
@@ -38,12 +45,12 @@ def test_grad_averager():
     for i in range(10):
     for i in range(10):
         time.sleep(0.1)
         time.sleep(0.1)
         if i % 3 == 0:
         if i % 3 == 0:
-            loss1 = F.mse_loss(model1.w, torch.ones(3))
+            loss1 = F.mse_loss(model1.w, torch.ones(parameter_shape))
             loss1.backward()
             loss1.backward()
             averager1.accumulate_grads_(batch_size=2)  # total: 4 times * 2 samples = 8
             averager1.accumulate_grads_(batch_size=2)  # total: 4 times * 2 samples = 8
             model1.zero_grad()
             model1.zero_grad()
         else:
         else:
-            loss2 = F.mse_loss(model2.w, -torch.ones(3))
+            loss2 = F.mse_loss(model2.w, -torch.ones(parameter_shape))
             loss2.backward()
             loss2.backward()
             averager2.accumulate_grads_(batch_size=3)  # total: 6 times * 3 samples = 18
             averager2.accumulate_grads_(batch_size=3)  # total: 6 times * 3 samples = 18
             # note: we do not call zero grad here because reuse_grad_buffers=True
             # note: we do not call zero grad here because reuse_grad_buffers=True
@@ -51,11 +58,11 @@ def test_grad_averager():
     assert control1.stage == control2.stage == AveragingStage.AWAITING_TRIGGER
     assert control1.stage == control2.stage == AveragingStage.AWAITING_TRIGGER
     peer1_samples, peer1_times, peer2_samples, peer2_times = 8, 4, 18, 6
     peer1_samples, peer1_times, peer2_samples, peer2_times = 8, 4, 18, 6
     assert averager1.local_samples_accumulated == peer1_samples and averager1.local_times_accumulated == peer1_times
     assert averager1.local_samples_accumulated == peer1_samples and averager1.local_times_accumulated == peer1_times
-    ref_grads1 = torch.full((3,), -2 * 1 / 3 * averager1.local_times_accumulated)
+    ref_grads1 = torch.full(parameter_shape, -2 / np.prod(parameter_shape) * averager1.local_times_accumulated)
     assert torch.allclose(next(averager1._grad_accumulators()), ref_grads1)
     assert torch.allclose(next(averager1._grad_accumulators()), ref_grads1)
 
 
     assert averager2.local_samples_accumulated == peer2_samples and averager2.local_times_accumulated == peer2_times
     assert averager2.local_samples_accumulated == peer2_samples and averager2.local_times_accumulated == peer2_times
-    ref_grads2 = torch.full((3,), 2 * 1 / 3 * averager2.local_times_accumulated)
+    ref_grads2 = torch.full(parameter_shape, 2 / np.prod(parameter_shape) * averager2.local_times_accumulated)
     assert torch.allclose(next(averager2._grad_accumulators()), ref_grads2)
     assert torch.allclose(next(averager2._grad_accumulators()), ref_grads2)
 
 
     averager1.step(control=control1, wait=False)
     averager1.step(control=control1, wait=False)
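
For reference, the updated expected gradients follow from mean-reduced MSE over the 5x5 parameter: d/dw mean((w - 1)^2) = 2(w - 1)/N, so each backward pass at w = 0 adds -2/25 to every entry, times local_times_accumulated (and +2/25 per pass for peer 2). A quick standalone check (not part of the test):

import torch
import torch.nn.functional as F

w = torch.zeros(5, 5, requires_grad=True)
F.mse_loss(w, torch.ones(5, 5)).backward()  # default reduction="mean" over N = 25 elements
assert torch.allclose(w.grad, torch.full((5, 5), -2 / 25))
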
@@ -162,7 +169,11 @@ def test_load_state_from_peers():
     )
     )
 
 
     avgr1 = TrainingStateAverager(
     avgr1 = TrainingStateAverager(
-        dht=dht1, params=model1.parameters(), allow_state_sharing=False, start=True, **common_kwargs
+        dht=dht1,
+        params=model1.parameters(),
+        allow_state_sharing=False,
+        start=True,
+        **common_kwargs,
     )
     )
 
 
     avgr2 = TrainingStateAverager(dht=dht2, params=model2.parameters(), start=True, **common_kwargs)
     avgr2 = TrainingStateAverager(dht=dht2, params=model2.parameters(), start=True, **common_kwargs)
@@ -286,12 +297,45 @@ def test_progress_tracker():
 
 
 
 
 @pytest.mark.forked
 @pytest.mark.forked
+@pytest.mark.parametrize(
+    "use_local_updates, delay_state_averaging, delay_optimizer_step, delay_grad_averaging, reuse_grad_buffers",
+    # fmt: off
+    [
+        (False, False, False, False, False),
+        (False, True, False, False, False),
+        (False, True, True, True, False),
+        (False, False, False, False, True),
+        (False, True, True, True, True),
+        (False, True, True, False, True),
+        (True, False, False, False, False),
+        (True, True, False, False, False),
+    ],
+    # fmt: on
+)
 def test_optimizer(
 def test_optimizer(
+    use_local_updates: bool,
+    delay_state_averaging: bool,
+    delay_optimizer_step: bool,
+    delay_grad_averaging: bool,
+    reuse_grad_buffers: bool,
+):
+    _test_optimizer(
+        use_local_updates=use_local_updates,
+        delay_state_averaging=delay_state_averaging,
+        delay_grad_averaging=delay_grad_averaging,
+        delay_optimizer_step=delay_optimizer_step,
+        reuse_grad_buffers=reuse_grad_buffers,
+    )
+
+
+def _test_optimizer(
     num_peers: int = 1,
     num_peers: int = 1,
     num_clients: int = 0,
     num_clients: int = 0,
     target_batch_size: int = 32,
     target_batch_size: int = 32,
     total_epochs: int = 3,
     total_epochs: int = 3,
+    use_local_updates: bool = False,
     reuse_grad_buffers: bool = True,
     reuse_grad_buffers: bool = True,
+    delay_state_averaging: bool = True,
     delay_grad_averaging: bool = True,
     delay_grad_averaging: bool = True,
     delay_optimizer_step: bool = True,
     delay_optimizer_step: bool = True,
     average_state_every: int = 1,
     average_state_every: int = 1,
@@ -319,9 +363,11 @@ def test_optimizer(
             dht=hivemind.DHT(initial_peers=dht.get_visible_maddrs(), client_mode=client_mode, start=True),
             dht=hivemind.DHT(initial_peers=dht.get_visible_maddrs(), client_mode=client_mode, start=True),
             tracker_opts=dict(private_key=RSAPrivateKey(), max_refresh_period=1.0),
             tracker_opts=dict(private_key=RSAPrivateKey(), max_refresh_period=1.0),
             averager_opts=dict(request_timeout=0.5),
             averager_opts=dict(request_timeout=0.5),
+            use_local_updates=use_local_updates,
             matchmaking_time=1.0,
             matchmaking_time=1.0,
             averaging_timeout=5.0,
             averaging_timeout=5.0,
             reuse_grad_buffers=reuse_grad_buffers,
             reuse_grad_buffers=reuse_grad_buffers,
+            delay_state_averaging=delay_state_averaging,
             delay_grad_averaging=delay_grad_averaging,
             delay_grad_averaging=delay_grad_averaging,
             delay_optimizer_step=delay_optimizer_step,
             delay_optimizer_step=delay_optimizer_step,
             average_state_every=average_state_every,
             average_state_every=average_state_every,
@@ -380,6 +426,10 @@ def test_optimizer(
     assert 4 / 0.3 * 0.8 <= optimizer.tracker.performance_ema.samples_per_second <= 4 / 0.3 * 1.2
     assert 4 / 0.3 * 0.8 <= optimizer.tracker.performance_ema.samples_per_second <= 4 / 0.3 * 1.2
 
 
     assert not optimizer.state_averager.is_alive()
     assert not optimizer.state_averager.is_alive()
-    assert not optimizer.grad_averager.is_alive()
     assert not optimizer.tracker.is_alive()
     assert not optimizer.tracker.is_alive()
+    if not use_local_updates:
+        assert not optimizer.grad_averager.is_alive()
+    else:
+        assert optimizer.grad_averager is None
+
     assert optimizer.scheduled_grads is None or optimizer.scheduled_grads.done()
     assert optimizer.scheduled_grads is None or optimizer.scheduled_grads.done()
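
A standalone sketch mirroring the parametrized test above: constructing a PowerSGD gradient averager directly, where averager_rank sets the rank of the low-rank gradient approximation (lower rank compresses more at the cost of fidelity). Illustration only; how Optimizer wires in the factory is outside this hunk.

import torch.nn as nn

import hivemind
from hivemind.optim.power_sgd_averager import PowerSGDGradientAverager

dht = hivemind.DHT(start=True)
model = nn.Linear(16, 16)

averager = PowerSGDGradientAverager(
    model.parameters(), averager_rank=1, dht=dht, prefix="demo", target_group_size=2,
    reuse_grad_buffers=False, start=True,
)

averager.shutdown()
dht.shutdown()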

+ 29 - 2
tests/test_p2p_daemon.py

@@ -1,6 +1,8 @@
 import asyncio
 import asyncio
 import multiprocessing as mp
 import multiprocessing as mp
+import os
 import subprocess
 import subprocess
+import tempfile
 from contextlib import closing
 from contextlib import closing
 from functools import partial
 from functools import partial
 from typing import List
 from typing import List
@@ -45,6 +47,31 @@ async def test_startup_error_message():
         await P2P.create(startup_timeout=0.01)  # Test that startup_timeout works
         await P2P.create(startup_timeout=0.01)  # Test that startup_timeout works
 
 
 
 
+@pytest.mark.asyncio
+async def test_identity():
+    with tempfile.TemporaryDirectory() as tempdir:
+        id1_path = os.path.join(tempdir, "id1")
+        id2_path = os.path.join(tempdir, "id2")
+        p2ps = await asyncio.gather(*[P2P.create(identity_path=path) for path in [None, None, id1_path, id2_path]])
+
+        # We create the second daemon with id2 separately
+        # to avoid a race condition while saving a newly generated identity
+        p2ps.append(await P2P.create(identity_path=id2_path))
+
+        # Using the same identity (if any) should lead to the same peer ID
+        assert p2ps[-2].peer_id == p2ps[-1].peer_id
+
+        # The rest of peer IDs should be different
+        peer_ids = {instance.peer_id for instance in p2ps}
+        assert len(peer_ids) == 4
+
+        for instance in p2ps:
+            await instance.shutdown()
+
+    with pytest.raises(FileNotFoundError, match=r"The directory.+does not exist"):
+        P2P.generate_identity(id1_path)
+
+
 @pytest.mark.parametrize(
 @pytest.mark.parametrize(
     "host_maddrs",
     "host_maddrs",
     [
     [
@@ -55,11 +82,11 @@ async def test_startup_error_message():
 )
 )
 @pytest.mark.asyncio
 @pytest.mark.asyncio
 async def test_transports(host_maddrs: List[Multiaddr]):
 async def test_transports(host_maddrs: List[Multiaddr]):
-    server = await P2P.create(quic=True, host_maddrs=host_maddrs)
+    server = await P2P.create(host_maddrs=host_maddrs)
     peers = await server.list_peers()
     peers = await server.list_peers()
     assert len(peers) == 0
     assert len(peers) == 0
 
 
-    client = await P2P.create(quic=True, host_maddrs=host_maddrs, initial_peers=await server.get_visible_maddrs())
+    client = await P2P.create(host_maddrs=host_maddrs, initial_peers=await server.get_visible_maddrs())
     await client.wait_for_at_least_n_peers(1)
     await client.wait_for_at_least_n_peers(1)
 
 
     peers = await client.list_peers()
     peers = await client.list_peers()