|
@@ -6,7 +6,7 @@ Author: Kevin Mai-Husan Chia
|
|
|
|
|
|
import asyncio
|
|
|
import uuid
|
|
|
-from contextlib import asynccontextmanager
|
|
|
+from contextlib import asynccontextmanager, closing
|
|
|
from typing import AsyncIterator, Awaitable, Callable, Dict, Iterable, Optional, Sequence, Tuple
|
|
|
|
|
|
from multiaddr import Multiaddr, protocols
|
|
@@ -89,6 +89,9 @@ class ControlClient:
|
|
|
self.pending_calls: Dict[CallID, asyncio.Future[bytes]] = {}
|
|
|
self.handler_tasks: Dict[CallID, asyncio.Task] = {}
|
|
|
|
|
|
+ self._read_task: Optional[asyncio.Task] = None
|
|
|
+ self._write_task: Optional[asyncio.Task] = None
|
|
|
+
|
|
|
@asynccontextmanager
|
|
|
async def listen(self) -> AsyncIterator["ControlClient"]:
|
|
|
proto_code = parse_conn_protocol(self.listen_maddr)
|
|
@@ -102,8 +105,14 @@ class ControlClient:
|
|
|
else:
|
|
|
raise ValueError(f"Protocol not supported: {protocols.protocol_with_code(proto_code)}")
|
|
|
|
|
|
- async with server:
|
|
|
- yield self
|
|
|
+ try:
|
|
|
+ async with server:
|
|
|
+ yield self
|
|
|
+ finally:
|
|
|
+ if self._read_task is not None:
|
|
|
+ self._read_task.cancel()
|
|
|
+ if self._write_task is not None:
|
|
|
+ self._write_task.cancel()
|
|
|
|
|
|
async def _read_from_persistent_conn(self, reader: asyncio.StreamReader):
|
|
|
while True:
|
|
@@ -129,6 +138,7 @@ class ControlClient:
|
|
|
self.handler_tasks[call_id].cancel()
|
|
|
|
|
|
async def _write_to_persistent_conn(self, writer: asyncio.StreamWriter):
|
|
|
+ #with closing(writer):
|
|
|
while True:
|
|
|
msg = await self.pending_messages.get()
|
|
|
await write_pbmsg(writer, msg)
|
|
@@ -172,8 +182,10 @@ class ControlClient:
|
|
|
async with self._ensure_conn_lock:
|
|
|
if not self._pers_conn_open:
|
|
|
reader, writer = await self.daemon_connector.open_persistent_connection()
|
|
|
- asyncio.create_task(self._read_from_persistent_conn(reader))
|
|
|
- asyncio.create_task(self._write_to_persistent_conn(writer))
|
|
|
+
|
|
|
+ self._read_task = asyncio.create_task(self._read_from_persistent_conn(reader))
|
|
|
+ self._write_task = asyncio.create_task(self._write_to_persistent_conn(writer))
|
|
|
+
|
|
|
self._pers_conn_open = True
|
|
|
|
|
|
async def add_unary_handler(self, proto: str, handler: TUnaryHandler):
|
|
@@ -214,7 +226,7 @@ class ControlClient:
|
|
|
raise
|
|
|
|
|
|
finally:
|
|
|
- await self.pending_calls.pop(call_id)
|
|
|
+ self.pending_calls.pop(call_id, None)
|
|
|
|
|
|
async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
|
|
|
reader, writer = await self.daemon_connector.open_connection()
|