|
@@ -5,8 +5,9 @@ Author: Kevin Mai-Husan Chia
|
|
|
"""
|
|
|
|
|
|
import asyncio
|
|
|
+import uuid
|
|
|
from contextlib import asynccontextmanager
|
|
|
-from typing import AsyncIterator, Awaitable, Callable, Dict, Iterable, Sequence, Tuple
|
|
|
+from typing import AsyncIterator, Awaitable, Callable, Dict, Iterable, Optional, Sequence, Tuple
|
|
|
|
|
|
from multiaddr import Multiaddr, protocols
|
|
|
|
|
@@ -54,6 +55,20 @@ class DaemonConnector:
|
|
|
else:
|
|
|
raise ValueError(f"Protocol not supported: {protocols.protocol_with_code(self.proto_code)}")
|
|
|
|
|
|
+ async def open_persistent_connection(self) -> (asyncio.StreamReader, asyncio.StreamWriter):
|
|
|
+ """
|
|
|
+ Open connection to daemon and upgrade it to a persistent one
|
|
|
+ """
|
|
|
+ reader, writer = await self.open_connection()
|
|
|
+ req = p2pd_pb.Request(type=p2pd_pb.Request.PERSISTENT_CONN_UPGRADE)
|
|
|
+ await write_pbmsg(writer, req)
|
|
|
+
|
|
|
+ return reader, writer
|
|
|
+
|
|
|
+
|
|
|
+TUnaryHandler = Callable[[bytes], bytes]
|
|
|
+CallID = uuid.UUID
|
|
|
+
|
|
|
|
|
|
class ControlClient:
|
|
|
DEFAULT_LISTEN_MADDR = "/unix/tmp/p2pclient.sock"
|
|
@@ -65,6 +80,54 @@ class ControlClient:
|
|
|
self.daemon_connector = daemon_connector
|
|
|
self.handlers: Dict[str, StreamHandler] = {}
|
|
|
|
|
|
+ # persistent connection readers & writers
|
|
|
+ self._pers_conn_open: bool = False
|
|
|
+ self.unary_handlers: Dict[str, TUnaryHandler] = {}
|
|
|
+
|
|
|
+ self.pending_messages: asyncio.Queue[p2pd_pb.Request] = asyncio.Queue()
|
|
|
+ self.pending_calls: Dict[CallID, asyncio.Future] = {}
|
|
|
+
|
|
|
+ async def read_from_persistent_conn(self, reader: asyncio.StreamReader):
|
|
|
+ while True:
|
|
|
+ resp: p2pd_pb.Response = p2pd_pb.Response() # type: ignore
|
|
|
+ await read_pbmsg_safe(reader, resp)
|
|
|
+
|
|
|
+ if resp.callUnaryResponse:
|
|
|
+ call_id = uuid.UUID(bytes=resp.callUnaryResponse.callId)
|
|
|
+
|
|
|
+ if call_id in self.pending_calls and resp.data:
|
|
|
+ self.pending_calls[call_id].set_result(call_id)
|
|
|
+ elif call_id in self.pending_calls and resp.error:
|
|
|
+ remote_exc = RemoteException(str(resp.error))
|
|
|
+ self.pending_calls[call_id].set_exception(remote_exc)
|
|
|
+ else:
|
|
|
+ logger.debug(f"received unexpected unary call")
|
|
|
+
|
|
|
+ elif resp.requestHandling:
|
|
|
+ # asyncio.create_task(self.read)
|
|
|
+ pass
|
|
|
+
|
|
|
+ async def _handle_persistent_request(self, request):
|
|
|
+ assert request.proto in self.unary_handlers
|
|
|
+
|
|
|
+ try:
|
|
|
+ response_payload: bytes = await self.unary_protocols[request.protocol](request.payload)
|
|
|
+ response = p2pd_pb.CallUnaryResponse(
|
|
|
+ call_id=request.call_id,
|
|
|
+ data=response_payload)
|
|
|
+ except Exception as e:
|
|
|
+ response = p2pd_pb.CallUnaryResponse(
|
|
|
+ call_id=request.call_id,
|
|
|
+ error=repr(e))
|
|
|
+
|
|
|
+ await self.pending_messages.put(
|
|
|
+ p2pd_pb.Request(type=p2pd_pb.Request.UNARY_RESPONSE, response=response))
|
|
|
+
|
|
|
+ async def write_to_persistent_conn(self, writer: asyncio.StreamWriter):
|
|
|
+ while True:
|
|
|
+ msg = await self.pending_messages.get()
|
|
|
+ await write_pbmsg(writer, msg)
|
|
|
+
|
|
|
async def _handler(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
|
|
|
pb_stream_info = p2pd_pb.StreamInfo() # type: ignore
|
|
|
await read_pbmsg_safe(reader, pb_stream_info)
|
|
@@ -93,6 +156,48 @@ class ControlClient:
|
|
|
async with server:
|
|
|
yield self
|
|
|
|
|
|
+ async def _ensure_persistent_conn(self):
|
|
|
+ if not self._pers_conn_open:
|
|
|
+ reader, writer = await self.daemon_connector.open_persistent_connection()
|
|
|
+ asyncio.create_task(self.read_from_persistent_conn(reader))
|
|
|
+ asyncio.create_task(self.write_to_persistent_conn(writer))
|
|
|
+
|
|
|
+ async def add_unary_handler(self, proto: str, handler: TUnaryHandler):
|
|
|
+ await self._ensure_persistent_conn()
|
|
|
+
|
|
|
+ add_unary_handler_req = p2pd_pb.AddUnaryHandlerRequest(proto=proto)
|
|
|
+ req = p2pd_pb.Request(
|
|
|
+ type=p2pd_pb.Request.ADD_UNARY_HANDLER,
|
|
|
+ addUnaryHandler=add_unary_handler_req,
|
|
|
+ )
|
|
|
+ await self.pending_messages.put(req)
|
|
|
+
|
|
|
+ if self.unary_handlers.get(proto):
|
|
|
+ raise ValueError(f"Handler for protocol {proto} already assigned")
|
|
|
+ self.unary_handlers[proto] = handler
|
|
|
+
|
|
|
+ async def unary_call(self, peer_id: PeerID, proto: str, data: bytes) -> bytes:
|
|
|
+ call_id = uuid.uuid4()
|
|
|
+ call_unary_req = p2pd_pb.CallUnaryRequest(
|
|
|
+ peer=peer_id.to_bytes(),
|
|
|
+ proto=proto,
|
|
|
+ data=data,
|
|
|
+ callId=call_id,
|
|
|
+ )
|
|
|
+ req = p2pd_pb.Request(
|
|
|
+ type=p2pd_pb.Request.CALL_UNARY,
|
|
|
+ callUnary=call_unary_req,
|
|
|
+ )
|
|
|
+
|
|
|
+ await self._ensure_persistent_conn()
|
|
|
+
|
|
|
+ try:
|
|
|
+ self.pending_calls[call_id] = asyncio.Future()
|
|
|
+ await self.pending_messages.put(req)
|
|
|
+ return await self.pending_calls[call_id]
|
|
|
+ finally:
|
|
|
+ await self.pending_calls.pop(call_id)
|
|
|
+
|
|
|
async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
|
|
|
reader, writer = await self.daemon_connector.open_connection()
|
|
|
req = p2pd_pb.Request(type=p2pd_pb.Request.IDENTIFY)
|
|
@@ -179,3 +284,9 @@ class ControlClient:
|
|
|
|
|
|
# if success, add the handler to the dict
|
|
|
self.handlers[proto] = handler_cb
|
|
|
+
|
|
|
+
|
|
|
+class RemoteException(Exception):
|
|
|
+ """
|
|
|
+ Raised if remote handled a request with an exception
|
|
|
+ """
|