소스 검색

add unary handler support to p2p control class

Denis Mazur 4 년 전
부모
커밋
3f7c2e381c
4개의 변경된 파일155개의 추가작업 그리고 16개의 파일을 삭제
  1. 1 0
      .gitignore
  2. 1 1
      hivemind/p2p/p2p_daemon.py
  3. 112 1
      hivemind/p2p/p2p_daemon_bindings/control.py
  4. 41 14
      hivemind/proto/p2pd.proto

+ 1 - 0
.gitignore

@@ -54,6 +54,7 @@ coverage.xml
 .project
 .pydevproject
 .idea
+.vscode
 .ipynb_checkpoints
 
 # Rope

+ 1 - 1
hivemind/p2p/p2p_daemon.py

@@ -6,7 +6,7 @@ from contextlib import closing, suppress
 from dataclasses import dataclass
 from importlib.resources import path
 from subprocess import Popen
-from typing import Any, AsyncIterator, Awaitable, Callable, List, Optional, Sequence, Tuple, TypeVar, Union
+from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
 
 from multiaddr import Multiaddr
 

+ 112 - 1
hivemind/p2p/p2p_daemon_bindings/control.py

@@ -5,8 +5,9 @@ Author: Kevin Mai-Husan Chia
 """
 
 import asyncio
+import uuid
 from contextlib import asynccontextmanager
-from typing import AsyncIterator, Awaitable, Callable, Dict, Iterable, Sequence, Tuple
+from typing import AsyncIterator, Awaitable, Callable, Dict, Iterable, Optional, Sequence, Tuple
 
 from multiaddr import Multiaddr, protocols
 
@@ -54,6 +55,20 @@ class DaemonConnector:
         else:
             raise ValueError(f"Protocol not supported: {protocols.protocol_with_code(self.proto_code)}")
 
+    async def open_persistent_connection(self) -> (asyncio.StreamReader, asyncio.StreamWriter):
+        """
+        Open connection to daemon and upgrade it to a persistent one
+        """
+        reader, writer = await self.open_connection()
+        req = p2pd_pb.Request(type=p2pd_pb.Request.PERSISTENT_CONN_UPGRADE)
+        await write_pbmsg(writer, req)
+
+        return reader, writer
+
+
+TUnaryHandler = Callable[[bytes], bytes]
+CallID = uuid.UUID
+
 
 class ControlClient:
     DEFAULT_LISTEN_MADDR = "/unix/tmp/p2pclient.sock"
@@ -65,6 +80,54 @@ class ControlClient:
         self.daemon_connector = daemon_connector
         self.handlers: Dict[str, StreamHandler] = {}
 
+        # persistent connection readers & writers
+        self._pers_conn_open: bool = False
+        self.unary_handlers: Dict[str, TUnaryHandler] = {}
+
+        self.pending_messages: asyncio.Queue[p2pd_pb.Request] = asyncio.Queue()
+        self.pending_calls: Dict[CallID, asyncio.Future] = {}
+
+    async def read_from_persistent_conn(self, reader: asyncio.StreamReader):
+        while True:
+            resp: p2pd_pb.Response = p2pd_pb.Response()  # type: ignore
+            await read_pbmsg_safe(reader, resp)
+
+            if resp.callUnaryResponse:
+                call_id = uuid.UUID(bytes=resp.callUnaryResponse.callId)
+
+                if call_id in self.pending_calls and resp.data:
+                    self.pending_calls[call_id].set_result(call_id)
+                elif call_id in self.pending_calls and resp.error:
+                    remote_exc = RemoteException(str(resp.error))
+                    self.pending_calls[call_id].set_exception(remote_exc)
+                else:
+                    logger.debug(f"received unexpected unary call")
+
+            elif resp.requestHandling:
+                # asyncio.create_task(self.read)
+                pass
+
+    async def _handle_persistent_request(self, request):
+        assert request.proto in self.unary_handlers
+
+        try:
+            response_payload: bytes = await self.unary_protocols[request.protocol](request.payload)
+            response = p2pd_pb.CallUnaryResponse(
+                call_id=request.call_id,
+                data=response_payload)
+        except Exception as e:
+            response = p2pd_pb.CallUnaryResponse(
+                call_id=request.call_id,
+                error=repr(e))
+
+        await self.pending_messages.put(
+            p2pd_pb.Request(type=p2pd_pb.Request.UNARY_RESPONSE, response=response))
+
+    async def write_to_persistent_conn(self, writer: asyncio.StreamWriter):
+        while True:
+            msg = await self.pending_messages.get()
+            await write_pbmsg(writer, msg)
+
     async def _handler(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
         pb_stream_info = p2pd_pb.StreamInfo()  # type: ignore
         await read_pbmsg_safe(reader, pb_stream_info)
@@ -93,6 +156,48 @@ class ControlClient:
         async with server:
             yield self
 
+    async def _ensure_persistent_conn(self):
+        if not self._pers_conn_open:
+            reader, writer = await self.daemon_connector.open_persistent_connection()
+            asyncio.create_task(self.read_from_persistent_conn(reader))
+            asyncio.create_task(self.write_to_persistent_conn(writer))
+
+    async def add_unary_handler(self, proto: str, handler: TUnaryHandler):
+        await self._ensure_persistent_conn()
+
+        add_unary_handler_req = p2pd_pb.AddUnaryHandlerRequest(proto=proto)
+        req = p2pd_pb.Request(
+            type=p2pd_pb.Request.ADD_UNARY_HANDLER,
+            addUnaryHandler=add_unary_handler_req,
+        )
+        await self.pending_messages.put(req)
+
+        if self.unary_handlers.get(proto):
+            raise ValueError(f"Handler for protocol {proto} already assigned")
+        self.unary_handlers[proto] = handler
+
+    async def unary_call(self, peer_id: PeerID, proto: str, data: bytes) -> bytes:
+        call_id = uuid.uuid4()
+        call_unary_req = p2pd_pb.CallUnaryRequest(
+            peer=peer_id.to_bytes(),
+            proto=proto,
+            data=data,
+            callId=call_id,
+        )
+        req = p2pd_pb.Request(
+            type=p2pd_pb.Request.CALL_UNARY,
+            callUnary=call_unary_req,
+        )
+
+        await self._ensure_persistent_conn()
+
+        try:
+            self.pending_calls[call_id] = asyncio.Future()
+            await self.pending_messages.put(req)
+            return await self.pending_calls[call_id]
+        finally:
+            await self.pending_calls.pop(call_id)
+
     async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
         reader, writer = await self.daemon_connector.open_connection()
         req = p2pd_pb.Request(type=p2pd_pb.Request.IDENTIFY)
@@ -179,3 +284,9 @@ class ControlClient:
 
         # if success, add the handler to the dict
         self.handlers[proto] = handler_cb
+
+
+class RemoteException(Exception):
+    """
+    Raised if remote handled a request with an exception
+    """

+ 41 - 14
hivemind/proto/p2pd.proto

@@ -7,18 +7,23 @@ syntax = "proto2";
 package p2pclient.p2pd.pb;
 
 message Request {
-  enum Type {
-    IDENTIFY       = 0;
-    CONNECT        = 1;
-    STREAM_OPEN    = 2;
-    STREAM_HANDLER = 3;
-    DHT            = 4;
-    LIST_PEERS     = 5;
-    CONNMANAGER    = 6;
-    DISCONNECT     = 7;
-    PUBSUB         = 8;
-  }
+   enum Type {
+    IDENTIFY                 = 0;
+    CONNECT                  = 1;
+    STREAM_OPEN              = 2;
+    STREAM_HANDLER           = 3;
+    DHT                      = 4;
+    LIST_PEERS               = 5;
+    CONNMANAGER              = 6;
+    DISCONNECT               = 7;      
+    PUBSUB                   = 8;
+
+    PERSISTENT_CONN_UPGRADE  = 9;
+    CALL_UNARY              = 10;
+    ADD_UNARY_HANDLER       = 11;
+    SEND_RESPONSE_TO_REMOTE = 12;
 
+  }
   required Type type = 1;
 
   optional ConnectRequest connect = 2;
@@ -28,6 +33,10 @@ message Request {
   optional ConnManagerRequest connManager = 6;
   optional DisconnectRequest disconnect = 7;
   optional PSRequest pubsub = 8;
+
+  optional CallUnaryRequest callUnary = 9;
+  optional AddUnaryHandlerRequest addUnaryHandler = 10;
+  optional CallUnaryResponse sendResponseToRemote = 11;
 }
 
 message Response {
@@ -43,6 +52,9 @@ message Response {
   optional DHTResponse dht = 5;
   repeated PeerInfo peers = 6;
   optional PSResponse pubsub = 7;
+
+  optional CallUnaryResponse callUnaryResponse = 8;
+  optional CallUnaryRequest requestHandling = 9;
 }
 
 message IdentifyResponse {
@@ -147,8 +159,9 @@ message PSRequest {
   optional bytes data = 3;
 }
 
+
 message PSMessage {
-  optional bytes from_id = 1;
+  optional bytes from = 1;
   optional bytes data = 2;
   optional bytes seqno = 3;
   repeated string topicIDs = 4;
@@ -161,6 +174,20 @@ message PSResponse {
   repeated bytes peerIDs = 2;
 }
 
-message RPCError {
-  optional string message = 1;
+message CallUnaryRequest {
+  required bytes peer = 1;
+  required string proto = 2;
+  required bytes data = 3;
+  required int64 callId = 4;
+  optional int64 timeout = 5;
+}
+
+message CallUnaryResponse {
+  required int64 callId = 1;
+  optional bytes result = 2;
+  optional bytes error = 3;
+}
+
+message AddUnaryHandlerRequest {
+  required string proto = 1;
 }