Kaynağa Gözat

Generate new private key if identity file doesn't exist (#473)

Alexander Borzunov 3 yıl önce
ebeveyn
işleme
97deaee2f3

+ 9 - 0
examples/albert/README.md

@@ -130,6 +130,15 @@ monitors on different servers and list all of them as `--initial_peers`. The sys
 as at least one externally accessible participant is available. For short- to mid-term experiments you can host the
 monitor on a [free-tier VM](https://www.quora.com/Are-there-any-free-online-virtual-machines).
 
+By default, the training monitor changes its address on restart, so you may launch two monitors on the same machine.
+If you'd like to fix the monitor's address (e.g., before sending it to your collaborators),
+you need to **(a)** make it listen a specific TCP/UDP port and **(b)** provide a path for storing the identity file
+(which allows [libp2p](https://libp2p.io/) to reuse the same peer ID after restart). You may do that like this:
+
+```bash
+./run_training_monitor.py --wandb_project YOUR_WANDB_PROJECT --host_maddrs /ip4/0.0.0.0/tcp/31337 --identity_path ./identity.key
+```
+
 ### Tuning for hardware/network
 
 The optimal training parameters for each peer depend on its GPU and internet connection. If a peer cannot accept

+ 1 - 1
examples/albert/arguments.py

@@ -38,7 +38,7 @@ class BaseTrainingArguments:
         default=None,
         metadata={
             "help": "Path to a pre-generated private key file. If defined, makes the peer ID deterministic. "
-            "May be generated using ``./p2p-keygen`` from ``go-libp2p-daemon``."
+            "If the file does not exist yet, writes a new private key to this file."
         },
     )
 

+ 23 - 3
hivemind/p2p/p2p_daemon.py

@@ -18,8 +18,10 @@ import hivemind.hivemind_cli as cli
 import hivemind.p2p.p2p_daemon_bindings.p2pclient as p2pclient
 from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, P2PDaemonError, P2PHandlerError
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
+from hivemind.proto import crypto_pb2
 from hivemind.proto.p2pd_pb2 import RPCError
 from hivemind.utils.asyncio import as_aiter, asingle
+from hivemind.utils.crypto import RSAPrivateKey
 from hivemind.utils.logging import get_logger, golog_level_to_python, loglevel, python_level_to_golog
 
 logger = get_logger(__name__)
@@ -113,8 +115,8 @@ class P2P:
                          Details: https://pkg.go.dev/github.com/libp2p/go-libp2p-kad-dht#ModeOpt
         :param force_reachability: Force reachability mode (public/private)
         :param host_maddrs: Multiaddrs to listen for external connections from other p2p instances
-        :param identity_path: Path to a pre-generated private key file. If defined, makes the peer ID deterministic.
-                              May be generated using ``./p2p-keygen`` from ``go-libp2p-daemon``.
+        :param identity_path: Path to a private key file. If defined, makes the peer ID deterministic.
+                              If the file does not exist yet, writes a new private key to this file.
         :param idle_timeout: kill daemon if client has been idle for a given number of
                              seconds before opening persistent streams
         :param nat_port_map: Enables NAT port mapping
@@ -156,7 +158,7 @@ class P2P:
                     raise ValueError("Please specify an explicit port in announce_maddrs: port 0 is not supported")
 
         need_bootstrap = bool(initial_peers) or use_ipfs
-        process_kwargs = cls.DHT_MODE_MAPPING.get(dht_mode, {"dht": 0})
+        process_kwargs = cls.DHT_MODE_MAPPING[dht_mode].copy()
         process_kwargs.update(cls.FORCE_REACHABILITY_MAPPING.get(force_reachability, {}))
         for param, value in [
             ("bootstrapPeers", initial_peers),
@@ -165,7 +167,11 @@ class P2P:
         ]:
             if value:
                 process_kwargs[param] = self._maddrs_to_str(value)
+
         if identity_path is not None:
+            if not os.path.isfile(identity_path):
+                logger.info(f"Generating new identity (libp2p private key) in `{identity_path}`")
+                self.generate_identity(identity_path)
             process_kwargs["id"] = identity_path
 
         proc_args = self._make_process_args(
@@ -211,6 +217,20 @@ class P2P:
         await self._ping_daemon()
         return self
 
+    @staticmethod
+    def generate_identity(identity_path: str) -> None:
+        private_key = RSAPrivateKey()
+        protobuf = crypto_pb2.PrivateKey(key_type=crypto_pb2.KeyType.RSA, data=private_key.to_bytes())
+
+        try:
+            with open(identity_path, "wb") as f:
+                f.write(protobuf.SerializeToString())
+        except FileNotFoundError:
+            raise FileNotFoundError(
+                f"The directory `{os.path.dirname(identity_path)}` for saving the identity does not exist"
+            )
+        os.chmod(identity_path, 0o400)
+
     @classmethod
     async def replicate(cls, daemon_listen_maddr: Multiaddr) -> "P2P":
         """

+ 24 - 0
hivemind/proto/crypto.proto

@@ -0,0 +1,24 @@
+// Originally taken from: https://github.com/mhchia/py-libp2p-daemon-bindings
+// Licence: MIT
+// Author: Kevin Mai-Husan Chia
+
+syntax = "proto2";
+
+package crypto.pb;
+
+enum KeyType {
+  RSA = 0;
+  Ed25519 = 1;
+  Secp256k1 = 2;
+  ECDSA = 3;
+}
+
+message PublicKey {
+  required KeyType key_type = 1;
+  required bytes data = 2;
+}
+
+message PrivateKey {
+  required KeyType key_type = 1;
+  required bytes data = 2;
+}

+ 4 - 4
hivemind/proto/p2pd.proto

@@ -1,6 +1,6 @@
-//Originally taken from: https://github.com/mhchia/py-libp2p-daemon-bindings
-//Licence: MIT
-//Author: Kevin Mai-Husan Chia
+// Originally taken from: https://github.com/mhchia/py-libp2p-daemon-bindings
+// Licence: MIT
+// Author: Kevin Mai-Husan Chia
 
 syntax = "proto2";
 
@@ -15,7 +15,7 @@ message Request {
     DHT                      = 4;
     LIST_PEERS               = 5;
     CONNMANAGER              = 6;
-    DISCONNECT               = 7;      
+    DISCONNECT               = 7;
     PUBSUB                   = 8;
 
     PERSISTENT_CONN_UPGRADE  = 9;

+ 9 - 6
hivemind/utils/crypto.py

@@ -60,19 +60,22 @@ class RSAPrivateKey(PrivateKey):
     def get_public_key(self) -> RSAPublicKey:
         return RSAPublicKey(self._private_key.public_key())
 
+    def to_bytes(self) -> bytes:
+        return self._private_key.private_bytes(
+            encoding=serialization.Encoding.DER,
+            format=serialization.PrivateFormat.TraditionalOpenSSL,
+            encryption_algorithm=serialization.NoEncryption(),
+        )
+
     def __getstate__(self):
         state = self.__dict__.copy()
         # Serializes the private key to make the class instances picklable
-        state["_private_key"] = self._private_key.private_bytes(
-            encoding=serialization.Encoding.PEM,
-            format=serialization.PrivateFormat.OpenSSH,
-            encryption_algorithm=serialization.NoEncryption(),
-        )
+        state["_private_key"] = self.to_bytes()
         return state
 
     def __setstate__(self, state):
         self.__dict__.update(state)
-        self._private_key = serialization.load_ssh_private_key(self._private_key, password=None)
+        self._private_key = serialization.load_der_private_key(self._private_key, password=None)
 
 
 class RSAPublicKey(PublicKey):

+ 27 - 0
tests/test_p2p_daemon.py

@@ -1,6 +1,8 @@
 import asyncio
 import multiprocessing as mp
+import os
 import subprocess
+import tempfile
 from contextlib import closing
 from functools import partial
 from typing import List
@@ -45,6 +47,31 @@ async def test_startup_error_message():
         await P2P.create(startup_timeout=0.01)  # Test that startup_timeout works
 
 
+@pytest.mark.asyncio
+async def test_identity():
+    with tempfile.TemporaryDirectory() as tempdir:
+        id1_path = os.path.join(tempdir, "id1")
+        id2_path = os.path.join(tempdir, "id2")
+        p2ps = await asyncio.gather(*[P2P.create(identity_path=path) for path in [None, None, id1_path, id2_path]])
+
+        # We create the second daemon with id2 separately
+        # to avoid a race condition while saving a newly generated identity
+        p2ps.append(await P2P.create(identity_path=id2_path))
+
+        # Using the same identity (if any) should lead to the same peer ID
+        assert p2ps[-2].peer_id == p2ps[-1].peer_id
+
+        # The rest of peer IDs should be different
+        peer_ids = {instance.peer_id for instance in p2ps}
+        assert len(peer_ids) == 4
+
+        for instance in p2ps:
+            await instance.shutdown()
+
+    with pytest.raises(FileNotFoundError, match=r"The directory.+does not exist"):
+        P2P.generate_identity(id1_path)
+
+
 @pytest.mark.parametrize(
     "host_maddrs",
     [