View source code

Implement protobuf-based stream handlers over libp2p backend (#318)

This PR implements protobuf-based stream handlers over libp2p backend (including unary-stream, stream-unary, and stream-stream). Similarly to gRPC, they can be used through the Servicer interface.
Alexander Borzunov, 4 years ago
parent
commit
fb4813347a

+ 2 - 2
hivemind/dht/protocol.py

@@ -7,7 +7,7 @@ from typing import Optional, List, Tuple, Dict, Sequence, Union, Collection
 from hivemind.dht.crypto import DHTRecord, RecordValidatorBase
 from hivemind.dht.routing import RoutingTable, DHTID, BinaryDHTValue, Subkey
 from hivemind.dht.storage import DHTLocalStorage, DictionaryDHTValue
-from hivemind.p2p import P2P, P2PContext, PeerID, Servicer
+from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase
 from hivemind.proto import dht_pb2
 from hivemind.utils import get_logger, MSGPackSerializer
 from hivemind.utils.auth import AuthRole, AuthRPCWrapper, AuthorizerBase
@@ -21,7 +21,7 @@ from hivemind.utils.timed_storage import (
 logger = get_logger(__name__)
 
 
-class DHTProtocol(Servicer):
+class DHTProtocol(ServicerBase):
     # fmt:off
     p2p: P2P
     node_id: DHTID; bucket_size: int; num_replicas: int; wait_timeout: float; node_info: dht_pb2.NodeInfo

+ 1 - 1
hivemind/p2p/__init__.py

@@ -1,3 +1,3 @@
 from hivemind.p2p.p2p_daemon import P2P, P2PContext, P2PHandlerError
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo
-from hivemind.p2p.servicer import Servicer
+from hivemind.p2p.servicer import ServicerBase, StubBase

+ 183 - 84
hivemind/p2p/p2p_daemon.py

@@ -1,19 +1,20 @@
 import asyncio
 import os
 import secrets
+from collections.abc import AsyncIterable as AsyncIterableABC
 from contextlib import closing, suppress
 from dataclasses import dataclass
 from importlib.resources import path
 from subprocess import Popen
-from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
+from typing import Any, AsyncIterator, Awaitable, Callable, List, Optional, Sequence, Tuple, TypeVar, Union
 
-import google.protobuf
 from multiaddr import Multiaddr
 
 import hivemind.hivemind_cli as cli
 import hivemind.p2p.p2p_daemon_bindings.p2pclient as p2pclient
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
-from hivemind.proto import p2pd_pb2
+from hivemind.proto.p2pd_pb2 import RPCError
+from hivemind.utils.asyncio import aiter
 from hivemind.utils.logging import get_logger
 
 logger = get_logger(__name__)
@@ -39,18 +40,19 @@ class P2P:
     use the public IPFS network (https://ipfs.io).
 
     For incoming connections, P2P instances add RPC handlers that may be accessed by other peers:
-      - `P2P.add_unary_handler` accepts a protobuf message and returns another protobuf
-      - `P2P.add_stream_handler` transfers raw data using bi-directional streaming interface
+      - `P2P.add_protobuf_handler` accepts a protobuf message and returns another protobuf
+      - `P2P.add_binary_stream_handler` transfers raw data using bi-directional streaming interface
 
-    To access these handlers, a P2P instance can `P2P.call_unary_handler`/`P2P.call_stream_handler`,
+    To access these handlers, a P2P instance can `P2P.call_protobuf_handler`/`P2P.call_binary_stream_handler`,
     using the recipient's unique `P2P.id` and the name of the corresponding handler.
     """
 
     HEADER_LEN = 8
     BYTEORDER = "big"
-    PB_HEADER_LEN = 1
-    RESULT_MESSAGE = b"\x00"
-    ERROR_MESSAGE = b"\x01"
+    MESSAGE_MARKER = b"\x00"
+    ERROR_MARKER = b"\x01"
+    END_OF_STREAM = RPCError()
+
     DHT_MODE_MAPPING = {
         "dht": {"dht": 1},
         "dht_server": {"dhtServer": 1},
@@ -253,15 +255,6 @@ class P2P:
             writer.write(data[offset : offset + chunk_size])
         await writer.drain()
 
-    @staticmethod
-    async def send_protobuf(protobuf, writer: asyncio.StreamWriter) -> None:
-        if isinstance(protobuf, p2pd_pb2.RPCError):
-            await P2P.send_raw_data(P2P.ERROR_MESSAGE, writer)
-        else:
-            await P2P.send_raw_data(P2P.RESULT_MESSAGE, writer)
-
-        await P2P.send_raw_data(protobuf.SerializeToString(), writer)
-
     @staticmethod
     async def receive_raw_data(reader: asyncio.StreamReader) -> bytes:
         header = await reader.readexactly(P2P.HEADER_LEN)
@@ -269,68 +262,192 @@ class P2P:
         data = await reader.readexactly(content_length)
         return data
 
+    TInputProtobuf = TypeVar("TInputProtobuf")
+    TOutputProtobuf = TypeVar("TOutputProtobuf")
+
+    @staticmethod
+    async def send_protobuf(protobuf: Union[TOutputProtobuf, RPCError], writer: asyncio.StreamWriter) -> None:
+        if isinstance(protobuf, RPCError):
+            writer.write(P2P.ERROR_MARKER)
+        else:
+            writer.write(P2P.MESSAGE_MARKER)
+        await P2P.send_raw_data(protobuf.SerializeToString(), writer)
+
     @staticmethod
     async def receive_protobuf(
-        in_proto_type: type, reader: asyncio.StreamReader
-    ) -> Tuple[Any, Optional[p2pd_pb2.RPCError]]:
-        msg_type = await P2P.receive_raw_data(reader)
-        if msg_type == P2P.RESULT_MESSAGE:
-            protobuf = in_proto_type()
+        input_protobuf_type: type, reader: asyncio.StreamReader
+    ) -> Tuple[Optional[TInputProtobuf], Optional[RPCError]]:
+        msg_type = await reader.readexactly(1)
+        if msg_type == P2P.MESSAGE_MARKER:
+            protobuf = input_protobuf_type()
             protobuf.ParseFromString(await P2P.receive_raw_data(reader))
             return protobuf, None
-        elif msg_type == P2P.ERROR_MESSAGE:
-            protobuf = p2pd_pb2.RPCError()
+        elif msg_type == P2P.ERROR_MARKER:
+            protobuf = RPCError()
             protobuf.ParseFromString(await P2P.receive_raw_data(reader))
             return None, protobuf
         else:
             raise TypeError("Invalid Protobuf message type")
 
-    def _handle_unary_stream(self, handler: Callable[[Any, P2PContext], Any], handle_name: str, in_proto_type: type):
-        async def watchdog(reader: asyncio.StreamReader) -> None:
-            await reader.read(n=1)
-            raise P2PInterruptedError()
+    TInputStream = AsyncIterator[TInputProtobuf]
+    TOutputStream = AsyncIterator[TOutputProtobuf]
+
+    async def _add_protobuf_stream_handler(
+        self,
+        name: str,
+        handler: Callable[[TInputStream, P2PContext], TOutputStream],
+        input_protobuf_type: type,
+        max_prefetch: int = 0,
+    ) -> None:
+        """
+        :param max_prefetch: Maximum number of items to prefetch from the request stream.
+          ``max_prefetch <= 0`` means unlimited (default).
+
+        :note:  Since the cancel messages are sent via the input stream,
+          they will not be received while the prefetch buffer is full.
+        """
+
+        if self._listen_task is None:
+            self._start_listening()
 
-        async def do_handle_unary_stream(
+        async def _handle_stream(
             stream_info: StreamInfo, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
         ) -> None:
-            with closing(writer):
+            context = P2PContext(
+                handle_name=name,
+                local_id=self.id,
+                remote_id=stream_info.peer_id,
+                remote_maddr=stream_info.addr,
+            )
+            requests = asyncio.Queue(max_prefetch)
+
+            async def _read_stream() -> P2P.TInputStream:
+                while True:
+                    request = await requests.get()
+                    if request is None:
+                        break
+                    yield request
+
+            async def _process_stream() -> None:
                 try:
-                    request, err = await P2P.receive_protobuf(in_proto_type, reader)
-                except asyncio.IncompleteReadError:
-                    logger.debug(f"Incomplete read while receiving request from peer in {handle_name}")
-                    return
-                except google.protobuf.message.DecodeError as error:
-                    logger.debug(
-                        f"Failed to decode request protobuf " f"of type {in_proto_type} in {handle_name}: {error}"
-                    )
-                    return
-                if err is not None:
-                    logger.debug(f"Got an error instead of a request in {handle_name}: {err}")
-
-                context = P2PContext(
-                    handle_name=handle_name,
-                    local_id=self.id,
-                    remote_id=stream_info.peer_id,
-                    remote_maddr=stream_info.addr,
-                )
-                done, pending = await asyncio.wait(
-                    [watchdog(reader), handler(request, context)], return_when=asyncio.FIRST_COMPLETED
-                )
+                    async for response in handler(_read_stream(), context):
+                        await P2P.send_protobuf(response, writer)
+                except Exception as e:
+                    logger.warning("Exception while processing stream and sending responses:", exc_info=True)
+                    await P2P.send_protobuf(RPCError(message=str(e)), writer)
+
+            with closing(writer):
+                processing_task = asyncio.create_task(_process_stream())
                 try:
-                    result = done.pop().result()
-                    await P2P.send_protobuf(result, writer)
-                except P2PInterruptedError:
-                    pass
-                except Exception as exc:
-                    error = p2pd_pb2.RPCError(message=str(exc))
-                    await P2P.send_protobuf(error, writer)
+                    while True:
+                        receive_task = asyncio.create_task(P2P.receive_protobuf(input_protobuf_type, reader))
+                        await asyncio.wait({processing_task, receive_task}, return_when=asyncio.FIRST_COMPLETED)
+
+                        if processing_task.done():
+                            receive_task.cancel()
+                            return
+
+                        if receive_task.done():
+                            try:
+                                request, _ = await receive_task
+                            except asyncio.IncompleteReadError:  # Connection is closed (the client cancelled or died)
+                                return
+                            await requests.put(request)  # `request` is None for the end-of-stream message
+                except Exception:
+                    logger.warning("Exception while receiving requests:", exc_info=True)
                 finally:
-                    if pending:
-                        for task in pending:
-                            task.cancel()
-                        await asyncio.wait(pending)
+                    processing_task.cancel()
+
+        await self._client.stream_handler(name, _handle_stream)
+
+    async def _iterate_protobuf_stream_handler(
+        self, peer_id: PeerID, name: str, requests: TInputStream, output_protobuf_type: type
+    ) -> TOutputStream:
+        _, reader, writer = await self._client.stream_open(peer_id, (name,))
+
+        async def _write_to_stream() -> None:
+            async for request in requests:
+                await P2P.send_protobuf(request, writer)
+            await P2P.send_protobuf(P2P.END_OF_STREAM, writer)
+
+        with closing(writer):
+            writing_task = asyncio.create_task(_write_to_stream())
+            try:
+                while True:
+                    try:
+                        response, err = await P2P.receive_protobuf(output_protobuf_type, reader)
+                    except asyncio.IncompleteReadError:  # Connection is closed
+                        break
+
+                    if err is not None:
+                        raise P2PHandlerError(f"Failed to call handler `{name}` at {peer_id}: {err.message}")
+                    yield response
+
+                await writing_task
+            finally:
+                writing_task.cancel()
+
+    async def add_protobuf_handler(
+        self,
+        name: str,
+        handler: Callable[
+            [Union[TInputProtobuf, TInputStream], P2PContext], Union[Awaitable[TOutputProtobuf], TOutputStream]
+        ],
+        input_protobuf_type: type,
+        *,
+        stream_input: bool = False,
+    ) -> None:
+        """
+        :param stream_input: If True, assume ``handler`` to take ``TInputStream``
+                             (not just ``TInputProtobuf``) as input.
+        """
 
-        return do_handle_unary_stream
+        async def _stream_handler(requests: P2P.TInputStream, context: P2PContext) -> P2P.TOutputStream:
+            if stream_input:
+                input = requests
+            else:
+                count = 0
+                async for input in requests:
+                    count += 1
+                if count != 1:
+                    raise ValueError(f"Got {count} requests for handler {name} instead of one")
+
+            output = handler(input, context)
+
+            if isinstance(output, AsyncIterableABC):
+                async for item in output:
+                    yield item
+            else:
+                yield await output
+
+        await self._add_protobuf_stream_handler(name, _stream_handler, input_protobuf_type)
+
+    async def call_protobuf_handler(
+        self,
+        peer_id: PeerID,
+        name: str,
+        input: Union[TInputProtobuf, TInputStream],
+        output_protobuf_type: type,
+    ) -> Awaitable[TOutputProtobuf]:
+        requests = input if isinstance(input, AsyncIterableABC) else aiter(input)
+        responses = self._iterate_protobuf_stream_handler(peer_id, name, requests, output_protobuf_type)
+
+        count = 0
+        async for response in responses:
+            count += 1
+        if count != 1:
+            raise ValueError(f"Got {count} responses from handler {name} instead of one")
+        return response
+
+    def iterate_protobuf_handler(
+        self,
+        peer_id: PeerID,
+        name: str,
+        input: Union[TInputProtobuf, TInputStream],
+        output_protobuf_type: type,
+    ) -> TOutputStream:
+        requests = input if isinstance(input, AsyncIterableABC) else aiter(input)
+        return self._iterate_protobuf_stream_handler(peer_id, name, requests, output_protobuf_type)
 
     def _start_listening(self) -> None:
         async def listen() -> None:
@@ -349,34 +466,16 @@ class P2P:
                 self._listen_task = None
                 self._server_stopped.clear()
 
-    async def add_stream_handler(self, name: str, handler: p2pclient.StreamHandler) -> None:
+    async def add_binary_stream_handler(self, name: str, handler: p2pclient.StreamHandler) -> None:
         if self._listen_task is None:
             self._start_listening()
         await self._client.stream_handler(name, handler)
 
-    async def add_unary_handler(
-        self, name: str, handler: Callable[[Any, P2PContext], Any], in_proto_type: type
-    ) -> None:
-        if self._listen_task is None:
-            self._start_listening()
-        await self._client.stream_handler(name, self._handle_unary_stream(handler, name, in_proto_type))
-
-    async def call_stream_handler(
+    async def call_binary_stream_handler(
         self, peer_id: PeerID, handler_name: str
     ) -> Tuple[StreamInfo, asyncio.StreamReader, asyncio.StreamWriter]:
         return await self._client.stream_open(peer_id, (handler_name,))
 
-    async def call_unary_handler(
-        self, peer_id: PeerID, handler_name: str, request_protobuf: Any, response_proto_type: type
-    ) -> Any:
-        _, reader, writer = await self._client.stream_open(peer_id, (handler_name,))
-        with closing(writer):
-            await P2P.send_protobuf(request_protobuf, writer)
-            result, err = await P2P.receive_protobuf(response_proto_type, reader)
-            if err is not None:
-                raise P2PHandlerError(f"Failed to call unary handler {handler_name} at {peer_id}: {err.message}")
-            return result
-
     def __del__(self):
         self._terminate()
 

+ 50 - 30
hivemind/p2p/servicer.py

@@ -1,7 +1,6 @@
 import asyncio
-import importlib
 from dataclasses import dataclass
-from typing import Any, Optional, Union
+from typing import Any, AsyncIterator, Optional, Tuple, get_type_hints
 
 from hivemind.p2p.p2p_daemon import P2P
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID
@@ -13,6 +12,8 @@ class RPCHandler:
     handle_name: str
     request_type: type
     response_type: type
+    stream_input: bool
+    stream_output: bool
 
 
 class StubBase:
@@ -28,11 +29,11 @@ class StubBase:
         self._peer = peer
 
 
-class Servicer:
+class ServicerBase:
     """
     Base class for P2P RPC servicers (e.g. DHT, averager, MoE server). The interface mimicks gRPC servicers.
 
-    - ``add_p2p_handlers(self, p2p)`` registers all rpc_* methods of the derived class as P2P unary handlers, allowing
+    - ``add_p2p_handlers(self, p2p)`` registers all rpc_* methods of the derived class as P2P handlers, allowing
       other peers to call them. It uses type annotations for the ``request`` parameter and the return value
       to infer protobufs the methods operate with.
 
@@ -48,18 +49,22 @@ class Servicer:
             if method_name.startswith("rpc_") and callable(method):
                 handle_name = f"{class_name}.{method_name}"
 
-                hints = method.__annotations__
+                hints = get_type_hints(method)
                 try:
-                    request_type = self._hint_to_type(hints["request"])
-                    response_type = self._hint_to_type(hints["return"])
-                except (KeyError, ValueError):
+                    request_type = hints["request"]
+                    response_type = hints["return"]
+                except KeyError:
                     raise ValueError(
-                        f"{handle_name} is expected to have type annotations like `dht_pb2.FindRequest` "
-                        f"(a type from the hivemind.proto module) for the `request` parameter "
-                        f"and the return value"
+                        f"{handle_name} is expected to have type annotations "
+                        f"like `dht_pb2.FindRequest` or `AsyncIterator[dht_pb2.FindRequest]` "
+                        f"for the `request` parameter and the return value"
                     )
+                request_type, stream_input = self._strip_iterator_hint(request_type)
+                response_type, stream_output = self._strip_iterator_hint(response_type)
 
-                self._rpc_handlers.append(RPCHandler(method_name, handle_name, request_type, response_type))
+                self._rpc_handlers.append(
+                    RPCHandler(method_name, handle_name, request_type, response_type, stream_input, stream_output)
+                )
 
         self._stub_type = type(
             f"{class_name}Stub",
@@ -69,14 +74,33 @@ class Servicer:
 
     @staticmethod
     def _make_rpc_caller(handler: RPCHandler):
+        input_type = AsyncIterator[handler.request_type] if handler.stream_input else handler.request_type
+
         # This method will be added to a new Stub type (a subclass of StubBase)
-        async def caller(
-            self: StubBase, request: handler.request_type, timeout: Optional[float] = None
-        ) -> handler.response_type:
-            return await asyncio.wait_for(
-                self._p2p.call_unary_handler(self._peer, handler.handle_name, request, handler.response_type),
-                timeout=timeout,
-            )
+        if handler.stream_output:
+
+            def caller(
+                self: StubBase, input: input_type, timeout: None = None
+            ) -> AsyncIterator[handler.response_type]:
+                if timeout is not None:
+                    raise ValueError("Timeouts for handlers returning streams are not supported")
+
+                return self._p2p.iterate_protobuf_handler(
+                    self._peer,
+                    handler.handle_name,
+                    input,
+                    handler.response_type,
+                )
+
+        else:
+
+            async def caller(
+                self: StubBase, input: input_type, timeout: Optional[float] = None
+            ) -> handler.response_type:
+                return await asyncio.wait_for(
+                    self._p2p.call_protobuf_handler(self._peer, handler.handle_name, input, handler.response_type),
+                    timeout=timeout,
+                )
 
         caller.__name__ = handler.method_name
         return caller
@@ -84,23 +108,19 @@ class Servicer:
     async def add_p2p_handlers(self, p2p: P2P, wrapper: Any = None) -> None:
         servicer = self if wrapper is None else wrapper
         for handler in self._rpc_handlers:
-            await p2p.add_unary_handler(
+            await p2p.add_protobuf_handler(
                 handler.handle_name,
                 getattr(servicer, handler.method_name),
                 handler.request_type,
+                stream_input=handler.stream_input,
             )
 
     def get_stub(self, p2p: P2P, peer: PeerID) -> StubBase:
         return self._stub_type(p2p, peer)
 
     @staticmethod
-    def _hint_to_type(hint: Union[type, str]) -> type:
-        if isinstance(hint, type):
-            return hint
-
-        module_name, proto_name = hint.split(".")
-        module = importlib.import_module("hivemind.proto." + module_name)
-        result = getattr(module, proto_name)
-        if not isinstance(result, type):
-            raise ValueError(f"`hivemind.proto.{hint}` is not a type")
-        return result
+    def _strip_iterator_hint(hint: type) -> Tuple[type, bool]:
+        if hasattr(hint, "_name") and hint._name in ("AsyncIterator", "AsyncIterable"):
+            return hint.__args__[0], True
+
+        return hint, False

+ 1 - 1
hivemind/proto/p2pd.proto

@@ -162,5 +162,5 @@ message PSResponse {
 }
 
 message RPCError {
-  required string message = 1;
+  optional string message = 1;
 }

+ 9 - 0
hivemind/proto/test.proto

@@ -0,0 +1,9 @@
+syntax = "proto3";
+
+message TestRequest {
+    int32 number = 1;
+}
+
+message TestResponse {
+    int32 number = 1;
+}

+ 27 - 22
tests/test_p2p_daemon.py

@@ -81,7 +81,7 @@ async def test_daemon_replica_does_not_affect_primary():
     ],
 )
 @pytest.mark.asyncio
-async def test_call_unary_handler(should_cancel, replicate, handle_name="handle"):
+async def test_call_protobuf_handler(should_cancel, replicate, handle_name="handle"):
     handler_cancelled = False
     server_primary = await P2P.create()
     server = await replicate_if_needed(server_primary, replicate)
@@ -95,7 +95,7 @@ async def test_call_unary_handler(should_cancel, replicate, handle_name="handle"
         return dht_pb2.PingResponse(peer=dht_pb2.NodeInfo(node_id=server.id.to_bytes()), available=True)
 
     server_pid = server_primary._child.pid
-    await server.add_unary_handler(handle_name, ping_handler, dht_pb2.PingRequest)
+    await server.add_protobuf_handler(handle_name, ping_handler, dht_pb2.PingRequest)
     assert is_process_running(server_pid)
 
     client_primary = await P2P.create(initial_peers=await server.get_visible_maddrs())
@@ -108,14 +108,19 @@ async def test_call_unary_handler(should_cancel, replicate, handle_name="handle"
     expected_response = dht_pb2.PingResponse(peer=dht_pb2.NodeInfo(node_id=server.id.to_bytes()), available=True)
 
     if should_cancel:
-        *_, writer = await client.call_stream_handler(server.id, handle_name)
-        with closing(writer):
-            await P2P.send_protobuf(ping_request, writer)
+        call_task = asyncio.create_task(
+            client.call_protobuf_handler(server.id, handle_name, ping_request, dht_pb2.PingResponse)
+        )
+        await asyncio.sleep(0.25)
+
+        call_task.cancel()
 
-        await asyncio.sleep(1)
+        await asyncio.sleep(0.25)
         assert handler_cancelled
     else:
-        actual_response = await client.call_unary_handler(server.id, handle_name, ping_request, dht_pb2.PingResponse)
+        actual_response = await client.call_protobuf_handler(
+            server.id, handle_name, ping_request, dht_pb2.PingResponse
+        )
         assert actual_response == expected_response
         assert not handler_cancelled
 
@@ -128,13 +133,13 @@ async def test_call_unary_handler(should_cancel, replicate, handle_name="handle"
 
 
 @pytest.mark.asyncio
-async def test_call_unary_handler_error(handle_name="handle"):
+async def test_call_protobuf_handler_error(handle_name="handle"):
     async def error_handler(request, context):
         raise ValueError("boom")
 
     server = await P2P.create()
     server_pid = server._child.pid
-    await server.add_unary_handler(handle_name, error_handler, dht_pb2.PingRequest)
+    await server.add_protobuf_handler(handle_name, error_handler, dht_pb2.PingRequest)
     assert is_process_running(server_pid)
 
     client = await P2P.create(initial_peers=await server.get_visible_maddrs())
@@ -145,7 +150,7 @@ async def test_call_unary_handler_error(handle_name="handle"):
     ping_request = dht_pb2.PingRequest(peer=dht_pb2.NodeInfo(node_id=client.id.to_bytes()), validate=True)
 
     with pytest.raises(P2PHandlerError) as excinfo:
-        await client.call_unary_handler(server.id, handle_name, ping_request, dht_pb2.PingResponse)
+        await client.call_protobuf_handler(server.id, handle_name, ping_request, dht_pb2.PingResponse)
     assert "boom" in str(excinfo.value)
 
     await server.shutdown()
@@ -183,7 +188,7 @@ async def test_call_peer_single_process():
     assert is_process_running(server_pid)
 
     handler_name = "square"
-    await server.add_stream_handler(handler_name, handle_square_stream)
+    await server.add_binary_stream_handler(handler_name, handle_square_stream)
 
     client = await P2P.create(initial_peers=await server.get_visible_maddrs())
     client_pid = client._child.pid
@@ -191,7 +196,7 @@ async def test_call_peer_single_process():
 
     await client.wait_for_at_least_n_peers(1)
 
-    _, reader, writer = await client.call_stream_handler(server.id, handler_name)
+    _, reader, writer = await client.call_binary_stream_handler(server.id, handler_name)
     await validate_square_stream(reader, writer)
 
     await server.shutdown()
@@ -206,7 +211,7 @@ async def run_server(handler_name, server_side, response_received):
     server_pid = server._child.pid
     assert is_process_running(server_pid)
 
-    await server.add_stream_handler(handler_name, handle_square_stream)
+    await server.add_binary_stream_handler(handler_name, handle_square_stream)
 
     server_side.send(server.id)
     server_side.send(await server.get_visible_maddrs())
@@ -241,7 +246,7 @@ async def test_call_peer_different_processes():
 
     await client.wait_for_at_least_n_peers(1)
 
-    _, reader, writer = await client.call_stream_handler(peer_id, handler_name)
+    _, reader, writer = await client.call_binary_stream_handler(peer_id, handler_name)
     await validate_square_stream(reader, writer)
 
     response_received.value = 1
@@ -268,7 +273,7 @@ async def test_error_closes_connection():
     assert is_process_running(server_pid)
 
     handler_name = "handler"
-    await server.add_stream_handler(handler_name, handle_raising_error)
+    await server.add_binary_stream_handler(handler_name, handle_raising_error)
 
     client = await P2P.create(initial_peers=await server.get_visible_maddrs())
     client_pid = client._child.pid
@@ -276,7 +281,7 @@ async def test_error_closes_connection():
 
     await client.wait_for_at_least_n_peers(1)
 
-    _, reader, writer = await client.call_stream_handler(server.id, handler_name)
+    _, reader, writer = await client.call_binary_stream_handler(server.id, handler_name)
     with closing(writer):
         await P2P.send_raw_data(b"raise_error", writer)
         with pytest.raises(asyncio.IncompleteReadError):  # Means that the connection is closed
@@ -285,7 +290,7 @@ async def test_error_closes_connection():
     # Despite the handler raised an exception, the server did not crash and ready for next requests
     assert is_process_running(server_pid)
 
-    _, reader, writer = await client.call_stream_handler(server.id, handler_name)
+    _, reader, writer = await client.call_binary_stream_handler(server.id, handler_name)
     with closing(writer):
         await P2P.send_raw_data(b"behave_normally", writer)
         assert await P2P.receive_raw_data(reader) == b"okay"
@@ -305,19 +310,19 @@ async def test_handlers_on_different_replicas():
 
     server_primary = await P2P.create()
     server_id = server_primary.id
-    await server_primary.add_stream_handler("handle_primary", partial(handler, key=b"primary"))
+    await server_primary.add_binary_stream_handler("handle_primary", partial(handler, key=b"primary"))
 
     server_replica1 = await replicate_if_needed(server_primary, True)
-    await server_replica1.add_stream_handler("handle1", partial(handler, key=b"replica1"))
+    await server_replica1.add_binary_stream_handler("handle1", partial(handler, key=b"replica1"))
 
     server_replica2 = await replicate_if_needed(server_primary, True)
-    await server_replica2.add_stream_handler("handle2", partial(handler, key=b"replica2"))
+    await server_replica2.add_binary_stream_handler("handle2", partial(handler, key=b"replica2"))
 
     client = await P2P.create(initial_peers=await server_primary.get_visible_maddrs())
     await client.wait_for_at_least_n_peers(1)
 
     for name, expected_key in [("handle_primary", b"primary"), ("handle1", b"replica1"), ("handle2", b"replica2")]:
-        _, reader, writer = await client.call_stream_handler(server_id, name)
+        _, reader, writer = await client.call_binary_stream_handler(server_id, name)
         with closing(writer):
             assert await P2P.receive_raw_data(reader) == expected_key
 
@@ -327,7 +332,7 @@ async def test_handlers_on_different_replicas():
     # Primary does not handle replicas protocols after their shutdown
 
     for name in ["handle1", "handle2"]:
-        _, reader, writer = await client.call_stream_handler(server_id, name)
+        _, reader, writer = await client.call_binary_stream_handler(server_id, name)
         with pytest.raises(asyncio.IncompleteReadError), closing(writer):
             await P2P.receive_raw_data(reader)
 

+ 148 - 0
tests/test_p2p_servicer.py

@@ -0,0 +1,148 @@
+import asyncio
+from typing import AsyncIterator
+
+import pytest
+
+from hivemind.p2p import P2P, P2PContext, ServicerBase
+from hivemind.proto import test_pb2
+
+
@pytest.fixture
async def server_client():
    """Provide a connected (server, client) pair of P2P instances; shut both down on teardown."""
    server_p2p = await P2P.create()
    client_p2p = await P2P.create(initial_peers=await server_p2p.get_visible_maddrs())
    yield server_p2p, client_p2p

    await asyncio.gather(server_p2p.shutdown(), client_p2p.shutdown())
+
+
@pytest.mark.asyncio
async def test_unary_unary(server_client):
    """A unary-unary handler returns exactly one response for one request."""

    class ExampleServicer(ServicerBase):
        async def rpc_square(self, request: test_pb2.TestRequest, _: P2PContext) -> test_pb2.TestResponse:
            return test_pb2.TestResponse(number=request.number ** 2)

    server, client = server_client
    servicer = ExampleServicer()
    await servicer.add_p2p_handlers(server)
    stub = servicer.get_stub(client, server.id)

    response = await stub.rpc_square(test_pb2.TestRequest(number=10))
    assert response == test_pb2.TestResponse(number=100)
+
+
@pytest.mark.asyncio
async def test_stream_unary(server_client):
    """A stream-unary handler aggregates a stream of requests into a single response."""

    class ExampleServicer(ServicerBase):
        async def rpc_sum(self, request: AsyncIterator[test_pb2.TestRequest], _: P2PContext) -> test_pb2.TestResponse:
            total = 0
            async for message in request:
                total += message.number
            return test_pb2.TestResponse(number=total)

    server, client = server_client
    servicer = ExampleServicer()
    await servicer.add_p2p_handlers(server)
    stub = servicer.get_stub(client, server.id)

    async def request_stream() -> AsyncIterator[test_pb2.TestRequest]:
        for value in range(10):
            yield test_pb2.TestRequest(number=value)

    # sum(range(10)) == 45
    assert await stub.rpc_sum(request_stream()) == test_pb2.TestResponse(number=45)
+
+
@pytest.mark.asyncio
async def test_unary_stream(server_client):
    """A unary-stream handler yields a stream of responses for a single request."""

    class ExampleServicer(ServicerBase):
        async def rpc_count(
            self, request: test_pb2.TestRequest, _: P2PContext
        ) -> AsyncIterator[test_pb2.TestResponse]:
            for value in range(request.number):
                yield test_pb2.TestResponse(number=value)

    server, client = server_client
    servicer = ExampleServicer()
    await servicer.add_p2p_handlers(server)
    stub = servicer.get_stub(client, server.id)

    expected = 0
    async for response in stub.rpc_count(test_pb2.TestRequest(number=10)):
        assert response == test_pb2.TestResponse(number=expected)
        expected += 1
    # Make sure the full stream arrived, not just a prefix.
    assert expected == 10
+
+
@pytest.mark.asyncio
async def test_stream_stream(server_client):
    """A stream-stream handler maps each request to multiple responses.

    For every request with value x, the servicer yields x ** 2 followed by x ** 3,
    so 10 requests must produce exactly 20 responses.
    """

    class ExampleServicer(ServicerBase):
        async def rpc_powers(
            self, request: AsyncIterator[test_pb2.TestRequest], _: P2PContext
        ) -> AsyncIterator[test_pb2.TestResponse]:
            async for item in request:
                yield test_pb2.TestResponse(number=item.number ** 2)
                yield test_pb2.TestResponse(number=item.number ** 3)

    server, client = server_client
    servicer = ExampleServicer()
    await servicer.add_p2p_handlers(server)
    stub = servicer.get_stub(client, server.id)

    async def generate_requests() -> AsyncIterator[test_pb2.TestRequest]:
        for i in range(10):
            yield test_pb2.TestRequest(number=i)

    i = 0
    async for item in stub.rpc_powers(generate_requests()):
        if i % 2 == 0:
            assert item == test_pb2.TestResponse(number=(i // 2) ** 2)
        else:
            assert item == test_pb2.TestResponse(number=(i // 2) ** 3)
        i += 1
    # Without this check the test would pass vacuously on an empty response stream
    # (mirrors the final count assertion in test_unary_stream).
    assert i == 20
+
+
@pytest.mark.parametrize(
    "cancel_reason",
    ["close_connection", "close_generator"],
)
@pytest.mark.asyncio
async def test_unary_stream_cancel(server_client, cancel_reason):
    """Cancelling a unary-stream call — by dropping the connection or closing the
    client-side generator — must cancel the server-side handler coroutine."""

    handler_cancelled = False

    class ExampleServicer(ServicerBase):
        async def rpc_wait(self, request: test_pb2.TestRequest, _: P2PContext) -> AsyncIterator[test_pb2.TestResponse]:
            try:
                yield test_pb2.TestResponse(number=request.number + 1)
                await asyncio.sleep(2)
                yield test_pb2.TestResponse(number=request.number + 2)
            except asyncio.CancelledError:
                nonlocal handler_cancelled
                handler_cancelled = True
                raise

    server, client = server_client
    servicer = ExampleServicer()
    await servicer.add_p2p_handlers(server)

    if cancel_reason == "close_connection":
        # Drive the wire protocol manually so we can drop the connection mid-stream.
        _, reader, writer = await client.call_binary_stream_handler(server.id, "ExampleServicer.rpc_wait")
        await P2P.send_protobuf(test_pb2.TestRequest(number=10), writer)
        await P2P.send_protobuf(P2P.END_OF_STREAM, writer)

        response, _ = await P2P.receive_protobuf(test_pb2.TestResponse, reader)
        assert response == test_pb2.TestResponse(number=11)
        await asyncio.sleep(0.25)

        writer.close()
    elif cancel_reason == "close_generator":
        stub = servicer.get_stub(client, server.id)
        # Named `iterator` (not `iter`) to avoid shadowing the builtin.
        iterator = stub.rpc_wait(test_pb2.TestRequest(number=10)).__aiter__()

        assert await iterator.__anext__() == test_pb2.TestResponse(number=11)
        await asyncio.sleep(0.25)

        await iterator.aclose()
    else:
        # pytest.fail instead of `assert False`: asserts are stripped under -O.
        pytest.fail(f"Unknown cancel_reason = `{cancel_reason}`")

    # Give the daemon time to propagate the cancellation to the handler task.
    await asyncio.sleep(0.25)
    assert handler_cancelled