Răsfoiți Sursa

feat P2P: add unary handler (#197)

* Add unary handler

* Add P2PContext to unary handler parameters

Co-authored-by: Ilya Kobelev <ilya.kobellev@gmail.com>
Ilya 4 ani în urmă
părinte
comite
8d873f630f
2 a modificat fișierele cu 168 adăugiri și 41 ștergeri
  1. 107 40
      hivemind/p2p/p2p_daemon.py
  2. 61 1
      tests/test_p2p_daemon.py

+ 107 - 40
hivemind/p2p/p2p_daemon.py

@@ -1,17 +1,27 @@
 import asyncio
-import contextlib
 import copy
 from pathlib import Path
 import pickle
-import socket
 import subprocess
 import typing as tp
 import warnings
 
+import google.protobuf
 from multiaddr import Multiaddr
 import p2pclient
 from libp2p.peer.id import ID
 
+from hivemind.utils.networking import find_open_port
+
+
+class P2PContext(object):
+    def __init__(self, ours_id, ours_port, handle_name):
+        self.peer_id = None
+        self.peer_addr = None
+        self.ours_id = ours_id
+        self.ours_port = ours_port
+        self.handle_name = handle_name
+
 
 class P2P(object):
     """
@@ -26,11 +36,16 @@ class P2P(object):
     HEADER_LEN = 8
     BYTEORDER = 'big'
 
+    class IncompleteRead(Exception):
+        pass
+
+    class InterruptedError(Exception):
+        pass
+
     def __init__(self):
         self._child = None
         self._listen_task = None
         self._server_stopped = asyncio.Event()
-        self._buffer = bytearray()
 
     @classmethod
     async def create(cls, *args, quic=1, tls=1, conn_manager=1, dht_client=1,
@@ -89,50 +104,108 @@ class P2P(object):
                 self._daemon_listen_port = find_open_port()
 
     @staticmethod
-    async def send_data(data, stream):
-        byte_str = pickle.dumps(data)
+    async def send_raw_data(byte_str, stream):
         request = len(byte_str).to_bytes(P2P.HEADER_LEN, P2P.BYTEORDER) + byte_str
         await stream.send_all(request)
 
-    class IncompleteRead(Exception):
-        pass
+    @staticmethod
+    async def send_data(data, stream):
+        await P2P.send_raw_data(pickle.dumps(data), stream)
+
+    @staticmethod
+    async def send_protobuf(protobuf, out_proto_type, stream):
+        if type(protobuf) != out_proto_type:
+            error = TypeError('Unary handler returned protobuf of wrong type.')
+            await P2P.send_raw_data(pickle.dumps(error), stream)
+            raise error
+        await P2P.send_raw_data(protobuf.SerializeToString(), stream)
 
-    async def _receive_exactly(self, stream, n_bytes, max_bytes=1 << 16):
-        while len(self._buffer) < n_bytes:
-            data = await stream.receive_some(max_bytes)
+    @staticmethod
+    async def receive_exactly(stream, n_bytes, max_bytes=1 << 16):
+        buffer = bytearray()
+        while len(buffer) < n_bytes:
+            data = await stream.receive_some(min(max_bytes, n_bytes - len(buffer)))
             if len(data) == 0:
                 raise P2P.IncompleteRead()
-            self._buffer.extend(data)
-
-        result = self._buffer[:n_bytes]
-        self._buffer = self._buffer[n_bytes:]
-        return bytes(result)
+            buffer.extend(data)
+        return bytes(buffer)
 
-    async def receive_data(self, stream, max_bytes=(1 < 16)):
-        header = await self._receive_exactly(stream, P2P.HEADER_LEN)
+    @staticmethod
+    async def receive_raw_data(stream):
+        header = await P2P.receive_exactly(stream, P2P.HEADER_LEN)
         content_length = int.from_bytes(header, P2P.BYTEORDER)
-        data = await self._receive_exactly(stream, content_length)
-        return pickle.loads(data)
+        data = await P2P.receive_exactly(stream, content_length)
+        return data
 
-    def _handle_stream(self, handle):
+    @staticmethod
+    async def receive_data(stream):
+        return pickle.loads(await P2P.receive_raw_data(stream))
+
+    @staticmethod
+    async def receive_protobuf(in_proto_type, stream):
+        protobuf = in_proto_type()
+        protobuf.ParseFromString(await P2P.receive_raw_data(stream))
+        return protobuf
+
+    @staticmethod
+    def _handle_stream(handle):
         async def do_handle_stream(stream_info, stream):
             try:
-                request = await self.receive_data(stream)
+                request = await P2P.receive_data(stream)
             except P2P.IncompleteRead:
                 warnings.warn("Incomplete read while receiving request from peer", RuntimeWarning)
+                await stream.close()
                 return
-            finally:
-                stream.close()
             try:
                 result = handle(request)
-                await self.send_data(result, stream)
+                await P2P.send_data(result, stream)
             except Exception as exc:
-                await self.send_data(exc, stream)
+                await P2P.send_data(exc, stream)
             finally:
                 await stream.close()
 
         return do_handle_stream
 
+    @staticmethod
+    def _handle_unary_stream(handle, context, in_proto_type, out_proto_type):
+        async def watchdog(stream):
+            await stream.receive_some(max_bytes=1)
+            raise P2P.InterruptedError()
+
+        async def do_handle_unary_stream(stream_info, stream):
+            try:
+                try:
+                    request = await P2P.receive_protobuf(in_proto_type, stream)
+                except P2P.IncompleteRead:
+                    warnings.warn("Incomplete read while receiving request from peer",
+                                  RuntimeWarning)
+                    return
+                except google.protobuf.message.DecodeError as error:
+                    warnings.warn(repr(error), RuntimeWarning)
+                    return
+
+                context.peer_id, context.peer_addr = stream_info.peer_id, stream_info.addr
+                done, pending = await asyncio.wait([watchdog(stream), handle(request, context)],
+                                                   return_when=asyncio.FIRST_COMPLETED)
+                try:
+                    result = done.pop().result()
+                    await P2P.send_protobuf(result, out_proto_type, stream)
+                except P2P.InterruptedError:
+                    pass
+                except Exception as exc:
+                    await P2P.send_data(exc, stream)
+                finally:
+                    pending_task = pending.pop()
+                    pending_task.cancel()
+                    try:
+                        await pending_task
+                    except asyncio.CancelledError:
+                        pass
+            finally:
+                await stream.close()
+
+        return do_handle_unary_stream
+
     def start_listening(self):
         async def listen():
             async with self._client.listen():
@@ -153,15 +226,21 @@ class P2P(object):
     async def add_stream_handler(self, name, handle):
         if self._listen_task is None:
             self.start_listening()
+        await self._client.stream_handler(name, P2P._handle_stream(handle))
 
-        await self._client.stream_handler(name, self._handle_stream(handle))
+    async def add_unary_handler(self, name, handle, in_proto_type, out_proto_type):
+        if self._listen_task is None:
+            self.start_listening()
+        context = P2PContext(ours_id=self.id, ours_port=self._host_port, handle_name=name)
+        await self._client.stream_handler(
+            name, P2P._handle_unary_stream(handle, context, in_proto_type, out_proto_type))
 
     async def call_peer_handler(self, peer_id, handler_name, input_data):
         libp2p_peer_id = ID.from_base58(peer_id)
         stream_info, stream = await self._client.stream_open(libp2p_peer_id, (handler_name,))
         try:
-            await self.send_data(input_data, stream)
-            return await self.receive_data(stream)
+            await P2P.send_data(input_data, stream)
+            return await P2P.receive_data(stream)
         finally:
             await stream.close()
 
@@ -183,15 +262,3 @@ class P2P(object):
             for key, value in kwargs.items()
         )
         return proc_args
-
-
-def find_open_port(params=(socket.AF_INET, socket.SOCK_STREAM),
-                   opt=(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)):
-    """ Finds a tcp port that can be occupied with a socket with *params and use *opt options """
-    try:
-        with contextlib.closing(socket.socket(*params)) as sock:
-            sock.bind(('', 0))
-            sock.setsockopt(*opt)
-            return sock.getsockname()[1]
-    except Exception:
-        raise

+ 61 - 1
tests/test_p2p_daemon.py

@@ -2,11 +2,13 @@ import asyncio
 import multiprocessing as mp
 import subprocess
 
+from libp2p.peer.id import ID
+
 import numpy as np
 import pytest
 
-import hivemind.p2p
 from hivemind.p2p import P2P
+from hivemind.proto import dht_pb2
 
 RUNNING = 'running'
 NOT_RUNNING = 'not running'
@@ -47,6 +49,64 @@ def handle_add(args):
     return result
 
 
+@pytest.mark.parametrize(
+    'should_cancel', [True, False]
+)
+@pytest.mark.asyncio
+async def test_call_unary_handler(should_cancel, handle_name="handle"):
+    handler_cancelled = False
+
+    async def ping_handler(request, context):
+        try:
+            await asyncio.sleep(2)
+        except asyncio.CancelledError:
+            nonlocal handler_cancelled
+            handler_cancelled = True
+        return dht_pb2.PingResponse(
+            peer=dht_pb2.NodeInfo(
+                node_id=context.ours_id.encode(), rpc_port=context.ours_port),
+            sender_endpoint=context.handle_name, available=True)
+
+    server = await P2P.create()
+    server_pid = server._child.pid
+    await server.add_unary_handler(handle_name, ping_handler, dht_pb2.PingRequest,
+                                   dht_pb2.PingResponse)
+    assert is_process_running(server_pid)
+
+    client = await P2P.create()
+    client_pid = client._child.pid
+    assert is_process_running(client_pid)
+
+    ping_request = dht_pb2.PingRequest(
+        peer=dht_pb2.NodeInfo(node_id=client.id.encode(), rpc_port=client._host_port),
+        validate=True)
+    expected_response = dht_pb2.PingResponse(
+        peer=dht_pb2.NodeInfo(node_id=server.id.encode(), rpc_port=server._host_port),
+        sender_endpoint=handle_name, available=True)
+
+    await asyncio.sleep(1)
+    libp2p_server_id = ID.from_base58(server.id)
+    stream_info, stream = await client._client.stream_open(libp2p_server_id, (handle_name,))
+
+    await P2P.send_raw_data(ping_request.SerializeToString(), stream)
+
+    if should_cancel:
+        await stream.close()
+        await asyncio.sleep(1)
+        assert handler_cancelled
+    else:
+        result = await P2P.receive_protobuf(dht_pb2.PingResponse, stream)
+        assert result == expected_response
+        assert not handler_cancelled
+
+    await server.stop_listening()
+    server.__del__()
+    assert not is_process_running(server_pid)
+
+    client.__del__()
+    assert not is_process_running(client_pid)
+
+
 @pytest.mark.parametrize(
     "test_input,handle",
     [