ソースを参照

fix suggestions
Co-authored-by: Denis Mazur <denismazur8@gmail.com>
Co-authored-by: Alexander Borzunov <hxrussia@gmail.com>

Denis Mazur 4 年 前
コミット
96c59b9c1e

+ 21 - 18
hivemind/p2p/p2p_daemon.py

@@ -4,11 +4,11 @@ import secrets
 from collections.abc import AsyncIterable as AsyncIterableABC
 from collections.abc import AsyncIterable as AsyncIterableABC
 from contextlib import closing, suppress
 from contextlib import closing, suppress
 from dataclasses import dataclass
 from dataclasses import dataclass
-from google.protobuf.message import Message
 from importlib.resources import path
 from importlib.resources import path
 from subprocess import Popen
 from subprocess import Popen
 from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
 from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
 
 
+from google.protobuf.message import Message
 from multiaddr import Multiaddr
 from multiaddr import Multiaddr
 
 
 import hivemind.hivemind_cli as cli
 import hivemind.hivemind_cli as cli
@@ -170,12 +170,6 @@ class P2P:
 
 
         return self
         return self
 
 
-    async def add_unary_handler(self, handle_name: str, handler: p2pclient.TUnaryHandler):
-        return await self._client.add_unary_handler(handle_name, handler)
-
-    async def call_unary_handler(self, peer_id: PeerID, handle_name: str, data: bytes) -> bytes:
-        return await self._client.call_unary_handler(peer_id, handle_name, data)
-
     async def _ping_daemon_with_retries(self, ping_n_attempts: int, ping_delay: float) -> None:
     async def _ping_daemon_with_retries(self, ping_n_attempts: int, ping_delay: float) -> None:
         for try_number in range(ping_n_attempts):
         for try_number in range(ping_n_attempts):
             await asyncio.sleep(ping_delay * (2 ** try_number))
             await asyncio.sleep(ping_delay * (2 ** try_number))
@@ -282,7 +276,7 @@ class P2P:
 
 
     @staticmethod
     @staticmethod
     async def receive_protobuf(
     async def receive_protobuf(
-        input_protobuf_type: type, reader: asyncio.StreamReader
+        input_protobuf_type: Message, reader: asyncio.StreamReader
     ) -> Tuple[Optional[TInputProtobuf], Optional[RPCError]]:
     ) -> Tuple[Optional[TInputProtobuf], Optional[RPCError]]:
         msg_type = await reader.readexactly(1)
         msg_type = await reader.readexactly(1)
         if msg_type == P2P.MESSAGE_MARKER:
         if msg_type == P2P.MESSAGE_MARKER:
@@ -303,7 +297,7 @@ class P2P:
         self,
         self,
         name: str,
         name: str,
         handler: Callable[[TInputStream, P2PContext], TOutputStream],
         handler: Callable[[TInputStream, P2PContext], TOutputStream],
-        input_protobuf_type: type,
+        input_protobuf_type: Message,
         max_prefetch: int = 5,
         max_prefetch: int = 5,
     ) -> None:
     ) -> None:
         """
         """
@@ -367,7 +361,7 @@ class P2P:
         await self._client.stream_handler(name, _handle_stream)
         await self._client.stream_handler(name, _handle_stream)
 
 
     async def _iterate_protobuf_stream_handler(
     async def _iterate_protobuf_stream_handler(
-        self, peer_id: PeerID, name: str, requests: TInputStream, output_protobuf_type: type
+        self, peer_id: PeerID, name: str, requests: TInputStream, output_protobuf_type: Message
     ) -> TOutputStream:
     ) -> TOutputStream:
         _, reader, writer = await self._client.stream_open(peer_id, (name,))
         _, reader, writer = await self._client.stream_open(peer_id, (name,))
 
 
@@ -399,7 +393,7 @@ class P2P:
         handler: Callable[
         handler: Callable[
             [Union[TInputProtobuf, TInputStream], P2PContext], Union[Awaitable[TOutputProtobuf], TOutputStream]
             [Union[TInputProtobuf, TInputStream], P2PContext], Union[Awaitable[TOutputProtobuf], TOutputStream]
         ],
         ],
-        input_protobuf_type: type,
+        input_protobuf_type: Message,
         *,
         *,
         stream_input: bool = False,
         stream_input: bool = False,
         stream_output: bool = False,
         stream_output: bool = False,
@@ -410,7 +404,7 @@ class P2P:
         :param stream_output: If True, assume ``handler`` to return ``TOutputStream``
         :param stream_output: If True, assume ``handler`` to return ``TOutputStream``
         """
         """
 
 
-        if not (stream_input or stream_output):
+        if not stream_input and not stream_output:
             await self._add_protobuf_unary_handler(name, handler, input_protobuf_type)
             await self._add_protobuf_unary_handler(name, handler, input_protobuf_type)
             return
             return
 
 
@@ -439,8 +433,17 @@ class P2P:
         self,
         self,
         handle_name: str,
         handle_name: str,
         handler: Callable[[TInputProtobuf, P2PContext], Awaitable[TOutputProtobuf]],
         handler: Callable[[TInputProtobuf, P2PContext], Awaitable[TOutputProtobuf]],
-        input_protobuf_type: type,
+        input_protobuf_type: Message,
     ) -> None:
     ) -> None:
+        """
+        Register a request-response (unary) handler. Unary requests and responses
+        are sent through persistent multiplexed connections to the daemon for the
+        sake of reducing the number of open files.
+        :param handle_name: name of the handler (protocol id)
+        :param handler: function handling the unary requests
+        :param input_protobuf_type: protobuf type of the request
+        """
+
         async def _unary_handler(request: bytes, remote_id: PeerID) -> bytes:
         async def _unary_handler(request: bytes, remote_id: PeerID) -> bytes:
             input_serialized = input_protobuf_type.FromString(request)
             input_serialized = input_protobuf_type.FromString(request)
             context = P2PContext(
             context = P2PContext(
@@ -452,14 +455,14 @@ class P2P:
             response = await handler(input_serialized, context)
             response = await handler(input_serialized, context)
             return response.SerializeToString()
             return response.SerializeToString()
 
 
-        await self.add_unary_handler(handle_name, _unary_handler)
+        await self._client.add_unary_handler(handle_name, _unary_handler)
 
 
     async def call_protobuf_handler(
     async def call_protobuf_handler(
         self,
         self,
         peer_id: PeerID,
         peer_id: PeerID,
         name: str,
         name: str,
         input: Union[TInputProtobuf, TInputStream],
         input: Union[TInputProtobuf, TInputStream],
-        output_protobuf_type: type,
+        output_protobuf_type: Message,
     ) -> Awaitable[TOutputProtobuf]:
     ) -> Awaitable[TOutputProtobuf]:
 
 
         if not isinstance(input, AsyncIterableABC):
         if not isinstance(input, AsyncIterableABC):
@@ -479,10 +482,10 @@ class P2P:
         peer_id: PeerID,
         peer_id: PeerID,
         handle_name: str,
         handle_name: str,
         input: TInputProtobuf,
         input: TInputProtobuf,
-        output_protobuf_type: type,
+        output_protobuf_type: Message,
     ) -> Awaitable[TOutputProtobuf]:
     ) -> Awaitable[TOutputProtobuf]:
         serialized_input = input.SerializeToString()
         serialized_input = input.SerializeToString()
-        response = await self.call_unary_handler(peer_id, handle_name, serialized_input)
+        response = await self._client.call_unary_handler(peer_id, handle_name, serialized_input)
         return output_protobuf_type().FromString(response)
         return output_protobuf_type().FromString(response)
 
 
     def iterate_protobuf_handler(
     def iterate_protobuf_handler(
@@ -490,7 +493,7 @@ class P2P:
         peer_id: PeerID,
         peer_id: PeerID,
         name: str,
         name: str,
         input: Union[TInputProtobuf, TInputStream],
         input: Union[TInputProtobuf, TInputStream],
-        output_protobuf_type: type,
+        output_protobuf_type: Message,
     ) -> TOutputStream:
     ) -> TOutputStream:
         requests = input if isinstance(input, AsyncIterableABC) else aiter(input)
         requests = input if isinstance(input, AsyncIterableABC) else aiter(input)
         return self._iterate_protobuf_stream_handler(peer_id, name, requests, output_protobuf_type)
         return self._iterate_protobuf_stream_handler(peer_id, name, requests, output_protobuf_type)

+ 42 - 38
hivemind/p2p/p2p_daemon_bindings/control.py

@@ -80,14 +80,13 @@ class ControlClient:
         self.daemon_connector = daemon_connector
         self.daemon_connector = daemon_connector
         self.handlers: Dict[str, StreamHandler] = {}
         self.handlers: Dict[str, StreamHandler] = {}
 
 
-        # persistent connection readers & writers
-        self._pers_conn_open: bool = False
+        self._is_persistent_conn_open: bool = False
         self.unary_handlers: Dict[str, TUnaryHandler] = {}
         self.unary_handlers: Dict[str, TUnaryHandler] = {}
 
 
         self._ensure_conn_lock = asyncio.Lock()
         self._ensure_conn_lock = asyncio.Lock()
-        self.pending_messages: asyncio.Queue[p2pd_pb.PCRequest] = asyncio.Queue()
-        self.pending_calls: Dict[CallID, asyncio.Future[bytes]] = {}
-        self.handler_tasks: Dict[CallID, asyncio.Task] = {}
+        self._pending_messages: asyncio.Queue[p2pd_pb.PersistentConnectionRequest] = asyncio.Queue()
+        self._pending_calls: Dict[CallID, asyncio.Future[bytes]] = {}
+        self._handler_tasks: Dict[CallID, asyncio.Task] = {}
 
 
         self._read_task: Optional[asyncio.Task] = None
         self._read_task: Optional[asyncio.Task] = None
         self._write_task: Optional[asyncio.Task] = None
         self._write_task: Optional[asyncio.Task] = None
@@ -115,38 +114,38 @@ class ControlClient:
                 self._write_task.cancel()
                 self._write_task.cancel()
 
 
     async def _read_from_persistent_conn(self, reader: asyncio.StreamReader):
     async def _read_from_persistent_conn(self, reader: asyncio.StreamReader):
-        with closing(reader):
-            while True:
-                resp = p2pd_pb.PCResponse()
-                await read_pbmsg_safe(reader, resp)
+        while True:
+            resp = p2pd_pb.PersistentConnectionResponse()
+            await read_pbmsg_safe(reader, resp)
 
 
-                call_id = uuid.UUID(bytes=resp.callId)
+            call_id = uuid.UUID(bytes=resp.callId)
 
 
-                if resp.HasField("callUnaryResponse"):
-                    if call_id in self.pending_calls and resp.callUnaryResponse.HasField("response"):
-                        self.pending_calls[call_id].set_result(resp.callUnaryResponse.response)
-                    elif call_id in self.pending_calls and resp.callUnaryResponse.HasField("error"):
-                        remote_exc = P2PHandlerError(resp.callUnaryResponse.error.decode())
-                        self.pending_calls[call_id].set_exception(remote_exc)
-                    else:
-                        logger.debug(f"received unexpected unary call")
+            if resp.HasField("callUnaryResponse"):
+                if call_id in self._pending_calls and resp.callUnaryResponse.HasField("response"):
+                    self._pending_calls[call_id].set_result(resp.callUnaryResponse.response)
+                elif call_id in self._pending_calls and resp.callUnaryResponse.HasField("error"):
+                    remote_exc = P2PHandlerError(resp.callUnaryResponse.error.decode(errors="ignore"))
+                    self._pending_calls[call_id].set_exception(remote_exc)
+                else:
+                    logger.debug("received unexpected unary call")
 
 
-                elif resp.HasField("requestHandling"):
-                    handler_task = asyncio.create_task(self._handle_persistent_request(call_id, resp.requestHandling))
-                    self.handler_tasks[call_id] = handler_task
+            elif resp.HasField("requestHandling"):
+                handler_task = asyncio.create_task(self._handle_persistent_request(call_id, resp.requestHandling))
+                self._handler_tasks[call_id] = handler_task
 
 
-                elif call_id in self.handler_tasks and resp.HasField("cancel"):
-                    self.handler_tasks[call_id].cancel()
+            elif call_id in self._handler_tasks and resp.HasField("cancel"):
+                self._handler_tasks[call_id].cancel()
 
 
     async def _write_to_persistent_conn(self, writer: asyncio.StreamWriter):
     async def _write_to_persistent_conn(self, writer: asyncio.StreamWriter):
         with closing(writer):
         with closing(writer):
             while True:
             while True:
-                msg = await self.pending_messages.get()
+                msg = await self._pending_messages.get()
                 await write_pbmsg(writer, msg)
                 await write_pbmsg(writer, msg)
 
 
     async def _handle_persistent_request(self, call_id: uuid.UUID, request: p2pd_pb.CallUnaryRequest):
     async def _handle_persistent_request(self, call_id: uuid.UUID, request: p2pd_pb.CallUnaryRequest):
         if request.proto not in self.unary_handlers:
         if request.proto not in self.unary_handlers:
             logger.warning(f"Protocol {request.proto} not supported")
             logger.warning(f"Protocol {request.proto} not supported")
+            return
 
 
         try:
         try:
             remote_id = PeerID(request.peer)
             remote_id = PeerID(request.peer)
@@ -156,8 +155,13 @@ class ControlClient:
         except Exception as e:
         except Exception as e:
             response = p2pd_pb.CallUnaryResponse(error=repr(e).encode())
             response = p2pd_pb.CallUnaryResponse(error=repr(e).encode())
 
 
-        await self.pending_messages.put(p2pd_pb.PCRequest(callId=call_id.bytes, unaryResponse=response))
-        self.handler_tasks.pop(call_id)
+        await self._pending_messages.put(
+            p2pd_pb.PersistentConnectionRequest(
+                callId=call_id.bytes,
+                unaryResponse=response,
+            )
+        )
+        self._handler_tasks.pop(call_id)
 
 
     async def _handler(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
     async def _handler(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
         pb_stream_info = p2pd_pb.StreamInfo()  # type: ignore
         pb_stream_info = p2pd_pb.StreamInfo()  # type: ignore
@@ -172,23 +176,23 @@ class ControlClient:
         await handler(stream_info, reader, writer)
         await handler(stream_info, reader, writer)
 
 
     async def _send_call_cancel(self, call_id: uuid.UUID):
     async def _send_call_cancel(self, call_id: uuid.UUID):
-        await self.pending_messages.put(
-            p2pd_pb.PCRequest(
+        await self._pending_messages.put(
+            p2pd_pb.PersistentConnectionRequest(
                 callId=call_id.bytes,
                 callId=call_id.bytes,
                 cancel=p2pd_pb.Cancel(),
                 cancel=p2pd_pb.Cancel(),
             ),
             ),
         )
         )
 
 
     async def _ensure_persistent_conn(self):
     async def _ensure_persistent_conn(self):
-        if not self._pers_conn_open:
+        if not self._is_persistent_conn_open:
             async with self._ensure_conn_lock:
             async with self._ensure_conn_lock:
-                if not self._pers_conn_open:
+                if not self._is_persistent_conn_open:
                     reader, writer = await self.daemon_connector.open_persistent_connection()
                     reader, writer = await self.daemon_connector.open_persistent_connection()
 
 
                     self._read_task = asyncio.create_task(self._read_from_persistent_conn(reader))
                     self._read_task = asyncio.create_task(self._read_from_persistent_conn(reader))
                     self._write_task = asyncio.create_task(self._write_to_persistent_conn(writer))
                     self._write_task = asyncio.create_task(self._write_to_persistent_conn(writer))
 
 
-                    self._pers_conn_open = True
+                    self._is_persistent_conn_open = True
 
 
     async def add_unary_handler(self, proto: str, handler: TUnaryHandler):
     async def add_unary_handler(self, proto: str, handler: TUnaryHandler):
         await self._ensure_persistent_conn()
         await self._ensure_persistent_conn()
@@ -196,13 +200,13 @@ class ControlClient:
         call_id = uuid.uuid4()
         call_id = uuid.uuid4()
 
 
         add_unary_handler_req = p2pd_pb.AddUnaryHandlerRequest(proto=proto)
         add_unary_handler_req = p2pd_pb.AddUnaryHandlerRequest(proto=proto)
-        req = p2pd_pb.PCRequest(callId=call_id.bytes, addUnaryHandler=add_unary_handler_req)
+        req = p2pd_pb.PersistentConnectionRequest(callId=call_id.bytes, addUnaryHandler=add_unary_handler_req)
 
 
         if self.unary_handlers.get(proto):
         if self.unary_handlers.get(proto):
             raise ValueError(f"Handler for protocol {proto} already assigned")
             raise ValueError(f"Handler for protocol {proto} already assigned")
         self.unary_handlers[proto] = handler
         self.unary_handlers[proto] = handler
 
 
-        await self.pending_messages.put(req)
+        await self._pending_messages.put(req)
 
 
     async def call_unary_handler(self, peer_id: PeerID, proto: str, data: bytes) -> bytes:
     async def call_unary_handler(self, peer_id: PeerID, proto: str, data: bytes) -> bytes:
         call_id = uuid.uuid4()
         call_id = uuid.uuid4()
@@ -211,7 +215,7 @@ class ControlClient:
             proto=proto,
             proto=proto,
             data=data,
             data=data,
         )
         )
-        req = p2pd_pb.PCRequest(
+        req = p2pd_pb.PersistentConnectionRequest(
             callId=call_id.bytes,
             callId=call_id.bytes,
             callUnary=call_unary_req,
             callUnary=call_unary_req,
         )
         )
@@ -219,16 +223,16 @@ class ControlClient:
         await self._ensure_persistent_conn()
         await self._ensure_persistent_conn()
 
 
         try:
         try:
-            self.pending_calls[call_id] = asyncio.Future()
-            await self.pending_messages.put(req)
-            return await self.pending_calls[call_id]
+            self._pending_calls[call_id] = asyncio.Future()
+            await self._pending_messages.put(req)
+            return await self._pending_calls[call_id]
 
 
         except asyncio.CancelledError:
         except asyncio.CancelledError:
             asyncio.create_task(self._send_call_cancel(call_id))
             asyncio.create_task(self._send_call_cancel(call_id))
             raise
             raise
 
 
         finally:
         finally:
-            self.pending_calls.pop(call_id, None)
+            self._pending_calls.pop(call_id, None)
 
 
     async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
     async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
         reader, writer = await self.daemon_connector.open_connection()
         reader, writer = await self.daemon_connector.open_connection()

+ 1 - 6
hivemind/p2p/servicer.py

@@ -98,12 +98,7 @@ class ServicerBase:
                 self: StubBase, input: input_type, timeout: Optional[float] = None
                 self: StubBase, input: input_type, timeout: Optional[float] = None
             ) -> handler.response_type:
             ) -> handler.response_type:
                 return await asyncio.wait_for(
                 return await asyncio.wait_for(
-                    self._p2p.call_protobuf_handler(
-                        self._peer,
-                        handler.handle_name,
-                        input,
-                        handler.response_type,
-                    ),
+                    self._p2p.call_protobuf_handler(self._peer, handler.handle_name, input, handler.response_type),
                     timeout=timeout,
                     timeout=timeout,
                 )
                 )
 
 

+ 3 - 5
hivemind/proto/p2pd.proto

@@ -47,8 +47,7 @@ message Response {
   optional PSResponse pubsub = 7;
   optional PSResponse pubsub = 7;
 }
 }
 
 
-// Persistent connection request
-message PCRequest {
+message PersistentConnectionRequest {
   required bytes callId = 1;
   required bytes callId = 1;
 
 
   oneof message {
   oneof message {
@@ -59,8 +58,7 @@ message PCRequest {
   }
   }
 }
 }
 
 
-// Persistent connection response
-message PCResponse {
+message PersistentConnectionResponse {
   required bytes callId = 1;
   required bytes callId = 1;
 
 
   oneof message {
   oneof message {
@@ -214,4 +212,4 @@ message Cancel {
 
 
 message RPCError {
 message RPCError {
   optional string message = 1;
   optional string message = 1;
-}
+}