Forráskód Böngészése

fix suggestions
Co-authored-by: Denis Mazur <denismazur8@gmail.com>
Co-authored-by: Alexander Borzunov <hxrussia@gmail.com>

Denis Mazur 4 éve
szülő
commit
96c59b9c1e

+ 21 - 18
hivemind/p2p/p2p_daemon.py

@@ -4,11 +4,11 @@ import secrets
 from collections.abc import AsyncIterable as AsyncIterableABC
 from contextlib import closing, suppress
 from dataclasses import dataclass
-from google.protobuf.message import Message
 from importlib.resources import path
 from subprocess import Popen
 from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
 
+from google.protobuf.message import Message
 from multiaddr import Multiaddr
 
 import hivemind.hivemind_cli as cli
@@ -170,12 +170,6 @@ class P2P:
 
         return self
 
-    async def add_unary_handler(self, handle_name: str, handler: p2pclient.TUnaryHandler):
-        return await self._client.add_unary_handler(handle_name, handler)
-
-    async def call_unary_handler(self, peer_id: PeerID, handle_name: str, data: bytes) -> bytes:
-        return await self._client.call_unary_handler(peer_id, handle_name, data)
-
     async def _ping_daemon_with_retries(self, ping_n_attempts: int, ping_delay: float) -> None:
         for try_number in range(ping_n_attempts):
             await asyncio.sleep(ping_delay * (2 ** try_number))
@@ -282,7 +276,7 @@ class P2P:
 
     @staticmethod
     async def receive_protobuf(
-        input_protobuf_type: type, reader: asyncio.StreamReader
+        input_protobuf_type: Message, reader: asyncio.StreamReader
     ) -> Tuple[Optional[TInputProtobuf], Optional[RPCError]]:
         msg_type = await reader.readexactly(1)
         if msg_type == P2P.MESSAGE_MARKER:
@@ -303,7 +297,7 @@ class P2P:
         self,
         name: str,
         handler: Callable[[TInputStream, P2PContext], TOutputStream],
-        input_protobuf_type: type,
+        input_protobuf_type: Message,
         max_prefetch: int = 5,
     ) -> None:
         """
@@ -367,7 +361,7 @@ class P2P:
         await self._client.stream_handler(name, _handle_stream)
 
     async def _iterate_protobuf_stream_handler(
-        self, peer_id: PeerID, name: str, requests: TInputStream, output_protobuf_type: type
+        self, peer_id: PeerID, name: str, requests: TInputStream, output_protobuf_type: Message
     ) -> TOutputStream:
         _, reader, writer = await self._client.stream_open(peer_id, (name,))
 
@@ -399,7 +393,7 @@ class P2P:
         handler: Callable[
             [Union[TInputProtobuf, TInputStream], P2PContext], Union[Awaitable[TOutputProtobuf], TOutputStream]
         ],
-        input_protobuf_type: type,
+        input_protobuf_type: Message,
         *,
         stream_input: bool = False,
         stream_output: bool = False,
@@ -410,7 +404,7 @@ class P2P:
         :param stream_output: If True, assume ``handler`` to return ``TOutputStream``
         """
 
-        if not (stream_input or stream_output):
+        if not stream_input and not stream_output:
             await self._add_protobuf_unary_handler(name, handler, input_protobuf_type)
             return
 
@@ -439,8 +433,17 @@ class P2P:
         self,
         handle_name: str,
         handler: Callable[[TInputProtobuf, P2PContext], Awaitable[TOutputProtobuf]],
-        input_protobuf_type: type,
+        input_protobuf_type: Message,
     ) -> None:
+        """
+        Register a request-response (unary) handler. Unary requests and responses
+        are sent through persistent multiplexed connections to the daemon for the
+        sake of reducing the number of open files.
+        :param handle_name: name of the handler (protocol id)
+        :param handler: function handling the unary requests
+        :param input_protobuf_type: protobuf type of the request
+        """
+
         async def _unary_handler(request: bytes, remote_id: PeerID) -> bytes:
             input_serialized = input_protobuf_type.FromString(request)
             context = P2PContext(
@@ -452,14 +455,14 @@ class P2P:
             response = await handler(input_serialized, context)
             return response.SerializeToString()
 
-        await self.add_unary_handler(handle_name, _unary_handler)
+        await self._client.add_unary_handler(handle_name, _unary_handler)
 
     async def call_protobuf_handler(
         self,
         peer_id: PeerID,
         name: str,
         input: Union[TInputProtobuf, TInputStream],
-        output_protobuf_type: type,
+        output_protobuf_type: Message,
     ) -> Awaitable[TOutputProtobuf]:
 
         if not isinstance(input, AsyncIterableABC):
@@ -479,10 +482,10 @@ class P2P:
         peer_id: PeerID,
         handle_name: str,
         input: TInputProtobuf,
-        output_protobuf_type: type,
+        output_protobuf_type: Message,
     ) -> Awaitable[TOutputProtobuf]:
         serialized_input = input.SerializeToString()
-        response = await self.call_unary_handler(peer_id, handle_name, serialized_input)
+        response = await self._client.call_unary_handler(peer_id, handle_name, serialized_input)
         return output_protobuf_type().FromString(response)
 
     def iterate_protobuf_handler(
@@ -490,7 +493,7 @@ class P2P:
         peer_id: PeerID,
         name: str,
         input: Union[TInputProtobuf, TInputStream],
-        output_protobuf_type: type,
+        output_protobuf_type: Message,
     ) -> TOutputStream:
         requests = input if isinstance(input, AsyncIterableABC) else aiter(input)
         return self._iterate_protobuf_stream_handler(peer_id, name, requests, output_protobuf_type)

+ 42 - 38
hivemind/p2p/p2p_daemon_bindings/control.py

@@ -80,14 +80,13 @@ class ControlClient:
         self.daemon_connector = daemon_connector
         self.handlers: Dict[str, StreamHandler] = {}
 
-        # persistent connection readers & writers
-        self._pers_conn_open: bool = False
+        self._is_persistent_conn_open: bool = False
         self.unary_handlers: Dict[str, TUnaryHandler] = {}
 
         self._ensure_conn_lock = asyncio.Lock()
-        self.pending_messages: asyncio.Queue[p2pd_pb.PCRequest] = asyncio.Queue()
-        self.pending_calls: Dict[CallID, asyncio.Future[bytes]] = {}
-        self.handler_tasks: Dict[CallID, asyncio.Task] = {}
+        self._pending_messages: asyncio.Queue[p2pd_pb.PersistentConnectionRequest] = asyncio.Queue()
+        self._pending_calls: Dict[CallID, asyncio.Future[bytes]] = {}
+        self._handler_tasks: Dict[CallID, asyncio.Task] = {}
 
         self._read_task: Optional[asyncio.Task] = None
         self._write_task: Optional[asyncio.Task] = None
@@ -115,38 +114,38 @@ class ControlClient:
                 self._write_task.cancel()
 
     async def _read_from_persistent_conn(self, reader: asyncio.StreamReader):
-        with closing(reader):
-            while True:
-                resp = p2pd_pb.PCResponse()
-                await read_pbmsg_safe(reader, resp)
+        while True:
+            resp = p2pd_pb.PersistentConnectionResponse()
+            await read_pbmsg_safe(reader, resp)
 
-                call_id = uuid.UUID(bytes=resp.callId)
+            call_id = uuid.UUID(bytes=resp.callId)
 
-                if resp.HasField("callUnaryResponse"):
-                    if call_id in self.pending_calls and resp.callUnaryResponse.HasField("response"):
-                        self.pending_calls[call_id].set_result(resp.callUnaryResponse.response)
-                    elif call_id in self.pending_calls and resp.callUnaryResponse.HasField("error"):
-                        remote_exc = P2PHandlerError(resp.callUnaryResponse.error.decode())
-                        self.pending_calls[call_id].set_exception(remote_exc)
-                    else:
-                        logger.debug(f"received unexpected unary call")
+            if resp.HasField("callUnaryResponse"):
+                if call_id in self._pending_calls and resp.callUnaryResponse.HasField("response"):
+                    self._pending_calls[call_id].set_result(resp.callUnaryResponse.response)
+                elif call_id in self._pending_calls and resp.callUnaryResponse.HasField("error"):
+                    remote_exc = P2PHandlerError(resp.callUnaryResponse.error.decode(errors="ignore"))
+                    self._pending_calls[call_id].set_exception(remote_exc)
+                else:
+                    logger.debug("received unexpected unary call")
 
-                elif resp.HasField("requestHandling"):
-                    handler_task = asyncio.create_task(self._handle_persistent_request(call_id, resp.requestHandling))
-                    self.handler_tasks[call_id] = handler_task
+            elif resp.HasField("requestHandling"):
+                handler_task = asyncio.create_task(self._handle_persistent_request(call_id, resp.requestHandling))
+                self._handler_tasks[call_id] = handler_task
 
-                elif call_id in self.handler_tasks and resp.HasField("cancel"):
-                    self.handler_tasks[call_id].cancel()
+            elif call_id in self._handler_tasks and resp.HasField("cancel"):
+                self._handler_tasks[call_id].cancel()
 
     async def _write_to_persistent_conn(self, writer: asyncio.StreamWriter):
         with closing(writer):
             while True:
-                msg = await self.pending_messages.get()
+                msg = await self._pending_messages.get()
                 await write_pbmsg(writer, msg)
 
     async def _handle_persistent_request(self, call_id: uuid.UUID, request: p2pd_pb.CallUnaryRequest):
         if request.proto not in self.unary_handlers:
             logger.warning(f"Protocol {request.proto} not supported")
+            return
 
         try:
             remote_id = PeerID(request.peer)
@@ -156,8 +155,13 @@ class ControlClient:
         except Exception as e:
             response = p2pd_pb.CallUnaryResponse(error=repr(e).encode())
 
-        await self.pending_messages.put(p2pd_pb.PCRequest(callId=call_id.bytes, unaryResponse=response))
-        self.handler_tasks.pop(call_id)
+        await self._pending_messages.put(
+            p2pd_pb.PersistentConnectionRequest(
+                callId=call_id.bytes,
+                unaryResponse=response,
+            )
+        )
+        self._handler_tasks.pop(call_id)
 
     async def _handler(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
         pb_stream_info = p2pd_pb.StreamInfo()  # type: ignore
@@ -172,23 +176,23 @@ class ControlClient:
         await handler(stream_info, reader, writer)
 
     async def _send_call_cancel(self, call_id: uuid.UUID):
-        await self.pending_messages.put(
-            p2pd_pb.PCRequest(
+        await self._pending_messages.put(
+            p2pd_pb.PersistentConnectionRequest(
                 callId=call_id.bytes,
                 cancel=p2pd_pb.Cancel(),
             ),
         )
 
     async def _ensure_persistent_conn(self):
-        if not self._pers_conn_open:
+        if not self._is_persistent_conn_open:
             async with self._ensure_conn_lock:
-                if not self._pers_conn_open:
+                if not self._is_persistent_conn_open:
                     reader, writer = await self.daemon_connector.open_persistent_connection()
 
                     self._read_task = asyncio.create_task(self._read_from_persistent_conn(reader))
                     self._write_task = asyncio.create_task(self._write_to_persistent_conn(writer))
 
-                    self._pers_conn_open = True
+                    self._is_persistent_conn_open = True
 
     async def add_unary_handler(self, proto: str, handler: TUnaryHandler):
         await self._ensure_persistent_conn()
@@ -196,13 +200,13 @@ class ControlClient:
         call_id = uuid.uuid4()
 
         add_unary_handler_req = p2pd_pb.AddUnaryHandlerRequest(proto=proto)
-        req = p2pd_pb.PCRequest(callId=call_id.bytes, addUnaryHandler=add_unary_handler_req)
+        req = p2pd_pb.PersistentConnectionRequest(callId=call_id.bytes, addUnaryHandler=add_unary_handler_req)
 
         if self.unary_handlers.get(proto):
             raise ValueError(f"Handler for protocol {proto} already assigned")
         self.unary_handlers[proto] = handler
 
-        await self.pending_messages.put(req)
+        await self._pending_messages.put(req)
 
     async def call_unary_handler(self, peer_id: PeerID, proto: str, data: bytes) -> bytes:
         call_id = uuid.uuid4()
@@ -211,7 +215,7 @@ class ControlClient:
             proto=proto,
             data=data,
         )
-        req = p2pd_pb.PCRequest(
+        req = p2pd_pb.PersistentConnectionRequest(
             callId=call_id.bytes,
             callUnary=call_unary_req,
         )
@@ -219,16 +223,16 @@ class ControlClient:
         await self._ensure_persistent_conn()
 
         try:
-            self.pending_calls[call_id] = asyncio.Future()
-            await self.pending_messages.put(req)
-            return await self.pending_calls[call_id]
+            self._pending_calls[call_id] = asyncio.Future()
+            await self._pending_messages.put(req)
+            return await self._pending_calls[call_id]
 
         except asyncio.CancelledError:
             asyncio.create_task(self._send_call_cancel(call_id))
             raise
 
         finally:
-            self.pending_calls.pop(call_id, None)
+            self._pending_calls.pop(call_id, None)
 
     async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
         reader, writer = await self.daemon_connector.open_connection()

+ 1 - 6
hivemind/p2p/servicer.py

@@ -98,12 +98,7 @@ class ServicerBase:
                 self: StubBase, input: input_type, timeout: Optional[float] = None
             ) -> handler.response_type:
                 return await asyncio.wait_for(
-                    self._p2p.call_protobuf_handler(
-                        self._peer,
-                        handler.handle_name,
-                        input,
-                        handler.response_type,
-                    ),
+                    self._p2p.call_protobuf_handler(self._peer, handler.handle_name, input, handler.response_type),
                     timeout=timeout,
                 )
 

+ 3 - 5
hivemind/proto/p2pd.proto

@@ -47,8 +47,7 @@ message Response {
   optional PSResponse pubsub = 7;
 }
 
-// Persistent connection request
-message PCRequest {
+message PersistentConnectionRequest {
   required bytes callId = 1;
 
   oneof message {
@@ -59,8 +58,7 @@ message PCRequest {
   }
 }
 
-// Persistent connection response
-message PCResponse {
+message PersistentConnectionResponse {
   required bytes callId = 1;
 
   oneof message {
@@ -214,4 +212,4 @@ message Cancel {
 
 message RPCError {
   optional string message = 1;
-}
+}