Răsfoiți Sursa

Check if identity is already taken (#511)

While using scripts built with hivemind, users often run two peers with the same identity by accident (e.g., if they forget to change the CLI command or copy the same identity file to another host via `scp`). Currently, this leads to undefined behavior in libp2p.

This PR makes `hivemind.P2P` check if the identity is already taken, thus solving this issue in all applications at once.

(cherry picked from commit 64a6c302c8cc122e55cce348fb98482c61d32b37)
Alexander Borzunov 2 ani în urmă
părinte
comite
8c69f325f9

+ 51 - 2
hivemind/p2p/p2p_daemon.py

@@ -18,6 +18,7 @@ import hivemind.hivemind_cli as cli
 import hivemind.p2p.p2p_daemon_bindings.p2pclient as p2pclient
 from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, P2PDaemonError, P2PHandlerError
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
+from hivemind.p2p.p2p_daemon_bindings.utils import ControlFailure
 from hivemind.proto import crypto_pb2
 from hivemind.proto.p2pd_pb2 import RPCError
 from hivemind.utils.asyncio import as_aiter, asingle
@@ -102,6 +103,7 @@ class P2P:
         quic: Optional[bool] = None,
         use_relay_hop: Optional[bool] = None,
         use_relay_discovery: Optional[bool] = None,
+        check_if_identity_free: bool = True,
     ) -> "P2P":
         """
         Start a new p2pd process and connect to it.
@@ -129,6 +131,10 @@ class P2P:
         :param quic: Deprecated, has no effect since libp2p 0.17.0
         :param use_relay_hop: Deprecated, has no effect since libp2p 0.17.0
         :param use_relay_discovery: Deprecated, has no effect since libp2p 0.17.0
+        :param check_if_identity_free: If enabled (default) and ``identity_path`` is provided,
+                                       ensure that this identity is not used by other peers already.
+                                       This slows down ``P2P.create()`` but protects from unintuitive libp2p errors
+                                       appearing in case of the identity collision.
         :return: a wrapper for the p2p daemon
         """
 
@@ -169,9 +175,22 @@ class P2P:
                 process_kwargs[param] = self._maddrs_to_str(value)
 
         if identity_path is not None:
-            if not os.path.isfile(identity_path):
-                logger.info(f"Generating new identity (libp2p private key) in `{identity_path}`")
+            if os.path.isfile(identity_path):
+                if check_if_identity_free:
+                    logger.info(f"Checking that identity from `{identity_path}` is not used by other peers")
+                    if await cls.is_identity_taken(
+                        identity_path,
+                        initial_peers=initial_peers,
+                        tls=tls,
+                        use_auto_relay=use_auto_relay,
+                        use_ipfs=use_ipfs,
+                        use_relay=use_relay,
+                    ):
+                        raise P2PDaemonError(f"Identity from `{identity_path}` is already taken by another peer")
+            else:
+                logger.info(f"Generating new identity to be saved in `{identity_path}`")
                 self.generate_identity(identity_path)
+                # A newly generated identity is not taken with ~100% probability
             process_kwargs["id"] = identity_path
 
         proc_args = self._make_process_args(
@@ -217,6 +236,36 @@ class P2P:
         await self._ping_daemon()
         return self
 
+    @classmethod
+    async def is_identity_taken(
+        cls,
+        identity_path: str,
+        *,
+        initial_peers: Optional[Sequence[Union[Multiaddr, str]]],
+        tls: bool,
+        use_auto_relay: bool,
+        use_ipfs: bool,
+        use_relay: bool,
+    ) -> bool:
+        with open(identity_path, "rb") as f:
+            peer_id = PeerID.from_identity(f.read())
+
+        anonymous_p2p = await cls.create(
+            initial_peers=initial_peers,
+            dht_mode="client",
+            tls=tls,
+            use_auto_relay=use_auto_relay,
+            use_ipfs=use_ipfs,
+            use_relay=use_relay,
+        )
+        try:
+            await anonymous_p2p._client.connect(peer_id, [])
+            return True
+        except ControlFailure:
+            return False
+        finally:
+            await anonymous_p2p.shutdown()
+
     @staticmethod
     def generate_identity(identity_path: str) -> None:
         private_key = RSAPrivateKey()

+ 28 - 1
hivemind/p2p/p2p_daemon_bindings/datastructures.py

@@ -9,9 +9,10 @@ from typing import Any, Sequence, Union
 
 import base58
 import multihash
+from cryptography.hazmat.primitives import serialization
 from multiaddr import Multiaddr, protocols
 
-from hivemind.proto import p2pd_pb2
+from hivemind.proto import crypto_pb2, p2pd_pb2
 
 # NOTE: On inlining...
 # See: https://github.com/libp2p/specs/issues/138
@@ -88,6 +89,32 @@ class PeerID:
         peer_id_bytes = base58.b58decode(base58_id)
         return cls(peer_id_bytes)
 
+    @classmethod
+    def from_identity(cls, data: bytes) -> "PeerID":
+        """
+        See [1] for the specification of how this conversion should happen.
+
+        [1] https://github.com/libp2p/specs/blob/master/peer-ids/peer-ids.md#peer-ids
+        """
+
+        key_data = crypto_pb2.PrivateKey.FromString(data).data
+        private_key = serialization.load_der_private_key(key_data, password=None)
+
+        encoded_public_key = private_key.public_key().public_bytes(
+            encoding=serialization.Encoding.DER,
+            format=serialization.PublicFormat.SubjectPublicKeyInfo,
+        )
+        encoded_public_key = crypto_pb2.PublicKey(
+            key_type=crypto_pb2.RSA,
+            data=encoded_public_key,
+        ).SerializeToString()
+
+        algo = multihash.Func.sha2_256
+        if ENABLE_INLINING and len(encoded_public_key) <= MAX_INLINE_KEY_LENGTH:
+            algo = IDENTITY_MULTIHASH_CODE
+        encoded_digest = multihash.digest(encoded_public_key, algo).encode()
+        return cls(encoded_digest)
+
 
 def sha256_digest(data: Union[str, bytes]) -> bytes:
     if isinstance(data, str):

+ 26 - 0
tests/test_p2p_daemon.py

@@ -73,6 +73,32 @@ async def test_identity():
         P2P.generate_identity(id1_path)
 
 
+@pytest.mark.asyncio
+async def test_check_if_identity_free():
+    with tempfile.TemporaryDirectory() as tempdir:
+        id1_path = os.path.join(tempdir, "id1")
+        id2_path = os.path.join(tempdir, "id2")
+
+        p2ps = [await P2P.create(identity_path=id1_path)]
+        initial_peers = await p2ps[0].get_visible_maddrs()
+
+        p2ps.append(await P2P.create(initial_peers=initial_peers))
+        p2ps.append(await P2P.create(initial_peers=initial_peers, identity_path=id2_path))
+
+        with pytest.raises(P2PDaemonError, match=r"Identity.+is already taken by another peer"):
+            await P2P.create(initial_peers=initial_peers, identity_path=id1_path)
+        with pytest.raises(P2PDaemonError, match=r"Identity.+is already taken by another peer"):
+            await P2P.create(initial_peers=initial_peers, identity_path=id2_path)
+
+        # Must work if a P2P with a certain identity is restarted
+        await p2ps[-1].shutdown()
+        p2ps.pop()
+        p2ps.append(await P2P.create(initial_peers=initial_peers, identity_path=id2_path))
+
+        for instance in p2ps:
+            await instance.shutdown()
+
+
 @pytest.mark.parametrize(
     "host_maddrs",
     [

+ 6 - 2
tests/test_start_server.py

@@ -34,8 +34,9 @@ def test_cli_run_server_identity_path():
             encoding="utf-8",
         )
 
-        # Skip line "Generating new identity (libp2p private key) in {path to file}"
-        server_1_proc.stderr.readline()
+        line = server_1_proc.stderr.readline()
+        assert "Generating new identity" in line
+
         line = server_1_proc.stderr.readline()
         addrs_pattern_result = re.search(pattern, line)
         assert addrs_pattern_result is not None, line
@@ -51,6 +52,9 @@ def test_cli_run_server_identity_path():
             encoding="utf-8",
         )
 
+        line = server_2_proc.stderr.readline()
+        assert re.search(r"Checking that identity.+is not used by other peers", line) is not None
+
         line = server_2_proc.stderr.readline()
         addrs_pattern_result = re.search(pattern, line)
         assert addrs_pattern_result is not None, line