|
@@ -80,14 +80,13 @@ class ControlClient:
|
|
|
self.daemon_connector = daemon_connector
|
|
|
self.handlers: Dict[str, StreamHandler] = {}
|
|
|
|
|
|
- # persistent connection readers & writers
|
|
|
- self._pers_conn_open: bool = False
|
|
|
+ self._is_persistent_conn_open: bool = False
|
|
|
self.unary_handlers: Dict[str, TUnaryHandler] = {}
|
|
|
|
|
|
self._ensure_conn_lock = asyncio.Lock()
|
|
|
- self.pending_messages: asyncio.Queue[p2pd_pb.PCRequest] = asyncio.Queue()
|
|
|
- self.pending_calls: Dict[CallID, asyncio.Future[bytes]] = {}
|
|
|
- self.handler_tasks: Dict[CallID, asyncio.Task] = {}
|
|
|
+ self._pending_messages: asyncio.Queue[p2pd_pb.PersistentConnectionRequest] = asyncio.Queue()
|
|
|
+ self._pending_calls: Dict[CallID, asyncio.Future[bytes]] = {}
|
|
|
+ self._handler_tasks: Dict[CallID, asyncio.Task] = {}
|
|
|
|
|
|
self._read_task: Optional[asyncio.Task] = None
|
|
|
self._write_task: Optional[asyncio.Task] = None
|
|
@@ -115,38 +114,38 @@ class ControlClient:
|
|
|
self._write_task.cancel()
|
|
|
|
|
|
async def _read_from_persistent_conn(self, reader: asyncio.StreamReader):
|
|
|
- with closing(reader):
|
|
|
- while True:
|
|
|
- resp = p2pd_pb.PCResponse()
|
|
|
- await read_pbmsg_safe(reader, resp)
|
|
|
+ while True:
|
|
|
+ resp = p2pd_pb.PersistentConnectionResponse()
|
|
|
+ await read_pbmsg_safe(reader, resp)
|
|
|
|
|
|
- call_id = uuid.UUID(bytes=resp.callId)
|
|
|
+ call_id = uuid.UUID(bytes=resp.callId)
|
|
|
|
|
|
- if resp.HasField("callUnaryResponse"):
|
|
|
- if call_id in self.pending_calls and resp.callUnaryResponse.HasField("response"):
|
|
|
- self.pending_calls[call_id].set_result(resp.callUnaryResponse.response)
|
|
|
- elif call_id in self.pending_calls and resp.callUnaryResponse.HasField("error"):
|
|
|
- remote_exc = P2PHandlerError(resp.callUnaryResponse.error.decode())
|
|
|
- self.pending_calls[call_id].set_exception(remote_exc)
|
|
|
- else:
|
|
|
- logger.debug(f"received unexpected unary call")
|
|
|
+ if resp.HasField("callUnaryResponse"):
|
|
|
+ if call_id in self._pending_calls and resp.callUnaryResponse.HasField("response"):
|
|
|
+ self._pending_calls[call_id].set_result(resp.callUnaryResponse.response)
|
|
|
+ elif call_id in self._pending_calls and resp.callUnaryResponse.HasField("error"):
|
|
|
+ remote_exc = P2PHandlerError(resp.callUnaryResponse.error.decode(errors="ignore"))
|
|
|
+ self._pending_calls[call_id].set_exception(remote_exc)
|
|
|
+ else:
|
|
|
+ logger.debug("received unexpected unary call")
|
|
|
|
|
|
- elif resp.HasField("requestHandling"):
|
|
|
- handler_task = asyncio.create_task(self._handle_persistent_request(call_id, resp.requestHandling))
|
|
|
- self.handler_tasks[call_id] = handler_task
|
|
|
+ elif resp.HasField("requestHandling"):
|
|
|
+ handler_task = asyncio.create_task(self._handle_persistent_request(call_id, resp.requestHandling))
|
|
|
+ self._handler_tasks[call_id] = handler_task
|
|
|
|
|
|
- elif call_id in self.handler_tasks and resp.HasField("cancel"):
|
|
|
- self.handler_tasks[call_id].cancel()
|
|
|
+ elif call_id in self._handler_tasks and resp.HasField("cancel"):
|
|
|
+ self._handler_tasks[call_id].cancel()
|
|
|
|
|
|
async def _write_to_persistent_conn(self, writer: asyncio.StreamWriter):
|
|
|
with closing(writer):
|
|
|
while True:
|
|
|
- msg = await self.pending_messages.get()
|
|
|
+ msg = await self._pending_messages.get()
|
|
|
await write_pbmsg(writer, msg)
|
|
|
|
|
|
async def _handle_persistent_request(self, call_id: uuid.UUID, request: p2pd_pb.CallUnaryRequest):
|
|
|
if request.proto not in self.unary_handlers:
|
|
|
logger.warning(f"Protocol {request.proto} not supported")
|
|
|
+ return
|
|
|
|
|
|
try:
|
|
|
remote_id = PeerID(request.peer)
|
|
@@ -156,8 +155,13 @@ class ControlClient:
|
|
|
except Exception as e:
|
|
|
response = p2pd_pb.CallUnaryResponse(error=repr(e).encode())
|
|
|
|
|
|
- await self.pending_messages.put(p2pd_pb.PCRequest(callId=call_id.bytes, unaryResponse=response))
|
|
|
- self.handler_tasks.pop(call_id)
|
|
|
+ await self._pending_messages.put(
|
|
|
+ p2pd_pb.PersistentConnectionRequest(
|
|
|
+ callId=call_id.bytes,
|
|
|
+ unaryResponse=response,
|
|
|
+ )
|
|
|
+ )
|
|
|
+ self._handler_tasks.pop(call_id)
|
|
|
|
|
|
async def _handler(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
|
|
|
pb_stream_info = p2pd_pb.StreamInfo() # type: ignore
|
|
@@ -172,23 +176,23 @@ class ControlClient:
|
|
|
await handler(stream_info, reader, writer)
|
|
|
|
|
|
async def _send_call_cancel(self, call_id: uuid.UUID):
|
|
|
- await self.pending_messages.put(
|
|
|
- p2pd_pb.PCRequest(
|
|
|
+ await self._pending_messages.put(
|
|
|
+ p2pd_pb.PersistentConnectionRequest(
|
|
|
callId=call_id.bytes,
|
|
|
cancel=p2pd_pb.Cancel(),
|
|
|
),
|
|
|
)
|
|
|
|
|
|
async def _ensure_persistent_conn(self):
|
|
|
- if not self._pers_conn_open:
|
|
|
+ if not self._is_persistent_conn_open:
|
|
|
async with self._ensure_conn_lock:
|
|
|
- if not self._pers_conn_open:
|
|
|
+ if not self._is_persistent_conn_open:
|
|
|
reader, writer = await self.daemon_connector.open_persistent_connection()
|
|
|
|
|
|
self._read_task = asyncio.create_task(self._read_from_persistent_conn(reader))
|
|
|
self._write_task = asyncio.create_task(self._write_to_persistent_conn(writer))
|
|
|
|
|
|
- self._pers_conn_open = True
|
|
|
+ self._is_persistent_conn_open = True
|
|
|
|
|
|
async def add_unary_handler(self, proto: str, handler: TUnaryHandler):
|
|
|
await self._ensure_persistent_conn()
|
|
@@ -196,13 +200,13 @@ class ControlClient:
|
|
|
call_id = uuid.uuid4()
|
|
|
|
|
|
add_unary_handler_req = p2pd_pb.AddUnaryHandlerRequest(proto=proto)
|
|
|
- req = p2pd_pb.PCRequest(callId=call_id.bytes, addUnaryHandler=add_unary_handler_req)
|
|
|
+ req = p2pd_pb.PersistentConnectionRequest(callId=call_id.bytes, addUnaryHandler=add_unary_handler_req)
|
|
|
|
|
|
if self.unary_handlers.get(proto):
|
|
|
raise ValueError(f"Handler for protocol {proto} already assigned")
|
|
|
self.unary_handlers[proto] = handler
|
|
|
|
|
|
- await self.pending_messages.put(req)
|
|
|
+ await self._pending_messages.put(req)
|
|
|
|
|
|
async def call_unary_handler(self, peer_id: PeerID, proto: str, data: bytes) -> bytes:
|
|
|
call_id = uuid.uuid4()
|
|
@@ -211,7 +215,7 @@ class ControlClient:
|
|
|
proto=proto,
|
|
|
data=data,
|
|
|
)
|
|
|
- req = p2pd_pb.PCRequest(
|
|
|
+ req = p2pd_pb.PersistentConnectionRequest(
|
|
|
callId=call_id.bytes,
|
|
|
callUnary=call_unary_req,
|
|
|
)
|
|
@@ -219,16 +223,16 @@ class ControlClient:
|
|
|
await self._ensure_persistent_conn()
|
|
|
|
|
|
try:
|
|
|
- self.pending_calls[call_id] = asyncio.Future()
|
|
|
- await self.pending_messages.put(req)
|
|
|
- return await self.pending_calls[call_id]
|
|
|
+ self._pending_calls[call_id] = asyncio.Future()
|
|
|
+ await self._pending_messages.put(req)
|
|
|
+ return await self._pending_calls[call_id]
|
|
|
|
|
|
except asyncio.CancelledError:
|
|
|
asyncio.create_task(self._send_call_cancel(call_id))
|
|
|
raise
|
|
|
|
|
|
finally:
|
|
|
- self.pending_calls.pop(call_id, None)
|
|
|
+ self._pending_calls.pop(call_id, None)
|
|
|
|
|
|
async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
|
|
|
reader, writer = await self.daemon_connector.open_connection()
|