|
@@ -87,6 +87,22 @@ class ControlClient:
|
|
|
self.pending_messages: asyncio.Queue[p2pd_pb.Request] = asyncio.Queue()
|
|
|
self.pending_calls: Dict[CallID, asyncio.Future] = {}
|
|
|
|
|
|
+ @asynccontextmanager
|
|
|
+ async def listen(self) -> AsyncIterator["ControlClient"]:
|
|
|
+ proto_code = parse_conn_protocol(self.listen_maddr)
|
|
|
+ if proto_code == protocols.P_UNIX:
|
|
|
+ listen_path = self.listen_maddr.value_for_protocol(protocols.P_UNIX)
|
|
|
+ server = await asyncio.start_unix_server(self._handler, path=listen_path)
|
|
|
+ elif proto_code == protocols.P_IP4:
|
|
|
+ host = self.listen_maddr.value_for_protocol(protocols.P_IP4)
|
|
|
+ port = int(self.listen_maddr.value_for_protocol(protocols.P_TCP))
|
|
|
+ server = await asyncio.start_server(self._handler, port=port, host=host)
|
|
|
+ else:
|
|
|
+ raise ValueError(f"Protocol not supported: {protocols.protocol_with_code(proto_code)}")
|
|
|
+
|
|
|
+ async with server:
|
|
|
+ yield self
|
|
|
+
|
|
|
async def _read_from_persistent_conn(self, reader: asyncio.StreamReader):
|
|
|
while True:
|
|
|
resp: p2pd_pb.Response = p2pd_pb.Response() # type: ignore
|
|
@@ -106,6 +122,11 @@ class ControlClient:
|
|
|
elif resp.requestHandling:
|
|
|
asyncio.create_task(self._handle_persistent_request(resp.requestHandling))
|
|
|
pass
|
|
|
+
|
|
|
+ async def _write_to_persistent_conn(self, writer: asyncio.StreamWriter):
|
|
|
+ while True:
|
|
|
+ msg = await self.pending_messages.get()
|
|
|
+ await write_pbmsg(writer, msg)
|
|
|
|
|
|
async def _handle_persistent_request(self, request):
|
|
|
assert request.proto in self.unary_handlers
|
|
@@ -123,11 +144,6 @@ class ControlClient:
|
|
|
await self.pending_messages.put(
|
|
|
p2pd_pb.Request(type=p2pd_pb.Request.UNARY_RESPONSE, response=response))
|
|
|
|
|
|
- async def _write_to_persistent_conn(self, writer: asyncio.StreamWriter):
|
|
|
- while True:
|
|
|
- msg = await self.pending_messages.get()
|
|
|
- await write_pbmsg(writer, msg)
|
|
|
-
|
|
|
async def _handler(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
|
|
|
pb_stream_info = p2pd_pb.StreamInfo() # type: ignore
|
|
|
await read_pbmsg_safe(reader, pb_stream_info)
|
|
@@ -140,22 +156,6 @@ class ControlClient:
|
|
|
raise DispatchFailure(e)
|
|
|
await handler(stream_info, reader, writer)
|
|
|
|
|
|
- @asynccontextmanager
|
|
|
- async def listen(self) -> AsyncIterator["ControlClient"]:
|
|
|
- proto_code = parse_conn_protocol(self.listen_maddr)
|
|
|
- if proto_code == protocols.P_UNIX:
|
|
|
- listen_path = self.listen_maddr.value_for_protocol(protocols.P_UNIX)
|
|
|
- server = await asyncio.start_unix_server(self._handler, path=listen_path)
|
|
|
- elif proto_code == protocols.P_IP4:
|
|
|
- host = self.listen_maddr.value_for_protocol(protocols.P_IP4)
|
|
|
- port = int(self.listen_maddr.value_for_protocol(protocols.P_TCP))
|
|
|
- server = await asyncio.start_server(self._handler, port=port, host=host)
|
|
|
- else:
|
|
|
- raise ValueError(f"Protocol not supported: {protocols.protocol_with_code(proto_code)}")
|
|
|
-
|
|
|
- async with server:
|
|
|
- yield self
|
|
|
-
|
|
|
async def _ensure_persistent_conn(self):
|
|
|
if not self._pers_conn_open:
|
|
|
reader, writer = await self.daemon_connector.open_persistent_connection()
|