|
@@ -85,7 +85,7 @@ class ControlClient:
|
|
|
self.unary_handlers: Dict[str, TUnaryHandler] = {}
|
|
|
|
|
|
self._ensure_conn_lock = asyncio.Lock()
|
|
|
- self.pending_messages: asyncio.Queue[p2pd_pb.Request] = asyncio.Queue()
|
|
|
+ self.pending_messages: asyncio.Queue[p2pd_pb.PCRequest] = asyncio.Queue()
|
|
|
self.pending_calls: Dict[CallID, asyncio.Future[bytes]] = {}
|
|
|
|
|
|
@asynccontextmanager
|
|
@@ -106,12 +106,12 @@ class ControlClient:
|
|
|
|
|
|
async def _read_from_persistent_conn(self, reader: asyncio.StreamReader):
|
|
|
while True:
|
|
|
- resp = p2pd_pb.Response()
|
|
|
+ resp = p2pd_pb.PCResponse()
|
|
|
await read_pbmsg_safe(reader, resp)
|
|
|
|
|
|
- if resp.HasField("callUnaryResponse"):
|
|
|
- call_id = uuid.UUID(bytes=resp.callUnaryResponse.callId)
|
|
|
+ call_id = uuid.UUID(bytes=resp.callId)
|
|
|
|
|
|
+ if resp.HasField("callUnaryResponse"):
|
|
|
if call_id in self.pending_calls and resp.callUnaryResponse.HasField("result"):
|
|
|
self.pending_calls[call_id].set_result(resp.callUnaryResponse.result)
|
|
|
elif call_id in self.pending_calls and resp.callUnaryResponse.HasField("error"):
|
|
@@ -121,27 +121,24 @@ class ControlClient:
|
|
|
logger.debug(f"received unexpected unary call")
|
|
|
|
|
|
elif resp.HasField("requestHandling"):
|
|
|
- asyncio.create_task(self._handle_persistent_request(resp.requestHandling))
|
|
|
- pass
|
|
|
+ asyncio.create_task(self._handle_persistent_request(call_id, resp.requestHandling))
|
|
|
|
|
|
async def _write_to_persistent_conn(self, writer: asyncio.StreamWriter):
|
|
|
while True:
|
|
|
msg = await self.pending_messages.get()
|
|
|
await write_pbmsg(writer, msg)
|
|
|
|
|
|
- async def _handle_persistent_request(self, request):
|
|
|
+ async def _handle_persistent_request(self, call_id: uuid.UUID, request: p2pd_pb.CallUnaryRequest):
|
|
|
assert request.proto in self.unary_handlers
|
|
|
|
|
|
try:
|
|
|
remote_id = PeerID(request.peer)
|
|
|
response_payload: bytes = await self.unary_handlers[request.proto](request.data, remote_id)
|
|
|
- response = p2pd_pb.CallUnaryResponse(callId=request.callId, result=response_payload)
|
|
|
+ response = p2pd_pb.CallUnaryResponse(result=response_payload)
|
|
|
except Exception as e:
|
|
|
- response = p2pd_pb.CallUnaryResponse(callId=request.callId, error=repr(e).encode())
|
|
|
+ response = p2pd_pb.CallUnaryResponse(error=repr(e).encode())
|
|
|
|
|
|
- await self.pending_messages.put(
|
|
|
- p2pd_pb.Request(type=p2pd_pb.Request.SEND_RESPONSE_TO_REMOTE, sendResponseToRemote=response)
|
|
|
- )
|
|
|
+ await self.pending_messages.put(p2pd_pb.PCRequest(callId=call_id.bytes, unaryResponse=response))
|
|
|
|
|
|
async def _handler(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
|
|
|
pb_stream_info = p2pd_pb.StreamInfo() # type: ignore
|
|
@@ -167,11 +164,11 @@ class ControlClient:
|
|
|
async def add_unary_handler(self, proto: str, handler: TUnaryHandler):
|
|
|
await self._ensure_persistent_conn()
|
|
|
|
|
|
+ call_id = uuid.uuid4()
|
|
|
+
|
|
|
add_unary_handler_req = p2pd_pb.AddUnaryHandlerRequest(proto=proto)
|
|
|
- req = p2pd_pb.Request(
|
|
|
- type=p2pd_pb.Request.ADD_UNARY_HANDLER,
|
|
|
- addUnaryHandler=add_unary_handler_req,
|
|
|
- )
|
|
|
+ req = p2pd_pb.PCRequest(callId=call_id.bytes, addUnaryHandler=add_unary_handler_req)
|
|
|
+
|
|
|
if self.unary_handlers.get(proto):
|
|
|
raise ValueError(f"Handler for protocol {proto} already assigned")
|
|
|
self.unary_handlers[proto] = handler
|
|
@@ -184,10 +181,9 @@ class ControlClient:
|
|
|
peer=peer_id.to_bytes(),
|
|
|
proto=proto,
|
|
|
data=data,
|
|
|
- callId=call_id.bytes,
|
|
|
)
|
|
|
- req = p2pd_pb.Request(
|
|
|
- type=p2pd_pb.Request.CALL_UNARY,
|
|
|
+ req = p2pd_pb.PCRequest(
|
|
|
+ callId=call_id.bytes,
|
|
|
callUnary=call_unary_req,
|
|
|
)
|
|
|
|