瀏覽代碼

Merge branch 'better-messages' into unary-handlers

Denis Mazur 4 年之前
父節點
當前提交
0e3a4b1f44
共有 3 個文件被更改,包括 72 次插入26 次删除
  1. 9 0
      hivemind/p2p/p2p_daemon.py
  2. 39 7
      hivemind/p2p/p2p_daemon_bindings/control.py
  3. 24 19
      hivemind/proto/p2pd.proto

+ 9 - 0
hivemind/p2p/p2p_daemon.py

@@ -173,6 +173,8 @@ class P2P:
     async def add_unary_handler(self, handle_name: str, handler: p2pclient.TUnaryHandler):
     async def add_unary_handler(self, handle_name: str, handler: p2pclient.TUnaryHandler):
         return await self._client.add_unary_handler(handle_name, handler)
         return await self._client.add_unary_handler(handle_name, handler)
 
 
+
+
     async def call_unary_handler(self, peer_id: PeerID, handle_name: str, data: bytes) -> bytes:
     async def call_unary_handler(self, peer_id: PeerID, handle_name: str, data: bytes) -> bytes:
         return await self._client.call_unary_handler(peer_id, handle_name, data)
         return await self._client.call_unary_handler(peer_id, handle_name, data)
 
 
@@ -538,6 +540,13 @@ class P2P:
 
 
     def _terminate(self) -> None:
     def _terminate(self) -> None:
         self._alive = False
         self._alive = False
+
+        if self._client.control._write_task is not None:
+            self._client.control._write_task.cancel()
+
+        if self._client.control._read_task is not None:
+            self._client.control._read_task.cancel()
+
         if self._child is not None and self._child.poll() is None:
         if self._child is not None and self._child.poll() is None:
             self._child.terminate()
             self._child.terminate()
             self._child.wait()
             self._child.wait()

+ 39 - 7
hivemind/p2p/p2p_daemon_bindings/control.py

@@ -6,7 +6,7 @@ Author: Kevin Mai-Husan Chia
 
 
 import asyncio
 import asyncio
 import uuid
 import uuid
-from contextlib import asynccontextmanager
+from contextlib import asynccontextmanager, closing
 from typing import AsyncIterator, Awaitable, Callable, Dict, Iterable, Optional, Sequence, Tuple
 from typing import AsyncIterator, Awaitable, Callable, Dict, Iterable, Optional, Sequence, Tuple
 
 
 from multiaddr import Multiaddr, protocols
 from multiaddr import Multiaddr, protocols
@@ -87,6 +87,10 @@ class ControlClient:
         self._ensure_conn_lock = asyncio.Lock()
         self._ensure_conn_lock = asyncio.Lock()
         self.pending_messages: asyncio.Queue[p2pd_pb.PCRequest] = asyncio.Queue()
         self.pending_messages: asyncio.Queue[p2pd_pb.PCRequest] = asyncio.Queue()
         self.pending_calls: Dict[CallID, asyncio.Future[bytes]] = {}
         self.pending_calls: Dict[CallID, asyncio.Future[bytes]] = {}
+        self.handler_tasks: Dict[CallID, asyncio.Task] = {}
+
+        self._read_task: Optional[asyncio.Task] = None
+        self._write_task: Optional[asyncio.Task] = None
 
 
     @asynccontextmanager
     @asynccontextmanager
     async def listen(self) -> AsyncIterator["ControlClient"]:
     async def listen(self) -> AsyncIterator["ControlClient"]:
@@ -101,8 +105,14 @@ class ControlClient:
         else:
         else:
             raise ValueError(f"Protocol not supported: {protocols.protocol_with_code(proto_code)}")
             raise ValueError(f"Protocol not supported: {protocols.protocol_with_code(proto_code)}")
 
 
-        async with server:
-            yield self
+        try:
+            async with server:
+                yield self
+        finally:
+            if self._read_task is not None:
+                self._read_task.cancel()
+            if self._write_task is not None:
+                self._write_task.cancel()
 
 
     async def _read_from_persistent_conn(self, reader: asyncio.StreamReader):
     async def _read_from_persistent_conn(self, reader: asyncio.StreamReader):
         while True:
         while True:
@@ -121,9 +131,14 @@ class ControlClient:
                     logger.debug(f"received unexpected unary call")
                     logger.debug(f"received unexpected unary call")
 
 
             elif resp.HasField("requestHandling"):
             elif resp.HasField("requestHandling"):
-                asyncio.create_task(self._handle_persistent_request(call_id, resp.requestHandling))
+                handler_task = asyncio.create_task(self._handle_persistent_request(call_id, resp.requestHandling))
+                self.handler_tasks[call_id] = handler_task
+
+            elif call_id in self.handler_tasks and resp.HasField("cancel"):
+                self.handler_tasks[call_id].cancel()
 
 
     async def _write_to_persistent_conn(self, writer: asyncio.StreamWriter):
     async def _write_to_persistent_conn(self, writer: asyncio.StreamWriter):
+        #with closing(writer):
         while True:
         while True:
             msg = await self.pending_messages.get()
             msg = await self.pending_messages.get()
             await write_pbmsg(writer, msg)
             await write_pbmsg(writer, msg)
@@ -135,10 +150,12 @@ class ControlClient:
             remote_id = PeerID(request.peer)
             remote_id = PeerID(request.peer)
             response_payload: bytes = await self.unary_handlers[request.proto](request.data, remote_id)
             response_payload: bytes = await self.unary_handlers[request.proto](request.data, remote_id)
             response = p2pd_pb.CallUnaryResponse(result=response_payload)
             response = p2pd_pb.CallUnaryResponse(result=response_payload)
+
         except Exception as e:
         except Exception as e:
             response = p2pd_pb.CallUnaryResponse(error=repr(e).encode())
             response = p2pd_pb.CallUnaryResponse(error=repr(e).encode())
 
 
         await self.pending_messages.put(p2pd_pb.PCRequest(callId=call_id.bytes, unaryResponse=response))
         await self.pending_messages.put(p2pd_pb.PCRequest(callId=call_id.bytes, unaryResponse=response))
+        self.handler_tasks.pop(call_id)
 
 
     async def _handler(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
     async def _handler(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
         pb_stream_info = p2pd_pb.StreamInfo()  # type: ignore
         pb_stream_info = p2pd_pb.StreamInfo()  # type: ignore
@@ -152,13 +169,23 @@ class ControlClient:
             raise DispatchFailure(e)
             raise DispatchFailure(e)
         await handler(stream_info, reader, writer)
         await handler(stream_info, reader, writer)
 
 
+    async def _send_call_cancel(self, call_id: uuid.UUID):
+        await self.pending_messages.put(
+            p2pd_pb.PCRequest(
+                callId=call_id.bytes,
+                cancel=p2pd_pb.Cancel(),
+            ),
+        )
+
     async def _ensure_persistent_conn(self):
     async def _ensure_persistent_conn(self):
         if not self._pers_conn_open:
         if not self._pers_conn_open:
             async with self._ensure_conn_lock:
             async with self._ensure_conn_lock:
                 if not self._pers_conn_open:
                 if not self._pers_conn_open:
                     reader, writer = await self.daemon_connector.open_persistent_connection()
                     reader, writer = await self.daemon_connector.open_persistent_connection()
-                    asyncio.create_task(self._read_from_persistent_conn(reader))
-                    asyncio.create_task(self._write_to_persistent_conn(writer))
+
+                    self._read_task = asyncio.create_task(self._read_from_persistent_conn(reader))
+                    self._write_task = asyncio.create_task(self._write_to_persistent_conn(writer))
+
                     self._pers_conn_open = True
                     self._pers_conn_open = True
 
 
     async def add_unary_handler(self, proto: str, handler: TUnaryHandler):
     async def add_unary_handler(self, proto: str, handler: TUnaryHandler):
@@ -193,8 +220,13 @@ class ControlClient:
             self.pending_calls[call_id] = asyncio.Future()
             self.pending_calls[call_id] = asyncio.Future()
             await self.pending_messages.put(req)
             await self.pending_messages.put(req)
             return await self.pending_calls[call_id]
             return await self.pending_calls[call_id]
+
+        except asyncio.CancelledError:
+            asyncio.create_task(self._send_call_cancel(call_id))
+            raise
+
         finally:
         finally:
-            await self.pending_calls.pop(call_id)
+            self.pending_calls.pop(call_id, None)
 
 
     async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
     async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
         reader, writer = await self.daemon_connector.open_connection()
         reader, writer = await self.daemon_connector.open_connection()

+ 24 - 19
hivemind/proto/p2pd.proto

@@ -35,14 +35,31 @@ message Request {
   optional PSRequest pubsub = 8;
   optional PSRequest pubsub = 8;
 }
 }
 
 
+message Response {
+  enum Type {
+    OK    = 0;
+    ERROR = 1;
+  }
+
+  required Type type = 1;
+  optional ErrorResponse error = 2;
+  optional StreamInfo streamInfo = 3;
+  optional IdentifyResponse identify = 4;
+  optional DHTResponse dht = 5;
+  repeated PeerInfo peers = 6;
+  optional PSResponse pubsub = 7;
+}
+
 // Persistent connection request
 // Persistent connection request
 message PCRequest {
 message PCRequest {
   required bytes callId = 1;
   required bytes callId = 1;
+  optional int64 timeout = 2;
 
 
   oneof message {
   oneof message {
-    AddUnaryHandlerRequest addUnaryHandler = 2;
-    CallUnaryRequest callUnary = 3;
-    CallUnaryResponse unaryResponse = 4;
+    AddUnaryHandlerRequest addUnaryHandler = 3;
+    CallUnaryRequest  callUnary = 4;
+    CallUnaryResponse unaryResponse = 5;
+    Cancel cancel = 6;
   }
   }
 }
 }
 
 
@@ -54,24 +71,10 @@ message PCResponse {
     CallUnaryResponse callUnaryResponse = 2;
     CallUnaryResponse callUnaryResponse = 2;
     CallUnaryRequest requestHandling = 3;
     CallUnaryRequest requestHandling = 3;
     DaemonError daemonError = 4;
     DaemonError daemonError = 4;
+    Cancel cancel = 5;
   }
   }
 }
 }
 
 
-message Response {
-  enum Type {
-    OK    = 0;
-    ERROR = 1;
-  }
-
-  required Type type = 1;
-  optional ErrorResponse error = 2;
-  optional StreamInfo streamInfo = 3;
-  optional IdentifyResponse identify = 4;
-  optional DHTResponse dht = 5;
-  repeated PeerInfo peers = 6;
-  optional PSResponse pubsub = 7;
-}
-
 message IdentifyResponse {
 message IdentifyResponse {
   required bytes id = 1;
   required bytes id = 1;
   repeated bytes addrs = 2;
   repeated bytes addrs = 2;
@@ -192,7 +195,6 @@ message CallUnaryRequest {
   required bytes peer = 1;
   required bytes peer = 1;
   required string proto = 2;
   required string proto = 2;
   required bytes data = 3;
   required bytes data = 3;
-  optional int64 timeout = 4;
 }
 }
 
 
 message CallUnaryResponse {
 message CallUnaryResponse {
@@ -208,6 +210,9 @@ message DaemonError {
   optional string message = 1;
   optional string message = 1;
 }
 }
 
 
+message Cancel {
+}
+
 message RPCError {
 message RPCError {
   optional string message = 1;
   optional string message = 1;
 }
 }