Denis Mazur преди 4 години
родител
ревизия
38a182cde6
променени са 2 файла, в които са добавени 27 реда и са изтрити 6 реда
  1. 9 0
      hivemind/p2p/p2p_daemon.py
  2. 18 6
      hivemind/p2p/p2p_daemon_bindings/control.py

+ 9 - 0
hivemind/p2p/p2p_daemon.py

@@ -173,6 +173,8 @@ class P2P:
     async def add_unary_handler(self, handle_name: str, handler: p2pclient.TUnaryHandler):
     async def add_unary_handler(self, handle_name: str, handler: p2pclient.TUnaryHandler):
         return await self._client.add_unary_handler(handle_name, handler)
         return await self._client.add_unary_handler(handle_name, handler)
 
 
+
+
     async def call_unary_handler(self, peer_id: PeerID, handle_name: str, data: bytes) -> bytes:
     async def call_unary_handler(self, peer_id: PeerID, handle_name: str, data: bytes) -> bytes:
         return await self._client.call_unary_handler(peer_id, handle_name, data)
         return await self._client.call_unary_handler(peer_id, handle_name, data)
 
 
@@ -538,6 +540,13 @@ class P2P:
 
 
     def _terminate(self) -> None:
     def _terminate(self) -> None:
         self._alive = False
         self._alive = False
+
+        if self._client.control._write_task is not None:
+            self._client.control._write_task.cancel()
+
+        if self._client.control._read_task is not None:
+            self._client.control._read_task.cancel()
+
         if self._child is not None and self._child.poll() is None:
         if self._child is not None and self._child.poll() is None:
             self._child.terminate()
             self._child.terminate()
             self._child.wait()
             self._child.wait()

+ 18 - 6
hivemind/p2p/p2p_daemon_bindings/control.py

@@ -6,7 +6,7 @@ Author: Kevin Mai-Husan Chia
 
 
 import asyncio
 import asyncio
 import uuid
 import uuid
-from contextlib import asynccontextmanager
+from contextlib import asynccontextmanager, closing
 from typing import AsyncIterator, Awaitable, Callable, Dict, Iterable, Optional, Sequence, Tuple
 from typing import AsyncIterator, Awaitable, Callable, Dict, Iterable, Optional, Sequence, Tuple
 
 
 from multiaddr import Multiaddr, protocols
 from multiaddr import Multiaddr, protocols
@@ -89,6 +89,9 @@ class ControlClient:
         self.pending_calls: Dict[CallID, asyncio.Future[bytes]] = {}
         self.pending_calls: Dict[CallID, asyncio.Future[bytes]] = {}
         self.handler_tasks: Dict[CallID, asyncio.Task] = {}
         self.handler_tasks: Dict[CallID, asyncio.Task] = {}
 
 
+        self._read_task: Optional[asyncio.Task] = None
+        self._write_task: Optional[asyncio.Task] = None
+
     @asynccontextmanager
     @asynccontextmanager
     async def listen(self) -> AsyncIterator["ControlClient"]:
     async def listen(self) -> AsyncIterator["ControlClient"]:
         proto_code = parse_conn_protocol(self.listen_maddr)
         proto_code = parse_conn_protocol(self.listen_maddr)
@@ -102,8 +105,14 @@ class ControlClient:
         else:
         else:
             raise ValueError(f"Protocol not supported: {protocols.protocol_with_code(proto_code)}")
             raise ValueError(f"Protocol not supported: {protocols.protocol_with_code(proto_code)}")
 
 
-        async with server:
-            yield self
+        try:
+            async with server:
+                yield self
+        finally:
+            if self._read_task is not None:
+                self._read_task.cancel()
+            if self._write_task is not None:
+                self._write_task.cancel()
 
 
     async def _read_from_persistent_conn(self, reader: asyncio.StreamReader):
     async def _read_from_persistent_conn(self, reader: asyncio.StreamReader):
         while True:
         while True:
@@ -129,6 +138,7 @@ class ControlClient:
                 self.handler_tasks[call_id].cancel()
                 self.handler_tasks[call_id].cancel()
 
 
     async def _write_to_persistent_conn(self, writer: asyncio.StreamWriter):
     async def _write_to_persistent_conn(self, writer: asyncio.StreamWriter):
+        #with closing(writer):
         while True:
         while True:
             msg = await self.pending_messages.get()
             msg = await self.pending_messages.get()
             await write_pbmsg(writer, msg)
             await write_pbmsg(writer, msg)
@@ -172,8 +182,10 @@ class ControlClient:
             async with self._ensure_conn_lock:
             async with self._ensure_conn_lock:
                 if not self._pers_conn_open:
                 if not self._pers_conn_open:
                     reader, writer = await self.daemon_connector.open_persistent_connection()
                     reader, writer = await self.daemon_connector.open_persistent_connection()
-                    asyncio.create_task(self._read_from_persistent_conn(reader))
-                    asyncio.create_task(self._write_to_persistent_conn(writer))
+
+                    self._read_task = asyncio.create_task(self._read_from_persistent_conn(reader))
+                    self._write_task = asyncio.create_task(self._write_to_persistent_conn(writer))
+
                     self._pers_conn_open = True
                     self._pers_conn_open = True
 
 
     async def add_unary_handler(self, proto: str, handler: TUnaryHandler):
     async def add_unary_handler(self, proto: str, handler: TUnaryHandler):
@@ -214,7 +226,7 @@ class ControlClient:
             raise
             raise
 
 
         finally:
         finally:
-            await self.pending_calls.pop(call_id)
+            self.pending_calls.pop(call_id, None)
 
 
     async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
     async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
         reader, writer = await self.daemon_connector.open_connection()
         reader, writer = await self.daemon_connector.open_connection()