|
@@ -87,6 +87,7 @@ class ControlClient:
|
|
|
self._ensure_conn_lock = asyncio.Lock()
|
|
|
self.pending_messages: asyncio.Queue[p2pd_pb.PCRequest] = asyncio.Queue()
|
|
|
self.pending_calls: Dict[CallID, asyncio.Future[bytes]] = {}
|
|
|
+ self.handler_tasks: Dict[CallID, asyncio.Task] = {}
|
|
|
|
|
|
@asynccontextmanager
|
|
|
async def listen(self) -> AsyncIterator["ControlClient"]:
|
|
@@ -121,7 +122,12 @@ class ControlClient:
|
|
|
logger.debug(f"received unexpected unary call")
|
|
|
|
|
|
elif resp.HasField("requestHandling"):
|
|
|
- asyncio.create_task(self._handle_persistent_request(call_id, resp.requestHandling))
|
|
|
+ handler_task = asyncio.create_task(self._handle_persistent_request(call_id, resp.requestHandling))
|
|
|
+ # TODO: fix race condition at the of ._handle_persistent_request(...)
|
|
|
+ self.handler_tasks[call_id] = handler_task
|
|
|
+
|
|
|
+ elif call_id in self.handler_tasks and resp.HasField("cancel"):
|
|
|
+ self.handler_tasks[call_id].cancel()
|
|
|
|
|
|
async def _write_to_persistent_conn(self, writer: asyncio.StreamWriter):
|
|
|
while True:
|
|
@@ -135,10 +141,12 @@ class ControlClient:
|
|
|
remote_id = PeerID(request.peer)
|
|
|
response_payload: bytes = await self.unary_handlers[request.proto](request.data, remote_id)
|
|
|
response = p2pd_pb.CallUnaryResponse(result=response_payload)
|
|
|
+
|
|
|
except Exception as e:
|
|
|
response = p2pd_pb.CallUnaryResponse(error=repr(e).encode())
|
|
|
|
|
|
await self.pending_messages.put(p2pd_pb.PCRequest(callId=call_id.bytes, unaryResponse=response))
|
|
|
+ self.handler_tasks.pop(call_id)
|
|
|
|
|
|
async def _handler(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
|
|
|
pb_stream_info = p2pd_pb.StreamInfo() # type: ignore
|
|
@@ -152,6 +160,14 @@ class ControlClient:
|
|
|
raise DispatchFailure(e)
|
|
|
await handler(stream_info, reader, writer)
|
|
|
|
|
|
+ async def _send_call_cancel(self, call_id: uuid.UUID):
|
|
|
+ await self.pending_messages.put(
|
|
|
+ p2pd_pb.PCRequest(
|
|
|
+ callId=call_id.bytes,
|
|
|
+ cancel=p2pd_pb.Cancel(),
|
|
|
+ ),
|
|
|
+ )
|
|
|
+
|
|
|
async def _ensure_persistent_conn(self):
|
|
|
if not self._pers_conn_open:
|
|
|
async with self._ensure_conn_lock:
|
|
@@ -193,6 +209,11 @@ class ControlClient:
|
|
|
self.pending_calls[call_id] = asyncio.Future()
|
|
|
await self.pending_messages.put(req)
|
|
|
return await self.pending_calls[call_id]
|
|
|
+
|
|
|
+ except asyncio.CancelledError:
|
|
|
+ await self._send_call_cancel(call_id)
|
|
|
+ raise
|
|
|
+
|
|
|
finally:
|
|
|
await self.pending_calls.pop(call_id)
|
|
|
|