Browse Source

Implement protobuf-based stream handlers over libp2p backend (#318)

This PR implements protobuf-based stream handlers over libp2p backend (including unary-stream, stream-unary, and stream-stream). Similarly to gRPC, they can be used through the Servicer interface.
Alexander Borzunov 4 years ago
parent
commit
fb4813347a

+ 2 - 2
hivemind/dht/protocol.py

@@ -7,7 +7,7 @@ from typing import Optional, List, Tuple, Dict, Sequence, Union, Collection
 from hivemind.dht.crypto import DHTRecord, RecordValidatorBase
 from hivemind.dht.crypto import DHTRecord, RecordValidatorBase
 from hivemind.dht.routing import RoutingTable, DHTID, BinaryDHTValue, Subkey
 from hivemind.dht.routing import RoutingTable, DHTID, BinaryDHTValue, Subkey
 from hivemind.dht.storage import DHTLocalStorage, DictionaryDHTValue
 from hivemind.dht.storage import DHTLocalStorage, DictionaryDHTValue
-from hivemind.p2p import P2P, P2PContext, PeerID, Servicer
+from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase
 from hivemind.proto import dht_pb2
 from hivemind.proto import dht_pb2
 from hivemind.utils import get_logger, MSGPackSerializer
 from hivemind.utils import get_logger, MSGPackSerializer
 from hivemind.utils.auth import AuthRole, AuthRPCWrapper, AuthorizerBase
 from hivemind.utils.auth import AuthRole, AuthRPCWrapper, AuthorizerBase
@@ -21,7 +21,7 @@ from hivemind.utils.timed_storage import (
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 
 
 
-class DHTProtocol(Servicer):
+class DHTProtocol(ServicerBase):
     # fmt:off
     # fmt:off
     p2p: P2P
     p2p: P2P
     node_id: DHTID; bucket_size: int; num_replicas: int; wait_timeout: float; node_info: dht_pb2.NodeInfo
     node_id: DHTID; bucket_size: int; num_replicas: int; wait_timeout: float; node_info: dht_pb2.NodeInfo

+ 1 - 1
hivemind/p2p/__init__.py

@@ -1,3 +1,3 @@
 from hivemind.p2p.p2p_daemon import P2P, P2PContext, P2PHandlerError
 from hivemind.p2p.p2p_daemon import P2P, P2PContext, P2PHandlerError
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo
-from hivemind.p2p.servicer import Servicer
+from hivemind.p2p.servicer import ServicerBase, StubBase

+ 183 - 84
hivemind/p2p/p2p_daemon.py

@@ -1,19 +1,20 @@
 import asyncio
 import asyncio
 import os
 import os
 import secrets
 import secrets
+from collections.abc import AsyncIterable as AsyncIterableABC
 from contextlib import closing, suppress
 from contextlib import closing, suppress
 from dataclasses import dataclass
 from dataclasses import dataclass
 from importlib.resources import path
 from importlib.resources import path
 from subprocess import Popen
 from subprocess import Popen
-from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
+from typing import Any, AsyncIterator, Awaitable, Callable, List, Optional, Sequence, Tuple, TypeVar, Union
 
 
-import google.protobuf
 from multiaddr import Multiaddr
 from multiaddr import Multiaddr
 
 
 import hivemind.hivemind_cli as cli
 import hivemind.hivemind_cli as cli
 import hivemind.p2p.p2p_daemon_bindings.p2pclient as p2pclient
 import hivemind.p2p.p2p_daemon_bindings.p2pclient as p2pclient
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
-from hivemind.proto import p2pd_pb2
+from hivemind.proto.p2pd_pb2 import RPCError
+from hivemind.utils.asyncio import aiter
 from hivemind.utils.logging import get_logger
 from hivemind.utils.logging import get_logger
 
 
 logger = get_logger(__name__)
 logger = get_logger(__name__)
@@ -39,18 +40,19 @@ class P2P:
     use the public IPFS network (https://ipfs.io).
     use the public IPFS network (https://ipfs.io).
 
 
     For incoming connections, P2P instances add RPC handlers that may be accessed by other peers:
     For incoming connections, P2P instances add RPC handlers that may be accessed by other peers:
-      - `P2P.add_unary_handler` accepts a protobuf message and returns another protobuf
-      - `P2P.add_stream_handler` transfers raw data using bi-directional streaming interface
+      - `P2P.add_protobuf_handler` accepts a protobuf message and returns another protobuf
+      - `P2P.add_binary_stream_handler` transfers raw data using bi-directional streaming interface
 
 
-    To access these handlers, a P2P instance can `P2P.call_unary_handler`/`P2P.call_stream_handler`,
+    To access these handlers, a P2P instance can `P2P.call_protobuf_handler`/`P2P.call_binary_stream_handler`,
     using the recipient's unique `P2P.id` and the name of the corresponding handler.
     using the recipient's unique `P2P.id` and the name of the corresponding handler.
     """
     """
 
 
     HEADER_LEN = 8
     HEADER_LEN = 8
     BYTEORDER = "big"
     BYTEORDER = "big"
-    PB_HEADER_LEN = 1
-    RESULT_MESSAGE = b"\x00"
-    ERROR_MESSAGE = b"\x01"
+    MESSAGE_MARKER = b"\x00"
+    ERROR_MARKER = b"\x01"
+    END_OF_STREAM = RPCError()
+
     DHT_MODE_MAPPING = {
     DHT_MODE_MAPPING = {
         "dht": {"dht": 1},
         "dht": {"dht": 1},
         "dht_server": {"dhtServer": 1},
         "dht_server": {"dhtServer": 1},
@@ -253,15 +255,6 @@ class P2P:
             writer.write(data[offset : offset + chunk_size])
             writer.write(data[offset : offset + chunk_size])
         await writer.drain()
         await writer.drain()
 
 
-    @staticmethod
-    async def send_protobuf(protobuf, writer: asyncio.StreamWriter) -> None:
-        if isinstance(protobuf, p2pd_pb2.RPCError):
-            await P2P.send_raw_data(P2P.ERROR_MESSAGE, writer)
-        else:
-            await P2P.send_raw_data(P2P.RESULT_MESSAGE, writer)
-
-        await P2P.send_raw_data(protobuf.SerializeToString(), writer)
-
     @staticmethod
     @staticmethod
     async def receive_raw_data(reader: asyncio.StreamReader) -> bytes:
     async def receive_raw_data(reader: asyncio.StreamReader) -> bytes:
         header = await reader.readexactly(P2P.HEADER_LEN)
         header = await reader.readexactly(P2P.HEADER_LEN)
@@ -269,68 +262,192 @@ class P2P:
         data = await reader.readexactly(content_length)
         data = await reader.readexactly(content_length)
         return data
         return data
 
 
+    TInputProtobuf = TypeVar("TInputProtobuf")
+    TOutputProtobuf = TypeVar("TOutputProtobuf")
+
+    @staticmethod
+    async def send_protobuf(protobuf: Union[TOutputProtobuf, RPCError], writer: asyncio.StreamWriter) -> None:
+        if isinstance(protobuf, RPCError):
+            writer.write(P2P.ERROR_MARKER)
+        else:
+            writer.write(P2P.MESSAGE_MARKER)
+        await P2P.send_raw_data(protobuf.SerializeToString(), writer)
+
     @staticmethod
     @staticmethod
     async def receive_protobuf(
     async def receive_protobuf(
-        in_proto_type: type, reader: asyncio.StreamReader
-    ) -> Tuple[Any, Optional[p2pd_pb2.RPCError]]:
-        msg_type = await P2P.receive_raw_data(reader)
-        if msg_type == P2P.RESULT_MESSAGE:
-            protobuf = in_proto_type()
+        input_protobuf_type: type, reader: asyncio.StreamReader
+    ) -> Tuple[Optional[TInputProtobuf], Optional[RPCError]]:
+        msg_type = await reader.readexactly(1)
+        if msg_type == P2P.MESSAGE_MARKER:
+            protobuf = input_protobuf_type()
             protobuf.ParseFromString(await P2P.receive_raw_data(reader))
             protobuf.ParseFromString(await P2P.receive_raw_data(reader))
             return protobuf, None
             return protobuf, None
-        elif msg_type == P2P.ERROR_MESSAGE:
-            protobuf = p2pd_pb2.RPCError()
+        elif msg_type == P2P.ERROR_MARKER:
+            protobuf = RPCError()
             protobuf.ParseFromString(await P2P.receive_raw_data(reader))
             protobuf.ParseFromString(await P2P.receive_raw_data(reader))
             return None, protobuf
             return None, protobuf
         else:
         else:
             raise TypeError("Invalid Protobuf message type")
             raise TypeError("Invalid Protobuf message type")
 
 
-    def _handle_unary_stream(self, handler: Callable[[Any, P2PContext], Any], handle_name: str, in_proto_type: type):
-        async def watchdog(reader: asyncio.StreamReader) -> None:
-            await reader.read(n=1)
-            raise P2PInterruptedError()
+    TInputStream = AsyncIterator[TInputProtobuf]
+    TOutputStream = AsyncIterator[TOutputProtobuf]
+
+    async def _add_protobuf_stream_handler(
+        self,
+        name: str,
+        handler: Callable[[TInputStream, P2PContext], TOutputStream],
+        input_protobuf_type: type,
+        max_prefetch: int = 0,
+    ) -> None:
+        """
+        :param max_prefetch: Maximum number of items to prefetch from the request stream.
+          ``max_prefetch <= 0`` means unlimited (default).
+
+        :note:  Since the cancel messages are sent via the input stream,
+          they will not be received while the prefetch buffer is full.
+        """
+
+        if self._listen_task is None:
+            self._start_listening()
 
 
-        async def do_handle_unary_stream(
+        async def _handle_stream(
             stream_info: StreamInfo, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
             stream_info: StreamInfo, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
         ) -> None:
         ) -> None:
-            with closing(writer):
+            context = P2PContext(
+                handle_name=name,
+                local_id=self.id,
+                remote_id=stream_info.peer_id,
+                remote_maddr=stream_info.addr,
+            )
+            requests = asyncio.Queue(max_prefetch)
+
+            async def _read_stream() -> P2P.TInputStream:
+                while True:
+                    request = await requests.get()
+                    if request is None:
+                        break
+                    yield request
+
+            async def _process_stream() -> None:
                 try:
                 try:
-                    request, err = await P2P.receive_protobuf(in_proto_type, reader)
-                except asyncio.IncompleteReadError:
-                    logger.debug(f"Incomplete read while receiving request from peer in {handle_name}")
-                    return
-                except google.protobuf.message.DecodeError as error:
-                    logger.debug(
-                        f"Failed to decode request protobuf " f"of type {in_proto_type} in {handle_name}: {error}"
-                    )
-                    return
-                if err is not None:
-                    logger.debug(f"Got an error instead of a request in {handle_name}: {err}")
-
-                context = P2PContext(
-                    handle_name=handle_name,
-                    local_id=self.id,
-                    remote_id=stream_info.peer_id,
-                    remote_maddr=stream_info.addr,
-                )
-                done, pending = await asyncio.wait(
-                    [watchdog(reader), handler(request, context)], return_when=asyncio.FIRST_COMPLETED
-                )
+                    async for response in handler(_read_stream(), context):
+                        await P2P.send_protobuf(response, writer)
+                except Exception as e:
+                    logger.warning("Exception while processing stream and sending responses:", exc_info=True)
+                    await P2P.send_protobuf(RPCError(message=str(e)), writer)
+
+            with closing(writer):
+                processing_task = asyncio.create_task(_process_stream())
                 try:
                 try:
-                    result = done.pop().result()
-                    await P2P.send_protobuf(result, writer)
-                except P2PInterruptedError:
-                    pass
-                except Exception as exc:
-                    error = p2pd_pb2.RPCError(message=str(exc))
-                    await P2P.send_protobuf(error, writer)
+                    while True:
+                        receive_task = asyncio.create_task(P2P.receive_protobuf(input_protobuf_type, reader))
+                        await asyncio.wait({processing_task, receive_task}, return_when=asyncio.FIRST_COMPLETED)
+
+                        if processing_task.done():
+                            receive_task.cancel()
+                            return
+
+                        if receive_task.done():
+                            try:
+                                request, _ = await receive_task
+                            except asyncio.IncompleteReadError:  # Connection is closed (the client cancelled or died)
+                                return
+                            await requests.put(request)  # `request` is None for the end-of-stream message
+                except Exception:
+                    logger.warning("Exception while receiving requests:", exc_info=True)
                 finally:
                 finally:
-                    if pending:
-                        for task in pending:
-                            task.cancel()
-                        await asyncio.wait(pending)
+                    processing_task.cancel()
+
+        await self._client.stream_handler(name, _handle_stream)
+
+    async def _iterate_protobuf_stream_handler(
+        self, peer_id: PeerID, name: str, requests: TInputStream, output_protobuf_type: type
+    ) -> TOutputStream:
+        _, reader, writer = await self._client.stream_open(peer_id, (name,))
+
+        async def _write_to_stream() -> None:
+            async for request in requests:
+                await P2P.send_protobuf(request, writer)
+            await P2P.send_protobuf(P2P.END_OF_STREAM, writer)
+
+        with closing(writer):
+            writing_task = asyncio.create_task(_write_to_stream())
+            try:
+                while True:
+                    try:
+                        response, err = await P2P.receive_protobuf(output_protobuf_type, reader)
+                    except asyncio.IncompleteReadError:  # Connection is closed
+                        break
+
+                    if err is not None:
+                        raise P2PHandlerError(f"Failed to call handler `{name}` at {peer_id}: {err.message}")
+                    yield response
+
+                await writing_task
+            finally:
+                writing_task.cancel()
+
+    async def add_protobuf_handler(
+        self,
+        name: str,
+        handler: Callable[
+            [Union[TInputProtobuf, TInputStream], P2PContext], Union[Awaitable[TOutputProtobuf], TOutputStream]
+        ],
+        input_protobuf_type: type,
+        *,
+        stream_input: bool = False,
+    ) -> None:
+        """
+        :param stream_input: If True, assume ``handler`` to take ``TInputStream``
+                             (not just ``TInputProtobuf``) as input.
+        """
 
 
-        return do_handle_unary_stream
+        async def _stream_handler(requests: P2P.TInputStream, context: P2PContext) -> P2P.TOutputStream:
+            if stream_input:
+                input = requests
+            else:
+                count = 0
+                async for input in requests:
+                    count += 1
+                if count != 1:
+                    raise ValueError(f"Got {count} requests for handler {name} instead of one")
+
+            output = handler(input, context)
+
+            if isinstance(output, AsyncIterableABC):
+                async for item in output:
+                    yield item
+            else:
+                yield await output
+
+        await self._add_protobuf_stream_handler(name, _stream_handler, input_protobuf_type)
+
+    async def call_protobuf_handler(
+        self,
+        peer_id: PeerID,
+        name: str,
+        input: Union[TInputProtobuf, TInputStream],
+        output_protobuf_type: type,
+    ) -> Awaitable[TOutputProtobuf]:
+        requests = input if isinstance(input, AsyncIterableABC) else aiter(input)
+        responses = self._iterate_protobuf_stream_handler(peer_id, name, requests, output_protobuf_type)
+
+        count = 0
+        async for response in responses:
+            count += 1
+        if count != 1:
+            raise ValueError(f"Got {count} responses from handler {name} instead of one")
+        return response
+
+    def iterate_protobuf_handler(
+        self,
+        peer_id: PeerID,
+        name: str,
+        input: Union[TInputProtobuf, TInputStream],
+        output_protobuf_type: type,
+    ) -> TOutputStream:
+        requests = input if isinstance(input, AsyncIterableABC) else aiter(input)
+        return self._iterate_protobuf_stream_handler(peer_id, name, requests, output_protobuf_type)
 
 
     def _start_listening(self) -> None:
     def _start_listening(self) -> None:
         async def listen() -> None:
         async def listen() -> None:
@@ -349,34 +466,16 @@ class P2P:
                 self._listen_task = None
                 self._listen_task = None
                 self._server_stopped.clear()
                 self._server_stopped.clear()
 
 
-    async def add_stream_handler(self, name: str, handler: p2pclient.StreamHandler) -> None:
+    async def add_binary_stream_handler(self, name: str, handler: p2pclient.StreamHandler) -> None:
         if self._listen_task is None:
         if self._listen_task is None:
             self._start_listening()
             self._start_listening()
         await self._client.stream_handler(name, handler)
         await self._client.stream_handler(name, handler)
 
 
-    async def add_unary_handler(
-        self, name: str, handler: Callable[[Any, P2PContext], Any], in_proto_type: type
-    ) -> None:
-        if self._listen_task is None:
-            self._start_listening()
-        await self._client.stream_handler(name, self._handle_unary_stream(handler, name, in_proto_type))
-
-    async def call_stream_handler(
+    async def call_binary_stream_handler(
         self, peer_id: PeerID, handler_name: str
         self, peer_id: PeerID, handler_name: str
     ) -> Tuple[StreamInfo, asyncio.StreamReader, asyncio.StreamWriter]:
     ) -> Tuple[StreamInfo, asyncio.StreamReader, asyncio.StreamWriter]:
         return await self._client.stream_open(peer_id, (handler_name,))
         return await self._client.stream_open(peer_id, (handler_name,))
 
 
-    async def call_unary_handler(
-        self, peer_id: PeerID, handler_name: str, request_protobuf: Any, response_proto_type: type
-    ) -> Any:
-        _, reader, writer = await self._client.stream_open(peer_id, (handler_name,))
-        with closing(writer):
-            await P2P.send_protobuf(request_protobuf, writer)
-            result, err = await P2P.receive_protobuf(response_proto_type, reader)
-            if err is not None:
-                raise P2PHandlerError(f"Failed to call unary handler {handler_name} at {peer_id}: {err.message}")
-            return result
-
     def __del__(self):
     def __del__(self):
         self._terminate()
         self._terminate()
 
 

+ 50 - 30
hivemind/p2p/servicer.py

@@ -1,7 +1,6 @@
 import asyncio
 import asyncio
-import importlib
 from dataclasses import dataclass
 from dataclasses import dataclass
-from typing import Any, Optional, Union
+from typing import Any, AsyncIterator, Optional, Tuple, get_type_hints
 
 
 from hivemind.p2p.p2p_daemon import P2P
 from hivemind.p2p.p2p_daemon import P2P
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID
@@ -13,6 +12,8 @@ class RPCHandler:
     handle_name: str
     handle_name: str
     request_type: type
     request_type: type
     response_type: type
     response_type: type
+    stream_input: bool
+    stream_output: bool
 
 
 
 
 class StubBase:
 class StubBase:
@@ -28,11 +29,11 @@ class StubBase:
         self._peer = peer
         self._peer = peer
 
 
 
 
-class Servicer:
+class ServicerBase:
     """
     """
     Base class for P2P RPC servicers (e.g. DHT, averager, MoE server). The interface mimicks gRPC servicers.
     Base class for P2P RPC servicers (e.g. DHT, averager, MoE server). The interface mimicks gRPC servicers.
 
 
-    - ``add_p2p_handlers(self, p2p)`` registers all rpc_* methods of the derived class as P2P unary handlers, allowing
+    - ``add_p2p_handlers(self, p2p)`` registers all rpc_* methods of the derived class as P2P handlers, allowing
       other peers to call them. It uses type annotations for the ``request`` parameter and the return value
       other peers to call them. It uses type annotations for the ``request`` parameter and the return value
       to infer protobufs the methods operate with.
       to infer protobufs the methods operate with.
 
 
@@ -48,18 +49,22 @@ class Servicer:
             if method_name.startswith("rpc_") and callable(method):
             if method_name.startswith("rpc_") and callable(method):
                 handle_name = f"{class_name}.{method_name}"
                 handle_name = f"{class_name}.{method_name}"
 
 
-                hints = method.__annotations__
+                hints = get_type_hints(method)
                 try:
                 try:
-                    request_type = self._hint_to_type(hints["request"])
-                    response_type = self._hint_to_type(hints["return"])
-                except (KeyError, ValueError):
+                    request_type = hints["request"]
+                    response_type = hints["return"]
+                except KeyError:
                     raise ValueError(
                     raise ValueError(
-                        f"{handle_name} is expected to have type annotations like `dht_pb2.FindRequest` "
-                        f"(a type from the hivemind.proto module) for the `request` parameter "
-                        f"and the return value"
+                        f"{handle_name} is expected to have type annotations "
+                        f"like `dht_pb2.FindRequest` or `AsyncIterator[dht_pb2.FindRequest]` "
+                        f"for the `request` parameter and the return value"
                     )
                     )
+                request_type, stream_input = self._strip_iterator_hint(request_type)
+                response_type, stream_output = self._strip_iterator_hint(response_type)
 
 
-                self._rpc_handlers.append(RPCHandler(method_name, handle_name, request_type, response_type))
+                self._rpc_handlers.append(
+                    RPCHandler(method_name, handle_name, request_type, response_type, stream_input, stream_output)
+                )
 
 
         self._stub_type = type(
         self._stub_type = type(
             f"{class_name}Stub",
             f"{class_name}Stub",
@@ -69,14 +74,33 @@ class Servicer:
 
 
     @staticmethod
     @staticmethod
     def _make_rpc_caller(handler: RPCHandler):
     def _make_rpc_caller(handler: RPCHandler):
+        input_type = AsyncIterator[handler.request_type] if handler.stream_input else handler.request_type
+
         # This method will be added to a new Stub type (a subclass of StubBase)
         # This method will be added to a new Stub type (a subclass of StubBase)
-        async def caller(
-            self: StubBase, request: handler.request_type, timeout: Optional[float] = None
-        ) -> handler.response_type:
-            return await asyncio.wait_for(
-                self._p2p.call_unary_handler(self._peer, handler.handle_name, request, handler.response_type),
-                timeout=timeout,
-            )
+        if handler.stream_output:
+
+            def caller(
+                self: StubBase, input: input_type, timeout: None = None
+            ) -> AsyncIterator[handler.response_type]:
+                if timeout is not None:
+                    raise ValueError("Timeouts for handlers returning streams are not supported")
+
+                return self._p2p.iterate_protobuf_handler(
+                    self._peer,
+                    handler.handle_name,
+                    input,
+                    handler.response_type,
+                )
+
+        else:
+
+            async def caller(
+                self: StubBase, input: input_type, timeout: Optional[float] = None
+            ) -> handler.response_type:
+                return await asyncio.wait_for(
+                    self._p2p.call_protobuf_handler(self._peer, handler.handle_name, input, handler.response_type),
+                    timeout=timeout,
+                )
 
 
         caller.__name__ = handler.method_name
         caller.__name__ = handler.method_name
         return caller
         return caller
@@ -84,23 +108,19 @@ class Servicer:
     async def add_p2p_handlers(self, p2p: P2P, wrapper: Any = None) -> None:
     async def add_p2p_handlers(self, p2p: P2P, wrapper: Any = None) -> None:
         servicer = self if wrapper is None else wrapper
         servicer = self if wrapper is None else wrapper
         for handler in self._rpc_handlers:
         for handler in self._rpc_handlers:
-            await p2p.add_unary_handler(
+            await p2p.add_protobuf_handler(
                 handler.handle_name,
                 handler.handle_name,
                 getattr(servicer, handler.method_name),
                 getattr(servicer, handler.method_name),
                 handler.request_type,
                 handler.request_type,
+                stream_input=handler.stream_input,
             )
             )
 
 
     def get_stub(self, p2p: P2P, peer: PeerID) -> StubBase:
     def get_stub(self, p2p: P2P, peer: PeerID) -> StubBase:
         return self._stub_type(p2p, peer)
         return self._stub_type(p2p, peer)
 
 
     @staticmethod
     @staticmethod
-    def _hint_to_type(hint: Union[type, str]) -> type:
-        if isinstance(hint, type):
-            return hint
-
-        module_name, proto_name = hint.split(".")
-        module = importlib.import_module("hivemind.proto." + module_name)
-        result = getattr(module, proto_name)
-        if not isinstance(result, type):
-            raise ValueError(f"`hivemind.proto.{hint}` is not a type")
-        return result
+    def _strip_iterator_hint(hint: type) -> Tuple[type, bool]:
+        if hasattr(hint, "_name") and hint._name in ("AsyncIterator", "AsyncIterable"):
+            return hint.__args__[0], True
+
+        return hint, False

+ 1 - 1
hivemind/proto/p2pd.proto

@@ -162,5 +162,5 @@ message PSResponse {
 }
 }
 
 
 message RPCError {
 message RPCError {
-  required string message = 1;
+  optional string message = 1;
 }
 }

+ 9 - 0
hivemind/proto/test.proto

@@ -0,0 +1,9 @@
+syntax = "proto3";
+
+message TestRequest {
+    int32 number = 1;
+}
+
+message TestResponse {
+    int32 number = 1;
+}

+ 27 - 22
tests/test_p2p_daemon.py

@@ -81,7 +81,7 @@ async def test_daemon_replica_does_not_affect_primary():
     ],
     ],
 )
 )
 @pytest.mark.asyncio
 @pytest.mark.asyncio
-async def test_call_unary_handler(should_cancel, replicate, handle_name="handle"):
+async def test_call_protobuf_handler(should_cancel, replicate, handle_name="handle"):
     handler_cancelled = False
     handler_cancelled = False
     server_primary = await P2P.create()
     server_primary = await P2P.create()
     server = await replicate_if_needed(server_primary, replicate)
     server = await replicate_if_needed(server_primary, replicate)
@@ -95,7 +95,7 @@ async def test_call_unary_handler(should_cancel, replicate, handle_name="handle"
         return dht_pb2.PingResponse(peer=dht_pb2.NodeInfo(node_id=server.id.to_bytes()), available=True)
         return dht_pb2.PingResponse(peer=dht_pb2.NodeInfo(node_id=server.id.to_bytes()), available=True)
 
 
     server_pid = server_primary._child.pid
     server_pid = server_primary._child.pid
-    await server.add_unary_handler(handle_name, ping_handler, dht_pb2.PingRequest)
+    await server.add_protobuf_handler(handle_name, ping_handler, dht_pb2.PingRequest)
     assert is_process_running(server_pid)
     assert is_process_running(server_pid)
 
 
     client_primary = await P2P.create(initial_peers=await server.get_visible_maddrs())
     client_primary = await P2P.create(initial_peers=await server.get_visible_maddrs())
@@ -108,14 +108,19 @@ async def test_call_unary_handler(should_cancel, replicate, handle_name="handle"
     expected_response = dht_pb2.PingResponse(peer=dht_pb2.NodeInfo(node_id=server.id.to_bytes()), available=True)
     expected_response = dht_pb2.PingResponse(peer=dht_pb2.NodeInfo(node_id=server.id.to_bytes()), available=True)
 
 
     if should_cancel:
     if should_cancel:
-        *_, writer = await client.call_stream_handler(server.id, handle_name)
-        with closing(writer):
-            await P2P.send_protobuf(ping_request, writer)
+        call_task = asyncio.create_task(
+            client.call_protobuf_handler(server.id, handle_name, ping_request, dht_pb2.PingResponse)
+        )
+        await asyncio.sleep(0.25)
+
+        call_task.cancel()
 
 
-        await asyncio.sleep(1)
+        await asyncio.sleep(0.25)
         assert handler_cancelled
         assert handler_cancelled
     else:
     else:
-        actual_response = await client.call_unary_handler(server.id, handle_name, ping_request, dht_pb2.PingResponse)
+        actual_response = await client.call_protobuf_handler(
+            server.id, handle_name, ping_request, dht_pb2.PingResponse
+        )
         assert actual_response == expected_response
         assert actual_response == expected_response
         assert not handler_cancelled
         assert not handler_cancelled
 
 
@@ -128,13 +133,13 @@ async def test_call_unary_handler(should_cancel, replicate, handle_name="handle"
 
 
 
 
 @pytest.mark.asyncio
 @pytest.mark.asyncio
-async def test_call_unary_handler_error(handle_name="handle"):
+async def test_call_protobuf_handler_error(handle_name="handle"):
     async def error_handler(request, context):
     async def error_handler(request, context):
         raise ValueError("boom")
         raise ValueError("boom")
 
 
     server = await P2P.create()
     server = await P2P.create()
     server_pid = server._child.pid
     server_pid = server._child.pid
-    await server.add_unary_handler(handle_name, error_handler, dht_pb2.PingRequest)
+    await server.add_protobuf_handler(handle_name, error_handler, dht_pb2.PingRequest)
     assert is_process_running(server_pid)
     assert is_process_running(server_pid)
 
 
     client = await P2P.create(initial_peers=await server.get_visible_maddrs())
     client = await P2P.create(initial_peers=await server.get_visible_maddrs())
@@ -145,7 +150,7 @@ async def test_call_unary_handler_error(handle_name="handle"):
     ping_request = dht_pb2.PingRequest(peer=dht_pb2.NodeInfo(node_id=client.id.to_bytes()), validate=True)
     ping_request = dht_pb2.PingRequest(peer=dht_pb2.NodeInfo(node_id=client.id.to_bytes()), validate=True)
 
 
     with pytest.raises(P2PHandlerError) as excinfo:
     with pytest.raises(P2PHandlerError) as excinfo:
-        await client.call_unary_handler(server.id, handle_name, ping_request, dht_pb2.PingResponse)
+        await client.call_protobuf_handler(server.id, handle_name, ping_request, dht_pb2.PingResponse)
     assert "boom" in str(excinfo.value)
     assert "boom" in str(excinfo.value)
 
 
     await server.shutdown()
     await server.shutdown()
@@ -183,7 +188,7 @@ async def test_call_peer_single_process():
     assert is_process_running(server_pid)
     assert is_process_running(server_pid)
 
 
     handler_name = "square"
     handler_name = "square"
-    await server.add_stream_handler(handler_name, handle_square_stream)
+    await server.add_binary_stream_handler(handler_name, handle_square_stream)
 
 
     client = await P2P.create(initial_peers=await server.get_visible_maddrs())
     client = await P2P.create(initial_peers=await server.get_visible_maddrs())
     client_pid = client._child.pid
     client_pid = client._child.pid
@@ -191,7 +196,7 @@ async def test_call_peer_single_process():
 
 
     await client.wait_for_at_least_n_peers(1)
     await client.wait_for_at_least_n_peers(1)
 
 
-    _, reader, writer = await client.call_stream_handler(server.id, handler_name)
+    _, reader, writer = await client.call_binary_stream_handler(server.id, handler_name)
     await validate_square_stream(reader, writer)
     await validate_square_stream(reader, writer)
 
 
     await server.shutdown()
     await server.shutdown()
@@ -206,7 +211,7 @@ async def run_server(handler_name, server_side, response_received):
     server_pid = server._child.pid
     server_pid = server._child.pid
     assert is_process_running(server_pid)
     assert is_process_running(server_pid)
 
 
-    await server.add_stream_handler(handler_name, handle_square_stream)
+    await server.add_binary_stream_handler(handler_name, handle_square_stream)
 
 
     server_side.send(server.id)
     server_side.send(server.id)
     server_side.send(await server.get_visible_maddrs())
     server_side.send(await server.get_visible_maddrs())
@@ -241,7 +246,7 @@ async def test_call_peer_different_processes():
 
 
     await client.wait_for_at_least_n_peers(1)
     await client.wait_for_at_least_n_peers(1)
 
 
-    _, reader, writer = await client.call_stream_handler(peer_id, handler_name)
+    _, reader, writer = await client.call_binary_stream_handler(peer_id, handler_name)
     await validate_square_stream(reader, writer)
     await validate_square_stream(reader, writer)
 
 
     response_received.value = 1
     response_received.value = 1
@@ -268,7 +273,7 @@ async def test_error_closes_connection():
     assert is_process_running(server_pid)
     assert is_process_running(server_pid)
 
 
     handler_name = "handler"
     handler_name = "handler"
-    await server.add_stream_handler(handler_name, handle_raising_error)
+    await server.add_binary_stream_handler(handler_name, handle_raising_error)
 
 
     client = await P2P.create(initial_peers=await server.get_visible_maddrs())
     client = await P2P.create(initial_peers=await server.get_visible_maddrs())
     client_pid = client._child.pid
     client_pid = client._child.pid
@@ -276,7 +281,7 @@ async def test_error_closes_connection():
 
 
     await client.wait_for_at_least_n_peers(1)
     await client.wait_for_at_least_n_peers(1)
 
 
-    _, reader, writer = await client.call_stream_handler(server.id, handler_name)
+    _, reader, writer = await client.call_binary_stream_handler(server.id, handler_name)
     with closing(writer):
     with closing(writer):
         await P2P.send_raw_data(b"raise_error", writer)
         await P2P.send_raw_data(b"raise_error", writer)
         with pytest.raises(asyncio.IncompleteReadError):  # Means that the connection is closed
         with pytest.raises(asyncio.IncompleteReadError):  # Means that the connection is closed
@@ -285,7 +290,7 @@ async def test_error_closes_connection():
     # Despite the handler raised an exception, the server did not crash and ready for next requests
     # Despite the handler raised an exception, the server did not crash and ready for next requests
     assert is_process_running(server_pid)
     assert is_process_running(server_pid)
 
 
-    _, reader, writer = await client.call_stream_handler(server.id, handler_name)
+    _, reader, writer = await client.call_binary_stream_handler(server.id, handler_name)
     with closing(writer):
     with closing(writer):
         await P2P.send_raw_data(b"behave_normally", writer)
         await P2P.send_raw_data(b"behave_normally", writer)
         assert await P2P.receive_raw_data(reader) == b"okay"
         assert await P2P.receive_raw_data(reader) == b"okay"
@@ -305,19 +310,19 @@ async def test_handlers_on_different_replicas():
 
 
     server_primary = await P2P.create()
     server_primary = await P2P.create()
     server_id = server_primary.id
     server_id = server_primary.id
-    await server_primary.add_stream_handler("handle_primary", partial(handler, key=b"primary"))
+    await server_primary.add_binary_stream_handler("handle_primary", partial(handler, key=b"primary"))
 
 
     server_replica1 = await replicate_if_needed(server_primary, True)
     server_replica1 = await replicate_if_needed(server_primary, True)
-    await server_replica1.add_stream_handler("handle1", partial(handler, key=b"replica1"))
+    await server_replica1.add_binary_stream_handler("handle1", partial(handler, key=b"replica1"))
 
 
     server_replica2 = await replicate_if_needed(server_primary, True)
     server_replica2 = await replicate_if_needed(server_primary, True)
-    await server_replica2.add_stream_handler("handle2", partial(handler, key=b"replica2"))
+    await server_replica2.add_binary_stream_handler("handle2", partial(handler, key=b"replica2"))
 
 
     client = await P2P.create(initial_peers=await server_primary.get_visible_maddrs())
     client = await P2P.create(initial_peers=await server_primary.get_visible_maddrs())
     await client.wait_for_at_least_n_peers(1)
     await client.wait_for_at_least_n_peers(1)
 
 
     for name, expected_key in [("handle_primary", b"primary"), ("handle1", b"replica1"), ("handle2", b"replica2")]:
     for name, expected_key in [("handle_primary", b"primary"), ("handle1", b"replica1"), ("handle2", b"replica2")]:
-        _, reader, writer = await client.call_stream_handler(server_id, name)
+        _, reader, writer = await client.call_binary_stream_handler(server_id, name)
         with closing(writer):
         with closing(writer):
             assert await P2P.receive_raw_data(reader) == expected_key
             assert await P2P.receive_raw_data(reader) == expected_key
 
 
@@ -327,7 +332,7 @@ async def test_handlers_on_different_replicas():
     # Primary does not handle replicas protocols after their shutdown
     # Primary does not handle replicas protocols after their shutdown
 
 
     for name in ["handle1", "handle2"]:
     for name in ["handle1", "handle2"]:
-        _, reader, writer = await client.call_stream_handler(server_id, name)
+        _, reader, writer = await client.call_binary_stream_handler(server_id, name)
         with pytest.raises(asyncio.IncompleteReadError), closing(writer):
         with pytest.raises(asyncio.IncompleteReadError), closing(writer):
             await P2P.receive_raw_data(reader)
             await P2P.receive_raw_data(reader)
 
 

+ 148 - 0
tests/test_p2p_servicer.py

@@ -0,0 +1,148 @@
+import asyncio
+from typing import AsyncIterator
+
+import pytest
+
+from hivemind.p2p import P2P, P2PContext, ServicerBase
+from hivemind.proto import test_pb2
+
+
+@pytest.fixture
+async def server_client():
+    server = await P2P.create()
+    client = await P2P.create(initial_peers=await server.get_visible_maddrs())
+    yield server, client
+
+    await asyncio.gather(server.shutdown(), client.shutdown())
+
+
+@pytest.mark.asyncio
+async def test_unary_unary(server_client):
+    class ExampleServicer(ServicerBase):
+        async def rpc_square(self, request: test_pb2.TestRequest, _: P2PContext) -> test_pb2.TestResponse:
+            return test_pb2.TestResponse(number=request.number ** 2)
+
+    server, client = server_client
+    servicer = ExampleServicer()
+    await servicer.add_p2p_handlers(server)
+    stub = servicer.get_stub(client, server.id)
+
+    assert await stub.rpc_square(test_pb2.TestRequest(number=10)) == test_pb2.TestResponse(number=100)
+
+
+@pytest.mark.asyncio
+async def test_stream_unary(server_client):
+    class ExampleServicer(ServicerBase):
+        async def rpc_sum(self, request: AsyncIterator[test_pb2.TestRequest], _: P2PContext) -> test_pb2.TestResponse:
+            result = 0
+            async for item in request:
+                result += item.number
+            return test_pb2.TestResponse(number=result)
+
+    server, client = server_client
+    servicer = ExampleServicer()
+    await servicer.add_p2p_handlers(server)
+    stub = servicer.get_stub(client, server.id)
+
+    async def generate_requests() -> AsyncIterator[test_pb2.TestRequest]:
+        for i in range(10):
+            yield test_pb2.TestRequest(number=i)
+
+    assert await stub.rpc_sum(generate_requests()) == test_pb2.TestResponse(number=45)
+
+
+@pytest.mark.asyncio
+async def test_unary_stream(server_client):
+    class ExampleServicer(ServicerBase):
+        async def rpc_count(
+            self, request: test_pb2.TestRequest, _: P2PContext
+        ) -> AsyncIterator[test_pb2.TestResponse]:
+            for i in range(request.number):
+                yield test_pb2.TestResponse(number=i)
+
+    server, client = server_client
+    servicer = ExampleServicer()
+    await servicer.add_p2p_handlers(server)
+    stub = servicer.get_stub(client, server.id)
+
+    i = 0
+    async for item in stub.rpc_count(test_pb2.TestRequest(number=10)):
+        assert item == test_pb2.TestResponse(number=i)
+        i += 1
+    assert i == 10
+
+
+@pytest.mark.asyncio
+async def test_stream_stream(server_client):
+    class ExampleServicer(ServicerBase):
+        async def rpc_powers(
+            self, request: AsyncIterator[test_pb2.TestRequest], _: P2PContext
+        ) -> AsyncIterator[test_pb2.TestResponse]:
+            async for item in request:
+                yield test_pb2.TestResponse(number=item.number ** 2)
+                yield test_pb2.TestResponse(number=item.number ** 3)
+
+    server, client = server_client
+    servicer = ExampleServicer()
+    await servicer.add_p2p_handlers(server)
+    stub = servicer.get_stub(client, server.id)
+
+    async def generate_requests() -> AsyncIterator[test_pb2.TestRequest]:
+        for i in range(10):
+            yield test_pb2.TestRequest(number=i)
+
+    i = 0
+    async for item in stub.rpc_powers(generate_requests()):
+        if i % 2 == 0:
+            assert item == test_pb2.TestResponse(number=(i // 2) ** 2)
+        else:
+            assert item == test_pb2.TestResponse(number=(i // 2) ** 3)
+        i += 1
+
+
+@pytest.mark.parametrize(
+    "cancel_reason",
+    ["close_connection", "close_generator"],
+)
+@pytest.mark.asyncio
+async def test_unary_stream_cancel(server_client, cancel_reason):
+    handler_cancelled = False
+
+    class ExampleServicer(ServicerBase):
+        async def rpc_wait(self, request: test_pb2.TestRequest, _: P2PContext) -> AsyncIterator[test_pb2.TestResponse]:
+            try:
+                yield test_pb2.TestResponse(number=request.number + 1)
+                await asyncio.sleep(2)
+                yield test_pb2.TestResponse(number=request.number + 2)
+            except asyncio.CancelledError:
+                nonlocal handler_cancelled
+                handler_cancelled = True
+                raise
+
+    server, client = server_client
+    servicer = ExampleServicer()
+    await servicer.add_p2p_handlers(server)
+
+    if cancel_reason == "close_connection":
+        _, reader, writer = await client.call_binary_stream_handler(server.id, "ExampleServicer.rpc_wait")
+        await P2P.send_protobuf(test_pb2.TestRequest(number=10), writer)
+        await P2P.send_protobuf(P2P.END_OF_STREAM, writer)
+
+        response, _ = await P2P.receive_protobuf(test_pb2.TestResponse, reader)
+        assert response == test_pb2.TestResponse(number=11)
+        await asyncio.sleep(0.25)
+
+        writer.close()
+    elif cancel_reason == "close_generator":
+        stub = servicer.get_stub(client, server.id)
+        iter = stub.rpc_wait(test_pb2.TestRequest(number=10)).__aiter__()
+
+        assert await iter.__anext__() == test_pb2.TestResponse(number=11)
+        await asyncio.sleep(0.25)
+
+        await iter.aclose()
+    else:
+        assert False, f"Unknown cancel_reason = `{cancel_reason}`"
+
+    await asyncio.sleep(0.25)
+    assert handler_cancelled