瀏覽代碼

feat p2p_daemon: add API to call peer handler (#181)

* Extend P2P api

* Add tests for new api

* Add p2pclient dependencies

* Test P2P from different processes

* Fix typo in tests

* Add default initialization

* Fix daemon ports assignment

* Replace del with __del__ in tests

* Read from input stream with receive_exactly

Co-authored-by: Ilya Kobelev <ilya.kobellev@gmail.com>
Ilya 4 年之前
父節點
當前提交
3595c94e2d
共有 2 個文件被更改,包括 292 次插入 和 40 次删除
  1. 171 19
      hivemind/p2p/p2p_daemon.py
  2. 121 21
      tests/test_p2p_daemon.py

+ 171 - 19
hivemind/p2p/p2p_daemon.py

@@ -1,45 +1,197 @@
+import asyncio
+import contextlib
+import copy
+from pathlib import Path
+import pickle
+import socket
 import subprocess
 import typing as tp
+import warnings
+
+from multiaddr import Multiaddr
+import p2pclient
+from libp2p.peer.id import ID
 
 
 class P2P(object):
     """
     Forks a child process and executes p2pd command with given arguments.
-    Sends SIGKILL to the child in destructor and on exit from contextmanager.
+    Can be used for peer to peer communication and procedure calls.
+    Sends SIGKILL to the child in destructor.
     """
 
-    LIBP2P_CMD = 'p2pd'
+    P2PD_RELATIVE_PATH = 'hivemind_cli/p2pd'
+    NUM_RETRIES = 3
+    RETRY_DELAY = 0.4
+    HEADER_LEN = 8
+    BYTEORDER = 'big'
 
-    def __init__(self, *args, **kwargs):
-        self._child = subprocess.Popen(args=self._make_process_args(args, kwargs))
-        try:
-            stdout, stderr = self._child.communicate(timeout=0.2)
-        except subprocess.TimeoutExpired:
-            pass
-        else:
-            raise RuntimeError(f'p2p daemon exited with stderr: {stderr}')
+    def __init__(self):
+        """Create an idle instance; the daemon process is started later by `create`."""
+        self._child = None  # subprocess.Popen handle of the daemon, assigned in _initialize
+        self._listen_task = None  # asyncio task created by start_listening
+        self._server_stopped = asyncio.Event()  # set by stop_listening to end the listen task
+        self._buffer = bytearray()  # bytes received but not yet consumed by _receive_exactly
 
-    def __enter__(self):
-        return self._child
+    @classmethod
+    async def create(cls, *args, quic=1, tls=1, conn_manager=1, dht_client=1,
+                     nat_port_map=True, auto_nat=True, bootstrap=True,
+                     host_port: int = None, daemon_listen_port: int = None, **kwargs):
+        """
+        Launch a p2pd child process and return a P2P instance connected to it.
+
+        :param args: extra positional arguments appended to the p2pd command line
+        :param quic, tls, conn_manager, dht_client, nat_port_map, auto_nat, bootstrap:
+            forwarded to p2pd as the -quic, -tls, -connManager, -dhtClient, -natPortMap,
+            -autonat and -b flags respectively
+        :param host_port: port used in the daemon's host addresses; a free port is picked if None
+        :param daemon_listen_port: port the daemon listens on for this client; a free port
+            is picked if None
+        :param kwargs: additional flags, rendered as -key=value (or -key when the value is None)
+        :returns: the initialized instance, with `id` set to the daemon's base58 peer id
+        :raises Exception: whatever the final attempt raised, after NUM_RETRIES failed attempts
+        """
+        instance = cls()
+        daemon_binary = Path(__file__).resolve().parents[1] / P2P.P2PD_RELATIVE_PATH
+        base_args = instance._make_process_args(
+            str(daemon_binary), *args,
+            quic=quic, tls=tls, connManager=conn_manager,
+            dhtClient=dht_client, natPortMap=nat_port_map,
+            autonat=auto_nat, b=bootstrap, **kwargs)
+        instance._assign_daemon_ports(host_port, daemon_listen_port)
+
+        final_attempt = P2P.NUM_RETRIES - 1
+        for attempt in range(instance.NUM_RETRIES):
+            try:
+                instance._initialize(base_args)
+                # the wait before identifying doubles on every retry
+                await instance._identify_client(P2P.RETRY_DELAY * (2 ** attempt))
+            except Exception as exc:
+                warnings.warn("Failed to initialize p2p daemon: " + str(exc), RuntimeWarning)
+                instance._kill_child()
+                if attempt == final_attempt:
+                    raise
+                # retries always run on freshly picked ports, even if ports were requested
+                instance._assign_daemon_ports()
+            else:
+                break
+        return instance
 
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        self._kill_child()
+    def _initialize(self, proc_args: tp.List[str]) -> None:
+        """
+        Start the p2pd child process on the currently assigned ports and create a client for it.
+
+        :param proc_args: base command line (binary path and flags); the list is copied,
+            not modified
+        """
+        port_args = self._make_process_args(
+            hostAddrs=f'/ip4/0.0.0.0/tcp/{self._host_port},/ip4/0.0.0.0/udp/{self._host_port}/quic',
+            listen=f'/ip4/127.0.0.1/tcp/{self._daemon_listen_port}'
+        )
+        command = copy.deepcopy(proc_args) + port_args
+        self._child = subprocess.Popen(
+            args=command,
+            stdin=subprocess.PIPE, stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE, encoding="utf8"
+        )
+
+        self._client_listen_port = find_open_port()
+        daemon_addr = Multiaddr(f'/ip4/127.0.0.1/tcp/{self._daemon_listen_port}')
+        client_addr = Multiaddr(f'/ip4/127.0.0.1/tcp/{self._client_listen_port}')
+        self._client = p2pclient.Client(daemon_addr, client_addr)
+
+    async def _identify_client(self, delay):
+        """
+        Sleep for `delay` seconds, then ask the daemon to identify itself and store the result.
+
+        The first element of the identify response is converted to base58 and saved as `self.id`.
+
+        :param delay: seconds to wait before calling identify
+        """
+        await asyncio.sleep(delay)
+        identity = await self._client.identify()
+        peer_id = identity[0]
+        self.id = peer_id.to_base58()
+
+    def _assign_daemon_ports(self, host_port=None, daemon_listen_port=None):
+        """
+        Store the daemon's host port and listen port, picking free ports for any left as None.
+
+        An automatically picked listen port is re-drawn until it differs from the host port.
+        """
+        self._host_port = find_open_port() if host_port is None else host_port
+        self._daemon_listen_port = daemon_listen_port
+        if daemon_listen_port is not None:
+            return
+
+        candidate = find_open_port()
+        while candidate == self._host_port:
+            candidate = find_open_port()
+        self._daemon_listen_port = candidate
+
+    @staticmethod
+    async def send_data(data, stream):
+        """
+        Pickle `data` and write it to `stream` as a single length-prefixed message.
+
+        The message is the payload length encoded in HEADER_LEN bytes (BYTEORDER order),
+        followed by the pickled payload, sent with one `stream.send_all` call.
+        """
+        payload = pickle.dumps(data)
+        header = len(payload).to_bytes(P2P.HEADER_LEN, P2P.BYTEORDER)
+        await stream.send_all(header + payload)
+
+    class IncompleteRead(Exception):
+        """Raised by `_receive_exactly` when a read returns no data before enough bytes arrived."""
+        pass
+
+    async def _receive_exactly(self, stream, n_bytes, max_bytes=1 << 16):
+        """
+        Read exactly `n_bytes` bytes from `stream`.
+
+        Bytes are accumulated in a buffer local to this call, and each read asks for at most
+        the number of bytes still missing. This keeps data from one stream from ever being
+        handed to a read on another stream, and keeps a failed read from leaving partial
+        data behind for the next one.
+
+        :param stream: object with an awaitable `receive_some(max_bytes)` method
+        :param n_bytes: number of bytes to return
+        :param max_bytes: upper bound on the size requested from a single `receive_some` call
+        :returns: `bytes` of length `n_bytes`
+        :raises P2P.IncompleteRead: if `receive_some` returns no data before `n_bytes` arrived
+        """
+        received = bytearray()
+        while len(received) < n_bytes:
+            missing = n_bytes - len(received)
+            # assumes receive_some returns at most the requested number of bytes
+            chunk = await stream.receive_some(min(max_bytes, missing))
+            if len(chunk) == 0:
+                raise P2P.IncompleteRead()
+            received.extend(chunk)
+
+        return bytes(received)
+
+    async def receive_data(self, stream, max_bytes=(1 << 16)):
+        """
+        Read one length-prefixed message from `stream` and unpickle it.
+
+        :param stream: stream object handed on to `_receive_exactly`
+        :param max_bytes: upper bound passed to `_receive_exactly` for each underlying read
+        :returns: the unpickled payload
+        :raises P2P.IncompleteRead: if the stream ends before the whole message arrived
+        """
+        header = await self._receive_exactly(stream, P2P.HEADER_LEN, max_bytes)
+        content_length = int.from_bytes(header, P2P.BYTEORDER)
+        data = await self._receive_exactly(stream, content_length, max_bytes)
+        # NOTE(security): pickle.loads can execute arbitrary code contained in the payload,
+        # so this is only safe when every peer able to open a stream is trusted.
+        return pickle.loads(data)
+
+    def _handle_stream(self, handle):
+        """
+        Wrap `handle` into a coroutine that serves a single request on a stream.
+
+        The returned coroutine reads one message, calls `handle` on it and sends back the
+        result. If `handle` or sending the result raises, the exception object itself is
+        sent back instead. The stream is closed exactly once on every path, including an
+        incomplete read.
+
+        :param handle: synchronous callable taking the unpickled request
+        :returns: coroutine function with signature `(stream_info, stream)`
+        """
+        async def do_handle_stream(stream_info, stream):
+            try:
+                try:
+                    request = await self.receive_data(stream)
+                except P2P.IncompleteRead:
+                    warnings.warn("Incomplete read while receiving request from peer", RuntimeWarning)
+                    return
+                try:
+                    result = handle(request)
+                    await self.send_data(result, stream)
+                except Exception as exc:
+                    await self.send_data(exc, stream)
+            finally:
+                # single awaited close; a bare `stream.close()` would only create a coroutine
+                await stream.close()
+
+        return do_handle_stream
+
+    def start_listening(self):
+        """
+        Schedule a background task that keeps the client's listener open.
+
+        The task stays inside `self._client.listen()` until `_server_stopped` is set or the
+        task is cancelled. It is created with `asyncio.create_task`, so an event loop must
+        be running when this is called.
+        """
+        async def serve_until_stopped():
+            async with self._client.listen():
+                await self._server_stopped.wait()
+
+        self._listen_task = asyncio.create_task(serve_until_stopped())
+
+    async def stop_listening(self):
+        """
+        Stop the background listen task, if there is one, and reset the listening state.
+
+        The stop event is set and the task is cancelled, then the task is awaited. The
+        state is reset however the task ended, so that a later `start_listening` begins
+        from a clean slate. An exception other than cancellation raised by the task is
+        propagated after the reset.
+        """
+        if self._listen_task is None:
+            return
+
+        self._server_stopped.set()
+        self._listen_task.cancel()
+        try:
+            await self._listen_task
+        except asyncio.CancelledError:
+            # expected outcome of the cancel() call above
+            pass
+        finally:
+            self._listen_task = None
+            self._server_stopped.clear()
+
+    async def add_stream_handler(self, name, handle):
+        """
+        Register `handle` as the handler for streams opened under the name `name`.
+
+        Starts the listen task first when none exists yet.
+
+        :param name: handler name passed to the client's `stream_handler`
+        :param handle: synchronous callable applied to each unpickled request
+        """
+        already_listening = self._listen_task is not None
+        if not already_listening:
+            self.start_listening()
+
+        wrapped_handler = self._handle_stream(handle)
+        await self._client.stream_handler(name, wrapped_handler)
+
+    async def call_peer_handler(self, peer_id, handler_name, input_data):
+        """
+        Open a stream to a peer's handler, send `input_data` and return the reply.
+
+        :param peer_id: base58-encoded id of the peer to call
+        :param handler_name: name the handler was registered under on the peer
+        :param input_data: picklable request object
+        :returns: the unpickled reply, exactly as received; the stream is closed afterwards
+            whether or not the exchange succeeded
+        """
+        target = ID.from_base58(peer_id)
+        protocols = (handler_name,)
+        _stream_info, stream = await self._client.stream_open(target, protocols)
+        try:
+            await self.send_data(input_data, stream)
+            reply = await self.receive_data(stream)
+        finally:
+            await stream.close()
+        return reply
 
     def __del__(self):
+        """Finalizer: kill the child daemon process, if one is still running."""
         self._kill_child()
 
     def _kill_child(self):
+        """Kill the daemon process and wait for it to exit; no-op if it never started or already exited."""
-        if self._child.poll() is None:
+        if self._child is not None and self._child.poll() is None:
             self._child.kill()
             self._child.wait()
 
-    def _make_process_args(self, args: tp.Tuple[tp.Any],
-                           kwargs: tp.Dict[str, tp.Any]) -> tp.List[str]:
-        proc_args = [self.LIBP2P_CMD]
+    def _make_process_args(self, *args, **kwargs) -> tp.List[str]:
+        """
+        Build a list of command-line arguments.
+
+        Positional arguments are converted with `str`; each keyword becomes `-key=value`,
+        or a bare `-key` flag when its value is None.
+        """
+        proc_args = []
         proc_args.extend(
             str(entry) for entry in args
         )
         proc_args.extend(
-            f'-{key}={str(value)}' for key, value in kwargs.items()
+            f'-{key}={value}' if value is not None else f'-{key}'
+            for key, value in kwargs.items()
         )
         return proc_args
+
+
+def find_open_port(params=(socket.AF_INET, socket.SOCK_STREAM),
+                   opt=(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)):
+    """
+    Find a currently free port by binding a throwaway socket to port 0.
+
+    :param params: positional arguments for `socket.socket` (family, type)
+    :param opt: positional arguments for `setsockopt`, applied before binding
+    :returns: the port number assigned by the OS
+    :raises OSError: if the socket cannot be created, configured or bound
+
+    NOTE: the socket is closed before returning, so another process may take the port
+    before the caller gets to use it.
+    """
+    with contextlib.closing(socket.socket(*params)) as sock:
+        # socket options can only influence bind() when they are set beforehand
+        sock.setsockopt(*opt)
+        sock.bind(('', 0))
+        return sock.getsockname()[1]

+ 121 - 21
tests/test_p2p_daemon.py

@@ -1,6 +1,8 @@
+import asyncio
+import multiprocessing as mp
 import subprocess
-from time import perf_counter
 
+import numpy as np
 import pytest
 
 import hivemind.p2p
@@ -23,33 +25,131 @@ def is_process_running(pid: int) -> bool:
     return subprocess.check_output(cmd, shell=True).decode('utf-8').strip() == RUNNING
 
 
-@pytest.fixture()
-def mock_p2p_class():
-    P2P.LIBP2P_CMD = "sleep"
-
-
-def test_daemon_killed_on_del(mock_p2p_class):
-    start = perf_counter()
-    p2p_daemon = P2P('10s')
+@pytest.mark.asyncio
+async def test_daemon_killed_on_del():
+    """The daemon child process must no longer be running after the P2P finalizer runs."""
+    p2p_daemon = await P2P.create()
 
     child_pid = p2p_daemon._child.pid
     assert is_process_running(child_pid)
 
-    del p2p_daemon
+    # the finalizer is invoked directly rather than through `del`
+    p2p_daemon.__del__()
     assert not is_process_running(child_pid)
-    assert perf_counter() - start < 1
 
 
-def test_daemon_killed_on_exit(mock_p2p_class):
-    start = perf_counter()
-    with P2P('10s') as daemon:
-        child_pid = daemon.pid
-        assert is_process_running(child_pid)
+def handle_square(x):
+    """Test handler: return `x ** 2`."""
+    return x ** 2
 
-    assert not is_process_running(child_pid)
-    assert perf_counter() - start < 1
+
+def handle_add(args):
+    """Test handler: combine the items of `args` with `+`, left to right, starting from the first."""
+    total = args[0]
+    for item in args[1:]:
+        total = total + item
+    return total
+
+
+@pytest.mark.parametrize(
+    "test_input,handle",
+    [
+        pytest.param(10, handle_square, id="square_integer"),
+        pytest.param((1, 2), handle_add, id="add_integers"),
+        pytest.param(([1, 2, 3], [12, 13]), handle_add, id="add_lists"),
+        pytest.param(2, lambda x: x ** 3, id="lambda")
+    ]
+)
+@pytest.mark.asyncio
+async def test_call_peer_single_process(test_input, handle, handler_name="handle"):
+    """Call a handler registered on one P2P instance from a second instance in the same process."""
+    server = await P2P.create()
+    server_pid = server._child.pid
+    await server.add_stream_handler(handler_name, handle)
+    assert is_process_running(server_pid)
+
+    client = await P2P.create()
+    client_pid = client._child.pid
+    assert is_process_running(client_pid)
+
+    # NOTE(review): presumably gives the two daemons time to become reachable — confirm
+    await asyncio.sleep(1)
+    result = await client.call_peer_handler(server.id, handler_name, test_input)
+    # the remote result must equal what the handler returns when called locally
+    assert result == handle(test_input)
+
+    await server.stop_listening()
+    server.__del__()
+    assert not is_process_running(server_pid)
+
+    client.__del__()
+    assert not is_process_running(client_pid)
+
+
+@pytest.mark.asyncio
+async def test_call_peer_different_processes():
+    """Call a handler served from a separate process and compare with the local result."""
+    handler_name = "square"
+    test_input = np.random.randn(2, 3)
+
+    # the pipe carries the server's peer id to the client; the shared value tells the
+    # server process that the client has received its response
+    server_side, client_side = mp.Pipe()
+    response_received = mp.Value(np.ctypeslib.as_ctypes_type(np.int32))
+    response_received.value = 0
+
+    async def run_server():
+        server = await P2P.create()
+        server_pid = server._child.pid
+        await server.add_stream_handler(handler_name, handle_square)
+        assert is_process_running(server_pid)
+
+        server_side.send(server.id)
+        # keep serving until the client flips the shared flag
+        while response_received.value == 0:
+            await asyncio.sleep(0.5)
+
+        await server.stop_listening()
+        server.__del__()
+        assert not is_process_running(server_pid)
+
+    def server_target():
+        asyncio.run(run_server())
+
+    # NOTE(review): a locally defined target presumably requires the "fork" start method — confirm
+    proc = mp.Process(target=server_target)
+    proc.start()
+
+    client = await P2P.create()
+    client_pid = client._child.pid
+    assert is_process_running(client_pid)
+
+    await asyncio.sleep(1)
+    # NOTE(review): recv() is a blocking call made from inside a coroutine
+    peer_id = client_side.recv()
+
+    result = await client.call_peer_handler(peer_id, handler_name, test_input)
+    assert np.allclose(result, handle_square(test_input))
+    response_received.value = 1
+
+    client.__del__()
+    assert not is_process_running(client_pid)
+
+    proc.join()
 
 
-def test_daemon_raises_on_faulty_args():
-    with pytest.raises(RuntimeError):
-        P2P(faulty='argument')
+@pytest.mark.parametrize(
+    "test_input,handle",
+    [
+        pytest.param(np.random.randn(2, 3), handle_square, id="square"),
+        pytest.param([np.random.randn(2, 3), np.random.randn(2, 3)], handle_add, id="add"),
+    ]
+)
+@pytest.mark.asyncio
+async def test_call_peer_numpy(test_input, handle, handler_name="handle"):
+    """Send numpy arrays through a peer handler and compare with the local result."""
+    # NOTE(review): neither daemon is shut down explicitly in this test
+    server = await P2P.create()
+    await server.add_stream_handler(handler_name, handle)
+    client = await P2P.create()
+
+    await asyncio.sleep(1)
+    result = await client.call_peer_handler(server.id, handler_name, test_input)
+    assert np.allclose(result, handle(test_input))
+
+
+@pytest.mark.asyncio
+async def test_call_peer_error(handler_name="handle"):
+    """An exception raised inside the remote handler comes back to the caller as the return value."""
+    server = await P2P.create()
+    await server.add_stream_handler(handler_name, handle_add)
+    client = await P2P.create()
+
+    await asyncio.sleep(1)
+    # adding arrays of shapes (2, 3) and (3, 2) fails inside the handler
+    result = await client.call_peer_handler(server.id, handler_name,
+                                            [np.zeros((2, 3)), np.zeros((3, 2))])
+    assert isinstance(result, ValueError)