瀏覽代碼

Add initial support for connecting via libp2p (#238)

* added hivemind.P2P that wraps go-libp2p
* moved pythonic libp2p daemon bindings to hivemind.p2p
* implemented add_unary/stream_handler API for p2p communication
* added configuration options for NAT traversal and circuit relays
* added functionality tests for hivemind.P2P

Co-authored-by: Maxim Kashirin <ksh.max@gmail.com>
Co-authored-by: Denis Mazur <denismazur8@gmail.com>
Co-authored-by: Ilya Kobelev <ilya.kobellev@gmail.com>
Co-authored-by: Alexey Bukhtiyarov <a.bukhtiyarov@yandex.ru>
Co-authored-by: Alexander Borzunov <borzunov.alexander@gmail.com>
Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
Co-authored-by: Michael Diskin <yhn112@users.noreply.github.com>
MaximKsh 4 年之前
父節點
當前提交
aea7a387b5

+ 18 - 0
.circleci/config.yml

@@ -1,5 +1,10 @@
 version: 2.1
 
+parameters:
+  go-version:
+    type: string
+    default: 1.16.2
+
 jobs:
   build-and-test-py37:
     docker:
@@ -9,6 +14,11 @@ jobs:
       - restore_cache:
           keys:
             - py37-v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
+            - v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
+      - run: |
+          wget https://golang.org/dl/go<< pipeline.parameters.go-version >>.linux-amd64.tar.gz -O go.tar.gz
+          tar -C ~/ -xzf go.tar.gz
+          echo "export PATH=~/go/bin:$PATH" >> $BASH_ENV
       - run: pip install -r requirements.txt
       - run: pip install -r requirements-dev.txt
       - save_cache:
@@ -29,6 +39,10 @@ jobs:
       - restore_cache:
           keys:
             - py38-v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
+      - run: |
+          wget https://golang.org/dl/go<< pipeline.parameters.go-version >>.linux-amd64.tar.gz -O go.tar.gz
+          tar -C ~/ -xzf go.tar.gz
+          echo "export PATH=~/go/bin:$PATH" >> $BASH_ENV
       - run: pip install -r requirements.txt
       - run: pip install -r requirements-dev.txt
       - save_cache:
@@ -49,6 +63,10 @@ jobs:
       - restore_cache:
           keys:
             - py39-v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }}
+      - run: |
+          wget https://golang.org/dl/go<< pipeline.parameters.go-version >>.linux-amd64.tar.gz -O go.tar.gz
+          tar -C ~/ -xzf go.tar.gz
+          echo "export PATH=~/go/bin:$PATH" >> $BASH_ENV
       - run: pip install -r requirements.txt
       - run: pip install -r requirements-dev.txt
       - save_cache:

+ 3 - 0
.gitignore

@@ -78,3 +78,6 @@ debian/files
 
 # protobuf stuff
 hivemind/proto/*_pb2*
+
+# libp2p-daemon binary
+hivemind/hivemind_cli/p2pd

+ 1 - 0
hivemind/__init__.py

@@ -1,5 +1,6 @@
 from hivemind.client import *
 from hivemind.dht import *
+from hivemind.p2p import *
 from hivemind.server import *
 from hivemind.utils import *
 from hivemind.optim import *

+ 1 - 0
hivemind/p2p/__init__.py

@@ -0,0 +1 @@
+from hivemind.p2p.p2p_daemon import P2P

+ 369 - 0
hivemind/p2p/p2p_daemon.py

@@ -0,0 +1,369 @@
+import asyncio
+from copy import deepcopy
+from dataclasses import dataclass
+from importlib.resources import path
+from subprocess import Popen
+from typing import List, Optional
+
+import google.protobuf
+from multiaddr import Multiaddr
+
+import hivemind.hivemind_cli as cli
+import hivemind.p2p.p2p_daemon_bindings.p2pclient as p2pclient
+from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, StreamInfo
+from hivemind.proto import p2pd_pb2
+from hivemind.utils import MSGPackSerializer
+from hivemind.utils.logging import get_logger
+from hivemind.utils.networking import find_open_port
+
+logger = get_logger(__name__)
+
+
+P2PD_FILENAME = 'p2pd'
+NUM_RETRIES = 3
+RETRY_DELAY = 0.4
+
+
class P2PInterruptedError(Exception):
    """Raised inside a unary handler when the remote peer closes its stream before the reply is sent."""
+
+
@dataclass(frozen=False)
class P2PContext(object):
    # Per-call metadata passed to unary handlers.
    id: str           # base58 peer id of the local node
    port: int         # host port of the local node
    handle_name: str  # name of the handler being invoked
    # The two fields below are filled in per incoming stream once the remote
    # peer is known (see P2P._handle_unary_stream).
    peer_id: Optional[PeerID] = None
    peer_addr: Optional[Multiaddr] = None
+
+
class P2P:
    """
    Forks a child process and executes p2pd command with given arguments.
    Can be used for peer to peer communication and procedure calls.
    Sends SIGKILL to the child in destructor.
    """

    HEADER_LEN = 8            # size (bytes) of the big-endian length prefix on raw messages
    BYTEORDER = 'big'
    PB_HEADER_LEN = 1
    RESULT_MESSAGE = b'\x00'  # marker sent before a successful protobuf reply
    ERROR_MESSAGE = b'\x01'   # marker sent before an RPCError protobuf
    # Translate user-facing mode names into p2pd command-line flags.
    DHT_MODE_MAPPING = {
        'dht': {'dht': 1},
        'dht_server': {'dhtServer': 1},
        'dht_client': {'dhtClient': 1},
    }
    FORCE_REACHABILITY_MAPPING = {
        'public': {'forceReachabilityPublic': 1},
        'private': {'forceReachabilityPrivate': 1},
    }

    def __init__(self):
        self._child = None                      # Popen handle of the spawned p2pd process, if any
        self._alive = False
        self._listen_task = None                # background task serving registered handlers
        self._server_stopped = asyncio.Event()

    @classmethod
    async def create(cls, *args, quic: bool = True, tls: bool = True, conn_manager: bool = True,
                     dht_mode: str = 'dht_server', force_reachability: Optional[str] = None,
                     nat_port_map: bool = True, auto_nat: bool = True, bootstrap: bool = False,
                     bootstrap_peers: Optional[List[str]] = None, use_global_ipfs: bool = False, host_port: int = None,
                     daemon_listen_port: int = None, **kwargs):
        """
        Start a new p2pd process and connect to it.
        :param args: positional arguments forwarded verbatim to the p2pd binary
        :param quic: Enables the QUIC transport
        :param tls: Enables TLS1.3 channel security protocol
        :param conn_manager: Enables the Connection Manager
        :param dht_mode: DHT mode (dht_client/dht_server/dht)
        :param force_reachability: Force reachability mode (public/private)
        :param nat_port_map: Enables NAT port mapping
        :param auto_nat: Enables the AutoNAT service
        :param bootstrap: Connects to bootstrap peers and bootstraps the dht if enabled
        :param bootstrap_peers: List of bootstrap peers; defaults to the IPFS DHT peers
        :param use_global_ipfs: Bootstrap to global ipfs (works only if bootstrap=True and bootstrap_peers=None)
        :param host_port: port for p2p network
        :param daemon_listen_port: port for connection daemon and client binding
        :param kwargs: additional -key=value flags forwarded to the p2pd binary
        :return: new wrapper for p2p daemon
        """

        assert not (bootstrap and bootstrap_peers is None and not use_global_ipfs), \
            'Trying to create with bootstrap node without bootstrap nodes list. ' \
            'It is very dangerous, because p2pd connects to global ipfs and it is very unstable. ' \
            'If you really want this, pass use_global_ipfs=True'
        # Fix: the original message concatenated to "incompatible.Choose" — a
        # trailing space was missing at the end of the first fragment.
        assert not (bootstrap_peers is not None and use_global_ipfs), \
            'Non empty bootstrap_nodes and use_global_ipfs=True are incompatible. ' \
            'Choose one option: your nodes list (preferable) or global ipfs (very unstable)'

        self = cls()
        with path(cli, P2PD_FILENAME) as p:
            p2pd_path = p
        bootstrap_peers = cls._make_bootstrap_peers(bootstrap_peers)
        dht = cls.DHT_MODE_MAPPING.get(dht_mode, {'dht': 0})
        force_reachability = cls.FORCE_REACHABILITY_MAPPING.get(force_reachability, {})
        proc_args = self._make_process_args(
            str(p2pd_path), *args,
            quic=quic, tls=tls, connManager=conn_manager,
            natPortMap=nat_port_map, autonat=auto_nat,
            b=bootstrap, **{**bootstrap_peers, **dht, **force_reachability, **kwargs})
        self._assign_daemon_ports(host_port, daemon_listen_port)

        # The daemon may fail to start (e.g. a chosen port got taken in the
        # meantime); retry with exponential backoff and fresh ports each time.
        for try_count in range(NUM_RETRIES):
            try:
                self._initialize(proc_args)
                await self._wait_for_client(RETRY_DELAY * (2 ** try_count))
                break
            except Exception as e:
                logger.debug(f"Failed to initialize p2p daemon: {e}")
                self._terminate()
                if try_count == NUM_RETRIES - 1:
                    raise
                self._assign_daemon_ports()

        return self

    @classmethod
    async def replicate(cls, daemon_listen_port: int, host_port: int):
        """
        Connect to existing p2p daemon
        :param daemon_listen_port: port for connection daemon and client binding
        :param host_port: port for p2p network
        :return: new wrapper for existing p2p daemon
        """

        self = cls()
        # There is no child under control
        # Use external already running p2pd
        self._child = None
        self._alive = True
        self._assign_daemon_ports(host_port, daemon_listen_port)
        self._client_listen_port = find_open_port()
        self._client = p2pclient.Client(
            Multiaddr(f'/ip4/127.0.0.1/tcp/{self._daemon_listen_port}'),
            Multiaddr(f'/ip4/127.0.0.1/tcp/{self._client_listen_port}'))
        await self._wait_for_client()
        return self

    async def wait_for_at_least_n_peers(self, n_peers, attempts=3, delay=1):
        """Poll the daemon until it reports at least n_peers connections; raise RuntimeError on timeout."""
        for _ in range(attempts):
            peers = await self._client.list_peers()
            if len(peers) >= n_peers:
                return
            await asyncio.sleep(delay)

        raise RuntimeError('Not enough peers')

    def _initialize(self, proc_args: List[str]) -> None:
        # Spawn the p2pd child with host/daemon addresses appended, then create
        # a client bound to a fresh local port.
        proc_args = deepcopy(proc_args)
        proc_args.extend(self._make_process_args(
            hostAddrs=f'/ip4/0.0.0.0/tcp/{self._host_port},/ip4/0.0.0.0/udp/{self._host_port}/quic',
            listen=f'/ip4/127.0.0.1/tcp/{self._daemon_listen_port}'
        ))
        self._child = Popen(args=proc_args, encoding="utf8")
        self._alive = True
        self._client_listen_port = find_open_port()
        self._client = p2pclient.Client(
            Multiaddr(f'/ip4/127.0.0.1/tcp/{self._daemon_listen_port}'),
            Multiaddr(f'/ip4/127.0.0.1/tcp/{self._client_listen_port}'))

    async def _wait_for_client(self, delay=0):
        # Give the daemon `delay` seconds to come up, then fetch our peer id.
        await asyncio.sleep(delay)
        encoded = await self._client.identify()
        self.id = encoded[0].to_base58()

    def _assign_daemon_ports(self, host_port=None, daemon_listen_port=None):
        # Pick two distinct free ports unless the caller supplied them.
        if host_port is None:
            host_port = find_open_port()
        if daemon_listen_port is None:
            daemon_listen_port = find_open_port()
            while daemon_listen_port == host_port:
                daemon_listen_port = find_open_port()

        self._host_port, self._daemon_listen_port = host_port, daemon_listen_port

    @staticmethod
    async def send_raw_data(byte_str, writer):
        """Write one length-prefixed binary message to the stream."""
        request = len(byte_str).to_bytes(P2P.HEADER_LEN, P2P.BYTEORDER) + byte_str
        writer.write(request)

    @staticmethod
    async def send_msgpack(data, writer):
        """Serialize `data` with msgpack and send it as a raw message."""
        raw_data = MSGPackSerializer.dumps(data)
        await P2P.send_raw_data(raw_data, writer)

    @staticmethod
    async def send_protobuf(protobuf, out_proto_type, writer):
        """Send a result/error marker byte followed by the serialized protobuf."""
        if type(protobuf) != out_proto_type:
            raise TypeError('Unary handler returned protobuf of wrong type.')
        if out_proto_type == p2pd_pb2.RPCError:
            await P2P.send_raw_data(P2P.ERROR_MESSAGE, writer)
        else:
            await P2P.send_raw_data(P2P.RESULT_MESSAGE, writer)

        await P2P.send_raw_data(protobuf.SerializeToString(), writer)

    @staticmethod
    async def receive_raw_data(reader: asyncio.StreamReader, header_len=HEADER_LEN):
        """Read one length-prefixed binary message from the stream."""
        header = await reader.readexactly(header_len)
        content_length = int.from_bytes(header, P2P.BYTEORDER)
        data = await reader.readexactly(content_length)
        return data

    @staticmethod
    async def receive_msgpack(reader):
        """Receive one raw message and deserialize it with msgpack."""
        return MSGPackSerializer.loads(await P2P.receive_raw_data(reader))

    @staticmethod
    async def receive_protobuf(in_proto_type, reader):
        """Receive (result, None) or (None, RPCError) depending on the leading marker byte."""
        msg_type = await P2P.receive_raw_data(reader)
        if msg_type == P2P.RESULT_MESSAGE:
            protobuf = in_proto_type()
            protobuf.ParseFromString(await P2P.receive_raw_data(reader))
            return protobuf, None
        elif msg_type == P2P.ERROR_MESSAGE:
            protobuf = p2pd_pb2.RPCError()
            protobuf.ParseFromString(await P2P.receive_raw_data(reader))
            return None, protobuf
        else:
            raise TypeError('Invalid Protobuf message type')

    @staticmethod
    def _handle_stream(handle):
        # Wrap a synchronous bytes -> bytes handler into a stream callback.
        async def do_handle_stream(stream_info, reader, writer):
            try:
                request = await P2P.receive_raw_data(reader)
            except asyncio.IncompleteReadError:
                logger.debug("Incomplete read while receiving request from peer")
                writer.close()
                return
            try:
                result = handle(request)
                await P2P.send_raw_data(result, writer)
            finally:
                writer.close()

        return do_handle_stream

    @staticmethod
    def _handle_unary_stream(handle, context, in_proto_type, out_proto_type):
        # Wrap an async protobuf handler; a concurrent watchdog aborts the
        # handler as soon as the peer closes its side of the stream.
        async def watchdog(reader: asyncio.StreamReader):
            await reader.read(n=1)
            raise P2PInterruptedError()

        async def do_handle_unary_stream(
                stream_info: StreamInfo,
                reader: asyncio.StreamReader,
                writer: asyncio.StreamWriter) -> None:
            try:
                try:
                    request = await P2P.receive_protobuf(in_proto_type, reader)
                except asyncio.IncompleteReadError:
                    logger.debug("Incomplete read while receiving request from peer")
                    return
                except google.protobuf.message.DecodeError as error:
                    logger.exception(error)
                    return

                context.peer_id, context.peer_addr = stream_info.peer_id, stream_info.addr
                done, pending = await asyncio.wait([watchdog(reader), handle(request, context)],
                                                   return_when=asyncio.FIRST_COMPLETED)
                try:
                    result = done.pop().result()
                    await P2P.send_protobuf(result, out_proto_type, writer)
                except P2PInterruptedError:
                    pass
                except Exception as exc:
                    # Handler failures are reported back to the caller as RPCError.
                    error = p2pd_pb2.RPCError(message=str(exc))
                    await P2P.send_protobuf(error, p2pd_pb2.RPCError, writer)
                finally:
                    # Cancel whichever of (watchdog, handler) is still running.
                    pending_task = pending.pop()
                    pending_task.cancel()
                    try:
                        await pending_task
                    except asyncio.CancelledError:
                        pass
            finally:
                writer.close()

        return do_handle_unary_stream

    def start_listening(self):
        """Spawn a background task serving registered handlers until stop_listening() is called."""
        async def listen():
            async with self._client.listen():
                await self._server_stopped.wait()

        self._listen_task = asyncio.create_task(listen())

    async def stop_listening(self):
        """Cancel the background listener task, if one is running."""
        if self._listen_task is not None:
            self._server_stopped.set()
            self._listen_task.cancel()
            try:
                await self._listen_task
            except asyncio.CancelledError:
                self._listen_task = None
                self._server_stopped.clear()

    async def add_stream_handler(self, name, handle):
        """Register a raw (bytes -> bytes) handler under `name`, starting the listener if needed."""
        if self._listen_task is None:
            self.start_listening()
        await self._client.stream_handler(name, self._handle_stream(handle))

    async def add_unary_handler(self, name, handle, in_proto_type, out_proto_type):
        """Register an async protobuf handler under `name`, starting the listener if needed."""
        if self._listen_task is None:
            self.start_listening()
        context = P2PContext(id=self.id, port=self._host_port, handle_name=name)
        await self._client.stream_handler(
            name, P2P._handle_unary_stream(handle, context, in_proto_type, out_proto_type))

    async def call_peer_handler(self, peer_id, handler_name, input_data):
        """Open a stream to `peer_id`, send `input_data` and return the peer's raw reply."""
        libp2p_peer_id = PeerID.from_base58(peer_id)
        stream_info, reader, writer = await self._client.stream_open(libp2p_peer_id, (handler_name,))
        try:
            await P2P.send_raw_data(input_data, writer)
            return await P2P.receive_raw_data(reader)
        finally:
            writer.close()

    def __del__(self):
        self._terminate()

    @property
    def is_alive(self):
        return self._alive

    async def shutdown(self):
        # Run the blocking terminate in the default executor so the event loop
        # is not stalled while waiting on the child process.
        await asyncio.get_event_loop().run_in_executor(None, self._terminate)

    def _terminate(self):
        self._alive = False
        if self._child is not None and self._child.poll() is None:
            self._child.kill()
            self._child.wait()

    @staticmethod
    def _make_process_args(*args, **kwargs) -> List[str]:
        # Positional args are passed through verbatim; kwargs become
        # `-key=value` flags, or bare `-key` when the value is None.
        proc_args = []
        proc_args.extend(
            str(entry) for entry in args
        )
        proc_args.extend(
            f'-{key}={P2P._convert_process_arg_type(value)}' if value is not None else f'-{key}'
            for key, value in kwargs.items()
        )
        return proc_args

    @staticmethod
    def _convert_process_arg_type(val):
        # p2pd expects 1/0 on the command line rather than True/False.
        if isinstance(val, bool):
            return 1 if val else 0
        return val

    @staticmethod
    def _make_bootstrap_peers(nodes):
        # None means "no flag at all", which is distinct from an empty list.
        if nodes is None:
            return {}
        return {'bootstrapPeers': ','.join(nodes)}

+ 0 - 0
hivemind/p2p/p2p_daemon_bindings/__init__.py


+ 210 - 0
hivemind/p2p/p2p_daemon_bindings/control.py

@@ -0,0 +1,210 @@
+"""
+Originally taken from: https://github.com/mhchia/py-libp2p-daemon-bindings
+Licence: MIT
+Author: Kevin Mai-Husan Chia
+"""
+
+import asyncio
+from contextlib import asynccontextmanager
+from typing import (AsyncIterator, Awaitable, Callable, Dict, Iterable,
+                    Sequence, Tuple)
+
+from multiaddr import Multiaddr, protocols
+
+from hivemind.p2p.p2p_daemon_bindings.datastructures import (PeerID, PeerInfo,
+                                                             StreamInfo)
+from hivemind.p2p.p2p_daemon_bindings.utils import (DispatchFailure,
+                                                    raise_if_failed,
+                                                    read_pbmsg_safe,
+                                                    write_pbmsg)
+from hivemind.proto import p2pd_pb2 as p2pd_pb
+from hivemind.utils.logging import get_logger
+
# Signature of callbacks invoked for each incoming daemon stream.
StreamHandler = Callable[[StreamInfo, asyncio.StreamReader, asyncio.StreamWriter], Awaitable[None]]

SUPPORT_CONN_PROTOCOLS = (
    protocols.P_IP4,
    # protocols.P_IP6,
    protocols.P_UNIX,
)
# Materialized as a tuple: the original generator expression was single-use and
# rendered as a bare generator repr when interpolated into error messages.
SUPPORTED_PROTOS = tuple(
    protocols.protocol_with_code(proto) for proto in SUPPORT_CONN_PROTOCOLS
)
logger = get_logger(__name__)


def parse_conn_protocol(maddr: Multiaddr) -> int:
    """Return the code of the single supported transport protocol present in `maddr`.

    :raises ValueError: if zero or more than one supported protocol is present
    """
    proto_codes = set(proto.code for proto in maddr.protocols())
    proto_cand = proto_codes.intersection(SUPPORT_CONN_PROTOCOLS)
    if len(proto_cand) != 1:
        raise ValueError(
            f"connection protocol should be only one protocol out of {SUPPORTED_PROTOS}"
            f", maddr={maddr}"
        )
    return tuple(proto_cand)[0]
+
+
class DaemonConnector:
    """Opens raw control connections to a running p2p daemon over unix or tcp sockets."""

    DEFAULT_CONTROL_MADDR = "/unix/tmp/p2pd.sock"

    def __init__(self, control_maddr: Multiaddr = Multiaddr(DEFAULT_CONTROL_MADDR)) -> None:
        self.control_maddr = control_maddr
        # Resolve the transport once up-front; raises for unsupported maddrs.
        self.proto_code = parse_conn_protocol(self.control_maddr)

    async def open_connection(self) -> (asyncio.StreamReader, asyncio.StreamWriter):
        """Connect to the daemon's control endpoint and return a (reader, writer) pair."""
        if self.proto_code == protocols.P_UNIX:
            control_path = self.control_maddr.value_for_protocol(protocols.P_UNIX)
            logger.debug(f"DaemonConnector {self} opens connection to {self.control_maddr}")
            return await asyncio.open_unix_connection(control_path)
        if self.proto_code == protocols.P_IP4:
            host = self.control_maddr.value_for_protocol(protocols.P_IP4)
            port = int(self.control_maddr.value_for_protocol(protocols.P_TCP))
            return await asyncio.open_connection(host, port)
        raise ValueError(
            f"Protocol not supported: {protocols.protocol_with_code(self.proto_code)}"
        )
+
+
class ControlClient:
    """
    Implements the libp2p daemon control protocol: identify, connect,
    peer listing and stream management over a DaemonConnector.
    """

    DEFAULT_LISTEN_MADDR = "/unix/tmp/p2pclient.sock"

    def __init__(
            self, daemon_connector: DaemonConnector, listen_maddr: Multiaddr = Multiaddr(DEFAULT_LISTEN_MADDR)
    ) -> None:
        self.listen_maddr = listen_maddr
        self.daemon_connector = daemon_connector
        self.handlers: Dict[str, StreamHandler] = {}

    async def _request(self, req: p2pd_pb.Request) -> p2pd_pb.Response:
        """Perform one request/response round-trip with the daemon.

        Opens a fresh control connection, sends `req`, reads a single response,
        closes the connection and raises if the daemon reported an error.
        """
        reader, writer = await self.daemon_connector.open_connection()
        await write_pbmsg(writer, req)
        resp = p2pd_pb.Response()  # type: ignore
        await read_pbmsg_safe(reader, resp)
        writer.close()
        raise_if_failed(resp)
        return resp

    async def _handler(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
        """Dispatch one incoming stream from the daemon to the registered handler."""
        pb_stream_info = p2pd_pb.StreamInfo()  # type: ignore
        await read_pbmsg_safe(reader, pb_stream_info)
        stream_info = StreamInfo.from_protobuf(pb_stream_info)
        logger.debug(f"New incoming stream: {stream_info}")
        try:
            handler = self.handlers[stream_info.proto]
        except KeyError as e:
            # should never enter here... daemon should reject the stream for us.
            writer.close()
            raise DispatchFailure(e)
        await handler(stream_info, reader, writer)

    @asynccontextmanager
    async def listen(self) -> AsyncIterator["ControlClient"]:
        """Serve incoming daemon streams on listen_maddr for the duration of the context."""
        proto_code = parse_conn_protocol(self.listen_maddr)
        if proto_code == protocols.P_UNIX:
            listen_path = self.listen_maddr.value_for_protocol(protocols.P_UNIX)
            server = await asyncio.start_unix_server(self._handler, path=listen_path)
        elif proto_code == protocols.P_IP4:
            host = self.listen_maddr.value_for_protocol(protocols.P_IP4)
            port = int(self.listen_maddr.value_for_protocol(protocols.P_TCP))
            server = await asyncio.start_server(self._handler, port=port, host=host)
        else:
            raise ValueError(
                f"Protocol not supported: {protocols.protocol_with_code(proto_code)}"
            )

        async with server:
            logger.info(f"DaemonConnector {self} starts listening to {self.listen_maddr}")
            yield self

        logger.info(f"DaemonConnector {self} closed")

    async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
        """Return the daemon's own peer id and the multiaddrs it listens on."""
        resp = await self._request(p2pd_pb.Request(type=p2pd_pb.Request.IDENTIFY))

        peer_id = PeerID(resp.identify.id)
        maddrs = tuple(Multiaddr(maddr_bytes) for maddr_bytes in resp.identify.addrs)

        return peer_id, maddrs

    async def connect(self, peer_id: PeerID, maddrs: Iterable[Multiaddr]) -> None:
        """Ask the daemon to connect to `peer_id` at any of `maddrs`."""
        connect_req = p2pd_pb.ConnectRequest(
            peer=peer_id.to_bytes(), addrs=[i.to_bytes() for i in maddrs]
        )
        await self._request(p2pd_pb.Request(type=p2pd_pb.Request.CONNECT, connect=connect_req))

    async def list_peers(self) -> Tuple[PeerInfo, ...]:
        """Return PeerInfo for every peer the daemon is currently connected to."""
        resp = await self._request(p2pd_pb.Request(type=p2pd_pb.Request.LIST_PEERS))
        return tuple(PeerInfo.from_protobuf(pinfo) for pinfo in resp.peers)

    async def disconnect(self, peer_id: PeerID) -> None:
        """Ask the daemon to drop its connection to `peer_id`."""
        disconnect_req = p2pd_pb.DisconnectRequest(peer=peer_id.to_bytes())
        await self._request(p2pd_pb.Request(
            type=p2pd_pb.Request.DISCONNECT, disconnect=disconnect_req
        ))

    async def stream_open(
        self, peer_id: PeerID, protocols: Sequence[str]
    ) -> Tuple[StreamInfo, asyncio.StreamReader, asyncio.StreamWriter]:
        """Open an outgoing stream to `peer_id` for one of `protocols`.

        Unlike the one-shot calls above, the connection must stay open — the
        returned reader/writer become the stream itself — so this method
        manages its own connection instead of using _request.
        """
        reader, writer = await self.daemon_connector.open_connection()

        stream_open_req = p2pd_pb.StreamOpenRequest(
            peer=peer_id.to_bytes(), proto=list(protocols)
        )
        req = p2pd_pb.Request(
            type=p2pd_pb.Request.STREAM_OPEN, streamOpen=stream_open_req
        )
        await write_pbmsg(writer, req)

        resp = p2pd_pb.Response()  # type: ignore
        await read_pbmsg_safe(reader, resp)
        raise_if_failed(resp)

        stream_info = StreamInfo.from_protobuf(resp.streamInfo)

        return stream_info, reader, writer

    async def stream_handler(self, proto: str, handler_cb: StreamHandler) -> None:
        """Register `handler_cb` for `proto`; the daemon routes matching streams to our listener."""
        stream_handler_req = p2pd_pb.StreamHandlerRequest(
            addr=self.listen_maddr.to_bytes(), proto=[proto]
        )
        await self._request(p2pd_pb.Request(
            type=p2pd_pb.Request.STREAM_HANDLER, streamHandler=stream_handler_req
        ))

        # if success, add the handler to the dict
        self.handlers[proto] = handler_cb

+ 170 - 0
hivemind/p2p/p2p_daemon_bindings/datastructures.py

@@ -0,0 +1,170 @@
+"""
+Originally taken from: https://github.com/mhchia/py-libp2p-daemon-bindings
+Licence: MIT
+Author: Kevin Mai-Husan Chia
+"""
+
+import hashlib
+from typing import Any, Sequence, Union
+
+import base58
+import multihash
+from multiaddr import Multiaddr, protocols
+
+from hivemind.proto import p2pd_pb2
+
# NOTE: On inlining...
# See: https://github.com/libp2p/specs/issues/138
# NOTE: enabling to be interoperable w/ the Go implementation
ENABLE_INLINING = True
MAX_INLINE_KEY_LENGTH = 42

IDENTITY_MULTIHASH_CODE = 0x00

if ENABLE_INLINING:

    class IdentityHash:
        """Pass-through "hash" whose digest is exactly the bytes fed to update()."""

        def __init__(self) -> None:
            self._digest = bytearray()

        def update(self, input: bytes) -> None:
            self._digest += input

        def digest(self) -> bytes:
            # Return an immutable copy so the annotated return type (bytes)
            # holds and later update() calls cannot mutate a digest that was
            # already handed out. (Originally returned the live bytearray.)
            return bytes(self._digest)

    # Register the identity multihash function (code 0x00) so short keys can
    # be inlined into peer ids, interoperating with go-libp2p.
    multihash.FuncReg.register(
        IDENTITY_MULTIHASH_CODE, "identity", hash_new=IdentityHash
    )
+
+
class PeerID:
    """A libp2p peer identifier backed by its raw multihash bytes."""

    def __init__(self, peer_id_bytes: bytes) -> None:
        self._bytes = peer_id_bytes
        # Integer form of sha256(bytes) — presumably used for Kademlia-style
        # XOR distance in the DHT (TODO confirm against DHT code).
        self._xor_id = int(sha256_digest(self._bytes).hex(), 16)
        self._b58_str = base58.b58encode(self._bytes).decode()

    @property
    def xor_id(self) -> int:
        return self._xor_id

    def to_bytes(self) -> bytes:
        return self._bytes

    def to_base58(self) -> str:
        return self._b58_str

    def __repr__(self) -> str:
        return f"<libp2p.peer.id.ID ({self.to_base58()})>"

    # All human-readable renderings are the base58 form.
    def __str__(self):
        return self.to_base58()

    def pretty(self):
        return self.to_base58()

    def to_string(self):
        return self.to_base58()

    def __eq__(self, other: object) -> bool:
        # Comparable against str (base58 form), raw bytes, and other PeerIDs.
        if isinstance(other, str):
            return self.to_base58() == other
        if isinstance(other, bytes):
            return self._bytes == other
        if isinstance(other, PeerID):
            return self._bytes == other._bytes
        return False

    def __hash__(self) -> int:
        return hash(self._bytes)

    @classmethod
    def from_base58(cls, base58_id: str) -> "PeerID":
        """Alternate constructor: decode a base58 string into a PeerID."""
        return cls(base58.b58decode(base58_id))
+
+
def sha256_digest(data: Union[str, bytes]) -> bytes:
    """Return the raw SHA-256 digest of `data`, utf-8 encoding strings first."""
    payload = data.encode("utf8") if isinstance(data, str) else data
    return hashlib.sha256(payload).digest()
+
+
class StreamInfo:
    """Describes one libp2p stream: the remote peer, its address and the protocol name."""

    def __init__(self, peer_id: PeerID, addr: Multiaddr, proto: str) -> None:
        self.peer_id = peer_id
        self.addr = addr
        self.proto = proto

    def __repr__(self) -> str:
        return (
            f"<StreamInfo peer_id={self.peer_id} addr={self.addr} proto={self.proto}>"
        )

    def to_protobuf(self) -> p2pd_pb2.StreamInfo:
        """Serialize into the daemon's protobuf representation."""
        return p2pd_pb2.StreamInfo(
            peer=self.peer_id.to_bytes(), addr=self.addr.to_bytes(), proto=self.proto
        )

    @classmethod
    def from_protobuf(cls, pb_msg: p2pd_pb2.StreamInfo) -> "StreamInfo":
        """Alternate constructor: build a StreamInfo from the daemon's protobuf message."""
        return cls(
            peer_id=PeerID(pb_msg.peer), addr=Multiaddr(pb_msg.addr), proto=pb_msg.proto
        )
+
+
class PeerInfo:
    """A peer id together with the multiaddresses it can be reached at."""

    def __init__(self, peer_id: PeerID, addrs: Sequence[Multiaddr]) -> None:
        self.peer_id = peer_id
        self.addrs = list(addrs)

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, PeerInfo):
            return False
        return self.peer_id == other.peer_id and self.addrs == other.addrs

    @classmethod
    def from_protobuf(cls, peer_info_pb: p2pd_pb2.PeerInfo) -> "PeerInfo":
        """Alternate constructor: build a PeerInfo from the daemon's protobuf message."""
        addresses = [Multiaddr(addr) for addr in peer_info_pb.addrs]
        return PeerInfo(PeerID(peer_info_pb.id), addresses)

    def __str__(self):
        return f"{self.peer_id.pretty()} {','.join(str(a) for a in self.addrs)}"
+
+
class InvalidAddrError(ValueError):
    """Raised when a multiaddress cannot be interpreted as a peer address."""
+
+
def info_from_p2p_addr(addr: Multiaddr) -> PeerInfo:
    """Split a full /p2p multiaddress into a PeerInfo (peer id plus transport address).

    The last component must be the /p2p/<peer id> part; whatever precedes it
    (possibly nothing) becomes the peer's transport address.
    """
    if addr is None:
        raise InvalidAddrError("`addr` should not be `None`")

    components = addr.split()
    if not components:
        raise InvalidAddrError(
            f"`parts`={components} should at least have a protocol `P_P2P`"
        )

    # The trailing component must carry the /p2p protocol.
    tail = components[-1]
    tail_code = tail.protocols()[0].code
    if tail_code != protocols.P_P2P:
        raise InvalidAddrError(
            f"The last protocol should be `P_P2P` instead of `{tail_code}`"
        )

    # make sure the /p2p value parses as a peer.ID
    peer_id = PeerID.from_base58(tail.value_for_protocol(protocols.P_P2P))

    # we might have received just an / p2p part, which means there's no addr.
    if len(components) > 1:
        addr = Multiaddr.join(*components[:-1])

    return PeerInfo(peer_id, [addr])

+ 85 - 0
hivemind/p2p/p2p_daemon_bindings/p2pclient.py

@@ -0,0 +1,85 @@
+"""
+Originally taken from: https://github.com/mhchia/py-libp2p-daemon-bindings
+Licence: MIT
+Author: Kevin Mai-Husan Chia
+"""
+
+import asyncio
+from contextlib import asynccontextmanager
+from typing import AsyncIterator, Iterable, Sequence, Tuple
+
+from multiaddr import Multiaddr
+
+from hivemind.p2p.p2p_daemon_bindings.control import (ControlClient,
+                                                      DaemonConnector,
+                                                      StreamHandler)
+from hivemind.p2p.p2p_daemon_bindings.datastructures import (PeerID, PeerInfo,
+                                                             StreamInfo)
+
+
class Client:
    """Thin asynchronous wrapper over ControlClient exposing the libp2p daemon API."""

    control: ControlClient

    def __init__(
        self, control_maddr: Multiaddr = None, listen_maddr: Multiaddr = None
    ) -> None:
        connector = DaemonConnector(control_maddr=control_maddr)
        self.control = ControlClient(
            daemon_connector=connector, listen_maddr=listen_maddr
        )

    @asynccontextmanager
    async def listen(self) -> AsyncIterator["Client"]:
        """
        Start accepting incoming connections for handlers registered via stream_handler.
        """
        async with self.control.listen():
            yield self

    async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
        """
        Return this node's peer id and the list of its addresses.
        """
        return await self.control.identify()

    async def connect(self, peer_id: PeerID, maddrs: Iterable[Multiaddr]) -> None:
        """
        Dial the p2p node identified by ``peer_id`` at the given multiaddresses.
        The addresses must, of course, be reachable.
        """
        await self.control.connect(peer_id=peer_id, maddrs=maddrs)

    async def list_peers(self) -> Tuple[PeerInfo, ...]:
        """
        Return the peers this node is currently connected to.
        """
        return await self.control.list_peers()

    async def disconnect(self, peer_id: PeerID) -> None:
        """
        Drop the connection to the node with the given peer id.
        """
        await self.control.disconnect(peer_id=peer_id)

    async def stream_open(
        self, peer_id: PeerID, protocols: Sequence[str]
    ) -> Tuple[StreamInfo, asyncio.StreamReader, asyncio.StreamWriter]:
        """
        Open a stream to the handler of peer ``peer_id`` for the given protocols.

        :return: stream info describing the connection plus a reader/writer pair
        """
        return await self.control.stream_open(peer_id=peer_id, protocols=protocols)

    async def stream_handler(self, proto: str, handler_cb: StreamHandler) -> None:
        """
        Register ``handler_cb`` as the handler serving protocol ``proto``.
        """
        await self.control.stream_handler(proto=proto, handler_cb=handler_cb)

+ 73 - 0
hivemind/p2p/p2p_daemon_bindings/utils.py

@@ -0,0 +1,73 @@
+"""
+Originally taken from: https://github.com/mhchia/py-libp2p-daemon-bindings
+Licence: MIT
+Author: Kevin Mai-Husan Chia
+"""
+
+import asyncio
+
+from google.protobuf.message import Message as PBMessage
+
+from hivemind.proto import p2pd_pb2 as p2pd_pb
+
# Width cap for varints: p2pd length prefixes fit in an unsigned 64-bit integer.
DEFAULT_MAX_BITS: int = 64


class ControlFailure(Exception):
    """The daemon rejected or failed a control request."""


class DispatchFailure(Exception):
    """A request could not be dispatched to a handler."""


async def write_unsigned_varint(stream: asyncio.StreamWriter, integer: int, max_bits: int = DEFAULT_MAX_BITS) -> None:
    """Encode ``integer`` as an unsigned LEB128 varint and write it to ``stream``.

    :raises ValueError: if ``integer`` is negative or does not fit in ``max_bits`` bits
    """
    if integer < 0:
        raise ValueError(f"negative integer: {integer}")
    if integer >= 1 << max_bits:
        raise ValueError(f"integer too large: {integer}")
    remaining = integer
    while True:
        septet = remaining & 0x7F
        remaining >>= 7
        if remaining:
            septet |= 0x80  # continuation bit: more bytes follow
        stream.write(septet.to_bytes(1, "big"))
        if not remaining:
            break


async def read_unsigned_varint(stream: asyncio.StreamReader, max_bits: int = DEFAULT_MAX_BITS) -> int:
    """Read an unsigned LEB128 varint from ``stream``.

    :raises ValueError: if the decoded value does not fit in ``max_bits`` bits
    :raises asyncio.IncompleteReadError: if the stream ends mid-varint
    """
    limit = 1 << max_bits
    shift = 0
    result = 0
    while True:
        (byte,) = await stream.readexactly(1)
        result |= (byte & 0x7F) << shift
        if result >= limit:
            raise ValueError(f"Varint overflowed: {result}")
        if not byte & 0x80:  # high bit clear: this was the last byte
            return result
        shift += 7
+
+
def raise_if_failed(response: p2pd_pb.Response) -> None:
    """Raise ControlFailure if the daemon reported an error response."""
    if response.type != p2pd_pb.Response.ERROR:
        return
    # NOTE(review): message says "Connect failed" for every failed request type
    raise ControlFailure(f"Connect failed. msg={response.error.msg}")
+
+
async def write_pbmsg(stream: asyncio.StreamWriter, pbmsg: PBMessage) -> None:
    """Write ``pbmsg`` to ``stream`` as a varint-length-prefixed protobuf payload."""
    await write_unsigned_varint(stream, pbmsg.ByteSize())
    stream.write(pbmsg.SerializeToString())
+
+
async def read_pbmsg_safe(stream: asyncio.StreamReader, pbmsg: PBMessage) -> None:
    """Read one varint-length-prefixed protobuf payload from ``stream`` into ``pbmsg``."""
    msg_len = await read_unsigned_varint(stream)
    pbmsg.ParseFromString(await stream.readexactly(msg_len))

+ 166 - 0
hivemind/proto/p2pd.proto

@@ -0,0 +1,166 @@
+//Originally taken from: https://github.com/mhchia/py-libp2p-daemon-bindings
+//Licence: MIT
+//Author: Kevin Mai-Husan Chia
+
+syntax = "proto2";
+
+package p2pclient.p2pd.pb;
+
+message Request {
+  enum Type {
+    IDENTIFY       = 0;
+    CONNECT        = 1;
+    STREAM_OPEN    = 2;
+    STREAM_HANDLER = 3;
+    DHT            = 4;
+    LIST_PEERS     = 5;
+    CONNMANAGER    = 6;
+    DISCONNECT     = 7;
+    PUBSUB         = 8;
+  }
+
+  required Type type = 1;
+
+  optional ConnectRequest connect = 2;
+  optional StreamOpenRequest streamOpen = 3;
+  optional StreamHandlerRequest streamHandler = 4;
+  optional DHTRequest dht = 5;
+  optional ConnManagerRequest connManager = 6;
+  optional DisconnectRequest disconnect = 7;
+  optional PSRequest pubsub = 8;
+}
+
+message Response {
+  enum Type {
+    OK    = 0;
+    ERROR = 1;
+  }
+
+  required Type type = 1;
+  optional ErrorResponse error = 2;
+  optional StreamInfo streamInfo = 3;
+  optional IdentifyResponse identify = 4;
+  optional DHTResponse dht = 5;
+  repeated PeerInfo peers = 6;
+  optional PSResponse pubsub = 7;
+}
+
+message IdentifyResponse {
+  required bytes id = 1;
+  repeated bytes addrs = 2;
+}
+
+message ConnectRequest {
+  required bytes peer = 1;
+  repeated bytes addrs = 2;
+  optional int64 timeout = 3;
+}
+
+message StreamOpenRequest {
+  required bytes peer = 1;
+  repeated string proto = 2;
+  optional int64 timeout = 3;
+}
+
+message StreamHandlerRequest {
+  required bytes addr = 1;
+  repeated string proto = 2;
+}
+
+message ErrorResponse {
+  required string msg = 1;
+}
+
+message StreamInfo {
+  required bytes peer = 1;
+  required bytes addr = 2;
+  required string proto = 3;
+}
+
+message DHTRequest {
+  enum Type {
+    FIND_PEER                    = 0;
+    FIND_PEERS_CONNECTED_TO_PEER = 1;
+    FIND_PROVIDERS               = 2;
+    GET_CLOSEST_PEERS            = 3;
+    GET_PUBLIC_KEY               = 4;
+    GET_VALUE                    = 5;
+    SEARCH_VALUE                 = 6;
+    PUT_VALUE                    = 7;
+    PROVIDE                      = 8;
+  }
+
+  required Type type = 1;
+  optional bytes peer = 2;
+  optional bytes cid = 3;
+  optional bytes key = 4;
+  optional bytes value = 5;
+  optional int32 count = 6;
+  optional int64 timeout = 7;
+}
+
+message DHTResponse {
+  enum Type {
+    BEGIN = 0;
+    VALUE = 1;
+    END   = 2;
+  }
+
+  required Type type = 1;
+  optional PeerInfo peer = 2;
+  optional bytes value = 3;
+}
+
+message PeerInfo {
+  required bytes id = 1;
+  repeated bytes addrs = 2;
+}
+
+message ConnManagerRequest {
+  enum Type {
+    TAG_PEER        = 0;
+    UNTAG_PEER      = 1;
+    TRIM            = 2;
+  }
+
+  required Type type = 1;
+
+  optional bytes peer = 2;
+  optional string tag = 3;
+  optional int64 weight = 4;
+}
+
+message DisconnectRequest {
+  required bytes peer = 1;
+}
+
+message PSRequest {
+  enum Type {
+    GET_TOPICS = 0;
+    LIST_PEERS = 1;
+    PUBLISH    = 2;
+    SUBSCRIBE  = 3;
+  }
+
+  required Type type = 1;
+  optional string topic = 2;
+  optional bytes data = 3;
+}
+
+message PSMessage {
+  optional bytes from_id = 1;
+  optional bytes data = 2;
+  optional bytes seqno = 3;
+  repeated string topicIDs = 4;
+  optional bytes signature = 5;
+  optional bytes key = 6;
+}
+
+message PSResponse {
+  repeated string topics = 1;
+  repeated bytes peerIDs = 2;
+}
+
+message RPCError {
+  required string message = 1;
+}

+ 2 - 0
requirements.txt

@@ -10,5 +10,7 @@ grpcio>=1.33.2
 grpcio-tools>=1.33.2
 protobuf>=3.12.2
 configargparse>=1.2.3
+multiaddr>=0.0.9
+pymultihash>=0.8.2
 cryptography>=3.4.6
 pydantic>=1.8.1

+ 67 - 6
setup.py

@@ -1,13 +1,35 @@
 import codecs
 import glob
+import hashlib
 import os
 import re
+import shlex
+import subprocess
+import tarfile
+import tempfile
+import urllib.request
 
+from packaging import version
 from pkg_resources import parse_requirements
-from setuptools import setup, find_packages
+from setuptools import find_packages, setup
 from setuptools.command.develop import develop
 from setuptools.command.install import install
 
+P2PD_VERSION = 'v0.3.1'
+P2PD_CHECKSUM = '8810097959db720208cdc9f2945804a4'
+LIBP2P_TAR_URL = f'https://github.com/learning-at-home/go-libp2p-daemon/archive/refs/tags/{P2PD_VERSION}.tar.gz'
+
+
+here = os.path.abspath(os.path.dirname(__file__))
+
+
def md5(fname, chunk_size=4096):
    """Return the hex MD5 digest of the file at ``fname``, read in ``chunk_size`` pieces."""
    digest = hashlib.md5()
    with open(fname, "rb") as fh:
        while True:
            chunk = fh.read(chunk_size)
            if not chunk:
                break
            digest.update(chunk)
    return digest.hexdigest()
+
 
 def proto_compile(output_path):
     import grpc_tools.protoc
@@ -28,20 +50,59 @@ def proto_compile(output_path):
             file.truncate()
 
 
-class ProtoCompileInstall(install):
def libp2p_build_install():
    """Build the go-libp2p-daemon ``p2pd`` binary from source into hivemind/hivemind_cli.

    Requires a local Go toolchain (>= 1.13) on PATH.

    :raises FileNotFoundError: if ``go`` is absent or produces unexpected output
    :raises EnvironmentError: if the installed Go is older than 1.13
    :raises RuntimeError: if ``go build`` exits with a non-zero status
    """
    result = subprocess.run("go version", capture_output=True, shell=True).stdout.decode('ascii', 'replace')
    m = re.search(r'^go version go([\d.]+)', result)
    if m is None:
        # With shell=True subprocess.run does not raise FileNotFoundError for a
        # missing binary, so detect the failure from the (empty/garbled) output.
        raise FileNotFoundError('Could not find golang installation')
    v = m.group(1)

    if version.parse(v) < version.parse("1.13"):
        # was f'... found {version}', which printed the packaging.version module
        raise EnvironmentError(f'Newer version of go required: must be >= 1.13, found {v}')

    with tempfile.TemporaryDirectory() as tempdir:
        dest = os.path.join(tempdir, 'libp2p-daemon.tar.gz')
        urllib.request.urlretrieve(LIBP2P_TAR_URL, dest)

        with tarfile.open(dest, 'r:gz') as tar:
            tar.extractall(tempdir)

        result = subprocess.run(f'go build -o {shlex.quote(os.path.join(here, "hivemind", "hivemind_cli", "p2pd"))}',
                                cwd=os.path.join(tempdir, f'go-libp2p-daemon-{P2PD_VERSION[1:]}', 'p2pd'), shell=True)

        if result.returncode:
            raise RuntimeError('Failed to build or install libp2p-daemon:'
                               f' exited with status code: {result.returncode}')
+
+
def libp2p_download_install():
    """Download the prebuilt ``p2pd`` binary into hivemind/hivemind_cli, verifying its md5.

    Skips the download when a binary with the expected checksum is already present.

    :raises RuntimeError: if the downloaded file does not match P2PD_CHECKSUM
    """
    install_path = os.path.join(here, 'hivemind', 'hivemind_cli')
    binary_path = os.path.join(install_path, 'p2pd')
    if os.path.isfile(binary_path) and md5(binary_path) == P2PD_CHECKSUM:
        return  # already installed and intact
    print('Downloading Peer to Peer Daemon')
    url = f'https://github.com/learning-at-home/go-libp2p-daemon/releases/download/{P2PD_VERSION}/p2pd'
    urllib.request.urlretrieve(url, binary_path)
    # rwxr-xr-x: the previous 0o777 left the executable world-writable
    os.chmod(binary_path, 0o755)
    if md5(binary_path) != P2PD_CHECKSUM:
        raise RuntimeError(f'Downloaded p2pd binary from {url} does not match with md5 checksum')
+
+
class Install(install):
    """``pip install``: fetch the prebuilt p2pd binary, then compile protobufs."""

    def run(self):
        # the binary and generated protos must exist before the standard install copies files
        libp2p_download_install()
        proto_compile(os.path.join(self.build_lib, 'hivemind', 'proto'))
        super().run()
 
 
-class ProtoCompileDevelop(develop):
class Develop(develop):
    """``setup.py develop``: build p2pd from source with the local Go toolchain."""

    def run(self):
        libp2p_build_install()
        proto_compile(os.path.join('hivemind', 'proto'))
        super().run()
 
 
-here = os.path.abspath(os.path.dirname(__file__))
-
 with open('requirements.txt') as requirements_file:
     install_requires = list(map(str, parse_requirements(requirements_file)))
 
@@ -63,7 +124,7 @@ extras['all'] = extras['dev'] + extras['docs']
 setup(
     name='hivemind',
     version=version_string,
-    cmdclass={'install': ProtoCompileInstall, 'develop': ProtoCompileDevelop},
+    cmdclass={'install': Install, 'develop': Develop},
     description='Decentralized deep learning in PyTorch',
     long_description='Decentralized deep learning in PyTorch. Built to train giant models on '
                      'thousands of volunteers across the world.',

+ 440 - 0
tests/test_p2p_daemon.py

@@ -0,0 +1,440 @@
+import asyncio
+import multiprocessing as mp
+import subprocess
+from functools import partial
+from typing import List
+
+import numpy as np
+import pytest
+import torch
+
+from hivemind.p2p import P2P
+from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID
+from hivemind.proto import dht_pb2, runtime_pb2
+from hivemind.utils import MSGPackSerializer
+from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
+
+
def is_process_running(pid: int) -> bool:
    """Check whether a process with the given PID is alive (via ``ps -p``)."""
    completed = subprocess.run(
        ["ps", "-p", str(pid)],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )
    return completed.returncode == 0
+
+
async def replicate_if_needed(p2p: P2P, replicate: bool):
    """Return a replica attached to ``p2p``'s daemon if ``replicate``, else ``p2p`` itself."""
    if not replicate:
        return p2p
    return await P2P.replicate(p2p._daemon_listen_port, p2p._host_port)
+
+
def bootstrap_addr(host_port, id_):
    """Compose a loopback libp2p multiaddr string for the given TCP port and peer id."""
    return '/ip4/127.0.0.1/tcp/' + str(host_port) + '/p2p/' + str(id_)
+
+
def bootstrap_from(daemons: List[P2P]) -> List[str]:
    """Build bootstrap multiaddr strings pointing at each daemon in ``daemons``."""
    return [bootstrap_addr(daemon._host_port, daemon.id) for daemon in daemons]
+
+
@pytest.mark.asyncio
async def test_daemon_killed_on_del():
    """Shutting down a P2P instance must terminate its daemon subprocess."""
    daemon = await P2P.create()
    daemon_pid = daemon._child.pid
    assert is_process_running(daemon_pid)

    await daemon.shutdown()
    assert not is_process_running(daemon_pid)
+
+
@pytest.mark.asyncio
async def test_server_client_connection():
    """A client bootstrapped from a fresh server must end up with exactly one peer each."""
    server = await P2P.create()
    assert len(await server._client.list_peers()) == 0

    client = await P2P.create(bootstrap=True, bootstrap_peers=bootstrap_from([server]))
    await client.wait_for_at_least_n_peers(1)

    assert len(await client._client.list_peers()) == 1
    assert len(await server._client.list_peers()) == 1
+
+
@pytest.mark.asyncio
async def test_daemon_replica_does_not_affect_primary():
    """Shutting down a replica must leave the primary's daemon process alive."""
    primary = await P2P.create()
    replica = await P2P.replicate(primary._daemon_listen_port, primary._host_port)

    daemon_pid = primary._child.pid
    assert is_process_running(daemon_pid)

    await replica.shutdown()
    assert is_process_running(daemon_pid)  # replica shutdown must not kill the daemon

    await primary.shutdown()
    assert not is_process_running(daemon_pid)
+
+
def handle_square(x):
    """Stream handler: msgpack-decode ``x``, square it, msgpack-encode the result."""
    value = MSGPackSerializer.loads(x)
    return MSGPackSerializer.dumps(value ** 2)
+
+
def handle_add(args):
    """Stream handler: msgpack-decode a sequence and fold it left-to-right with ``+``."""
    items = MSGPackSerializer.loads(args)
    total = items[0]
    for item in items[1:]:
        total = total + item  # `+` also concatenates lists, matching the test cases
    return MSGPackSerializer.dumps(total)
+
+
def handle_square_torch(x):
    """Stream handler: parse a serialized runtime_pb2.Tensor, square it, re-serialize."""
    pb_tensor = runtime_pb2.Tensor()
    pb_tensor.ParseFromString(x)
    squared = deserialize_torch_tensor(pb_tensor) ** 2
    return serialize_torch_tensor(squared).SerializeToString()
+
+
def handle_add_torch(args):
    """Stream handler: sum a msgpack-encoded list of serialized runtime_pb2 tensors."""
    serialized = MSGPackSerializer.loads(args)

    pb_tensor = runtime_pb2.Tensor()
    pb_tensor.ParseFromString(serialized[0])
    result = deserialize_torch_tensor(pb_tensor)

    for item in serialized[1:]:
        pb_tensor = runtime_pb2.Tensor()
        pb_tensor.ParseFromString(item)
        result = result + deserialize_torch_tensor(pb_tensor)

    return serialize_torch_tensor(result).SerializeToString()
+
+
def handle_add_torch_with_exc(args):
    """Like handle_add_torch, but reports any failure as a fixed error payload.

    The broad ``except`` is deliberate: the error-path test asserts on this payload.
    """
    try:
        return handle_add_torch(args)
    except Exception:
        return b'something went wrong :('
+
+
+@pytest.mark.parametrize(
+    'should_cancel,replicate', [
+        (True, False),
+        (True, True),
+        (False, False),
+        (False, True),
+    ]
+)
+@pytest.mark.asyncio
+async def test_call_unary_handler(should_cancel, replicate, handle_name="handle"):
+    handler_cancelled = False
+
+    async def ping_handler(request, context):
+        try:
+            await asyncio.sleep(2)
+        except asyncio.CancelledError:
+            nonlocal handler_cancelled
+            handler_cancelled = True
+        return dht_pb2.PingResponse(
+            peer=dht_pb2.NodeInfo(
+                node_id=context.id.encode(), rpc_port=context.port),
+            sender_endpoint=context.handle_name, available=True)
+
+    server_primary = await P2P.create()
+    server = await replicate_if_needed(server_primary, replicate)
+    server_pid = server_primary._child.pid
+    await server.add_unary_handler(handle_name, ping_handler, dht_pb2.PingRequest,
+                                   dht_pb2.PingResponse)
+    assert is_process_running(server_pid)
+
+    nodes = bootstrap_from([server])
+    client_primary = await P2P.create(bootstrap=True, bootstrap_peers=nodes)
+    client = await replicate_if_needed(client_primary, replicate)
+    client_pid = client_primary._child.pid
+    assert is_process_running(client_pid)
+
+    ping_request = dht_pb2.PingRequest(
+        peer=dht_pb2.NodeInfo(node_id=client.id.encode(), rpc_port=client._host_port),
+        validate=True)
+    expected_response = dht_pb2.PingResponse(
+        peer=dht_pb2.NodeInfo(node_id=server.id.encode(), rpc_port=server._host_port),
+        sender_endpoint=handle_name, available=True)
+
+    await client.wait_for_at_least_n_peers(1)
+    libp2p_server_id = PeerID.from_base58(server.id)
+    stream_info, reader, writer = await client._client.stream_open(libp2p_server_id, (handle_name,))
+
+    await P2P.send_protobuf(ping_request, dht_pb2.PingRequest, writer)
+
+    if should_cancel:
+        writer.close()
+        await asyncio.sleep(1)
+        assert handler_cancelled
+    else:
+        result, err = await P2P.receive_protobuf(dht_pb2.PingResponse, reader)
+        assert err is None
+        assert result == expected_response
+        assert not handler_cancelled
+
+    await server.stop_listening()
+    await server_primary.shutdown()
+    assert not is_process_running(server_pid)
+
+    await client_primary.shutdown()
+    assert not is_process_running(client_pid)
+
+
+@pytest.mark.asyncio
+async def test_call_unary_handler_error(handle_name="handle"):
+    async def error_handler(request, context):
+        raise ValueError('boom')
+
+    server = await P2P.create()
+    server_pid = server._child.pid
+    await server.add_unary_handler(handle_name, error_handler, dht_pb2.PingRequest, dht_pb2.PingResponse)
+    assert is_process_running(server_pid)
+
+    nodes = bootstrap_from([server])
+    client = await P2P.create(bootstrap=True, bootstrap_peers=nodes)
+    client_pid = client._child.pid
+    assert is_process_running(client_pid)
+    await client.wait_for_at_least_n_peers(1)
+
+    ping_request = dht_pb2.PingRequest(
+        peer=dht_pb2.NodeInfo(node_id=client.id.encode(), rpc_port=client._host_port),
+        validate=True)
+    libp2p_server_id = PeerID.from_base58(server.id)
+    stream_info, reader, writer = await client._client.stream_open(libp2p_server_id, (handle_name,))
+
+    await P2P.send_protobuf(ping_request, dht_pb2.PingRequest, writer)
+    result, err = await P2P.receive_protobuf(dht_pb2.PingResponse, reader)
+    assert result is None
+    assert err.message == 'boom'
+
+    await server.stop_listening()
+    await server.shutdown()
+    await client.shutdown()
+
+
+@pytest.mark.parametrize(
+    "test_input,expected,handle",
+    [
+        pytest.param(10, 100, handle_square, id="square_integer"),
+        pytest.param((1, 2), 3, handle_add, id="add_integers"),
+        pytest.param(([1, 2, 3], [12, 13]), [1, 2, 3, 12, 13], handle_add, id="add_lists"),
+        pytest.param(2, 8, lambda x: MSGPackSerializer.dumps(MSGPackSerializer.loads(x) ** 3), id="lambda")
+    ]
+)
+@pytest.mark.asyncio
+async def test_call_peer_single_process(test_input, expected, handle, handler_name="handle"):
+    server = await P2P.create()
+    server_pid = server._child.pid
+    await server.add_stream_handler(handler_name, handle)
+    assert is_process_running(server_pid)
+
+    nodes = bootstrap_from([server])
+    client = await P2P.create(bootstrap=True, bootstrap_peers=nodes)
+    client_pid = client._child.pid
+    assert is_process_running(client_pid)
+
+    await client.wait_for_at_least_n_peers(1)
+
+    test_input_msgp = MSGPackSerializer.dumps(test_input)
+    result_msgp = await client.call_peer_handler(server.id, handler_name, test_input_msgp)
+    result = MSGPackSerializer.loads(result_msgp)
+    assert result == expected
+
+    await server.stop_listening()
+    await server.shutdown()
+    assert not is_process_running(server_pid)
+
+    await client.shutdown()
+    assert not is_process_running(client_pid)
+
+
async def run_server(handler_name, server_side, client_side, response_received):
    """Run a P2P server with a square handler, publish its id/port, wait for client ack."""
    server = await P2P.create()
    server_pid = server._child.pid
    await server.add_stream_handler(handler_name, handle_square)
    assert is_process_running(server_pid)

    # hand the server's coordinates to the client process over the pipe
    server_side.send(server.id)
    server_side.send(server._host_port)

    # poll until the client signals that it received a response
    while not response_received.value:
        await asyncio.sleep(0.5)

    await server.stop_listening()
    await server.shutdown()
    assert not is_process_running(server_pid)
+
+
def server_target(handler_name, server_side, client_side, response_received):
    """Entry point for the server subprocess: drive ``run_server`` to completion."""
    asyncio.run(run_server(handler_name, server_side, client_side, response_received))
+
+
+@pytest.mark.asyncio
+async def test_call_peer_different_processes():
+    handler_name = "square"
+    test_input = 2
+
+    server_side, client_side = mp.Pipe()
+    response_received = mp.Value(np.ctypeslib.as_ctypes_type(np.int32))
+    response_received.value = 0
+
+    proc = mp.Process(target=server_target, args=(handler_name, server_side, client_side, response_received))
+    proc.start()
+
+    peer_id = client_side.recv()
+    peer_port = client_side.recv()
+
+    nodes = [bootstrap_addr(peer_port, peer_id)]
+    client = await P2P.create(bootstrap=True, bootstrap_peers=nodes)
+    client_pid = client._child.pid
+    assert is_process_running(client_pid)
+
+    await client.wait_for_at_least_n_peers(1)
+
+    test_input_msgp = MSGPackSerializer.dumps(2)
+    result_msgp = await client.call_peer_handler(peer_id, handler_name, test_input_msgp)
+    result = MSGPackSerializer.loads(result_msgp)
+    assert np.allclose(result, test_input ** 2)
+    response_received.value = 1
+
+    await client.shutdown()
+    assert not is_process_running(client_pid)
+
+    proc.join()
+
+
+@pytest.mark.parametrize(
+    "test_input,expected",
+    [
+        pytest.param(torch.tensor([2]), torch.tensor(4)),
+        pytest.param(
+            torch.tensor([[1.0, 2.0], [0.5, 0.1]]),
+            torch.tensor([[1.0, 2.0], [0.5, 0.1]]) ** 2),
+    ]
+)
+@pytest.mark.asyncio
+async def test_call_peer_torch_square(test_input, expected, handler_name="handle"):
+    handle = handle_square_torch
+    server = await P2P.create()
+    await server.add_stream_handler(handler_name, handle)
+
+    nodes = bootstrap_from([server])
+    client = await P2P.create(bootstrap=True, bootstrap_peers=nodes)
+
+    await client.wait_for_at_least_n_peers(1)
+
+    inp = serialize_torch_tensor(test_input).SerializeToString()
+    result_pb = await client.call_peer_handler(server.id, handler_name, inp)
+    result = runtime_pb2.Tensor()
+    result.ParseFromString(result_pb)
+    result = deserialize_torch_tensor(result)
+    assert torch.allclose(result, expected)
+
+    await server.stop_listening()
+    await server.shutdown()
+    await client.shutdown()
+
+
+@pytest.mark.parametrize(
+    "test_input,expected",
+    [
+        pytest.param([torch.tensor([1]), torch.tensor([2])], torch.tensor([3])),
+        pytest.param(
+            [torch.tensor([[0.1, 0.2], [0.3, 0.4]]), torch.tensor([[1.1, 1.2], [1.3, 1.4]])],
+            torch.tensor([[1.2, 1.4], [1.6, 1.8]])),
+    ]
+)
+@pytest.mark.asyncio
+async def test_call_peer_torch_add(test_input, expected, handler_name="handle"):
+    handle = handle_add_torch
+    server = await P2P.create()
+    await server.add_stream_handler(handler_name, handle)
+
+    nodes = bootstrap_from([server])
+    client = await P2P.create(bootstrap=True, bootstrap_peers=nodes)
+
+    await client.wait_for_at_least_n_peers(1)
+
+    inp = [serialize_torch_tensor(i).SerializeToString() for i in test_input]
+    inp_msgp = MSGPackSerializer.dumps(inp)
+    result_pb = await client.call_peer_handler(server.id, handler_name, inp_msgp)
+    result = runtime_pb2.Tensor()
+    result.ParseFromString(result_pb)
+    result = deserialize_torch_tensor(result)
+    assert torch.allclose(result, expected)
+
+    await server.stop_listening()
+    await server.shutdown()
+    await client.shutdown()
+
+
+@pytest.mark.parametrize(
+    "replicate",
+    [
+        pytest.param(False, id="primary"),
+        pytest.param(True, id="replica"),
+    ]
+)
+@pytest.mark.asyncio
+async def test_call_peer_error(replicate, handler_name="handle"):
+    server_primary = await P2P.create()
+    server = await replicate_if_needed(server_primary, replicate)
+    await server.add_stream_handler(handler_name, handle_add_torch_with_exc)
+
+    nodes = bootstrap_from([server])
+    client_primary = await P2P.create(bootstrap=True, bootstrap_peers=nodes)
+    client = await replicate_if_needed(client_primary, replicate)
+
+    await client.wait_for_at_least_n_peers(1)
+
+    inp = [serialize_torch_tensor(i).SerializeToString() for i in [torch.zeros((2, 3)), torch.zeros((3, 2))]]
+    inp_msgp = MSGPackSerializer.dumps(inp)
+    result = await client.call_peer_handler(server.id, handler_name, inp_msgp)
+    assert result == b'something went wrong :('
+
+    await server.stop_listening()
+    await server_primary.shutdown()
+    await client_primary.shutdown()
+
+
+@pytest.mark.asyncio
+async def test_handlers_on_different_replicas(handler_name="handle"):
+    def handler(arg, key):
+        return key
+
+    server_primary = await P2P.create(bootstrap=False)
+    server_id = server_primary.id
+    await server_primary.add_stream_handler(handler_name, partial(handler, key=b'primary'))
+
+    server_replica1 = await replicate_if_needed(server_primary, True)
+    await server_replica1.add_stream_handler(handler_name + '1', partial(handler, key=b'replica1'))
+
+    server_replica2 = await replicate_if_needed(server_primary, True)
+    await server_replica2.add_stream_handler(handler_name + '2', partial(handler, key=b'replica2'))
+
+    nodes = bootstrap_from([server_primary])
+    client = await P2P.create(bootstrap=True, bootstrap_peers=nodes)
+    await client.wait_for_at_least_n_peers(1)
+
+    result = await client.call_peer_handler(server_id, handler_name, b'1')
+    assert result == b"primary"
+
+    result = await client.call_peer_handler(server_id, handler_name + '1', b'2')
+    assert result == b"replica1"
+
+    result = await client.call_peer_handler(server_id, handler_name + '2', b'3')
+    assert result == b"replica2"
+
+    await server_replica1.stop_listening()
+    await server_replica2.stop_listening()
+
+    # Primary does not handle replicas protocols
+    with pytest.raises(asyncio.IncompleteReadError):
+        await client.call_peer_handler(server_id, handler_name + '1', b'')
+    with pytest.raises(asyncio.IncompleteReadError):
+        await client.call_peer_handler(server_id, handler_name + '2', b'')
+
+    await server_primary.stop_listening()
+    await server_primary.shutdown()
+    await client.shutdown()

+ 559 - 0
tests/test_p2p_daemon_bindings.py

@@ -0,0 +1,559 @@
+import asyncio
+import io
+from contextlib import AsyncExitStack
+
+import pytest
+from google.protobuf.message import EncodeError
+from multiaddr import Multiaddr, protocols
+
+from hivemind.p2p.p2p_daemon_bindings.control import ControlClient, DaemonConnector, parse_conn_protocol
+from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
+from hivemind.p2p.p2p_daemon_bindings.utils import (ControlFailure, raise_if_failed, read_pbmsg_safe,
+                                                    read_unsigned_varint, write_pbmsg, write_unsigned_varint)
+from hivemind.proto import p2pd_pb2 as p2pd_pb
+from test_utils import make_p2pd_pair_ip4, connect_safe
+
+
def test_raise_if_failed_raises():
    """A response of type ERROR must be converted into a ControlFailure."""
    error_resp = p2pd_pb.Response(type=p2pd_pb.Response.ERROR)
    with pytest.raises(ControlFailure):
        raise_if_failed(error_resp)
+
+
def test_raise_if_failed_not_raises():
    """A response of type OK must pass through raise_if_failed silently."""
    ok_resp = p2pd_pb.Response(type=p2pd_pb.Response.OK)
    raise_if_failed(ok_resp)
+
+
# (integer, canonical unsigned-varint encoding) pairs that fit within 64 bits
PAIRS_INT_SERIALIZED_VALID = (
    (0, b"\x00"),
    (1, b"\x01"),
    (128, b"\x80\x01"),
    (2 ** 32, b"\x80\x80\x80\x80\x10"),
    (2 ** 64 - 1, b"\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01"),
)

# values above the 64-bit limit that the varint codec must reject
PAIRS_INT_SERIALIZED_OVERFLOW = (
    (2 ** 64, b"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x02"),
    (2 ** 64 + 1, b"\x81\x80\x80\x80\x80\x80\x80\x80\x80\x02"),
    (
        2 ** 128,
        b"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x04",
    ),
)

# one fixed peer identity used throughout: its base58 form and raw multihash bytes
PEER_ID_STRING = "QmS5QmciTXXnCUCyxud5eWFenUMAmvAWSDa1c7dvdXRMZ7"
PEER_ID_BYTES = b'\x12 7\x87F.[\xb5\xb1o\xe5*\xc7\xb9\xbb\x11:"Z|j2\x8ad\x1b\xa6\xe5<Ip\xfe\xb4\xf5v'
PEER_ID = PeerID(PEER_ID_BYTES)
MADDR = Multiaddr("/unix/123")
# number of daemon/client pairs spawned by the `p2pcs` fixture
NUM_P2PDS = 4
# an unrelated peer id used to provoke mismatch/not-found failures
PEER_ID_RANDOM = PeerID.from_base58("QmcgpsyWgH8Y8ajJz1Cu72KnS5uo2Aa2LpzU7kinSupNK1")
# feature toggles passed to every spawned p2pd daemon
ENABLE_CONTROL = True
ENABLE_CONNMGR = False
ENABLE_DHT = False
ENABLE_PUBSUB = False
# which transport the daemon/client pair uses for its control connection
FUNC_MAKE_P2PD_PAIR = make_p2pd_pair_ip4
+
+
class MockReader(io.BytesIO):
    """In-memory stand-in for ``asyncio.StreamReader`` backed by a BytesIO buffer."""

    async def readexactly(self, n):
        """Pop up to ``n`` bytes from the buffer (short read possible at EOF)."""
        # yield to the event loop once, the way real stream I/O would
        await asyncio.sleep(0)
        return self.read(n)
+
+
class MockWriter(io.BytesIO):
    # Stands in for asyncio.StreamWriter: the code under test only calls write(),
    # which io.BytesIO already provides.
    pass
+
+
class MockReaderWriter(MockReader, MockWriter):
    # Duplex mock combining MockReader's async readexactly() with MockWriter's write().
    pass
+
+
@pytest.mark.parametrize("integer, serialized_integer", PAIRS_INT_SERIALIZED_VALID)
@pytest.mark.asyncio
async def test_write_unsigned_varint(integer, serialized_integer):
    """Valid integers must serialize to their canonical varint bytes."""
    writer = MockWriter()
    await write_unsigned_varint(writer, integer)
    assert writer.getvalue() == serialized_integer
+
+
@pytest.mark.parametrize("integer", tuple(i[0] for i in PAIRS_INT_SERIALIZED_OVERFLOW))
@pytest.mark.asyncio
async def test_write_unsigned_varint_overflow(integer):
    """Integers above the 64-bit limit must be rejected on write."""
    writer = MockWriter()
    with pytest.raises(ValueError):
        await write_unsigned_varint(writer, integer)
+
+
@pytest.mark.parametrize("integer", (-1, -(2 ** 32), -(2 ** 64), -(2 ** 128)))
@pytest.mark.asyncio
async def test_write_unsigned_varint_negative(integer):
    """Negative integers must be rejected on write (varints are unsigned)."""
    writer = MockWriter()
    with pytest.raises(ValueError):
        await write_unsigned_varint(writer, integer)
+
+
@pytest.mark.parametrize("integer, serialized_integer", PAIRS_INT_SERIALIZED_VALID)
@pytest.mark.asyncio
async def test_read_unsigned_varint(integer, serialized_integer):
    """Canonical varint bytes must deserialize back to the original integer."""
    reader = MockReader(serialized_integer)
    assert await read_unsigned_varint(reader) == integer
+
+
@pytest.mark.parametrize("serialized_integer", tuple(i[1] for i in PAIRS_INT_SERIALIZED_OVERFLOW))
@pytest.mark.asyncio
async def test_read_unsigned_varint_overflow(serialized_integer):
    """Encodings above the 64-bit limit must be rejected on read."""
    reader = MockReader(serialized_integer)
    with pytest.raises(ValueError):
        await read_unsigned_varint(reader)
+
+
@pytest.mark.parametrize("max_bits", (2, 31, 32, 63, 64, 127, 128))
@pytest.mark.asyncio
async def test_read_write_unsigned_varint_max_bits_edge(max_bits):
    """Round-trip the three largest representable values for each ``max_bits``."""
    upper_bound = 2 ** max_bits
    for value in (upper_bound - 3, upper_bound - 2, upper_bound - 1):
        stream = MockReaderWriter()
        await write_unsigned_varint(stream, value, max_bits=max_bits)
        # rewind so the reader starts from the freshly written bytes
        stream.seek(0, 0)
        assert await read_unsigned_varint(stream, max_bits=max_bits) == value
+
+
def test_peer_id():
    """PeerID round-trips between bytes and base58 and compares by value."""
    assert PEER_ID.to_bytes() == PEER_ID_BYTES
    assert PEER_ID.to_string() == PEER_ID_STRING

    same_peer = PeerID.from_base58(PEER_ID_STRING)
    assert same_peer.to_bytes() == PEER_ID_BYTES
    assert same_peer.to_string() == PEER_ID_STRING
    assert same_peer == PEER_ID

    other_peer = PeerID.from_base58("QmbmfNDEth7Ucvjuxiw3SP3E4PoJzbk7g4Ge6ZDigbCsNp")
    assert other_peer != PEER_ID
+
+
def test_stream_info():
    """StreamInfo round-trips through its protobuf representation."""
    proto = "123"
    info = StreamInfo(PEER_ID, MADDR, proto)
    assert (info.peer_id, info.addr, info.proto) == (PEER_ID, MADDR, proto)

    pb_info = info.to_protobuf()
    assert pb_info.peer == PEER_ID.to_bytes()
    assert pb_info.addr == MADDR.to_bytes()
    assert pb_info.proto == info.proto

    restored = StreamInfo.from_protobuf(pb_info)
    assert (restored.peer_id, restored.addr, restored.proto) == (PEER_ID, MADDR, proto)
+
+
def test_peer_info():
    """PeerInfo exposes its fields and can be rebuilt from a protobuf message."""
    info = PeerInfo(PEER_ID, [MADDR])
    assert info.peer_id == PEER_ID
    assert info.addrs == [MADDR]

    pb_info = p2pd_pb.PeerInfo(id=PEER_ID.to_bytes(), addrs=[MADDR.to_bytes()])
    restored = PeerInfo.from_protobuf(pb_info)
    assert restored.peer_id == info.peer_id
    assert restored.addrs == info.addrs
+
+
@pytest.mark.parametrize(
    "maddr_str, expected_proto",
    (("/unix/123", protocols.P_UNIX), ("/ip4/127.0.0.1/tcp/7777", protocols.P_IP4)),
)
def test_parse_conn_protocol_valid(maddr_str, expected_proto):
    """Unix and IPv4 multiaddrs map to their expected protocol codes."""
    maddr = Multiaddr(maddr_str)
    assert parse_conn_protocol(maddr) == expected_proto
+
+
@pytest.mark.parametrize(
    "maddr_str",
    (
        "/p2p/QmbHVEEepCi7rn7VL7Exxpd2Ci9NNB6ifvqwhsrbRMgQFP",
        "/onion/timaq4ygg2iegci7:1234",
    ),
)
def test_parse_conn_protocol_invalid(maddr_str):
    """Multiaddrs without a supported daemon transport must raise ValueError."""
    with pytest.raises(ValueError):
        parse_conn_protocol(Multiaddr(maddr_str))
+
+
@pytest.mark.parametrize("control_maddr_str", ("/unix/123", "/ip4/127.0.0.1/tcp/6666"))
def test_client_ctor_control_maddr(control_maddr_str):
    """DaemonConnector keeps the control multiaddr it was given."""
    expected = Multiaddr(control_maddr_str)
    connector = DaemonConnector(expected)
    assert connector.control_maddr == expected
+
+
def test_client_ctor_default_control_maddr():
    """Without arguments, DaemonConnector falls back to its default control multiaddr."""
    connector = DaemonConnector()
    assert connector.control_maddr == Multiaddr(DaemonConnector.DEFAULT_CONTROL_MADDR)
+
+
@pytest.mark.parametrize("listen_maddr_str", ("/unix/123", "/ip4/127.0.0.1/tcp/6666"))
def test_control_client_ctor_listen_maddr(listen_maddr_str):
    """ControlClient keeps the listen multiaddr it was given."""
    expected = Multiaddr(listen_maddr_str)
    client = ControlClient(daemon_connector=DaemonConnector(), listen_maddr=expected)
    assert client.listen_maddr == expected
+
+
def test_control_client_ctor_default_listen_maddr():
    """Without a listen multiaddr, ControlClient falls back to its default."""
    client = ControlClient(daemon_connector=DaemonConnector())
    assert client.listen_maddr == Multiaddr(ControlClient.DEFAULT_LISTEN_MADDR)
+
+
@pytest.mark.parametrize(
    "msg_bytes",
    (
        p2pd_pb.Response(
            type=p2pd_pb.Response.Type.OK,
            identify=p2pd_pb.IdentifyResponse(
                id=PeerID.from_base58('QmT7WhTne9zBLfAgAJt9aiZ8jZ5BxJGowRubxsHYmnyzUd').to_bytes(),
                addrs=[Multiaddr('/p2p-circuit').to_bytes(), Multiaddr('/ip4/127.0.0.1/tcp/51126').to_bytes(),
                       Multiaddr('/ip4/192.168.10.135/tcp/51126').to_bytes(),
                       Multiaddr('/ip6/::1/tcp/51127').to_bytes()]
            )).SerializeToString(),
        p2pd_pb.Response(
            type=p2pd_pb.Response.Type.OK,
            identify=p2pd_pb.IdentifyResponse(
                id=PeerID.from_base58('QmcQFt2MFfCZ9AxzUCNrk4k7TtMdZZvAAteaA6tHpBKdrk').to_bytes(),
                addrs=[Multiaddr('/p2p-circuit').to_bytes(), Multiaddr('/ip4/127.0.0.1/tcp/51493').to_bytes(),
                       Multiaddr('/ip4/192.168.10.135/tcp/51493').to_bytes(),
                       Multiaddr('/ip6/::1/tcp/51494').to_bytes()]
            )).SerializeToString(),
        p2pd_pb.Response(
            type=p2pd_pb.Response.Type.OK,
            identify=p2pd_pb.IdentifyResponse(
                id=PeerID.from_base58('QmbWqVVoz7v9LS9ZUQAhyyfdFJY3iU8ZrUY3XQozoTA5cc').to_bytes(),
                addrs=[Multiaddr('/p2p-circuit').to_bytes(), Multiaddr('/ip4/127.0.0.1/tcp/51552').to_bytes(),
                       Multiaddr('/ip4/192.168.10.135/tcp/51552').to_bytes(),
                       Multiaddr('/ip6/::1/tcp/51553').to_bytes()]
            )).SerializeToString(),
    ),
    # give test cases ids to prevent bytes from ruining the terminal
    ids=("pb example Response 0", "pb example Response 1", "pb example Response 2"),
)
@pytest.mark.asyncio
async def test_read_pbmsg_safe_valid(msg_bytes):
    """A varint-length-prefixed protobuf message must be read back intact."""
    s = MockReaderWriter()
    # frame the message the way the daemon does: varint length prefix, then payload
    await write_unsigned_varint(s, len(msg_bytes))
    s.write(msg_bytes)
    # reset the offset back to the beginning
    s.seek(0, 0)
    pb_msg = p2pd_pb.Response()
    await read_pbmsg_safe(s, pb_msg)
    # the parsed message re-serializes to exactly the bytes that were framed
    assert pb_msg.SerializeToString() == msg_bytes
+
+
@pytest.mark.parametrize(
    "pb_type, pb_msg",
    (
        (
            p2pd_pb.Response,
            p2pd_pb.Response(
                type=p2pd_pb.Response.Type.OK,
                dht=p2pd_pb.DHTResponse(
                    type=p2pd_pb.DHTResponse.Type.VALUE,
                    peer=p2pd_pb.PeerInfo(
                        id=PeerID.from_base58('QmNaXUy78W9moQ9APCoKaTtPjLcEJPN9hRBCqErY7o2fQs').to_bytes(),
                        addrs=[Multiaddr('/p2p-circuit').to_bytes(), Multiaddr('/ip4/127.0.0.1/tcp/56929').to_bytes(),
                               Multiaddr('/ip4/192.168.10.135/tcp/56929').to_bytes(),
                               Multiaddr('/ip6/::1/tcp/56930').to_bytes()]
                    )
                )
            ),
        ),
        (p2pd_pb.Request, p2pd_pb.Request(type=p2pd_pb.Request.Type.LIST_PEERS)),
        (
            p2pd_pb.DHTRequest,
            p2pd_pb.DHTRequest(type=p2pd_pb.DHTRequest.Type.FIND_PEER,
                               peer=PeerID.from_base58('QmcgHMuEhqdLHDVeNjiCGU7Ds6E7xK3f4amgiwHNPKKn7R').to_bytes()),
        ),
        (
            p2pd_pb.DHTResponse,
            p2pd_pb.DHTResponse(
                type=p2pd_pb.DHTResponse.Type.VALUE,
                peer=p2pd_pb.PeerInfo(
                    id=PeerID.from_base58('QmWP32GhEyXVQsLXFvV81eadDC8zQRZxZvJK359rXxLquk').to_bytes(),
                    addrs=[Multiaddr('/p2p-circuit').to_bytes(), Multiaddr('/ip4/127.0.0.1/tcp/56897').to_bytes(),
                           Multiaddr('/ip4/192.168.10.135/tcp/56897').to_bytes(),
                           Multiaddr('/ip6/::1/tcp/56898').to_bytes()]
                )
            ),
        ),
        (
            p2pd_pb.StreamInfo,
            p2pd_pb.StreamInfo(peer=PeerID.from_base58('QmewLxB46MftfxQiunRgJo2W8nW4Lh5NLEkRohkHhJ4wW6').to_bytes(),
                               addr=Multiaddr('/ip4/127.0.0.1/tcp/57029').to_bytes(),
                               proto=b'protocol123'),
        ),
    ),
    ids=(
        "pb example Response",
        "pb example Request",
        "pb example DHTRequest",
        "pb example DHTResponse",
        "pb example StreamInfo",
    ),
)
@pytest.mark.asyncio
async def test_write_pbmsg(pb_type, pb_msg):
    """write_pbmsg must emit a varint length prefix followed by the serialized message,
    and the frame must round-trip through read_pbmsg_safe."""
    # Build the expected frame with write_unsigned_varint instead of the previous
    # `bytes(chr(size), 'utf-8')` trick: chr() encoded via utf-8 produces two bytes
    # (not a varint) for any serialized size >= 128, silently corrupting the frame.
    prefix_writer = MockReaderWriter()
    await write_unsigned_varint(prefix_writer, pb_msg.ByteSize())
    msg_bytes = prefix_writer.getvalue() + pb_msg.SerializeToString()
    pb_obj = pb_type()

    # parse the frame, then re-serialize it and expect the identical bytes
    s_read = MockReaderWriter(msg_bytes)
    await read_pbmsg_safe(s_read, pb_obj)
    s_write = MockReaderWriter()
    await write_pbmsg(s_write, pb_obj)
    assert msg_bytes == s_write.getvalue()
+
+
@pytest.mark.parametrize(
    "pb_msg",
    (
        p2pd_pb.Response(),
        p2pd_pb.Request(),
        p2pd_pb.DHTRequest(),
        p2pd_pb.DHTResponse(),
        p2pd_pb.StreamInfo(),
    ),
)
@pytest.mark.asyncio
async def test_write_pbmsg_missing_fields(pb_msg):
    """Serializing a message with unset required fields must raise EncodeError."""
    stream = MockReaderWriter()
    with pytest.raises(EncodeError):
        await write_pbmsg(stream, pb_msg)
+
+
@pytest.fixture
async def p2pcs():
    # TODO: Change back to gather style
    async with AsyncExitStack() as stack:
        # spawn the daemon/client pairs sequentially, all torn down by the stack
        pairs = []
        for _ in range(NUM_P2PDS):
            pair = await stack.enter_async_context(
                FUNC_MAKE_P2PD_PAIR(
                    enable_control=ENABLE_CONTROL,
                    enable_connmgr=ENABLE_CONNMGR,
                    enable_dht=ENABLE_DHT,
                    enable_pubsub=ENABLE_PUBSUB,
                )
            )
            pairs.append(pair)
        yield tuple(pair.client for pair in pairs)
+
+
@pytest.mark.asyncio
async def test_client_identify_unix_socket(p2pcs):
    # Smoke test: identify() must succeed over the fixture's control transport.
    # NOTE(review): despite the name, the transport is whatever FUNC_MAKE_P2PD_PAIR
    # selects (currently make_p2pd_pair_ip4) — confirm whether a unix variant was intended.
    await p2pcs[0].identify()
+
+
@pytest.mark.asyncio
async def test_client_identify(p2pcs):
    # Smoke test: the daemon must answer an identify request without raising.
    await p2pcs[0].identify()
+
+
@pytest.mark.asyncio
async def test_client_connect_success(p2pcs):
    """Peers can connect to each other, and repeated connects are harmless."""
    peer_id_0, maddrs_0 = await p2pcs[0].identify()
    peer_id_1, maddrs_1 = await p2pcs[1].identify()

    await p2pcs[0].connect(peer_id_1, maddrs_1)
    # connecting again, from the other side, must also succeed
    await p2pcs[1].connect(peer_id_0, maddrs_0)
+
+
@pytest.mark.asyncio
async def test_client_connect_failure(p2pcs):
    """connect() must fail on a mismatched peer id, empty maddrs, or a wrong transport."""
    peer_id_1, maddrs_1 = await p2pcs[1].identify()
    await p2pcs[0].identify()

    # peer id does not belong to the addresses provided
    with pytest.raises(ControlFailure):
        await p2pcs[0].connect(PEER_ID_RANDOM, maddrs_1)
    # no addresses provided at all
    with pytest.raises(ControlFailure):
        await p2pcs[0].connect(peer_id_1, [])
    # address uses a transport the daemon cannot dial
    with pytest.raises(ControlFailure):
        await p2pcs[0].connect(peer_id_1, [Multiaddr("/ip4/127.0.0.1/udp/0")])
+
+
@pytest.mark.asyncio
async def test_connect_safe(p2pcs):
    # connect_safe() blocks until both peers see each other; just ensure it completes.
    await connect_safe(p2pcs[0], p2pcs[1])
+
+
@pytest.mark.asyncio
async def test_client_list_peers(p2pcs):
    """Peer lists must grow as pairwise connections are established."""
    # no connections yet
    assert len(await p2pcs[0].list_peers()) == 0

    # after connecting 0 <-> 1, each side sees exactly one peer
    await connect_safe(p2pcs[0], p2pcs[1])
    assert len(await p2pcs[0].list_peers()) == 1
    assert len(await p2pcs[1].list_peers()) == 1

    # adding 0 <-> 2 grows only the lists of the peers involved
    await connect_safe(p2pcs[0], p2pcs[2])
    assert len(await p2pcs[0].list_peers()) == 2
    assert len(await p2pcs[1].list_peers()) == 1
    assert len(await p2pcs[2].list_peers()) == 1
+
+
@pytest.mark.asyncio
async def test_client_disconnect(p2pcs):
    """disconnect() is a no-op without a connection and idempotent with one."""
    # disconnecting from an unknown peer must not fail
    await p2pcs[1].disconnect(PEER_ID_RANDOM)

    peer_id_0, _ = await p2pcs[0].identify()
    await connect_safe(p2pcs[0], p2pcs[1])
    assert len(await p2pcs[0].list_peers()) == 1
    assert len(await p2pcs[1].list_peers()) == 1

    # second pass verifies that disconnecting an already-disconnected peer is safe
    for _ in range(2):
        await p2pcs[1].disconnect(peer_id_0)
        assert len(await p2pcs[0].list_peers()) == 0
        assert len(await p2pcs[1].list_peers()) == 0
+
+
@pytest.mark.asyncio
async def test_client_stream_open_success(p2pcs):
    """stream_open() negotiates a registered protocol, offered alone or among several."""
    peer_id_1, maddrs_1 = await p2pcs[1].identify()
    await connect_safe(p2pcs[0], p2pcs[1])

    proto = "123"

    async def handle_proto(stream_info, reader, writer):
        await reader.readexactly(1)

    await p2pcs[1].stream_handler(proto, handle_proto)

    # the registered protocol must be selected whether offered alone or with others
    for offered_protocols in ((proto,), (proto, "another_protocol")):
        stream_info, reader, writer = await p2pcs[0].stream_open(peer_id_1, offered_protocols)
        assert stream_info.peer_id == peer_id_1
        assert stream_info.addr in maddrs_1
        assert stream_info.proto == "123"
        writer.close()
+
+
@pytest.mark.asyncio
async def test_client_stream_open_failure(p2pcs):
    """stream_open() must fail for protocols the remote peer has not registered."""
    peer_id_1, _ = await p2pcs[1].identify()
    await connect_safe(p2pcs[0], p2pcs[1])

    proto = "123"

    # remote peer has no handler registered at all yet
    with pytest.raises(ControlFailure):
        await p2pcs[0].stream_open(peer_id_1, (proto,))

    async def handle_proto(stream_info, reader, writer):
        pass

    await p2pcs[1].stream_handler(proto, handle_proto)
    # remote peer has a handler, but not for the protocol we offer
    with pytest.raises(ControlFailure):
        await p2pcs[0].stream_open(peer_id_1, ("another_protocol",))
+
+
@pytest.mark.asyncio
async def test_client_stream_handler_success(p2pcs):
    """Registered handlers must receive incoming streams and their payloads, and
    re-registering a protocol must override the previous handler."""
    peer_id_1, _ = await p2pcs[1].identify()
    await connect_safe(p2pcs[0], p2pcs[1])

    proto = "protocol123"
    bytes_to_send = b"yoyoyoyoyog"
    # event for this test function to wait until the handler function receiving the incoming data
    event_handler_finished = asyncio.Event()

    async def handle_proto(stream_info, reader, writer):
        nonlocal event_handler_finished
        bytes_received = await reader.readexactly(len(bytes_to_send))
        assert bytes_received == bytes_to_send
        event_handler_finished.set()

    await p2pcs[1].stream_handler(proto, handle_proto)
    # the handler must be registered in the client's local handler table
    assert proto in p2pcs[1].control.handlers
    assert handle_proto == p2pcs[1].control.handlers[proto]

    # test case: test the stream handler `handle_proto`

    _, reader, writer = await p2pcs[0].stream_open(peer_id_1, (proto,))

    # wait until the handler function starts blocking waiting for the data
    # because we haven't sent the data, we know the handler function must still blocking waiting.
    # get the task of the protocol handler
    writer.write(bytes_to_send)

    # wait for the handler to finish
    writer.close()

    await event_handler_finished.wait()

    # test case: two streams to different handlers respectively
    another_proto = "another_protocol123"
    another_bytes_to_send = b"456"
    event_another_proto = asyncio.Event()

    async def handle_another_proto(stream_info, reader, writer):
        event_another_proto.set()
        bytes_received = await reader.readexactly(len(another_bytes_to_send))
        assert bytes_received == another_bytes_to_send

    await p2pcs[1].stream_handler(another_proto, handle_another_proto)
    assert another_proto in p2pcs[1].control.handlers
    assert handle_another_proto == p2pcs[1].control.handlers[another_proto]

    _, reader, writer = await p2pcs[0].stream_open(peer_id_1, (another_proto,))
    await event_another_proto.wait()

    # we know at this moment the handler must still blocking wait

    writer.write(another_bytes_to_send)

    writer.close()

    # test case: registering twice can override the previous registration
    event_third = asyncio.Event()

    async def handler_third(stream_info, reader, writer):
        event_third.set()

    await p2pcs[1].stream_handler(another_proto, handler_third)
    assert another_proto in p2pcs[1].control.handlers
    # ensure the handler is override
    assert handler_third == p2pcs[1].control.handlers[another_proto]

    await p2pcs[0].stream_open(peer_id_1, (another_proto,))
    # ensure the overriding handler is called when the protocol is opened a stream
    await event_third.wait()
+
+
@pytest.mark.asyncio
async def test_client_stream_handler_failure(p2pcs):
    """Opening a protocol that the peer registered under a different name must fail."""
    peer_id_1, _ = await p2pcs[1].identify()
    await connect_safe(p2pcs[0], p2pcs[1])

    proto = "123"

    async def handle_proto_correct_params(stream_info, stream):
        pass

    # a handler exists, but under "another_protocol", not under `proto`
    await p2pcs[1].stream_handler("another_protocol", handle_proto_correct_params)
    with pytest.raises(ControlFailure):
        await p2pcs[0].stream_open(peer_id_1, (proto,))

+ 192 - 0
tests/test_utils/__init__.py

@@ -0,0 +1,192 @@
+import asyncio
+import functools
+import os
+import subprocess
+import time
+import uuid
+from contextlib import asynccontextmanager
+from typing import NamedTuple
+
+from multiaddr import Multiaddr, protocols
+
+from hivemind import find_open_port
+from hivemind.p2p.p2p_daemon_bindings.p2pclient import Client
+
+
TIMEOUT_DURATION = 30  # seconds


async def try_until_success(coro_func, timeout=TIMEOUT_DURATION):
    """
    Keep awaiting ``coro_func`` until it returns a truthy value or the time is out.

    All arguments of ``coro_func`` should be filled, i.e. it should be called without arguments.

    :param coro_func: zero-argument coroutine function; truthy result means success
    :param timeout: number of seconds to keep retrying before giving up
    :raises AssertionError: if ``coro_func`` is still falsy after ``timeout`` seconds
    """
    t_start = time.monotonic()
    while not await coro_func():
        if (time.monotonic() - t_start) >= timeout:
            # Raise explicitly instead of `assert False` so the timeout check
            # is not stripped away when Python runs with optimizations (-O).
            raise AssertionError(f"{coro_func} still failed after `{timeout}` seconds")
        # brief pause between retries to avoid busy-waiting
        await asyncio.sleep(0.01)
+
+
class Daemon:
    """Launches and supervises a single p2pd (go-libp2p daemon) subprocess,
    logging its output to a file so readiness can be detected by polling."""

    # Class-level defaults; instances overwrite them in __init__/_start_logging/_run.
    control_maddr = None
    proc_daemon = None
    log_filename = ""
    f_log = None
    closed = None  # NOTE(review): appears unused — instances track state via `is_closed`

    def __init__(
            self, control_maddr, enable_control, enable_connmgr, enable_dht, enable_pubsub
    ):
        # control_maddr: multiaddr on which the daemon accepts control connections
        self.control_maddr = control_maddr
        self.enable_control = enable_control
        self.enable_connmgr = enable_connmgr
        self.enable_dht = enable_dht
        self.enable_pubsub = enable_pubsub
        self.is_closed = False
        # open the log file first so the subprocess can write into it immediately
        self._start_logging()
        self._run()

    def _start_logging(self):
        # Derive a filesystem-safe log name from the control multiaddr.
        name_control_maddr = str(self.control_maddr).replace("/", "_").replace(".", "_")
        self.log_filename = f"/tmp/log_p2pd{name_control_maddr}.txt"
        self.f_log = open(self.log_filename, "wb")

    def _run(self):
        # Launch the bundled p2pd binary; stdout/stderr go to the log file so
        # wait_until_ready() can poll it for the startup banner.
        cmd_list = ["hivemind/hivemind_cli/p2pd", f"-listen={str(self.control_maddr)}"]
        cmd_list += [f"-hostAddrs=/ip4/127.0.0.1/tcp/{find_open_port()}"]
        # NOTE(review): enable_control is stored but never translated into a CLI flag
        # here — presumably the control endpoint is always on; confirm against p2pd.
        if self.enable_connmgr:
            cmd_list += ["-connManager=true", "-connLo=1", "-connHi=2", "-connGrace=0"]
        if self.enable_dht:
            cmd_list += ["-dht=true"]
        if self.enable_pubsub:
            cmd_list += ["-pubsub=true", "-pubsubRouter=gossipsub"]
        self.proc_daemon = subprocess.Popen(
            cmd_list, stdout=self.f_log, stderr=self.f_log, bufsize=0
        )

    async def wait_until_ready(self):
        # Poll the daemon's log until all three startup banner lines have appeared.
        lines_head_pattern = (b"Control socket:", b"Peer ID:", b"Peer Addrs:")
        lines_head_occurred = {line: False for line in lines_head_pattern}

        with open(self.log_filename, "rb") as f_log_read:

            async def read_from_daemon_and_check():
                # read one line per attempt; try_until_success retries until all match
                line = f_log_read.readline()
                for head_pattern in lines_head_occurred:
                    if line.startswith(head_pattern):
                        lines_head_occurred[head_pattern] = True
                return all([value for _, value in lines_head_occurred.items()])

            await try_until_success(read_from_daemon_and_check)

        # sleep for a while in case that the daemon haven't been ready after emitting these lines
        await asyncio.sleep(0.1)

    def close(self):
        # Idempotent teardown: terminate the subprocess and close the log file.
        if self.is_closed:
            return
        self.proc_daemon.terminate()
        self.proc_daemon.wait()
        self.f_log.close()
        self.is_closed = True
+
+
class DaemonTuple(NamedTuple):
    """A spawned p2pd process paired with the Client connected to it."""
    daemon: Daemon
    client: Client
+
+
class ConnectionFailure(Exception):
    """Raised when two daemons fail to establish a connection."""
    pass
+
+
@asynccontextmanager
async def make_p2pd_pair_unix(
        enable_control, enable_connmgr, enable_dht, enable_pubsub
):
    """Spawn a daemon/client pair that communicates over unix domain sockets."""
    name = str(uuid.uuid4())[:8]
    control_maddr = Multiaddr(f"/unix/tmp/test_p2pd_control_{name}.sock")
    listen_maddr = Multiaddr(f"/unix/tmp/test_p2pd_listen_{name}.sock")
    # drop stale socket files possibly left over from previous runs
    for maddr in (control_maddr, listen_maddr):
        try:
            os.unlink(maddr.value_for_protocol(protocols.P_UNIX))
        except FileNotFoundError:
            pass
    async with _make_p2pd_pair(
            control_maddr=control_maddr,
            listen_maddr=listen_maddr,
            enable_control=enable_control,
            enable_connmgr=enable_connmgr,
            enable_dht=enable_dht,
            enable_pubsub=enable_pubsub,
    ) as pair:
        yield pair
+
+
@asynccontextmanager
async def make_p2pd_pair_ip4(enable_control, enable_connmgr, enable_dht, enable_pubsub):
    """Spawn a daemon/client pair that communicates over TCP on localhost."""
    pair_manager = _make_p2pd_pair(
        control_maddr=Multiaddr(f"/ip4/127.0.0.1/tcp/{find_open_port()}"),
        listen_maddr=Multiaddr(f"/ip4/127.0.0.1/tcp/{find_open_port()}"),
        enable_control=enable_control,
        enable_connmgr=enable_connmgr,
        enable_dht=enable_dht,
        enable_pubsub=enable_pubsub,
    )
    async with pair_manager as pair:
        yield pair
+
+
@asynccontextmanager
async def _make_p2pd_pair(
        control_maddr,
        listen_maddr,
        enable_control,
        enable_connmgr,
        enable_dht,
        enable_pubsub,
):
    """Launch a p2pd daemon, attach a listening Client to it, and guarantee teardown."""
    daemon = Daemon(
        control_maddr=control_maddr,
        enable_control=enable_control,
        enable_connmgr=enable_connmgr,
        enable_dht=enable_dht,
        enable_pubsub=enable_pubsub,
    )
    # block until the daemon's startup banner shows up in its log
    await daemon.wait_until_ready()
    client = Client(control_maddr=control_maddr, listen_maddr=listen_maddr)
    try:
        async with client.listen():
            yield DaemonTuple(daemon=daemon, client=client)
    finally:
        # shut the daemon down even if the test body raised
        if not daemon.is_closed:
            daemon.close()
+
+
async def _check_connection(p2pd_tuple_0, p2pd_tuple_1):
    """Return True iff each peer appears in the other's peer list."""
    id_0, _ = await p2pd_tuple_0.identify()
    id_1, _ = await p2pd_tuple_1.identify()
    known_to_0 = [info.peer_id for info in await p2pd_tuple_0.list_peers()]
    known_to_1 = [info.peer_id for info in await p2pd_tuple_1.list_peers()]
    return id_0 in known_to_1 and id_1 in known_to_0
+
+
async def connect_safe(p2pd_tuple_0, p2pd_tuple_1):
    """Connect two peers and block until both sides report the connection."""
    peer_id, maddrs = await p2pd_tuple_1.identify()
    await p2pd_tuple_0.connect(peer_id, maddrs)
    # the daemon may report success before both sides see each other; poll until they do
    check_both_sides = functools.partial(
        _check_connection, p2pd_tuple_0=p2pd_tuple_0, p2pd_tuple_1=p2pd_tuple_1
    )
    await try_until_success(check_both_sides)