|
@@ -5,12 +5,14 @@ from collections.abc import AsyncIterable as AsyncIterableABC
|
|
|
from contextlib import closing, suppress
|
|
|
from dataclasses import dataclass
|
|
|
from importlib.resources import path
|
|
|
-from typing import Any, AsyncIterator, Awaitable, Callable, List, Optional, Sequence, Tuple, TypeVar, Union
|
|
|
+from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional, Sequence, Tuple, Type, TypeVar, Union
|
|
|
|
|
|
+from google.protobuf.message import Message
|
|
|
from multiaddr import Multiaddr
|
|
|
|
|
|
import hivemind.hivemind_cli as cli
|
|
|
import hivemind.p2p.p2p_daemon_bindings.p2pclient as p2pclient
|
|
|
+from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError, P2PHandlerError
|
|
|
from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
|
|
|
from hivemind.proto.p2pd_pb2 import RPCError
|
|
|
from hivemind.utils.asyncio import aiter, asingle
|
|
@@ -27,7 +29,6 @@ class P2PContext(object):
|
|
|
handle_name: str
|
|
|
local_id: PeerID
|
|
|
remote_id: PeerID = None
|
|
|
- remote_maddr: Multiaddr = None
|
|
|
|
|
|
|
|
|
class P2P:
|
|
@@ -65,6 +66,7 @@ class P2P:
|
|
|
|
|
|
def __init__(self):
|
|
|
self.peer_id = None
|
|
|
+ self._client = None
|
|
|
self._child = None
|
|
|
self._alive = False
|
|
|
self._reader_task = None
|
|
@@ -90,6 +92,7 @@ class P2P:
|
|
|
use_auto_relay: bool = False,
|
|
|
relay_hop_limit: int = 0,
|
|
|
startup_timeout: float = 15,
|
|
|
+ idle_timeout: float = 30,
|
|
|
) -> "P2P":
|
|
|
"""
|
|
|
Start a new p2pd process and connect to it.
|
|
@@ -111,6 +114,8 @@ class P2P:
|
|
|
:param use_auto_relay: enables autorelay
|
|
|
:param relay_hop_limit: sets the hop limit for hop relays
|
|
|
:param startup_timeout: raise a P2PDaemonError if the daemon does not start in ``startup_timeout`` seconds
|
|
|
+ :param idle_timeout: kill daemon if client has been idle for a given number of
|
|
|
+ seconds before opening persistent streams
|
|
|
:return: a wrapper for the p2p daemon
|
|
|
"""
|
|
|
|
|
@@ -150,6 +155,7 @@ class P2P:
|
|
|
relayDiscovery=use_relay_discovery,
|
|
|
autoRelay=use_auto_relay,
|
|
|
relayHopLimit=relay_hop_limit,
|
|
|
+ idleTimeout=f"{idle_timeout}s",
|
|
|
b=need_bootstrap,
|
|
|
**process_kwargs,
|
|
|
)
|
|
@@ -167,7 +173,7 @@ class P2P:
|
|
|
await self.shutdown()
|
|
|
raise P2PDaemonError(f"Daemon failed to start in {startup_timeout:.1f} seconds")
|
|
|
|
|
|
- self._client = p2pclient.Client(self._daemon_listen_maddr, self._client_listen_maddr)
|
|
|
+ self._client = await p2pclient.Client.create(self._daemon_listen_maddr, self._client_listen_maddr)
|
|
|
await self._ping_daemon()
|
|
|
return self
|
|
|
|
|
@@ -189,7 +195,7 @@ class P2P:
|
|
|
self._daemon_listen_maddr = daemon_listen_maddr
|
|
|
self._client_listen_maddr = Multiaddr(cls._UNIX_SOCKET_PREFIX + f"p2pclient-{socket_uid}.sock")
|
|
|
|
|
|
- self._client = p2pclient.Client(self._daemon_listen_maddr, self._client_listen_maddr)
|
|
|
+ self._client = await p2pclient.Client.create(self._daemon_listen_maddr, self._client_listen_maddr)
|
|
|
|
|
|
await self._ping_daemon()
|
|
|
return self
|
|
@@ -258,7 +264,7 @@ class P2P:
|
|
|
|
|
|
@staticmethod
|
|
|
async def receive_protobuf(
|
|
|
- input_protobuf_type: type, reader: asyncio.StreamReader
|
|
|
+ input_protobuf_type: Type[Message], reader: asyncio.StreamReader
|
|
|
) -> Tuple[Optional[TInputProtobuf], Optional[RPCError]]:
|
|
|
msg_type = await reader.readexactly(1)
|
|
|
if msg_type == P2P.MESSAGE_MARKER:
|
|
@@ -279,7 +285,7 @@ class P2P:
|
|
|
self,
|
|
|
name: str,
|
|
|
handler: Callable[[TInputStream, P2PContext], TOutputStream],
|
|
|
- input_protobuf_type: type,
|
|
|
+ input_protobuf_type: Type[Message],
|
|
|
max_prefetch: int = 5,
|
|
|
) -> None:
|
|
|
"""
|
|
@@ -297,7 +303,6 @@ class P2P:
|
|
|
handle_name=name,
|
|
|
local_id=self.peer_id,
|
|
|
remote_id=stream_info.peer_id,
|
|
|
- remote_maddr=stream_info.addr,
|
|
|
)
|
|
|
requests = asyncio.Queue(max_prefetch)
|
|
|
|
|
@@ -349,7 +354,7 @@ class P2P:
|
|
|
await self.add_binary_stream_handler(name, _handle_stream)
|
|
|
|
|
|
async def _iterate_protobuf_stream_handler(
|
|
|
- self, peer_id: PeerID, name: str, requests: TInputStream, output_protobuf_type: type
|
|
|
+ self, peer_id: PeerID, name: str, requests: TInputStream, output_protobuf_type: Type[Message]
|
|
|
) -> TOutputStream:
|
|
|
_, reader, writer = await self.call_binary_stream_handler(peer_id, name)
|
|
|
|
|
@@ -381,15 +386,22 @@ class P2P:
|
|
|
handler: Callable[
|
|
|
[Union[TInputProtobuf, TInputStream], P2PContext], Union[Awaitable[TOutputProtobuf], TOutputStream]
|
|
|
],
|
|
|
- input_protobuf_type: type,
|
|
|
+ input_protobuf_type: Type[Message],
|
|
|
*,
|
|
|
stream_input: bool = False,
|
|
|
+ stream_output: bool = False,
|
|
|
) -> None:
|
|
|
"""
|
|
|
:param stream_input: If True, assume ``handler`` to take ``TInputStream``
|
|
|
(not just ``TInputProtobuf``) as input.
|
|
|
+ :param stream_output: If True, assume ``handler`` to return ``TOutputStream``
|
|
|
+ (not ``Awaitable[TOutputProtobuf]``).
|
|
|
"""
|
|
|
|
|
|
+ if not stream_input and not stream_output:
|
|
|
+ await self._add_protobuf_unary_handler(name, handler, input_protobuf_type)
|
|
|
+ return
|
|
|
+
|
|
|
async def _stream_handler(requests: P2P.TInputStream, context: P2PContext) -> P2P.TOutputStream:
|
|
|
input = requests if stream_input else await asingle(requests)
|
|
|
output = handler(input, context)
|
|
@@ -402,23 +414,65 @@ class P2P:
|
|
|
|
|
|
await self._add_protobuf_stream_handler(name, _stream_handler, input_protobuf_type)
|
|
|
|
|
|
+ async def _add_protobuf_unary_handler(
|
|
|
+ self,
|
|
|
+ handle_name: str,
|
|
|
+ handler: Callable[[TInputProtobuf, P2PContext], Awaitable[TOutputProtobuf]],
|
|
|
+ input_protobuf_type: Type[Message],
|
|
|
+ ) -> None:
|
|
|
+ """
|
|
|
+ Register a request-response (unary) handler. Unary requests and responses
|
|
|
+ are sent through persistent multiplexed connections to the daemon for the
|
|
|
+ sake of reducing the number of open files.
|
|
|
+ :param handle_name: name of the handler (protocol id)
|
|
|
+ :param handler: function handling the unary requests
|
|
|
+ :param input_protobuf_type: protobuf type of the request
|
|
|
+ """
|
|
|
+
|
|
|
+ async def _unary_handler(request: bytes, remote_id: PeerID) -> bytes:
|
|
|
+ input_serialized = input_protobuf_type.FromString(request)
|
|
|
+ context = P2PContext(
|
|
|
+ handle_name=handle_name,
|
|
|
+ local_id=self.peer_id,
|
|
|
+ remote_id=remote_id,
|
|
|
+ )
|
|
|
+
|
|
|
+ response = await handler(input_serialized, context)
|
|
|
+ return response.SerializeToString()
|
|
|
+
|
|
|
+ await self._client.add_unary_handler(handle_name, _unary_handler)
|
|
|
+
|
|
|
async def call_protobuf_handler(
|
|
|
self,
|
|
|
peer_id: PeerID,
|
|
|
name: str,
|
|
|
input: Union[TInputProtobuf, TInputStream],
|
|
|
- output_protobuf_type: type,
|
|
|
+ output_protobuf_type: Type[Message],
|
|
|
) -> Awaitable[TOutputProtobuf]:
|
|
|
- requests = input if isinstance(input, AsyncIterableABC) else aiter(input)
|
|
|
- responses = self._iterate_protobuf_stream_handler(peer_id, name, requests, output_protobuf_type)
|
|
|
+
|
|
|
+ if not isinstance(input, AsyncIterableABC):
|
|
|
+ return await self._call_unary_protobuf_handler(peer_id, name, input, output_protobuf_type)
|
|
|
+
|
|
|
+ responses = self._iterate_protobuf_stream_handler(peer_id, name, input, output_protobuf_type)
|
|
|
return await asingle(responses)
|
|
|
|
|
|
+ async def _call_unary_protobuf_handler(
|
|
|
+ self,
|
|
|
+ peer_id: PeerID,
|
|
|
+ handle_name: str,
|
|
|
+ input: TInputProtobuf,
|
|
|
+ output_protobuf_type: Type[Message],
|
|
|
+ ) -> Awaitable[TOutputProtobuf]:
|
|
|
+ serialized_input = input.SerializeToString()
|
|
|
+ response = await self._client.call_unary_handler(peer_id, handle_name, serialized_input)
|
|
|
+ return output_protobuf_type.FromString(response)
|
|
|
+
|
|
|
def iterate_protobuf_handler(
|
|
|
self,
|
|
|
peer_id: PeerID,
|
|
|
name: str,
|
|
|
input: Union[TInputProtobuf, TInputStream],
|
|
|
- output_protobuf_type: type,
|
|
|
+ output_protobuf_type: Type[Message],
|
|
|
) -> TOutputStream:
|
|
|
requests = input if isinstance(input, AsyncIterableABC) else aiter(input)
|
|
|
return self._iterate_protobuf_stream_handler(peer_id, name, requests, output_protobuf_type)
|
|
@@ -453,6 +507,8 @@ class P2P:
|
|
|
await self._child.wait()
|
|
|
|
|
|
def _terminate(self) -> None:
|
|
|
+ if self._client is not None:
|
|
|
+ self._client.close()
|
|
|
if self._listen_task is not None:
|
|
|
self._listen_task.cancel()
|
|
|
if self._reader_task is not None:
|
|
@@ -501,11 +557,3 @@ class P2P:
|
|
|
|
|
|
if not ready.done():
|
|
|
ready.set_exception(P2PDaemonError(f"Daemon failed to start: {last_line}"))
|
|
|
-
|
|
|
-
|
|
|
-class P2PDaemonError(RuntimeError):
|
|
|
- pass
|
|
|
-
|
|
|
-
|
|
|
-class P2PHandlerError(Exception):
|
|
|
- pass
|