Denis Mazur 4 years ago
parent
commit
38a182cde6
2 changed files with 27 additions and 6 deletions
  1. 9 0
      hivemind/p2p/p2p_daemon.py
  2. 18 6
      hivemind/p2p/p2p_daemon_bindings/control.py

+ 9 - 0
hivemind/p2p/p2p_daemon.py

@@ -173,6 +173,8 @@ class P2P:
     async def add_unary_handler(self, handle_name: str, handler: p2pclient.TUnaryHandler):
         return await self._client.add_unary_handler(handle_name, handler)
 
+
+
     async def call_unary_handler(self, peer_id: PeerID, handle_name: str, data: bytes) -> bytes:
         return await self._client.call_unary_handler(peer_id, handle_name, data)
 
@@ -538,6 +540,13 @@ class P2P:
 
     def _terminate(self) -> None:
         self._alive = False
+
+        if self._client.control._write_task is not None:
+            self._client.control._write_task.cancel()
+
+        if self._client.control._read_task is not None:
+            self._client.control._read_task.cancel()
+
         if self._child is not None and self._child.poll() is None:
             self._child.terminate()
             self._child.wait()

+ 18 - 6
hivemind/p2p/p2p_daemon_bindings/control.py

@@ -6,7 +6,7 @@ Author: Kevin Mai-Husan Chia
 
 import asyncio
 import uuid
-from contextlib import asynccontextmanager
+from contextlib import asynccontextmanager, closing
 from typing import AsyncIterator, Awaitable, Callable, Dict, Iterable, Optional, Sequence, Tuple
 
 from multiaddr import Multiaddr, protocols
@@ -89,6 +89,9 @@ class ControlClient:
         self.pending_calls: Dict[CallID, asyncio.Future[bytes]] = {}
         self.handler_tasks: Dict[CallID, asyncio.Task] = {}
 
+        self._read_task: Optional[asyncio.Task] = None
+        self._write_task: Optional[asyncio.Task] = None
+
     @asynccontextmanager
     async def listen(self) -> AsyncIterator["ControlClient"]:
         proto_code = parse_conn_protocol(self.listen_maddr)
@@ -102,8 +105,14 @@ class ControlClient:
         else:
             raise ValueError(f"Protocol not supported: {protocols.protocol_with_code(proto_code)}")
 
-        async with server:
-            yield self
+        try:
+            async with server:
+                yield self
+        finally:
+            if self._read_task is not None:
+                self._read_task.cancel()
+            if self._write_task is not None:
+                self._write_task.cancel()
 
     async def _read_from_persistent_conn(self, reader: asyncio.StreamReader):
         while True:
@@ -129,6 +138,7 @@ class ControlClient:
                 self.handler_tasks[call_id].cancel()
 
     async def _write_to_persistent_conn(self, writer: asyncio.StreamWriter):
+        #with closing(writer):
         while True:
             msg = await self.pending_messages.get()
             await write_pbmsg(writer, msg)
@@ -172,8 +182,10 @@ class ControlClient:
             async with self._ensure_conn_lock:
                 if not self._pers_conn_open:
                     reader, writer = await self.daemon_connector.open_persistent_connection()
-                    asyncio.create_task(self._read_from_persistent_conn(reader))
-                    asyncio.create_task(self._write_to_persistent_conn(writer))
+
+                    self._read_task = asyncio.create_task(self._read_from_persistent_conn(reader))
+                    self._write_task = asyncio.create_task(self._write_to_persistent_conn(writer))
+
                     self._pers_conn_open = True
 
     async def add_unary_handler(self, proto: str, handler: TUnaryHandler):
@@ -214,7 +226,7 @@ class ControlClient:
             raise
 
         finally:
-            await self.pending_calls.pop(call_id)
+            self.pending_calls.pop(call_id, None)
 
     async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
         reader, writer = await self.daemon_connector.open_connection()