|
@@ -6,7 +6,7 @@ Author: Kevin Mai-Husan Chia
|
|
|
|
|
|
import asyncio
|
|
|
import uuid
|
|
|
-from contextlib import asynccontextmanager
|
|
|
+from contextlib import asynccontextmanager, closing
|
|
|
from typing import AsyncIterator, Awaitable, Callable, Dict, Iterable, Optional, Sequence, Tuple
|
|
|
|
|
|
from multiaddr import Multiaddr, protocols
|
|
@@ -87,6 +87,10 @@ class ControlClient:
|
|
|
self._ensure_conn_lock = asyncio.Lock()
|
|
|
self.pending_messages: asyncio.Queue[p2pd_pb.PCRequest] = asyncio.Queue()
|
|
|
self.pending_calls: Dict[CallID, asyncio.Future[bytes]] = {}
|
|
|
+ self.handler_tasks: Dict[CallID, asyncio.Task] = {}
|
|
|
+
|
|
|
+ self._read_task: Optional[asyncio.Task] = None
|
|
|
+ self._write_task: Optional[asyncio.Task] = None
|
|
|
|
|
|
@asynccontextmanager
|
|
|
async def listen(self) -> AsyncIterator["ControlClient"]:
|
|
@@ -101,8 +105,14 @@ class ControlClient:
|
|
|
else:
|
|
|
raise ValueError(f"Protocol not supported: {protocols.protocol_with_code(proto_code)}")
|
|
|
|
|
|
- async with server:
|
|
|
- yield self
|
|
|
+ try:
|
|
|
+ async with server:
|
|
|
+ yield self
|
|
|
+ finally:
|
|
|
+ if self._read_task is not None:
|
|
|
+ self._read_task.cancel()
|
|
|
+ if self._write_task is not None:
|
|
|
+ self._write_task.cancel()
|
|
|
|
|
|
async def _read_from_persistent_conn(self, reader: asyncio.StreamReader):
|
|
|
while True:
|
|
@@ -121,9 +131,14 @@ class ControlClient:
|
|
|
logger.debug(f"received unexpected unary call")
|
|
|
|
|
|
elif resp.HasField("requestHandling"):
|
|
|
- asyncio.create_task(self._handle_persistent_request(call_id, resp.requestHandling))
|
|
|
+ handler_task = asyncio.create_task(self._handle_persistent_request(call_id, resp.requestHandling))
|
|
|
+ self.handler_tasks[call_id] = handler_task
|
|
|
+
|
|
|
+ elif call_id in self.handler_tasks and resp.HasField("cancel"):
|
|
|
+ self.handler_tasks[call_id].cancel()
|
|
|
|
|
|
async def _write_to_persistent_conn(self, writer: asyncio.StreamWriter):
|
|
|
+ #with closing(writer):
|
|
|
while True:
|
|
|
msg = await self.pending_messages.get()
|
|
|
await write_pbmsg(writer, msg)
|
|
@@ -135,10 +150,12 @@ class ControlClient:
|
|
|
remote_id = PeerID(request.peer)
|
|
|
response_payload: bytes = await self.unary_handlers[request.proto](request.data, remote_id)
|
|
|
response = p2pd_pb.CallUnaryResponse(result=response_payload)
|
|
|
+
|
|
|
except Exception as e:
|
|
|
response = p2pd_pb.CallUnaryResponse(error=repr(e).encode())
|
|
|
|
|
|
await self.pending_messages.put(p2pd_pb.PCRequest(callId=call_id.bytes, unaryResponse=response))
|
|
|
+ self.handler_tasks.pop(call_id)
|
|
|
|
|
|
async def _handler(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
|
|
|
pb_stream_info = p2pd_pb.StreamInfo() # type: ignore
|
|
@@ -152,13 +169,23 @@ class ControlClient:
|
|
|
raise DispatchFailure(e)
|
|
|
await handler(stream_info, reader, writer)
|
|
|
|
|
|
+ async def _send_call_cancel(self, call_id: uuid.UUID):
|
|
|
+ await self.pending_messages.put(
|
|
|
+ p2pd_pb.PCRequest(
|
|
|
+ callId=call_id.bytes,
|
|
|
+ cancel=p2pd_pb.Cancel(),
|
|
|
+ ),
|
|
|
+ )
|
|
|
+
|
|
|
async def _ensure_persistent_conn(self):
|
|
|
if not self._pers_conn_open:
|
|
|
async with self._ensure_conn_lock:
|
|
|
if not self._pers_conn_open:
|
|
|
reader, writer = await self.daemon_connector.open_persistent_connection()
|
|
|
- asyncio.create_task(self._read_from_persistent_conn(reader))
|
|
|
- asyncio.create_task(self._write_to_persistent_conn(writer))
|
|
|
+
|
|
|
+ self._read_task = asyncio.create_task(self._read_from_persistent_conn(reader))
|
|
|
+ self._write_task = asyncio.create_task(self._write_to_persistent_conn(writer))
|
|
|
+
|
|
|
self._pers_conn_open = True
|
|
|
|
|
|
async def add_unary_handler(self, proto: str, handler: TUnaryHandler):
|
|
@@ -193,8 +220,13 @@ class ControlClient:
|
|
|
self.pending_calls[call_id] = asyncio.Future()
|
|
|
await self.pending_messages.put(req)
|
|
|
return await self.pending_calls[call_id]
|
|
|
+
|
|
|
+ except asyncio.CancelledError:
|
|
|
+ asyncio.create_task(self._send_call_cancel(call_id))
|
|
|
+ raise
|
|
|
+
|
|
|
finally:
|
|
|
- await self.pending_calls.pop(call_id)
|
|
|
+ self.pending_calls.pop(call_id, None)
|
|
|
|
|
|
async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
|
|
|
reader, writer = await self.daemon_connector.open_connection()
|