|
@@ -1,19 +1,20 @@
|
|
|
import asyncio
|
|
|
import os
|
|
|
import secrets
|
|
|
+from collections.abc import AsyncIterable as AsyncIterableABC
|
|
|
from contextlib import closing, suppress
|
|
|
from dataclasses import dataclass
|
|
|
from importlib.resources import path
|
|
|
from subprocess import Popen
|
|
|
-from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
|
|
|
+from typing import Any, AsyncIterator, Awaitable, Callable, List, Optional, Sequence, Tuple, TypeVar, Union
|
|
|
|
|
|
-import google.protobuf
|
|
|
from multiaddr import Multiaddr
|
|
|
|
|
|
import hivemind.hivemind_cli as cli
|
|
|
import hivemind.p2p.p2p_daemon_bindings.p2pclient as p2pclient
|
|
|
from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
|
|
|
-from hivemind.proto import p2pd_pb2
|
|
|
+from hivemind.proto.p2pd_pb2 import RPCError
|
|
|
+from hivemind.utils.asyncio import aiter
|
|
|
from hivemind.utils.logging import get_logger
|
|
|
|
|
|
logger = get_logger(__name__)
|
|
@@ -39,18 +40,19 @@ class P2P:
|
|
|
use the public IPFS network (https://ipfs.io).
|
|
|
|
|
|
For incoming connections, P2P instances add RPC handlers that may be accessed by other peers:
|
|
|
- - `P2P.add_unary_handler` accepts a protobuf message and returns another protobuf
|
|
|
- - `P2P.add_stream_handler` transfers raw data using bi-directional streaming interface
|
|
|
+ - `P2P.add_protobuf_handler` accepts a protobuf message and returns another protobuf
|
|
|
+ - `P2P.add_binary_stream_handler` transfers raw data using bi-directional streaming interface
|
|
|
|
|
|
- To access these handlers, a P2P instance can `P2P.call_unary_handler`/`P2P.call_stream_handler`,
|
|
|
+ To access these handlers, a P2P instance can `P2P.call_protobuf_handler`/`P2P.call_binary_stream_handler`,
|
|
|
using the recipient's unique `P2P.id` and the name of the corresponding handler.
|
|
|
"""
|
|
|
|
|
|
HEADER_LEN = 8
|
|
|
BYTEORDER = "big"
|
|
|
- PB_HEADER_LEN = 1
|
|
|
- RESULT_MESSAGE = b"\x00"
|
|
|
- ERROR_MESSAGE = b"\x01"
|
|
|
+ MESSAGE_MARKER = b"\x00"
|
|
|
+ ERROR_MARKER = b"\x01"
|
|
|
+ END_OF_STREAM = RPCError()
|
|
|
+
|
|
|
DHT_MODE_MAPPING = {
|
|
|
"dht": {"dht": 1},
|
|
|
"dht_server": {"dhtServer": 1},
|
|
@@ -253,15 +255,6 @@ class P2P:
|
|
|
writer.write(data[offset : offset + chunk_size])
|
|
|
await writer.drain()
|
|
|
|
|
|
- @staticmethod
|
|
|
- async def send_protobuf(protobuf, writer: asyncio.StreamWriter) -> None:
|
|
|
- if isinstance(protobuf, p2pd_pb2.RPCError):
|
|
|
- await P2P.send_raw_data(P2P.ERROR_MESSAGE, writer)
|
|
|
- else:
|
|
|
- await P2P.send_raw_data(P2P.RESULT_MESSAGE, writer)
|
|
|
-
|
|
|
- await P2P.send_raw_data(protobuf.SerializeToString(), writer)
|
|
|
-
|
|
|
@staticmethod
|
|
|
async def receive_raw_data(reader: asyncio.StreamReader) -> bytes:
|
|
|
header = await reader.readexactly(P2P.HEADER_LEN)
|
|
@@ -269,68 +262,192 @@ class P2P:
|
|
|
data = await reader.readexactly(content_length)
|
|
|
return data
|
|
|
|
|
|
+ TInputProtobuf = TypeVar("TInputProtobuf")
|
|
|
+ TOutputProtobuf = TypeVar("TOutputProtobuf")
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ async def send_protobuf(protobuf: Union[TOutputProtobuf, RPCError], writer: asyncio.StreamWriter) -> None:
|
|
|
+ if isinstance(protobuf, RPCError):
|
|
|
+ writer.write(P2P.ERROR_MARKER)
|
|
|
+ else:
|
|
|
+ writer.write(P2P.MESSAGE_MARKER)
|
|
|
+ await P2P.send_raw_data(protobuf.SerializeToString(), writer)
|
|
|
+
|
|
|
@staticmethod
|
|
|
async def receive_protobuf(
|
|
|
- in_proto_type: type, reader: asyncio.StreamReader
|
|
|
- ) -> Tuple[Any, Optional[p2pd_pb2.RPCError]]:
|
|
|
- msg_type = await P2P.receive_raw_data(reader)
|
|
|
- if msg_type == P2P.RESULT_MESSAGE:
|
|
|
- protobuf = in_proto_type()
|
|
|
+ input_protobuf_type: type, reader: asyncio.StreamReader
|
|
|
+ ) -> Tuple[Optional[TInputProtobuf], Optional[RPCError]]:
|
|
|
+ msg_type = await reader.readexactly(1)
|
|
|
+ if msg_type == P2P.MESSAGE_MARKER:
|
|
|
+ protobuf = input_protobuf_type()
|
|
|
protobuf.ParseFromString(await P2P.receive_raw_data(reader))
|
|
|
return protobuf, None
|
|
|
- elif msg_type == P2P.ERROR_MESSAGE:
|
|
|
- protobuf = p2pd_pb2.RPCError()
|
|
|
+ elif msg_type == P2P.ERROR_MARKER:
|
|
|
+ protobuf = RPCError()
|
|
|
protobuf.ParseFromString(await P2P.receive_raw_data(reader))
|
|
|
return None, protobuf
|
|
|
else:
|
|
|
raise TypeError("Invalid Protobuf message type")
|
|
|
|
|
|
- def _handle_unary_stream(self, handler: Callable[[Any, P2PContext], Any], handle_name: str, in_proto_type: type):
|
|
|
- async def watchdog(reader: asyncio.StreamReader) -> None:
|
|
|
- await reader.read(n=1)
|
|
|
- raise P2PInterruptedError()
|
|
|
+ TInputStream = AsyncIterator[TInputProtobuf]
|
|
|
+ TOutputStream = AsyncIterator[TOutputProtobuf]
|
|
|
+
|
|
|
+ async def _add_protobuf_stream_handler(
|
|
|
+ self,
|
|
|
+ name: str,
|
|
|
+ handler: Callable[[TInputStream, P2PContext], TOutputStream],
|
|
|
+ input_protobuf_type: type,
|
|
|
+ max_prefetch: int = 0,
|
|
|
+ ) -> None:
|
|
|
+ """
|
|
|
+ :param max_prefetch: Maximum number of items to prefetch from the request stream.
|
|
|
+ ``max_prefetch <= 0`` means unlimited (default).
|
|
|
+
|
|
|
+ :note: Since the cancel messages are sent via the input stream,
|
|
|
+ they will not be received while the prefetch buffer is full.
|
|
|
+ """
|
|
|
+
|
|
|
+ if self._listen_task is None:
|
|
|
+ self._start_listening()
|
|
|
|
|
|
- async def do_handle_unary_stream(
|
|
|
+ async def _handle_stream(
|
|
|
stream_info: StreamInfo, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
|
|
|
) -> None:
|
|
|
- with closing(writer):
|
|
|
+ context = P2PContext(
|
|
|
+ handle_name=name,
|
|
|
+ local_id=self.id,
|
|
|
+ remote_id=stream_info.peer_id,
|
|
|
+ remote_maddr=stream_info.addr,
|
|
|
+ )
|
|
|
+ requests = asyncio.Queue(max_prefetch)
|
|
|
+
|
|
|
+ async def _read_stream() -> P2P.TInputStream:
|
|
|
+ while True:
|
|
|
+ request = await requests.get()
|
|
|
+ if request is None:
|
|
|
+ break
|
|
|
+ yield request
|
|
|
+
|
|
|
+ async def _process_stream() -> None:
|
|
|
try:
|
|
|
- request, err = await P2P.receive_protobuf(in_proto_type, reader)
|
|
|
- except asyncio.IncompleteReadError:
|
|
|
- logger.debug(f"Incomplete read while receiving request from peer in {handle_name}")
|
|
|
- return
|
|
|
- except google.protobuf.message.DecodeError as error:
|
|
|
- logger.debug(
|
|
|
- f"Failed to decode request protobuf " f"of type {in_proto_type} in {handle_name}: {error}"
|
|
|
- )
|
|
|
- return
|
|
|
- if err is not None:
|
|
|
- logger.debug(f"Got an error instead of a request in {handle_name}: {err}")
|
|
|
-
|
|
|
- context = P2PContext(
|
|
|
- handle_name=handle_name,
|
|
|
- local_id=self.id,
|
|
|
- remote_id=stream_info.peer_id,
|
|
|
- remote_maddr=stream_info.addr,
|
|
|
- )
|
|
|
- done, pending = await asyncio.wait(
|
|
|
- [watchdog(reader), handler(request, context)], return_when=asyncio.FIRST_COMPLETED
|
|
|
- )
|
|
|
+ async for response in handler(_read_stream(), context):
|
|
|
+ await P2P.send_protobuf(response, writer)
|
|
|
+ except Exception as e:
|
|
|
+ logger.warning("Exception while processing stream and sending responses:", exc_info=True)
|
|
|
+ await P2P.send_protobuf(RPCError(message=str(e)), writer)
|
|
|
+
|
|
|
+ with closing(writer):
|
|
|
+ processing_task = asyncio.create_task(_process_stream())
|
|
|
try:
|
|
|
- result = done.pop().result()
|
|
|
- await P2P.send_protobuf(result, writer)
|
|
|
- except P2PInterruptedError:
|
|
|
- pass
|
|
|
- except Exception as exc:
|
|
|
- error = p2pd_pb2.RPCError(message=str(exc))
|
|
|
- await P2P.send_protobuf(error, writer)
|
|
|
+ while True:
|
|
|
+ receive_task = asyncio.create_task(P2P.receive_protobuf(input_protobuf_type, reader))
|
|
|
+ await asyncio.wait({processing_task, receive_task}, return_when=asyncio.FIRST_COMPLETED)
|
|
|
+
|
|
|
+ if processing_task.done():
|
|
|
+ receive_task.cancel()
|
|
|
+ return
|
|
|
+
|
|
|
+ if receive_task.done():
|
|
|
+ try:
|
|
|
+ request, _ = await receive_task
|
|
|
+ except asyncio.IncompleteReadError: # Connection is closed (the client cancelled or died)
|
|
|
+ return
|
|
|
+ await requests.put(request) # `request` is None for the end-of-stream message
|
|
|
+ except Exception:
|
|
|
+ logger.warning("Exception while receiving requests:", exc_info=True)
|
|
|
finally:
|
|
|
- if pending:
|
|
|
- for task in pending:
|
|
|
- task.cancel()
|
|
|
- await asyncio.wait(pending)
|
|
|
+ processing_task.cancel()
|
|
|
+
|
|
|
+ await self._client.stream_handler(name, _handle_stream)
|
|
|
+
|
|
|
+ async def _iterate_protobuf_stream_handler(
|
|
|
+ self, peer_id: PeerID, name: str, requests: TInputStream, output_protobuf_type: type
|
|
|
+ ) -> TOutputStream:
|
|
|
+ _, reader, writer = await self._client.stream_open(peer_id, (name,))
|
|
|
+
|
|
|
+ async def _write_to_stream() -> None:
|
|
|
+ async for request in requests:
|
|
|
+ await P2P.send_protobuf(request, writer)
|
|
|
+ await P2P.send_protobuf(P2P.END_OF_STREAM, writer)
|
|
|
+
|
|
|
+ with closing(writer):
|
|
|
+ writing_task = asyncio.create_task(_write_to_stream())
|
|
|
+ try:
|
|
|
+ while True:
|
|
|
+ try:
|
|
|
+ response, err = await P2P.receive_protobuf(output_protobuf_type, reader)
|
|
|
+ except asyncio.IncompleteReadError: # Connection is closed
|
|
|
+ break
|
|
|
+
|
|
|
+ if err is not None:
|
|
|
+ raise P2PHandlerError(f"Failed to call handler `{name}` at {peer_id}: {err.message}")
|
|
|
+ yield response
|
|
|
+
|
|
|
+ await writing_task
|
|
|
+ finally:
|
|
|
+ writing_task.cancel()
|
|
|
+
|
|
|
+ async def add_protobuf_handler(
|
|
|
+ self,
|
|
|
+ name: str,
|
|
|
+ handler: Callable[
|
|
|
+ [Union[TInputProtobuf, TInputStream], P2PContext], Union[Awaitable[TOutputProtobuf], TOutputStream]
|
|
|
+ ],
|
|
|
+ input_protobuf_type: type,
|
|
|
+ *,
|
|
|
+ stream_input: bool = False,
|
|
|
+ ) -> None:
|
|
|
+ """
|
|
|
+ :param stream_input: If True, assume ``handler`` to take ``TInputStream``
|
|
|
+ (not just ``TInputProtobuf``) as input.
|
|
|
+ """
|
|
|
|
|
|
- return do_handle_unary_stream
|
|
|
+ async def _stream_handler(requests: P2P.TInputStream, context: P2PContext) -> P2P.TOutputStream:
|
|
|
+ if stream_input:
|
|
|
+ input = requests
|
|
|
+ else:
|
|
|
+ count = 0
|
|
|
+ async for input in requests:
|
|
|
+ count += 1
|
|
|
+ if count != 1:
|
|
|
+ raise ValueError(f"Got {count} requests for handler {name} instead of one")
|
|
|
+
|
|
|
+ output = handler(input, context)
|
|
|
+
|
|
|
+ if isinstance(output, AsyncIterableABC):
|
|
|
+ async for item in output:
|
|
|
+ yield item
|
|
|
+ else:
|
|
|
+ yield await output
|
|
|
+
|
|
|
+ await self._add_protobuf_stream_handler(name, _stream_handler, input_protobuf_type)
|
|
|
+
|
|
|
+ async def call_protobuf_handler(
|
|
|
+ self,
|
|
|
+ peer_id: PeerID,
|
|
|
+ name: str,
|
|
|
+ input: Union[TInputProtobuf, TInputStream],
|
|
|
+ output_protobuf_type: type,
|
|
|
+ ) -> Awaitable[TOutputProtobuf]:
|
|
|
+ requests = input if isinstance(input, AsyncIterableABC) else aiter(input)
|
|
|
+ responses = self._iterate_protobuf_stream_handler(peer_id, name, requests, output_protobuf_type)
|
|
|
+
|
|
|
+ count = 0
|
|
|
+ async for response in responses:
|
|
|
+ count += 1
|
|
|
+ if count != 1:
|
|
|
+ raise ValueError(f"Got {count} responses from handler {name} instead of one")
|
|
|
+ return response
|
|
|
+
|
|
|
+ def iterate_protobuf_handler(
|
|
|
+ self,
|
|
|
+ peer_id: PeerID,
|
|
|
+ name: str,
|
|
|
+ input: Union[TInputProtobuf, TInputStream],
|
|
|
+ output_protobuf_type: type,
|
|
|
+ ) -> TOutputStream:
|
|
|
+ requests = input if isinstance(input, AsyncIterableABC) else aiter(input)
|
|
|
+ return self._iterate_protobuf_stream_handler(peer_id, name, requests, output_protobuf_type)
|
|
|
|
|
|
def _start_listening(self) -> None:
|
|
|
async def listen() -> None:
|
|
@@ -349,34 +466,16 @@ class P2P:
|
|
|
self._listen_task = None
|
|
|
self._server_stopped.clear()
|
|
|
|
|
|
- async def add_stream_handler(self, name: str, handler: p2pclient.StreamHandler) -> None:
|
|
|
+ async def add_binary_stream_handler(self, name: str, handler: p2pclient.StreamHandler) -> None:
|
|
|
if self._listen_task is None:
|
|
|
self._start_listening()
|
|
|
await self._client.stream_handler(name, handler)
|
|
|
|
|
|
- async def add_unary_handler(
|
|
|
- self, name: str, handler: Callable[[Any, P2PContext], Any], in_proto_type: type
|
|
|
- ) -> None:
|
|
|
- if self._listen_task is None:
|
|
|
- self._start_listening()
|
|
|
- await self._client.stream_handler(name, self._handle_unary_stream(handler, name, in_proto_type))
|
|
|
-
|
|
|
- async def call_stream_handler(
|
|
|
+ async def call_binary_stream_handler(
|
|
|
self, peer_id: PeerID, handler_name: str
|
|
|
) -> Tuple[StreamInfo, asyncio.StreamReader, asyncio.StreamWriter]:
|
|
|
return await self._client.stream_open(peer_id, (handler_name,))
|
|
|
|
|
|
- async def call_unary_handler(
|
|
|
- self, peer_id: PeerID, handler_name: str, request_protobuf: Any, response_proto_type: type
|
|
|
- ) -> Any:
|
|
|
- _, reader, writer = await self._client.stream_open(peer_id, (handler_name,))
|
|
|
- with closing(writer):
|
|
|
- await P2P.send_protobuf(request_protobuf, writer)
|
|
|
- result, err = await P2P.receive_protobuf(response_proto_type, reader)
|
|
|
- if err is not None:
|
|
|
- raise P2PHandlerError(f"Failed to call unary handler {handler_name} at {peer_id}: {err.message}")
|
|
|
- return result
|
|
|
-
|
|
|
def __del__(self):
|
|
|
self._terminate()
|
|
|
|