Pārlūkot izejas kodu

Merge branch 'better-messages' into unary-handlers

Denis Mazur 4 gadi atpakaļ
vecāks
revīzija
0e3a4b1f44

+ 9 - 0
hivemind/p2p/p2p_daemon.py

@@ -173,6 +173,8 @@ class P2P:
     async def add_unary_handler(self, handle_name: str, handler: p2pclient.TUnaryHandler):
         return await self._client.add_unary_handler(handle_name, handler)
 
+
+
     async def call_unary_handler(self, peer_id: PeerID, handle_name: str, data: bytes) -> bytes:
         return await self._client.call_unary_handler(peer_id, handle_name, data)
 
@@ -538,6 +540,13 @@ class P2P:
 
     def _terminate(self) -> None:
         self._alive = False
+
+        if self._client.control._write_task is not None:
+            self._client.control._write_task.cancel()
+
+        if self._client.control._read_task is not None:
+            self._client.control._read_task.cancel()
+
         if self._child is not None and self._child.poll() is None:
             self._child.terminate()
             self._child.wait()

+ 39 - 7
hivemind/p2p/p2p_daemon_bindings/control.py

@@ -6,7 +6,7 @@ Author: Kevin Mai-Husan Chia
 
 import asyncio
 import uuid
-from contextlib import asynccontextmanager
+from contextlib import asynccontextmanager, closing
 from typing import AsyncIterator, Awaitable, Callable, Dict, Iterable, Optional, Sequence, Tuple
 
 from multiaddr import Multiaddr, protocols
@@ -87,6 +87,10 @@ class ControlClient:
         self._ensure_conn_lock = asyncio.Lock()
         self.pending_messages: asyncio.Queue[p2pd_pb.PCRequest] = asyncio.Queue()
         self.pending_calls: Dict[CallID, asyncio.Future[bytes]] = {}
+        self.handler_tasks: Dict[CallID, asyncio.Task] = {}
+
+        self._read_task: Optional[asyncio.Task] = None
+        self._write_task: Optional[asyncio.Task] = None
 
     @asynccontextmanager
     async def listen(self) -> AsyncIterator["ControlClient"]:
@@ -101,8 +105,14 @@ class ControlClient:
         else:
             raise ValueError(f"Protocol not supported: {protocols.protocol_with_code(proto_code)}")
 
-        async with server:
-            yield self
+        try:
+            async with server:
+                yield self
+        finally:
+            if self._read_task is not None:
+                self._read_task.cancel()
+            if self._write_task is not None:
+                self._write_task.cancel()
 
     async def _read_from_persistent_conn(self, reader: asyncio.StreamReader):
         while True:
@@ -121,9 +131,14 @@ class ControlClient:
                     logger.debug(f"received unexpected unary call")
 
             elif resp.HasField("requestHandling"):
-                asyncio.create_task(self._handle_persistent_request(call_id, resp.requestHandling))
+                handler_task = asyncio.create_task(self._handle_persistent_request(call_id, resp.requestHandling))
+                self.handler_tasks[call_id] = handler_task
+
+            elif call_id in self.handler_tasks and resp.HasField("cancel"):
+                self.handler_tasks[call_id].cancel()
 
     async def _write_to_persistent_conn(self, writer: asyncio.StreamWriter):
+        #with closing(writer):
         while True:
             msg = await self.pending_messages.get()
             await write_pbmsg(writer, msg)
@@ -135,10 +150,12 @@ class ControlClient:
             remote_id = PeerID(request.peer)
             response_payload: bytes = await self.unary_handlers[request.proto](request.data, remote_id)
             response = p2pd_pb.CallUnaryResponse(result=response_payload)
+
         except Exception as e:
             response = p2pd_pb.CallUnaryResponse(error=repr(e).encode())
 
         await self.pending_messages.put(p2pd_pb.PCRequest(callId=call_id.bytes, unaryResponse=response))
+        self.handler_tasks.pop(call_id)
 
     async def _handler(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
         pb_stream_info = p2pd_pb.StreamInfo()  # type: ignore
@@ -152,13 +169,23 @@ class ControlClient:
             raise DispatchFailure(e)
         await handler(stream_info, reader, writer)
 
+    async def _send_call_cancel(self, call_id: uuid.UUID):
+        await self.pending_messages.put(
+            p2pd_pb.PCRequest(
+                callId=call_id.bytes,
+                cancel=p2pd_pb.Cancel(),
+            ),
+        )
+
     async def _ensure_persistent_conn(self):
         if not self._pers_conn_open:
             async with self._ensure_conn_lock:
                 if not self._pers_conn_open:
                     reader, writer = await self.daemon_connector.open_persistent_connection()
-                    asyncio.create_task(self._read_from_persistent_conn(reader))
-                    asyncio.create_task(self._write_to_persistent_conn(writer))
+
+                    self._read_task = asyncio.create_task(self._read_from_persistent_conn(reader))
+                    self._write_task = asyncio.create_task(self._write_to_persistent_conn(writer))
+
                     self._pers_conn_open = True
 
     async def add_unary_handler(self, proto: str, handler: TUnaryHandler):
@@ -193,8 +220,13 @@ class ControlClient:
             self.pending_calls[call_id] = asyncio.Future()
             await self.pending_messages.put(req)
             return await self.pending_calls[call_id]
+
+        except asyncio.CancelledError:
+            asyncio.create_task(self._send_call_cancel(call_id))
+            raise
+
         finally:
-            await self.pending_calls.pop(call_id)
+            self.pending_calls.pop(call_id, None)
 
     async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
         reader, writer = await self.daemon_connector.open_connection()

+ 24 - 19
hivemind/proto/p2pd.proto

@@ -35,14 +35,31 @@ message Request {
   optional PSRequest pubsub = 8;
 }
 
+message Response {
+  enum Type {
+    OK    = 0;
+    ERROR = 1;
+  }
+
+  required Type type = 1;
+  optional ErrorResponse error = 2;
+  optional StreamInfo streamInfo = 3;
+  optional IdentifyResponse identify = 4;
+  optional DHTResponse dht = 5;
+  repeated PeerInfo peers = 6;
+  optional PSResponse pubsub = 7;
+}
+
 // Persistent connection request
 message PCRequest {
   required bytes callId = 1;
+  optional int64 timeout = 2;
 
   oneof message {
-    AddUnaryHandlerRequest addUnaryHandler = 2;
-    CallUnaryRequest callUnary = 3;
-    CallUnaryResponse unaryResponse = 4;
+    AddUnaryHandlerRequest addUnaryHandler = 3;
+    CallUnaryRequest  callUnary = 4;
+    CallUnaryResponse unaryResponse = 5;
+    Cancel cancel = 6;
   }
 }
 
@@ -54,24 +71,10 @@ message PCResponse {
     CallUnaryResponse callUnaryResponse = 2;
     CallUnaryRequest requestHandling = 3;
     DaemonError daemonError = 4;
+    Cancel cancel = 5;
   }
 }
 
-message Response {
-  enum Type {
-    OK    = 0;
-    ERROR = 1;
-  }
-
-  required Type type = 1;
-  optional ErrorResponse error = 2;
-  optional StreamInfo streamInfo = 3;
-  optional IdentifyResponse identify = 4;
-  optional DHTResponse dht = 5;
-  repeated PeerInfo peers = 6;
-  optional PSResponse pubsub = 7;
-}
-
 message IdentifyResponse {
   required bytes id = 1;
   repeated bytes addrs = 2;
@@ -192,7 +195,6 @@ message CallUnaryRequest {
   required bytes peer = 1;
   required string proto = 2;
   required bytes data = 3;
-  optional int64 timeout = 4;
 }
 
 message CallUnaryResponse {
@@ -208,6 +210,9 @@ message DaemonError {
   optional string message = 1;
 }
 
+message Cancel {
+}
+
 message RPCError {
   optional string message = 1;
 }