Ver Fonte

Py libp2p bindings (#193)

* #183 p2p daemon pybinding

* #183 rename py bindings dir, fix imports and migrate tests

* #183 move pb to hivemind.proto

* #183 fix p2p tests

* #183 remove config.py, move constants to classes

* add docstrings and minor fixes
MaximKsh há 4 anos atrás
pai
commit
3b5ce78828

+ 39 - 36
hivemind/p2p/p2p_daemon.py

@@ -8,8 +8,8 @@ import warnings
 
 import google.protobuf
 from multiaddr import Multiaddr
-import p2pclient
-from libp2p.peer.id import ID
+import hivemind.p2p.p2p_daemon_bindings.p2pclient as p2pclient
+from hivemind.p2p.p2p_daemon_bindings.datastructures import ID, StreamInfo
 
 from hivemind.utils.networking import find_open_port
 
@@ -104,78 +104,81 @@ class P2P(object):
                 self._daemon_listen_port = find_open_port()
 
     @staticmethod
-    async def send_raw_data(byte_str, stream):
+    async def send_raw_data(byte_str, writer):
         request = len(byte_str).to_bytes(P2P.HEADER_LEN, P2P.BYTEORDER) + byte_str
-        await stream.send_all(request)
+        writer.write(request)
 
     @staticmethod
-    async def send_data(data, stream):
-        await P2P.send_raw_data(pickle.dumps(data), stream)
+    async def send_data(data, writer):
+        await P2P.send_raw_data(pickle.dumps(data), writer)
 
     @staticmethod
-    async def send_protobuf(protobuf, out_proto_type, stream):
+    async def send_protobuf(protobuf, out_proto_type, writer):
         if type(protobuf) != out_proto_type:
             error = TypeError('Unary handler returned protobuf of wrong type.')
-            await P2P.send_raw_data(pickle.dumps(error), stream)
+            await P2P.send_raw_data(pickle.dumps(error), writer)
             raise error
-        await P2P.send_raw_data(protobuf.SerializeToString(), stream)
+        await P2P.send_raw_data(protobuf.SerializeToString(), writer)
 
     @staticmethod
-    async def receive_exactly(stream, n_bytes, max_bytes=1 << 16):
+    async def receive_exactly(reader, n_bytes, max_bytes=1 << 16):
         buffer = bytearray()
         while len(buffer) < n_bytes:
-            data = await stream.receive_some(min(max_bytes, n_bytes - len(buffer)))
+            data = await reader.read(min(max_bytes, n_bytes - len(buffer)))
             if len(data) == 0:
                 raise P2P.IncompleteRead()
             buffer.extend(data)
         return bytes(buffer)
 
     @staticmethod
-    async def receive_raw_data(stream):
-        header = await P2P.receive_exactly(stream, P2P.HEADER_LEN)
+    async def receive_raw_data(reader):
+        header = await P2P.receive_exactly(reader, P2P.HEADER_LEN)
         content_length = int.from_bytes(header, P2P.BYTEORDER)
-        data = await P2P.receive_exactly(stream, content_length)
+        data = await P2P.receive_exactly(reader, content_length)
         return data
 
     @staticmethod
-    async def receive_data(stream):
-        return pickle.loads(await P2P.receive_raw_data(stream))
+    async def receive_data(reader):
+        return pickle.loads(await P2P.receive_raw_data(reader))
 
     @staticmethod
-    async def receive_protobuf(in_proto_type, stream):
+    async def receive_protobuf(in_proto_type, reader):
         protobuf = in_proto_type()
-        protobuf.ParseFromString(await P2P.receive_raw_data(stream))
+        protobuf.ParseFromString(await P2P.receive_raw_data(reader))
         return protobuf
 
     @staticmethod
     def _handle_stream(handle):
-        async def do_handle_stream(stream_info, stream):
+        async def do_handle_stream(stream_info, reader, writer):
             try:
-                request = await P2P.receive_data(stream)
+                request = await P2P.receive_data(reader)
             except P2P.IncompleteRead:
                 warnings.warn("Incomplete read while receiving request from peer", RuntimeWarning)
-                await stream.close()
+                writer.close()
                 return
             try:
                 result = handle(request)
-                await P2P.send_data(result, stream)
+                await P2P.send_data(result, writer)
             except Exception as exc:
-                await P2P.send_data(exc, stream)
+                await P2P.send_data(exc, writer)
             finally:
-                await stream.close()
+                writer.close()
 
         return do_handle_stream
 
     @staticmethod
     def _handle_unary_stream(handle, context, in_proto_type, out_proto_type):
-        async def watchdog(stream):
-            await stream.receive_some(max_bytes=1)
+        async def watchdog(reader: asyncio.StreamReader):
+            await reader.read(n=1)
             raise P2P.InterruptedError()
 
-        async def do_handle_unary_stream(stream_info, stream):
+        async def do_handle_unary_stream(
+                stream_info: StreamInfo,
+                reader: asyncio.StreamReader,
+                writer: asyncio.StreamWriter) -> None:
             try:
                 try:
-                    request = await P2P.receive_protobuf(in_proto_type, stream)
+                    request = await P2P.receive_protobuf(in_proto_type, reader)
                 except P2P.IncompleteRead:
                     warnings.warn("Incomplete read while receiving request from peer",
                                   RuntimeWarning)
@@ -185,15 +188,15 @@ class P2P(object):
                     return
 
                 context.peer_id, context.peer_addr = stream_info.peer_id, stream_info.addr
-                done, pending = await asyncio.wait([watchdog(stream), handle(request, context)],
+                done, pending = await asyncio.wait([watchdog(reader), handle(request, context)],
                                                    return_when=asyncio.FIRST_COMPLETED)
                 try:
                     result = done.pop().result()
-                    await P2P.send_protobuf(result, out_proto_type, stream)
+                    await P2P.send_protobuf(result, out_proto_type, writer)
                 except P2P.InterruptedError:
                     pass
                 except Exception as exc:
-                    await P2P.send_data(exc, stream)
+                    await P2P.send_data(exc, writer)
                 finally:
                     pending_task = pending.pop()
                     pending_task.cancel()
@@ -202,7 +205,7 @@ class P2P(object):
                     except asyncio.CancelledError:
                         pass
             finally:
-                await stream.close()
+                writer.close()
 
         return do_handle_unary_stream
 
@@ -237,12 +240,12 @@ class P2P(object):
 
     async def call_peer_handler(self, peer_id, handler_name, input_data):
         libp2p_peer_id = ID.from_base58(peer_id)
-        stream_info, stream = await self._client.stream_open(libp2p_peer_id, (handler_name,))
+        stream_info, reader, writer = await self._client.stream_open(libp2p_peer_id, (handler_name,))
         try:
-            await P2P.send_data(input_data, stream)
-            return await P2P.receive_data(stream)
+            await P2P.send_data(input_data, writer)
+            return await P2P.receive_data(reader)
         finally:
-            await stream.close()
+            writer.close()
 
     def __del__(self):
         self._kill_child()

+ 0 - 0
hivemind/p2p/p2p_daemon_bindings/__init__.py


+ 211 - 0
hivemind/p2p/p2p_daemon_bindings/control.py

@@ -0,0 +1,211 @@
+import logging
+from typing import AsyncIterator, Awaitable, Callable, Dict, Iterable, Sequence, Tuple
+
+import asyncio
+from contextlib import asynccontextmanager
+from multiaddr import Multiaddr, protocols
+from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerInfo, StreamInfo, ID
+from hivemind.proto import p2pd_pb2 as p2pd_pb
+from hivemind.p2p.p2p_daemon_bindings.utils import DispatchFailure, read_pbmsg_safe, write_pbmsg, raise_if_failed
+
+StreamHandler = Callable[[StreamInfo, asyncio.StreamReader, asyncio.StreamWriter], Awaitable[None]]
+
+_supported_conn_protocols = (
+    protocols.P_IP4,
+    # protocols.P_IP6,
+    protocols.P_UNIX,
+)
+
+
def parse_conn_protocol(maddr: Multiaddr) -> int:
    """Return the code of the single supported transport protocol in ``maddr``.

    Daemon control/listen addresses must use exactly one supported connection
    protocol (currently IP4 or UNIX).

    :param maddr: multiaddress to inspect
    :return: protocol code (e.g. ``protocols.P_UNIX``)
    :raises ValueError: if zero, or more than one, supported protocol is present
    """
    proto_codes = set(proto.code for proto in maddr.protocols())
    proto_cand = proto_codes.intersection(_supported_conn_protocols)
    if len(proto_cand) != 1:
        # Materialize the names eagerly: a generator expression interpolated
        # into an f-string would render as "<generator object ...>" instead of
        # the list of supported protocols.
        supported_protos = [
            protocols.protocol_with_code(proto) for proto in _supported_conn_protocols
        ]
        raise ValueError(
            f"connection protocol should be only one protocol out of {supported_protos}"
            f", maddr={maddr}"
        )
    return tuple(proto_cand)[0]
+
+
class DaemonConnector:
    """Opens fresh stream connections to the p2p daemon's control socket."""

    control_maddr: Multiaddr
    logger = logging.getLogger("p2pclient.DaemonConnector")
    DEFAULT_CONTROL_MADDR = "/unix/tmp/p2pd.sock"

    def __init__(self, control_maddr: Multiaddr = None) -> None:
        # None means "use the default unix socket path".
        if control_maddr is None:
            control_maddr = Multiaddr(self.DEFAULT_CONTROL_MADDR)
        self.control_maddr = control_maddr

    async def open_connection(self) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
        """Open a new connection to the daemon.

        Note: the previous annotation ``(StreamReader, StreamWriter)`` was a
        tuple *instance*, not a type; ``Tuple[...]`` is the correct form.

        :return: (reader, writer) pair for the freshly opened connection
        :raises ValueError: if the control multiaddr uses an unsupported protocol
        """
        proto_code = parse_conn_protocol(self.control_maddr)
        if proto_code == protocols.P_UNIX:
            control_path = self.control_maddr.value_for_protocol(protocols.P_UNIX)
            self.logger.debug(
                "DaemonConnector %s opens connection to %s", self, self.control_maddr
            )
            return await asyncio.open_unix_connection(control_path)
        elif proto_code == protocols.P_IP4:
            host = self.control_maddr.value_for_protocol(protocols.P_IP4)
            port = int(self.control_maddr.value_for_protocol(protocols.P_TCP))
            return await asyncio.open_connection(host, port)
        else:
            raise ValueError(
                f"protocol not supported: protocol={protocols.protocol_with_code(proto_code)}"
            )
+
+
class ControlClient:
    """Client for the libp2p daemon control protocol.

    Sends control requests (identify/connect/stream_open/...) over fresh
    daemon connections, and serves a listener socket on ``listen_maddr`` so
    the daemon can route incoming streams to handlers registered via
    ``stream_handler``.
    """

    listen_maddr: Multiaddr
    daemon_connector: DaemonConnector
    handlers: Dict[str, StreamHandler]
    logger = logging.getLogger("p2pclient.ControlClient")
    DEFAULT_LISTEN_MADDR = "/unix/tmp/p2pclient.sock"

    def __init__(
        self, daemon_connector: DaemonConnector, listen_maddr: Multiaddr = None
    ) -> None:
        # listen_maddr=None falls back to the default unix socket path.
        if listen_maddr is None:
            listen_maddr = Multiaddr(self.DEFAULT_LISTEN_MADDR)
        self.listen_maddr = listen_maddr
        self.daemon_connector = daemon_connector
        self.handlers = {}

    async def _handler(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
        """Dispatch one incoming stream from the daemon to its registered handler."""
        pb_stream_info = p2pd_pb.StreamInfo()  # type: ignore
        await read_pbmsg_safe(reader, pb_stream_info)
        stream_info = StreamInfo.from_pb(pb_stream_info)
        self.logger.info("New incoming stream: %s", stream_info)
        try:
            handler = self.handlers[stream_info.proto]
        except KeyError as e:
            # should never enter here... daemon should reject the stream for us.
            writer.close()
            raise DispatchFailure(e)
        await handler(stream_info, reader, writer)

    @asynccontextmanager
    async def listen(self) -> AsyncIterator["ControlClient"]:
        """Serve ``listen_maddr`` for daemon-forwarded streams while the context is open.

        :yield: this client; the server is shut down when the context exits
        :raises ValueError: if ``listen_maddr`` uses an unsupported protocol
        """
        proto_code = parse_conn_protocol(self.listen_maddr)
        if proto_code == protocols.P_UNIX:
            listen_path = self.listen_maddr.value_for_protocol(protocols.P_UNIX)
            server = await asyncio.start_unix_server(self._handler, path=listen_path)
        elif proto_code == protocols.P_IP4:
            host = self.listen_maddr.value_for_protocol(protocols.P_IP4)
            port = int(self.listen_maddr.value_for_protocol(protocols.P_TCP))
            server = await asyncio.start_server(self._handler, port=port, host=host)
        else:
            raise ValueError(
                f"protocol not supported: protocol={protocols.protocol_with_code(proto_code)}"
            )

        async with server:
            self.logger.info(
                "DaemonConnector %s starts listening to %s", self, self.listen_maddr
            )
            yield self

        self.logger.info("DaemonConnector %s closed", self)

    async def identify(self) -> Tuple[ID, Tuple[Multiaddr, ...]]:
        """Ask the daemon for its own peer id and listen addresses.

        :raises ControlFailure: if the daemon reports an error (via raise_if_failed)
        """
        reader, writer = await self.daemon_connector.open_connection()
        req = p2pd_pb.Request(type=p2pd_pb.Request.IDENTIFY)
        await write_pbmsg(writer, req)

        resp = p2pd_pb.Response()  # type: ignore
        await read_pbmsg_safe(reader, resp)
        writer.close()

        raise_if_failed(resp)
        peer_id_bytes = resp.identify.id
        maddrs_bytes = resp.identify.addrs

        maddrs = tuple(Multiaddr(maddr_bytes) for maddr_bytes in maddrs_bytes)
        peer_id = ID(peer_id_bytes)

        return peer_id, maddrs

    async def connect(self, peer_id: ID, maddrs: Iterable[Multiaddr]) -> None:
        """Ask the daemon to connect to peer ``peer_id`` at any of ``maddrs``."""
        reader, writer = await self.daemon_connector.open_connection()

        maddrs_bytes = [i.to_bytes() for i in maddrs]
        connect_req = p2pd_pb.ConnectRequest(
            peer=peer_id.to_bytes(), addrs=maddrs_bytes
        )
        req = p2pd_pb.Request(type=p2pd_pb.Request.CONNECT, connect=connect_req)
        await write_pbmsg(writer, req)

        resp = p2pd_pb.Response()  # type: ignore
        await read_pbmsg_safe(reader, resp)
        writer.close()
        raise_if_failed(resp)

    async def list_peers(self) -> Tuple[PeerInfo, ...]:
        """Return the peers the daemon is currently connected to."""
        req = p2pd_pb.Request(type=p2pd_pb.Request.LIST_PEERS)
        reader, writer = await self.daemon_connector.open_connection()
        await write_pbmsg(writer, req)
        resp = p2pd_pb.Response()  # type: ignore
        await read_pbmsg_safe(reader, resp)
        writer.close()
        raise_if_failed(resp)

        peers = tuple(PeerInfo.from_pb(pinfo) for pinfo in resp.peers)
        return peers

    async def disconnect(self, peer_id: ID) -> None:
        """Ask the daemon to drop its connection to ``peer_id``."""
        disconnect_req = p2pd_pb.DisconnectRequest(peer=peer_id.to_bytes())
        req = p2pd_pb.Request(
            type=p2pd_pb.Request.DISCONNECT, disconnect=disconnect_req
        )
        reader, writer = await self.daemon_connector.open_connection()
        await write_pbmsg(writer, req)
        resp = p2pd_pb.Response()  # type: ignore
        await read_pbmsg_safe(reader, resp)
        writer.close()
        raise_if_failed(resp)

    async def stream_open(
        self, peer_id: ID, protocols: Sequence[str]
    ) -> Tuple[StreamInfo, asyncio.StreamReader, asyncio.StreamWriter]:
        """Open an outgoing stream to ``peer_id`` for one of ``protocols``.

        Unlike the other control calls, the connection is NOT closed here: the
        returned reader/writer carry the stream payload and belong to the caller.
        """
        reader, writer = await self.daemon_connector.open_connection()

        stream_open_req = p2pd_pb.StreamOpenRequest(
            peer=peer_id.to_bytes(), proto=list(protocols)
        )
        req = p2pd_pb.Request(
            type=p2pd_pb.Request.STREAM_OPEN, streamOpen=stream_open_req
        )
        await write_pbmsg(writer, req)

        resp = p2pd_pb.Response()  # type: ignore
        await read_pbmsg_safe(reader, resp)
        raise_if_failed(resp)

        pb_stream_info = resp.streamInfo
        stream_info = StreamInfo.from_pb(pb_stream_info)

        return stream_info, reader, writer

    async def stream_handler(self, proto: str, handler_cb: StreamHandler) -> None:
        """Register ``handler_cb`` for ``proto`` with the daemon and locally."""
        reader, writer = await self.daemon_connector.open_connection()

        listen_path_maddr_bytes = self.listen_maddr.to_bytes()
        stream_handler_req = p2pd_pb.StreamHandlerRequest(
            addr=listen_path_maddr_bytes, proto=[proto]
        )
        req = p2pd_pb.Request(
            type=p2pd_pb.Request.STREAM_HANDLER, streamHandler=stream_handler_req
        )
        await write_pbmsg(writer, req)

        resp = p2pd_pb.Response()  # type: ignore
        await read_pbmsg_safe(reader, resp)
        writer.close()
        raise_if_failed(resp)

        # if success, add the handler to the dict
        self.handlers[proto] = handler_cb

+ 186 - 0
hivemind/p2p/p2p_daemon_bindings/datastructures.py

@@ -0,0 +1,186 @@
+import hashlib
+from typing import Union, List, Sequence, Any
+
+import base58
+import multihash
+
+from multiaddr import Multiaddr, protocols
+from hivemind.proto import p2pd_pb2
+
+from hivemind.p2p.p2p_daemon_bindings.keys import PublicKey
+
# NOTE: On inlining...
# See: https://github.com/libp2p/specs/issues/138
# NOTE: enabling to be interoperable w/ the Go implementation
ENABLE_INLINING = True
MAX_INLINE_KEY_LENGTH = 42

IDENTITY_MULTIHASH_CODE = 0x00

if ENABLE_INLINING:

    class IdentityHash:
        """Hashlib-style "hash" that returns its input unchanged.

        Registered with multihash under the identity code so short public keys
        can be inlined into peer IDs, as the Go implementation does.
        """

        _digest: bytearray  # accumulated input; was mis-annotated as bytes

        def __init__(self) -> None:
            self._digest = bytearray()

        def update(self, input: bytes) -> None:
            self._digest += input

        def digest(self) -> bytes:
            # Return an immutable copy: the annotation promises `bytes`, and
            # callers must not observe later update() calls through the result.
            return bytes(self._digest)

    multihash.FuncReg.register(
        IDENTITY_MULTIHASH_CODE, "identity", hash_new=lambda: IdentityHash()
    )
+
+
class ID:
    """An immutable libp2p peer identifier wrapping its raw multihash bytes."""

    _bytes: bytes
    _xor_id: int = None  # lazily computed in `xor_id`
    _b58_str: str = None  # lazily computed in `to_base58`

    def __init__(self, peer_id_bytes: bytes) -> None:
        self._bytes = peer_id_bytes

    @property
    def xor_id(self) -> int:
        """Integer form of sha256(id bytes), usable for XOR-distance metrics."""
        # Compare against None explicitly: a falsy cached value would otherwise
        # be recomputed on every access.
        if self._xor_id is None:
            self._xor_id = int(sha256_digest(self._bytes).hex(), 16)
        return self._xor_id

    def to_bytes(self) -> bytes:
        return self._bytes

    def to_base58(self) -> str:
        """Return (and cache) the base58 text form of this peer ID."""
        if self._b58_str is None:
            self._b58_str = base58.b58encode(self._bytes).decode()
        return self._b58_str

    def __repr__(self) -> str:
        return f"<libp2p.peer.id.ID ({self!s})>"

    __str__ = pretty = to_string = to_base58

    def __eq__(self, other: object) -> bool:
        # Allow comparison against raw bytes and base58 strings for convenience.
        if isinstance(other, str):
            return self.to_base58() == other
        elif isinstance(other, bytes):
            return self._bytes == other
        elif isinstance(other, ID):
            return self._bytes == other._bytes
        else:
            return NotImplemented

    def __hash__(self) -> int:
        return hash(self._bytes)

    @classmethod
    def from_base58(cls, b58_encoded_peer_id_str: str) -> "ID":
        """Decode a base58-encoded peer ID string."""
        peer_id_bytes = base58.b58decode(b58_encoded_peer_id_str)
        pid = ID(peer_id_bytes)
        return pid

    @classmethod
    def from_pubkey(cls, key: "PublicKey") -> "ID":
        """Derive a peer ID from a public key (identity-inlined when short)."""
        serialized_key = key.serialize()
        algo = multihash.Func.sha2_256
        if ENABLE_INLINING and len(serialized_key) <= MAX_INLINE_KEY_LENGTH:
            algo = IDENTITY_MULTIHASH_CODE
        mh_digest = multihash.digest(serialized_key, algo)
        return cls(mh_digest.encode())
+
+
def sha256_digest(data: Union[str, bytes]) -> bytes:
    """Return the SHA-256 digest of *data*, encoding str input as UTF-8 first."""
    payload = data.encode("utf8") if isinstance(data, str) else data
    return hashlib.sha256(payload).digest()
+
+
class StreamInfo:
    """Metadata for one libp2p stream: remote peer, its address, and the protocol."""

    peer_id: ID
    addr: Multiaddr
    proto: str

    def __init__(self, peer_id: ID, addr: Multiaddr, proto: str) -> None:
        self.peer_id = peer_id
        self.addr = addr
        self.proto = proto

    def __repr__(self) -> str:
        return f"<StreamInfo peer_id={self.peer_id} addr={self.addr} proto={self.proto}>"

    def to_pb(self) -> p2pd_pb2.StreamInfo:
        """Serialize into the daemon's protobuf representation."""
        return p2pd_pb2.StreamInfo(
            peer=self.peer_id.to_bytes(),
            addr=self.addr.to_bytes(),
            proto=self.proto,
        )

    @classmethod
    def from_pb(cls, pb_msg: p2pd_pb2.StreamInfo) -> "StreamInfo":
        """Build a StreamInfo from the daemon's protobuf representation."""
        return cls(peer_id=ID(pb_msg.peer), addr=Multiaddr(pb_msg.addr), proto=pb_msg.proto)
+
+
class PeerInfoLibP2P:
    """A peer id together with the multiaddresses it can be reached at."""

    peer_id: ID
    addrs: List[Multiaddr]

    def __init__(self, peer_id: ID, addrs: Sequence[Multiaddr]) -> None:
        self.peer_id = peer_id
        self.addrs = list(addrs)

    def __eq__(self, other: Any) -> bool:
        # Note: intentionally compares against PeerInfo (the concrete subclass).
        if not isinstance(other, PeerInfo):
            return False
        return self.peer_id == other.peer_id and self.addrs == other.addrs
+
+
def info_from_p2p_addr(addr: Multiaddr) -> PeerInfoLibP2P:
    """Split a full /p2p multiaddress into a PeerInfo.

    :param addr: multiaddress whose final component is /p2p/<base58-peer-id>
    :return: PeerInfo holding the decoded peer id and the transport part of `addr`
    :raises InvalidAddrError: if `addr` is empty or does not end with a /p2p part
    """
    if not addr:
        raise InvalidAddrError("`addr` should not be `None`")

    parts = addr.split()
    if not parts:
        raise InvalidAddrError(
            f"`parts`={parts} should at least have a protocol `P_P2P`"
        )

    p2p_part = parts[-1]
    last_protocol_code = p2p_part.protocols()[0].code
    if last_protocol_code != protocols.P_P2P:
        raise InvalidAddrError(
            f"The last protocol should be `P_P2P` instead of `{last_protocol_code}`"
        )

    # make sure the /p2p value parses as a peer.ID
    peer_id_str: str = p2p_part.value_for_protocol(protocols.P_P2P)
    peer_id: ID = ID.from_base58(peer_id_str)

    # we might have received just a /p2p part, which means there's no transport addr.
    if len(parts) > 1:
        addr = Multiaddr.join(*parts[:-1])

    return PeerInfo(peer_id, [addr])
+
+
class InvalidAddrError(ValueError):
    """Raised when a multiaddress cannot be interpreted as a peer address."""
+
+
class PeerInfo(PeerInfoLibP2P):
    """PeerInfo that can be decoded from the daemon's protobuf messages."""

    @classmethod
    def from_pb(cls, peer_info_pb: p2pd_pb2.PeerInfo) -> PeerInfoLibP2P:
        maddrs = [Multiaddr(addr) for addr in peer_info_pb.addrs]
        return PeerInfo(ID(peer_info_pb.id), maddrs)

    def __str__(self):
        addr_list = ",".join(str(a) for a in self.addrs)
        return f"{self.peer_id.pretty()} {addr_list}"

+ 91 - 0
hivemind/p2p/p2p_daemon_bindings/keys.py

@@ -0,0 +1,91 @@
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from enum import Enum, unique
+
+from hivemind.proto import crypto_pb2 as protobuf
+
+
@unique
class KeyType(Enum):
    # Numeric values mirror the KeyType enum in hivemind/proto/crypto.proto.
    # NOTE(review): ECC_P256 = 4 has no counterpart in crypto.proto — confirm
    # before serializing keys of this type through the protobuf schema.
    RSA = 0
    Ed25519 = 1
    Secp256k1 = 2
    ECDSA = 3
    ECC_P256 = 4
+
+
class Key(ABC):
    """A ``Key`` represents a cryptographic key."""

    @abstractmethod
    def to_bytes(self) -> bytes:
        """Returns the byte representation of this key."""
        ...

    @abstractmethod
    def get_type(self) -> KeyType:
        """Returns the ``KeyType`` for ``self``."""
        ...

    def __eq__(self, other: object) -> bool:
        if isinstance(other, Key):
            return self.to_bytes() == other.to_bytes()
        return NotImplemented
+
+
class PublicKey(Key):
    """A ``PublicKey`` represents a cryptographic public key."""

    @abstractmethod
    def verify(self, data: bytes, signature: bytes) -> bool:
        """Verify that ``signature`` is the cryptographic signature of the hash
        of ``data``."""
        ...

    def _serialize_to_protobuf(self) -> protobuf.PublicKey:
        """Return the protobuf representation of this ``Key``."""
        return protobuf.PublicKey(key_type=self.get_type().value, data=self.to_bytes())

    def serialize(self) -> bytes:
        """Return the canonical serialization of this ``Key``."""
        return self._serialize_to_protobuf().SerializeToString()

    @classmethod
    def deserialize_from_protobuf(cls, protobuf_data: bytes) -> protobuf.PublicKey:
        """Parse serialized bytes back into a protobuf ``PublicKey``."""
        return protobuf.PublicKey.FromString(protobuf_data)
+
+
class PrivateKey(Key):
    """A ``PrivateKey`` represents a cryptographic private key."""

    @abstractmethod
    def sign(self, data: bytes) -> bytes:
        """Return the signature of ``data`` made with this key."""
        ...

    @abstractmethod
    def get_public_key(self) -> PublicKey:
        """Return the public half of this key pair."""
        ...

    def _serialize_to_protobuf(self) -> protobuf.PrivateKey:
        """Return the protobuf representation of this ``Key``."""
        return protobuf.PrivateKey(key_type=self.get_type().value, data=self.to_bytes())

    def serialize(self) -> bytes:
        """Return the canonical serialization of this ``Key``."""
        return self._serialize_to_protobuf().SerializeToString()

    @classmethod
    def deserialize_from_protobuf(cls, protobuf_data: bytes) -> protobuf.PrivateKey:
        """Parse serialized bytes back into a protobuf ``PrivateKey``."""
        return protobuf.PrivateKey.FromString(protobuf_data)
+
+
@dataclass(frozen=True)
class KeyPair:
    """An immutable private/public key pair."""

    private_key: PrivateKey
    public_key: PublicKey

+ 75 - 0
hivemind/p2p/p2p_daemon_bindings/p2pclient.py

@@ -0,0 +1,75 @@
+from typing import AsyncIterator, Iterable, Sequence, Tuple
+
+import asyncio
+from hivemind.p2p.p2p_daemon_bindings.control import ControlClient, DaemonConnector, StreamHandler
+from contextlib import asynccontextmanager
+from multiaddr import Multiaddr
+from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerInfo, StreamInfo, ID
+
+
class Client:
    """High-level facade over ControlClient for talking to a libp2p daemon."""

    control: ControlClient

    def __init__(
        self, control_maddr: Multiaddr = None, listen_maddr: Multiaddr = None
    ) -> None:
        # Both maddrs may be None, in which case the DaemonConnector /
        # ControlClient defaults are used.
        daemon_connector = DaemonConnector(control_maddr=control_maddr)
        self.control = ControlClient(
            daemon_connector=daemon_connector, listen_maddr=listen_maddr
        )

    @asynccontextmanager
    async def listen(self) -> AsyncIterator["Client"]:
        """
        Starts to listen incoming connections for handlers registered via stream_handler.

        :yield: this client, listening for incoming streams until the context exits
        """
        async with self.control.listen():
            yield self

    async def identify(self) -> Tuple[ID, Tuple[Multiaddr, ...]]:
        """
        Get current node peer id and list of addresses
        """
        return await self.control.identify()

    async def connect(self, peer_id: ID, maddrs: Iterable[Multiaddr]) -> None:
        """
        Connect to p2p node with specified addresses and peer id.

        :param peer_id: node peer id you want connect to
        :param maddrs: node multiaddresses you want connect to. Of course, it must be reachable.
        """
        await self.control.connect(peer_id=peer_id, maddrs=maddrs)

    async def list_peers(self) -> Tuple[PeerInfo, ...]:
        """
        Get list of peers that node connect to
        """
        return await self.control.list_peers()

    async def disconnect(self, peer_id: ID) -> None:
        """
        Disconnect from node with specified peer id

        :param peer_id: peer id of the node to disconnect from
        """
        await self.control.disconnect(peer_id=peer_id)

    async def stream_open(
        self, peer_id: ID, protocols: Sequence[str]
    ) -> Tuple[StreamInfo, asyncio.StreamReader, asyncio.StreamWriter]:
        """
        Open a stream to call other peer (with peer_id) handler for specified protocols

        :param peer_id: peer id of the remote node
        :param protocols: protocol names to negotiate for the stream
        :return: Returns tuple of stream info (info about connection to second peer) and reader/writer
        """
        return await self.control.stream_open(peer_id=peer_id, protocols=protocols)

    async def stream_handler(self, proto: str, handler_cb: StreamHandler) -> None:
        """
        Register a stream handler

        :param proto: protocols that handler serves
        :param handler_cb: handler callback
        """
        await self.control.stream_handler(proto=proto, handler_cb=handler_cb)

+ 72 - 0
hivemind/p2p/p2p_daemon_bindings/utils.py

@@ -0,0 +1,72 @@
+import asyncio
+
+from google.protobuf.message import Message as PBMessage
+
+from hivemind.proto import p2pd_pb2 as p2pd_pb
+
+
# Upper bound (in bits) on varint values accepted by the codec below.
DEFAULT_MAX_BITS: int = 64


class ControlFailure(Exception):
    """Raised when the daemon answers a control request with an error response."""


class DispatchFailure(Exception):
    """Raised when an incoming stream targets a protocol with no registered handler."""


async def write_unsigned_varint(
    stream: asyncio.StreamWriter, integer: int, max_bits: int = DEFAULT_MAX_BITS
) -> None:
    """Encode ``integer`` as an unsigned LEB128 varint and write it to ``stream``.

    :raises ValueError: if ``integer`` is negative or needs more than ``max_bits`` bits
    """
    if integer < 0:
        raise ValueError(f"negative integer: {integer}")
    if integer >= (1 << max_bits):
        raise ValueError(f"integer too large: {integer}")
    while True:
        septet = integer & 0x7F
        integer >>= 7
        if integer != 0:
            septet |= 0x80  # continuation bit: more bytes follow
        stream.write(septet.to_bytes(1, "big"))
        if integer == 0:
            break


async def read_unsigned_varint(
    stream: asyncio.StreamReader, max_bits: int = DEFAULT_MAX_BITS
) -> int:
    """Read an unsigned LEB128 varint from ``stream`` and return it as an int.

    :raises ValueError: if the decoded value needs more than ``max_bits`` bits
    :raises asyncio.IncompleteReadError: if the stream ends mid-varint
    """
    max_int = 1 << max_bits
    shift = 0
    result = 0
    while True:
        byte = (await stream.readexactly(1))[0]
        result |= (byte & 0x7F) << shift
        if result >= max_int:
            raise ValueError(f"varint overflowed: {result}")
        shift += 7
        if not (byte & 0x80):
            return result
+
+
def raise_if_failed(response: p2pd_pb.Response) -> None:
    """Raise ControlFailure if the daemon's response reports an error.

    :param response: a parsed p2pd Response message
    :raises ControlFailure: carrying the daemon-provided error message
    """
    if response.type == p2pd_pb.Response.ERROR:
        # This helper guards every control request (identify, list_peers, ...),
        # not only `connect`, so the message must not claim a specific operation.
        raise ControlFailure(f"control request failed. msg={response.error.msg}")
+
+
async def write_pbmsg(stream: asyncio.StreamWriter, pbmsg: PBMessage) -> None:
    """Write one protobuf message to `stream`, prefixed with its varint-encoded length."""
    size = pbmsg.ByteSize()
    await write_unsigned_varint(stream, size)
    msg_bytes: bytes = pbmsg.SerializeToString()
    stream.write(msg_bytes)
+
+
async def read_pbmsg_safe(stream: asyncio.StreamReader, pbmsg: PBMessage) -> None:
    """Read one varint-length-prefixed protobuf message from `stream` into `pbmsg` (in place)."""
    len_msg_bytes = await read_unsigned_varint(stream)
    msg_bytes = await stream.readexactly(len_msg_bytes)
    pbmsg.ParseFromString(msg_bytes)

+ 20 - 0
hivemind/proto/crypto.proto

@@ -0,0 +1,20 @@
syntax = "proto2";

package crypto.pb;

// Key type identifiers shared with other libp2p implementations.
// NOTE(review): the Python-side KeyType enum also defines ECC_P256 = 4,
// which is missing here — confirm before serializing such keys.
enum KeyType {
  RSA = 0;
  Ed25519 = 1;
  Secp256k1 = 2;
  ECDSA = 3;
}

// A serialized public key: its algorithm plus algorithm-specific bytes.
message PublicKey {
  required KeyType key_type = 1;
  required bytes data = 2;
}

// A serialized private key: its algorithm plus algorithm-specific bytes.
message PrivateKey {
  required KeyType key_type = 1;
  required bytes data = 2;
}

+ 158 - 0
hivemind/proto/p2pd.proto

@@ -0,0 +1,158 @@
syntax = "proto2";

// Control protocol spoken between a p2pclient and the libp2p daemon (p2pd).
// Each exchange is a single length-prefixed Request answered by one or more
// Response messages on the same connection.
package p2pclient.p2pd.pb;

message Request {
  enum Type {
    IDENTIFY       = 0;
    CONNECT        = 1;
    STREAM_OPEN    = 2;
    STREAM_HANDLER = 3;
    DHT            = 4;
    LIST_PEERS     = 5;
    CONNMANAGER    = 6;
    DISCONNECT     = 7;
    PUBSUB         = 8;
  }

  required Type type = 1;

  // Exactly one of the following is expected to be set, matching `type`.
  optional ConnectRequest connect = 2;
  optional StreamOpenRequest streamOpen = 3;
  optional StreamHandlerRequest streamHandler = 4;
  optional DHTRequest dht = 5;
  optional ConnManagerRequest connManager = 6;
  optional DisconnectRequest disconnect = 7;
  optional PSRequest pubsub = 8;
}

// Generic daemon reply: `error` is set when type == ERROR, otherwise the
// field matching the original request type is populated.
message Response {
  enum Type {
    OK    = 0;
    ERROR = 1;
  }

  required Type type = 1;
  optional ErrorResponse error = 2;
  optional StreamInfo streamInfo = 3;
  optional IdentifyResponse identify = 4;
  optional DHTResponse dht = 5;
  repeated PeerInfo peers = 6;
  optional PSResponse pubsub = 7;
}

message IdentifyResponse {
  required bytes id = 1;
  repeated bytes addrs = 2;
}

message ConnectRequest {
  required bytes peer = 1;
  repeated bytes addrs = 2;
  optional int64 timeout = 3;
}

message StreamOpenRequest {
  required bytes peer = 1;
  repeated string proto = 2;
  optional int64 timeout = 3;
}

message StreamHandlerRequest {
  required bytes addr = 1;
  repeated string proto = 2;
}

message ErrorResponse {
  required string msg = 1;
}

// Metadata for an opened or incoming stream.
message StreamInfo {
  required bytes peer = 1;
  required bytes addr = 2;
  required string proto = 3;
}

message DHTRequest {
  enum Type {
    FIND_PEER                    = 0;
    FIND_PEERS_CONNECTED_TO_PEER = 1;
    FIND_PROVIDERS               = 2;
    GET_CLOSEST_PEERS            = 3;
    GET_PUBLIC_KEY               = 4;
    GET_VALUE                    = 5;
    SEARCH_VALUE                 = 6;
    PUT_VALUE                    = 7;
    PROVIDE                      = 8;
  }

  required Type type = 1;
  optional bytes peer = 2;
  optional bytes cid = 3;
  optional bytes key = 4;
  optional bytes value = 5;
  optional int32 count = 6;
  optional int64 timeout = 7;
}

// DHT queries stream results: a BEGIN marker, zero or more VALUE items,
// then an END marker.
message DHTResponse {
  enum Type {
    BEGIN = 0;
    VALUE = 1;
    END   = 2;
  }

  required Type type = 1;
  optional PeerInfo peer = 2;
  optional bytes value = 3;
}

message PeerInfo {
  required bytes id = 1;
  repeated bytes addrs = 2;
}

message ConnManagerRequest {
  enum Type {
    TAG_PEER        = 0;
    UNTAG_PEER      = 1;
    TRIM            = 2;
  }

  required Type type = 1;

  optional bytes peer = 2;
  optional string tag = 3;
  optional int64 weight = 4;
}

message DisconnectRequest {
  required bytes peer = 1;
}

message PSRequest {
  enum Type {
    GET_TOPICS = 0;
    LIST_PEERS = 1;
    PUBLISH    = 2;
    SUBSCRIBE  = 3;
  }

  required Type type = 1;
  optional string topic = 2;
  optional bytes data = 3;
}

message PSMessage {
  optional bytes from_id = 1;
  optional bytes data = 2;
  optional bytes seqno = 3;
  repeated string topicIDs = 4;
  optional bytes signature = 5;
  optional bytes key = 6;
}

message PSResponse {
  repeated string topics = 1;
  repeated bytes peerIDs = 2;
}

+ 2 - 0
requirements.txt

@@ -10,4 +10,6 @@ grpcio>=1.33.2
 grpcio-tools>=1.33.2
 protobuf>=3.12.2
 configargparse>=1.2.3
+multiaddr==0.0.9
+pymultihash==0.8.2
 cryptography>=3.4.6

+ 25 - 23
tests/test_p2p_daemon.py

@@ -2,7 +2,7 @@ import asyncio
 import multiprocessing as mp
 import subprocess
 
-from libp2p.peer.id import ID
+from hivemind.p2p.p2p_daemon_bindings.datastructures import ID
 
 import numpy as np
 import pytest
@@ -86,16 +86,16 @@ async def test_call_unary_handler(should_cancel, handle_name="handle"):
 
     await asyncio.sleep(1)
     libp2p_server_id = ID.from_base58(server.id)
-    stream_info, stream = await client._client.stream_open(libp2p_server_id, (handle_name,))
+    stream_info, reader, writer = await client._client.stream_open(libp2p_server_id, (handle_name,))
 
-    await P2P.send_raw_data(ping_request.SerializeToString(), stream)
+    await P2P.send_raw_data(ping_request.SerializeToString(), writer)
 
     if should_cancel:
-        await stream.close()
+        writer.close()
         await asyncio.sleep(1)
         assert handler_cancelled
     else:
-        result = await P2P.receive_protobuf(dht_pb2.PingResponse, stream)
+        result = await P2P.receive_protobuf(dht_pb2.PingResponse, reader)
         assert result == expected_response
         assert not handler_cancelled
 
@@ -139,6 +139,25 @@ async def test_call_peer_single_process(test_input, handle, handler_name="handle
     assert not is_process_running(client_pid)
 
 
+async def run_server(handler_name, server_side, client_side, response_received):
+    server = await P2P.create()
+    server_pid = server._child.pid
+    await server.add_stream_handler(handler_name, handle_square)
+    assert is_process_running(server_pid)
+
+    server_side.send(server.id)
+    while response_received.value == 0:
+        await asyncio.sleep(0.5)
+
+    await server.stop_listening()
+    server.__del__()
+    assert not is_process_running(server_pid)
+
+
+def server_target(handler_name, server_side, client_side, response_received):
+    asyncio.run(run_server(handler_name, server_side, client_side, response_received))
+
+
 @pytest.mark.asyncio
 async def test_call_peer_different_processes():
     handler_name = "square"
@@ -148,24 +167,7 @@ async def test_call_peer_different_processes():
     response_received = mp.Value(np.ctypeslib.as_ctypes_type(np.int32))
     response_received.value = 0
 
-    async def run_server():
-        server = await P2P.create()
-        server_pid = server._child.pid
-        await server.add_stream_handler(handler_name, handle_square)
-        assert is_process_running(server_pid)
-
-        server_side.send(server.id)
-        while response_received.value == 0:
-            await asyncio.sleep(0.5)
-
-        await server.stop_listening()
-        server.__del__()
-        assert not is_process_running(server_pid)
-
-    def server_target():
-        asyncio.run(run_server())
-
-    proc = mp.Process(target=server_target)
+    proc = mp.Process(target=server_target, args=(handler_name, server_side, client_side, response_received))
     proc.start()
 
     client = await P2P.create()

+ 769 - 0
tests/test_p2p_daemon_bindings.py

@@ -0,0 +1,769 @@
+import asyncio
+import functools
+import io
+import os
+import subprocess
+import time
+import uuid
+from contextlib import asynccontextmanager, AsyncExitStack
+from typing import NamedTuple
+
+from google.protobuf.message import EncodeError
+from multiaddr import Multiaddr, protocols
+
+import pytest
+
+from hivemind import find_open_port
+from hivemind.p2p.p2p_daemon_bindings.control import parse_conn_protocol, DaemonConnector, ControlClient
+from hivemind.p2p.p2p_daemon_bindings.p2pclient import Client
+from hivemind.p2p.p2p_daemon_bindings.utils import ControlFailure, raise_if_failed, write_unsigned_varint, \
+    read_unsigned_varint, read_pbmsg_safe, write_pbmsg
+from hivemind.proto import p2pd_pb2 as p2pd_pb
+from hivemind.p2p.p2p_daemon_bindings.datastructures import ID, StreamInfo, PeerInfo
+
+
+def test_raise_if_failed_raises():
+    resp = p2pd_pb.Response()
+    resp.type = p2pd_pb.Response.ERROR
+    with pytest.raises(ControlFailure):
+        raise_if_failed(resp)
+
+
+def test_raise_if_failed_not_raises():
+    resp = p2pd_pb.Response()
+    resp.type = p2pd_pb.Response.OK
+    raise_if_failed(resp)
+
+
+pairs_int_varint_valid = (
+    (0, b"\x00"),
+    (1, b"\x01"),
+    (128, b"\x80\x01"),
+    (2 ** 32, b"\x80\x80\x80\x80\x10"),
+    (2 ** 64 - 1, b"\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01"),
+)
+
+pairs_int_varint_overflow = (
+    (2 ** 64, b"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x02"),
+    (2 ** 64 + 1, b"\x81\x80\x80\x80\x80\x80\x80\x80\x80\x02"),
+    (
+        2 ** 128,
+        b"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x04",
+    ),
+)
+
+
+class MockReader(io.BytesIO):
+    async def readexactly(self, n):
+        await asyncio.sleep(0)
+        return self.read(n)
+
+
+class MockWriter(io.BytesIO):
+    pass
+
+
+class MockReaderWriter(MockReader, MockWriter):
+    pass
+
+
+@pytest.mark.parametrize("integer, var_integer", pairs_int_varint_valid)
+@pytest.mark.asyncio
+async def test_write_unsigned_varint(integer, var_integer):
+    s = MockWriter()
+    await write_unsigned_varint(s, integer)
+    assert s.getvalue() == var_integer
+
+
+@pytest.mark.parametrize("integer", tuple(i[0] for i in pairs_int_varint_overflow))
+@pytest.mark.asyncio
+async def test_write_unsigned_varint_overflow(integer):
+    s = MockWriter()
+    with pytest.raises(ValueError):
+        await write_unsigned_varint(s, integer)
+
+
+@pytest.mark.parametrize("integer", (-1, -(2 ** 32), -(2 ** 64), -(2 ** 128)))
+@pytest.mark.asyncio
+async def test_write_unsigned_varint_negative(integer):
+    s = MockWriter()
+    with pytest.raises(ValueError):
+        await write_unsigned_varint(s, integer)
+
+
+@pytest.mark.parametrize("integer, var_integer", pairs_int_varint_valid)
+@pytest.mark.asyncio
+async def test_read_unsigned_varint(integer, var_integer):
+    s = MockReader(var_integer)
+    result = await read_unsigned_varint(s)
+    assert result == integer
+
+
+@pytest.mark.parametrize("var_integer", tuple(i[1] for i in pairs_int_varint_overflow))
+@pytest.mark.asyncio
+async def test_read_unsigned_varint_overflow(var_integer):
+    s = MockReader(var_integer)
+    with pytest.raises(ValueError):
+        await read_unsigned_varint(s)
+
+
+@pytest.mark.parametrize("max_bits", (2, 31, 32, 63, 64, 127, 128))
+@pytest.mark.asyncio
+async def test_read_write_unsigned_varint_max_bits_edge(max_bits):
+    """
+    Test the edge with different `max_bits`
+    """
+    for i in range(-3, 0):
+        integer = i + (2 ** max_bits)
+        s = MockReaderWriter()
+        await write_unsigned_varint(s, integer, max_bits=max_bits)
+        s.seek(0, 0)
+        result = await read_unsigned_varint(s, max_bits=max_bits)
+        assert integer == result
+
+
+@pytest.fixture(scope="module")
+def peer_id_string():
+    return "QmS5QmciTXXnCUCyxud5eWFenUMAmvAWSDa1c7dvdXRMZ7"
+
+
+@pytest.fixture(scope="module")
+def peer_id_bytes():
+    return b'\x12 7\x87F.[\xb5\xb1o\xe5*\xc7\xb9\xbb\x11:"Z|j2\x8ad\x1b\xa6\xe5<Ip\xfe\xb4\xf5v'
+
+
+@pytest.fixture(scope="module")
+def peer_id(peer_id_bytes):
+    return ID(peer_id_bytes)
+
+
+@pytest.fixture(scope="module")
+def maddr():
+    return Multiaddr("/unix/123")
+
+
+def test_peer_id(peer_id_string, peer_id_bytes, peer_id):
+    # test initialized with bytes
+    assert peer_id.to_bytes() == peer_id_bytes
+    assert peer_id.to_string() == peer_id_string
+    # test initialized with string
+    peer_id_2 = ID.from_base58(peer_id_string)
+    assert peer_id_2.to_bytes() == peer_id_bytes
+    assert peer_id_2.to_string() == peer_id_string
+    # test equal
+    assert peer_id == peer_id_2
+    # test not equal
+    peer_id_3 = ID.from_base58("QmbmfNDEth7Ucvjuxiw3SP3E4PoJzbk7g4Ge6ZDigbCsNp")
+    assert peer_id != peer_id_3
+
+
+def test_stream_info(peer_id, maddr):
+    proto = "123"
+    # test case: `StreamInfo.__init__`
+    si = StreamInfo(peer_id, maddr, proto)
+    assert si.peer_id == peer_id
+    assert si.addr == maddr
+    assert si.proto == proto
+    # test case: `StreamInfo.to_pb`
+    pb_si = si.to_pb()
+    assert pb_si.peer == peer_id.to_bytes()
+    assert pb_si.addr == maddr.to_bytes()
+    assert pb_si.proto == si.proto
+    # test case: `StreamInfo.from_pb`
+    si_1 = StreamInfo.from_pb(pb_si)
+    assert si_1.peer_id == peer_id
+    assert si_1.addr == maddr
+    assert si_1.proto == proto
+
+
+def test_peer_info(peer_id, maddr):
+    pi = PeerInfo(peer_id, [maddr])
+    # test case: `PeerInfo.__init__`
+    assert pi.peer_id == peer_id
+    assert pi.addrs == [maddr]
+    # test case: `PeerInfo.from_pb`
+    pi_pb = p2pd_pb.PeerInfo(id=peer_id.to_bytes(), addrs=[maddr.to_bytes()])
+    pi_1 = PeerInfo.from_pb(pi_pb)
+    assert pi.peer_id == pi_1.peer_id
+    assert pi.addrs == pi_1.addrs
+
+
+@pytest.mark.parametrize(
+    "maddr_str, expected_proto",
+    (("/unix/123", protocols.P_UNIX), ("/ip4/127.0.0.1/tcp/7777", protocols.P_IP4)),
+)
+def test_parse_conn_protocol_valid(maddr_str, expected_proto):
+    assert parse_conn_protocol(Multiaddr(maddr_str)) == expected_proto
+
+
+@pytest.mark.parametrize(
+    "maddr_str",
+    (
+        "/p2p/QmbHVEEepCi7rn7VL7Exxpd2Ci9NNB6ifvqwhsrbRMgQFP",
+        "/onion/timaq4ygg2iegci7:1234",
+    ),
+)
+def test_parse_conn_protocol_invalid(maddr_str):
+    maddr = Multiaddr(maddr_str)
+    with pytest.raises(ValueError):
+        parse_conn_protocol(maddr)
+
+
+@pytest.mark.parametrize("control_maddr_str", ("/unix/123", "/ip4/127.0.0.1/tcp/6666"))
+def test_client_ctor_control_maddr(control_maddr_str):
+    c = DaemonConnector(Multiaddr(control_maddr_str))
+    assert c.control_maddr == Multiaddr(control_maddr_str)
+
+
+def test_client_ctor_default_control_maddr():
+    c = DaemonConnector()
+    assert c.control_maddr == Multiaddr(DaemonConnector.DEFAULT_CONTROL_MADDR)
+
+
+@pytest.mark.parametrize("listen_maddr_str", ("/unix/123", "/ip4/127.0.0.1/tcp/6666"))
+def test_control_client_ctor_listen_maddr(listen_maddr_str):
+    c = ControlClient(
+        daemon_connector=DaemonConnector(), listen_maddr=Multiaddr(listen_maddr_str)
+    )
+    assert c.listen_maddr == Multiaddr(listen_maddr_str)
+
+
+def test_control_client_ctor_default_listen_maddr():
+    c = ControlClient(daemon_connector=DaemonConnector())
+    assert c.listen_maddr == Multiaddr(ControlClient.DEFAULT_LISTEN_MADDR)
+
+
+@pytest.mark.parametrize(
+    "msg_bytes",
+    (
+        b'\x08\x00"R\n"\x12 F\xec\xd3p0X\xbeT\x95p^\xc8{\xc8\x13\xa3\x9c\x84d\x0b\x1b\xbb\xa0P\x98w\xc1\xb3\x981i\x16\x12\x02\xa2\x02\x12\x08\x04\x7f\x00\x00\x01\x06\xc7\xb6\x12\x08\x04\xc0\xa8\n\x87\x06\xc7\xb6\x12\x14)\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x06\xc7\xb7',  # noqa: E501
+        b'\x08\x00"R\n"\x12 \xd0\xf0 \x9a\xc6v\xa6\xd3;\xcac|\x95\x94\xa0\xe6:\nM\xc53T\x0e\xf0\x89\x8e(\x0c\xb9\xf7\\\xa5\x12\x02\xa2\x02\x12\x08\x04\x7f\x00\x00\x01\x06\xc9%\x12\x08\x04\xc0\xa8\n\x87\x06\xc9%\x12\x14)\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x06\xc9&',  # noqa: E501
+        b'\x08\x00"R\n"\x12 \xc3\xc3\xee\x18i\x8a\xde\x13\xa9y\x905\xeb\xcb\xa4\xd07\x14\xbe\xf4\xf8\x1b\xe8[g94\x94\xe3f\x18\xa9\x12\x02\xa2\x02\x12\x08\x04\x7f\x00\x00\x01\x06\xc9`\x12\x08\x04\xc0\xa8\n\x87\x06\xc9`\x12\x14)\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x06\xc9a',  # noqa: E501
+    ),
+    # give test cases ids to prevent bytes from ruining the terminal
+    ids=("pb example Response 0", "pb example Response 1", "pb example Response 2"),
+)
+@pytest.mark.asyncio
+async def test_read_pbmsg_safe_valid(msg_bytes):
+    s = MockReaderWriter()
+    await write_unsigned_varint(s, len(msg_bytes))
+    s.write(msg_bytes)
+    # reset the offset back to the beginning
+    s.seek(0, 0)
+    pb_msg = p2pd_pb.Response()
+    await read_pbmsg_safe(s, pb_msg)
+    assert pb_msg.SerializeToString() == msg_bytes
+
+
+@pytest.mark.parametrize(
+    "pb_msg, msg_bytes",
+    (
+        (
+            p2pd_pb.Response(),
+            b'Z\x08\x00*V\x08\x01\x12R\n"\x12 \x03\x8d\xf5\xd4(/#\xd6\xed\xa5\x1bU\xb8s\x8c\xfa\xad\xfc{\x04\xe3\xecw\xdeK\xc9,\xfe\x9c\x00:\xc8\x12\x02\xa2\x02\x12\x08\x04\x7f\x00\x00\x01\x06\xdea\x12\x08\x04\xc0\xa8\n\x87\x06\xdea\x12\x14)\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x06\xdeb',  # noqa: E501
+        ),
+        (p2pd_pb.Request(), b"\x02\x08\x05"),
+        (
+            p2pd_pb.DHTRequest(),
+            b'&\x08\x00\x12"\x12 \xd5\x0b\x18/\x9e\xa5G\x06.\xdd\xebW\xf0N\xf5\x0eW\xd3\xec\xdf\x06\x02\xe2\x89\x1e\xf0\xbb.\xc0\xbdE\xb8',  # noqa: E501
+        ),
+        (
+            p2pd_pb.DHTResponse(),
+            b'V\x08\x01\x12R\n"\x12 wy\xe2\xfa\x11\x9e\xe2\x84X]\x84\xf8\x98\xba\x8c\x8cQ\xd7,\xb59\x1e!G\x92\x86G{\x141\xe9\x1b\x12\x02\xa2\x02\x12\x08\x04\x7f\x00\x00\x01\x06\xdeA\x12\x08\x04\xc0\xa8\n\x87\x06\xdeA\x12\x14)\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x06\xdeB',  # noqa: E501
+        ),
+        (
+            p2pd_pb.StreamInfo(),
+            b';\n"\x12 \xf6\x9e=\x9f\xc1J\xfe\x02\x93k!S\x80\xa0\xcc(s\xea&\xbe\xed\x9274qTI\xc1\xf7\xb6\xbd7\x12\x08\x04\x7f\x00\x00\x01\x06\xde\xc5\x1a\x0bprotocol123',  # noqa: E501
+        ),
+    ),
+    ids=(
+        "pb example Response",
+        "pb example Request",
+        "pb example DHTRequest",
+        "pb example DHTResponse",
+        "pb example StreamInfo",
+    ),
+)
+@pytest.mark.asyncio
+async def test_write_pbmsg(pb_msg, msg_bytes):
+    s_read = MockReaderWriter(msg_bytes)
+    await read_pbmsg_safe(s_read, pb_msg)
+    s_write = MockReaderWriter()
+    await write_pbmsg(s_write, pb_msg)
+    assert msg_bytes == s_write.getvalue()
+
+
+@pytest.mark.parametrize(
+    "pb_msg",
+    (
+        p2pd_pb.Response(),
+        p2pd_pb.Request(),
+        p2pd_pb.DHTRequest(),
+        p2pd_pb.DHTResponse(),
+        p2pd_pb.StreamInfo(),
+    ),
+)
+@pytest.mark.asyncio
+async def test_write_pbmsg_missing_fields(pb_msg):
+    with pytest.raises(EncodeError):
+        await write_pbmsg(MockReaderWriter(), pb_msg)
+
+TIMEOUT_DURATION = 30  # seconds
+
+@pytest.fixture
+def num_p2pds():
+    return 4
+
+
+@pytest.fixture(scope="module")
+def peer_id_random():
+    return ID.from_base58("QmcgpsyWgH8Y8ajJz1Cu72KnS5uo2Aa2LpzU7kinSupNK1")
+
+
+@pytest.fixture
+def enable_control():
+    return True
+
+
+@pytest.fixture
+def enable_connmgr():
+    return False
+
+
+@pytest.fixture
+def enable_dht():
+    return False
+
+
+@pytest.fixture
+def enable_pubsub():
+    return False
+
+
+@pytest.fixture
+def func_make_p2pd_pair():
+    return make_p2pd_pair_ip4
+
+
+async def try_until_success(coro_func, timeout=TIMEOUT_DURATION):
+    """
+    Keep running ``coro_func`` until the time is out.
+    All arguments of ``coro_func`` should be filled, i.e. it should be called without arguments.
+    """
+    t_start = time.monotonic()
+    while True:
+        result = await coro_func()
+        if result:
+            break
+        if (time.monotonic() - t_start) >= timeout:
+            # timeout
+            assert False, f"{coro_func} still failed after `{timeout}` seconds"
+        await asyncio.sleep(0.01)
+
+
+class Daemon:
+    control_maddr = None
+    proc_daemon = None
+    log_filename = ""
+    f_log = None
+    closed = None
+
+    def __init__(
+        self, control_maddr, enable_control, enable_connmgr, enable_dht, enable_pubsub
+    ):
+        self.control_maddr = control_maddr
+        self.enable_control = enable_control
+        self.enable_connmgr = enable_connmgr
+        self.enable_dht = enable_dht
+        self.enable_pubsub = enable_pubsub
+        self.is_closed = False
+        self._start_logging()
+        self._run()
+
+    def _start_logging(self):
+        name_control_maddr = str(self.control_maddr).replace("/", "_").replace(".", "_")
+        self.log_filename = f"/tmp/log_p2pd{name_control_maddr}.txt"
+        self.f_log = open(self.log_filename, "wb")
+
+    def _run(self):
+        cmd_list = ["hivemind/hivemind_cli/p2pd", f"-listen={str(self.control_maddr)}"]
+        cmd_list += [f"-hostAddrs=/ip4/127.0.0.1/tcp/{find_open_port()}"]
+        if self.enable_connmgr:
+            cmd_list += ["-connManager=true", "-connLo=1", "-connHi=2", "-connGrace=0"]
+        if self.enable_dht:
+            cmd_list += ["-dht=true"]
+        if self.enable_pubsub:
+            cmd_list += ["-pubsub=true", "-pubsubRouter=gossipsub"]
+        self.proc_daemon = subprocess.Popen(
+            cmd_list, stdout=self.f_log, stderr=self.f_log, bufsize=0
+        )
+
+    async def wait_until_ready(self):
+        lines_head_pattern = (b"Control socket:", b"Peer ID:", b"Peer Addrs:")
+        lines_head_occurred = {line: False for line in lines_head_pattern}
+
+        with open(self.log_filename, "rb") as f_log_read:
+
+            async def read_from_daemon_and_check():
+                line = f_log_read.readline()
+                for head_pattern in lines_head_occurred:
+                    if line.startswith(head_pattern):
+                        lines_head_occurred[head_pattern] = True
+                return all([value for _, value in lines_head_occurred.items()])
+
+            await try_until_success(read_from_daemon_and_check)
+
+        # sleep for a while in case the daemon isn't ready yet even after emitting these lines
+        await asyncio.sleep(0.1)
+
+    def close(self):
+        if self.is_closed:
+            return
+        self.proc_daemon.terminate()
+        self.proc_daemon.wait()
+        self.f_log.close()
+        self.is_closed = True
+
+
+class DaemonTuple(NamedTuple):
+    daemon: Daemon
+    client: Client
+
+
+class ConnectionFailure(Exception):
+    pass
+
+
+@asynccontextmanager
+async def make_p2pd_pair_unix(
+    enable_control, enable_connmgr, enable_dht, enable_pubsub
+):
+    name = str(uuid.uuid4())[:8]
+    control_maddr = Multiaddr(f"/unix/tmp/test_p2pd_control_{name}.sock")
+    listen_maddr = Multiaddr(f"/unix/tmp/test_p2pd_listen_{name}.sock")
+    # Remove the unix socket files if they already exist
+    try:
+        os.unlink(control_maddr.value_for_protocol(protocols.P_UNIX))
+    except FileNotFoundError:
+        pass
+    try:
+        os.unlink(listen_maddr.value_for_protocol(protocols.P_UNIX))
+    except FileNotFoundError:
+        pass
+    async with _make_p2pd_pair(
+        control_maddr=control_maddr,
+        listen_maddr=listen_maddr,
+        enable_control=enable_control,
+        enable_connmgr=enable_connmgr,
+        enable_dht=enable_dht,
+        enable_pubsub=enable_pubsub,
+    ) as pair:
+        yield pair
+
+
+@asynccontextmanager
+async def make_p2pd_pair_ip4(enable_control, enable_connmgr, enable_dht, enable_pubsub):
+    control_maddr = Multiaddr(f"/ip4/127.0.0.1/tcp/{find_open_port()}")
+    listen_maddr = Multiaddr(f"/ip4/127.0.0.1/tcp/{find_open_port()}")
+    async with _make_p2pd_pair(
+        control_maddr=control_maddr,
+        listen_maddr=listen_maddr,
+        enable_control=enable_control,
+        enable_connmgr=enable_connmgr,
+        enable_dht=enable_dht,
+        enable_pubsub=enable_pubsub,
+    ) as pair:
+        yield pair
+
+
+@asynccontextmanager
+async def _make_p2pd_pair(
+    control_maddr,
+    listen_maddr,
+    enable_control,
+    enable_connmgr,
+    enable_dht,
+    enable_pubsub,
+):
+    p2pd = Daemon(
+        control_maddr=control_maddr,
+        enable_control=enable_control,
+        enable_connmgr=enable_connmgr,
+        enable_dht=enable_dht,
+        enable_pubsub=enable_pubsub,
+    )
+    # wait for daemon ready
+    await p2pd.wait_until_ready()
+    client = Client(control_maddr=control_maddr, listen_maddr=listen_maddr)
+    try:
+        async with client.listen():
+            yield DaemonTuple(daemon=p2pd, client=client)
+    finally:
+        if not p2pd.is_closed:
+            p2pd.close()
+
+
+@pytest.fixture
+async def p2pcs(
+    num_p2pds,
+    enable_control,
+    enable_connmgr,
+    enable_dht,
+    enable_pubsub,
+    func_make_p2pd_pair,
+):
+    # TODO: Change back to gather style
+    async with AsyncExitStack() as stack:
+        p2pd_tuples = [
+            await stack.enter_async_context(
+                func_make_p2pd_pair(
+                    enable_control=enable_control,
+                    enable_connmgr=enable_connmgr,
+                    enable_dht=enable_dht,
+                    enable_pubsub=enable_pubsub,
+                )
+            )
+            for _ in range(num_p2pds)
+        ]
+        yield tuple(p2pd_tuple.client for p2pd_tuple in p2pd_tuples)
+
+
+@pytest.mark.parametrize(
+    "enable_control, func_make_p2pd_pair", ((True, make_p2pd_pair_unix),)
+)
+@pytest.mark.asyncio
+async def test_client_identify_unix_socket(p2pcs):
+    await p2pcs[0].identify()
+
+
+@pytest.mark.parametrize("enable_control", (True,))
+@pytest.mark.asyncio
+async def test_client_identify(p2pcs):
+    await p2pcs[0].identify()
+
+
+@pytest.mark.parametrize("enable_control", (True,))
+@pytest.mark.asyncio
+async def test_client_connect_success(p2pcs):
+    peer_id_0, maddrs_0 = await p2pcs[0].identify()
+    peer_id_1, maddrs_1 = await p2pcs[1].identify()
+    await p2pcs[0].connect(peer_id_1, maddrs_1)
+    # test case: repeated connections
+    await p2pcs[1].connect(peer_id_0, maddrs_0)
+
+
+@pytest.mark.parametrize("enable_control", (True,))
+@pytest.mark.asyncio
+async def test_client_connect_failure(peer_id_random, p2pcs):
+    peer_id_1, maddrs_1 = await p2pcs[1].identify()
+    await p2pcs[0].identify()
+    # test case: `peer_id` mismatches
+    with pytest.raises(ControlFailure):
+        await p2pcs[0].connect(peer_id_random, maddrs_1)
+    # test case: empty maddrs
+    with pytest.raises(ControlFailure):
+        await p2pcs[0].connect(peer_id_1, [])
+    # test case: wrong maddrs
+    with pytest.raises(ControlFailure):
+        await p2pcs[0].connect(peer_id_1, [Multiaddr("/ip4/127.0.0.1/udp/0")])
+
+
+async def _check_connection(p2pd_tuple_0, p2pd_tuple_1):
+    peer_id_0, _ = await p2pd_tuple_0.identify()
+    peer_id_1, _ = await p2pd_tuple_1.identify()
+    peers_0 = [pinfo.peer_id for pinfo in await p2pd_tuple_0.list_peers()]
+    peers_1 = [pinfo.peer_id for pinfo in await p2pd_tuple_1.list_peers()]
+    return (peer_id_0 in peers_1) and (peer_id_1 in peers_0)
+
+
+async def connect_safe(p2pd_tuple_0, p2pd_tuple_1):
+    peer_id_1, maddrs_1 = await p2pd_tuple_1.identify()
+    await p2pd_tuple_0.connect(peer_id_1, maddrs_1)
+    await try_until_success(
+        functools.partial(
+            _check_connection, p2pd_tuple_0=p2pd_tuple_0, p2pd_tuple_1=p2pd_tuple_1
+        )
+    )
+
+
+@pytest.mark.parametrize("enable_control", (True,))
+@pytest.mark.asyncio
+async def test_connect_safe(p2pcs):
+    await connect_safe(p2pcs[0], p2pcs[1])
+
+
+@pytest.mark.parametrize("enable_control", (True,))
+@pytest.mark.asyncio
+async def test_client_list_peers(p2pcs):
+    # test case: no peers
+    assert len(await p2pcs[0].list_peers()) == 0
+    # test case: 1 peer
+    await connect_safe(p2pcs[0], p2pcs[1])
+    assert len(await p2pcs[0].list_peers()) == 1
+    assert len(await p2pcs[1].list_peers()) == 1
+    # test case: one more peer
+    await connect_safe(p2pcs[0], p2pcs[2])
+    assert len(await p2pcs[0].list_peers()) == 2
+    assert len(await p2pcs[1].list_peers()) == 1
+    assert len(await p2pcs[2].list_peers()) == 1
+
+
+@pytest.mark.parametrize("enable_control", (True,))
+@pytest.mark.asyncio
+async def test_client_disconnect(peer_id_random, p2pcs):
+    # test case: disconnect a peer without connections
+    await p2pcs[1].disconnect(peer_id_random)
+    # test case: disconnect
+    peer_id_0, _ = await p2pcs[0].identify()
+    await connect_safe(p2pcs[0], p2pcs[1])
+    assert len(await p2pcs[0].list_peers()) == 1
+    assert len(await p2pcs[1].list_peers()) == 1
+    await p2pcs[1].disconnect(peer_id_0)
+    assert len(await p2pcs[0].list_peers()) == 0
+    assert len(await p2pcs[1].list_peers()) == 0
+    # test case: disconnect twice
+    await p2pcs[1].disconnect(peer_id_0)
+    assert len(await p2pcs[0].list_peers()) == 0
+    assert len(await p2pcs[1].list_peers()) == 0
+
+
+@pytest.mark.parametrize("enable_control", (True,))
+@pytest.mark.asyncio
+async def test_client_stream_open_success(p2pcs):
+    peer_id_1, maddrs_1 = await p2pcs[1].identify()
+    await connect_safe(p2pcs[0], p2pcs[1])
+
+    proto = "123"
+
+    async def handle_proto(stream_info, reader, writer):
+        await reader.readexactly(1)
+
+    await p2pcs[1].stream_handler(proto, handle_proto)
+
+    # test case: normal
+    stream_info, reader, writer = await p2pcs[0].stream_open(peer_id_1, (proto,))
+    assert stream_info.peer_id == peer_id_1
+    assert stream_info.addr in maddrs_1
+    assert stream_info.proto == "123"
+    writer.close()
+
+    # test case: open with multiple protocols
+    stream_info, reader, writer = await p2pcs[0].stream_open(
+        peer_id_1, (proto, "another_protocol")
+    )
+    assert stream_info.peer_id == peer_id_1
+    assert stream_info.addr in maddrs_1
+    assert stream_info.proto == "123"
+    writer.close()
+
+
+@pytest.mark.parametrize("enable_control", (True,))
+@pytest.mark.asyncio
+async def test_client_stream_open_failure(p2pcs):
+    peer_id_1, _ = await p2pcs[1].identify()
+    await connect_safe(p2pcs[0], p2pcs[1])
+
+    proto = "123"
+
+    # test case: `stream_open` to a peer who didn't register the protocol
+    with pytest.raises(ControlFailure):
+        await p2pcs[0].stream_open(peer_id_1, (proto,))
+
+    # test case: `stream_open` to a peer for a non-registered protocol
+    async def handle_proto(stream_info, reader, writer):
+        pass
+
+    await p2pcs[1].stream_handler(proto, handle_proto)
+    with pytest.raises(ControlFailure):
+        await p2pcs[0].stream_open(peer_id_1, ("another_protocol",))
+
+
+@pytest.mark.parametrize("enable_control", (True,))
+@pytest.mark.asyncio
+async def test_client_stream_handler_success(p2pcs):
+    peer_id_1, _ = await p2pcs[1].identify()
+    await connect_safe(p2pcs[0], p2pcs[1])
+
+    proto = "protocol123"
+    bytes_to_send = b"yoyoyoyoyog"
+    # event for this test function to wait on until the handler function receives the incoming data
+    event_handler_finished = asyncio.Event()
+
+    async def handle_proto(stream_info, reader, writer):
+        nonlocal event_handler_finished
+        bytes_received = await reader.readexactly(len(bytes_to_send))
+        assert bytes_received == bytes_to_send
+        event_handler_finished.set()
+
+    await p2pcs[1].stream_handler(proto, handle_proto)
+    assert proto in p2pcs[1].control.handlers
+    assert handle_proto == p2pcs[1].control.handlers[proto]
+
+    # test case: test the stream handler `handle_proto`
+
+    _, reader, writer = await p2pcs[0].stream_open(peer_id_1, (proto,))
+
+    # wait until the handler function starts blocking, waiting for the data;
+    # because we haven't sent the data yet, we know the handler function must still be blocked waiting.
+    # get the task of the protocol handler
+    writer.write(bytes_to_send)
+
+    # wait for the handler to finish
+    writer.close()
+
+    await event_handler_finished.wait()
+
+    # test case: two streams to different handlers respectively
+    another_proto = "another_protocol123"
+    another_bytes_to_send = b"456"
+    event_another_proto = asyncio.Event()
+
+    async def handle_another_proto(stream_info, reader, writer):
+        event_another_proto.set()
+        bytes_received = await reader.readexactly(len(another_bytes_to_send))
+        assert bytes_received == another_bytes_to_send
+
+    await p2pcs[1].stream_handler(another_proto, handle_another_proto)
+    assert another_proto in p2pcs[1].control.handlers
+    assert handle_another_proto == p2pcs[1].control.handlers[another_proto]
+
+    _, reader, writer = await p2pcs[0].stream_open(peer_id_1, (another_proto,))
+    await event_another_proto.wait()
+
+    # we know that at this moment the handler must still be blocked, waiting
+
+    writer.write(another_bytes_to_send)
+
+    writer.close()
+
+    # test case: registering twice can override the previous registration
+    event_third = asyncio.Event()
+
+    async def handler_third(stream_info, reader, writer):
+        event_third.set()
+
+    await p2pcs[1].stream_handler(another_proto, handler_third)
+    assert another_proto in p2pcs[1].control.handlers
+    # ensure the handler is overridden
+    assert handler_third == p2pcs[1].control.handlers[another_proto]
+
+    await p2pcs[0].stream_open(peer_id_1, (another_proto,))
+    # ensure the overriding handler is called when a stream is opened for the protocol
+    await event_third.wait()
+
+
+@pytest.mark.parametrize("enable_control", (True,))
+@pytest.mark.asyncio
+async def test_client_stream_handler_failure(p2pcs):
+    peer_id_1, _ = await p2pcs[1].identify()
+    await connect_safe(p2pcs[0], p2pcs[1])
+
+    proto = "123"
+
+    # test case: registered a wrong protocol name
+    async def handle_proto_correct_params(stream_info, stream):
+        pass
+
+    await p2pcs[1].stream_handler("another_protocol", handle_proto_correct_params)
+    with pytest.raises(ControlFailure):
+        await p2pcs[0].stream_open(peer_id_1, (proto,))