Procházet zdrojové kódy

add cancellation support for unary handlers

Denis Mazur před 4 roky
rodič
revize
86d01c8df0

+ 22 - 1
hivemind/p2p/p2p_daemon_bindings/control.py

@@ -87,6 +87,7 @@ class ControlClient:
         self._ensure_conn_lock = asyncio.Lock()
         self.pending_messages: asyncio.Queue[p2pd_pb.PCRequest] = asyncio.Queue()
         self.pending_calls: Dict[CallID, asyncio.Future[bytes]] = {}
+        self.handler_tasks: Dict[CallID, asyncio.Task] = {}
 
     @asynccontextmanager
     async def listen(self) -> AsyncIterator["ControlClient"]:
@@ -121,7 +122,12 @@ class ControlClient:
                     logger.debug(f"received unexpected unary call")
 
             elif resp.HasField("requestHandling"):
-                asyncio.create_task(self._handle_persistent_request(call_id, resp.requestHandling))
+                handler_task = asyncio.create_task(self._handle_persistent_request(call_id, resp.requestHandling))
+                # TODO: fix race condition at the of ._handle_persistent_request(...)
+                self.handler_tasks[call_id] = handler_task
+
+            elif call_id in self.handler_tasks and resp.HasField("cancel"):
+                self.handler_tasks[call_id].cancel()
 
     async def _write_to_persistent_conn(self, writer: asyncio.StreamWriter):
         while True:
@@ -135,10 +141,12 @@ class ControlClient:
             remote_id = PeerID(request.peer)
             response_payload: bytes = await self.unary_handlers[request.proto](request.data, remote_id)
             response = p2pd_pb.CallUnaryResponse(result=response_payload)
+
         except Exception as e:
             response = p2pd_pb.CallUnaryResponse(error=repr(e).encode())
 
         await self.pending_messages.put(p2pd_pb.PCRequest(callId=call_id.bytes, unaryResponse=response))
+        self.handler_tasks.pop(call_id)
 
     async def _handler(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
         pb_stream_info = p2pd_pb.StreamInfo()  # type: ignore
@@ -152,6 +160,14 @@ class ControlClient:
             raise DispatchFailure(e)
         await handler(stream_info, reader, writer)
 
+    async def _send_call_cancel(self, call_id: uuid.UUID):
+        await self.pending_messages.put(
+            p2pd_pb.PCRequest(
+                callId=call_id.bytes,
+                cancel=p2pd_pb.Cancel(),
+            ),
+        )
+
     async def _ensure_persistent_conn(self):
         if not self._pers_conn_open:
             async with self._ensure_conn_lock:
@@ -193,6 +209,11 @@ class ControlClient:
             self.pending_calls[call_id] = asyncio.Future()
             await self.pending_messages.put(req)
             return await self.pending_calls[call_id]
+
+        except asyncio.CancelledError:
+            await self._send_call_cancel(call_id)
+            raise
+
         finally:
             await self.pending_calls.pop(call_id)
 

+ 24 - 19
hivemind/proto/p2pd.proto

@@ -35,14 +35,31 @@ message Request {
   optional PSRequest pubsub = 8;
 }
 
+message Response {
+  enum Type {
+    OK    = 0;
+    ERROR = 1;
+  }
+
+  required Type type = 1;
+  optional ErrorResponse error = 2;
+  optional StreamInfo streamInfo = 3;
+  optional IdentifyResponse identify = 4;
+  optional DHTResponse dht = 5;
+  repeated PeerInfo peers = 6;
+  optional PSResponse pubsub = 7;
+}
+
 // Persistent connection request
 message PCRequest {
   required bytes callId = 1;
+  optional int64 timeout = 2;
 
   oneof message {
-    AddUnaryHandlerRequest addUnaryHandler = 2;
-    CallUnaryRequest callUnary = 3;
-    CallUnaryResponse unaryResponse = 4;
+    AddUnaryHandlerRequest addUnaryHandler = 3;
+    CallUnaryRequest  callUnary = 4;
+    CallUnaryResponse unaryResponse = 5;
+    Cancel cancel = 6;
   }
 }
 
@@ -54,24 +71,10 @@ message PCResponse {
     CallUnaryResponse callUnaryResponse = 2;
     CallUnaryRequest requestHandling = 3;
     DaemonError daemonError = 4;
+    Cancel cancel = 5;
   }
 }
 
-message Response {
-  enum Type {
-    OK    = 0;
-    ERROR = 1;
-  }
-
-  required Type type = 1;
-  optional ErrorResponse error = 2;
-  optional StreamInfo streamInfo = 3;
-  optional IdentifyResponse identify = 4;
-  optional DHTResponse dht = 5;
-  repeated PeerInfo peers = 6;
-  optional PSResponse pubsub = 7;
-}
-
 message IdentifyResponse {
   required bytes id = 1;
   repeated bytes addrs = 2;
@@ -192,7 +195,6 @@ message CallUnaryRequest {
   required bytes peer = 1;
   required string proto = 2;
   required bytes data = 3;
-  optional int64 timeout = 4;
 }
 
 message CallUnaryResponse {
@@ -208,6 +210,9 @@ message DaemonError {
   optional string message = 1;
 }
 
+message Cancel {
+}
+
 message RPCError {
   optional string message = 1;
 }