소스 검색

Parametrize max message size for persistent connections (#376)

Denis Mazur 4 년 전
부모
커밋
fb3f57b03c
4개의 변경된 파일, 59개의 추가 그리고 16개의 삭제
  1. 9 2
      hivemind/p2p/p2p_daemon.py
  2. 26 6
      hivemind/p2p/p2p_daemon_bindings/control.py
  3. 19 3
      hivemind/p2p/p2p_daemon_bindings/p2pclient.py
  4. 5 5
      setup.py

+ 9 - 2
hivemind/p2p/p2p_daemon.py

@@ -15,7 +15,7 @@ from multiaddr import Multiaddr
 
 import hivemind.hivemind_cli as cli
 import hivemind.p2p.p2p_daemon_bindings.p2pclient as p2pclient
-from hivemind.p2p.p2p_daemon_bindings.control import P2PDaemonError, P2PHandlerError
+from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, P2PDaemonError, P2PHandlerError
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
 from hivemind.proto.p2pd_pb2 import RPCError
 from hivemind.utils.asyncio import as_aiter, asingle
@@ -98,6 +98,7 @@ class P2P:
         use_relay: bool = True,
         use_relay_hop: bool = False,
         use_relay_discovery: bool = False,
+        persistent_conn_max_msg_size: int = DEFAULT_MAX_MSG_SIZE,
     ) -> "P2P":
         """
         Start a new p2pd process and connect to it.
@@ -168,6 +169,7 @@ class P2P:
             relayHop=use_relay_hop,
             relayHopLimit=relay_hop_limit,
             tls=tls,
+            persistentConnMaxMsgSize=persistent_conn_max_msg_size,
             **process_kwargs,
         )
 
@@ -189,7 +191,12 @@ class P2P:
             await self.shutdown()
             raise P2PDaemonError(f"Daemon failed to start in {startup_timeout:.1f} seconds")
 
-        self._client = await p2pclient.Client.create(self._daemon_listen_maddr, self._client_listen_maddr)
+        self._client = await p2pclient.Client.create(
+            control_maddr=self._daemon_listen_maddr,
+            listen_maddr=self._client_listen_maddr,
+            persistent_conn_max_msg_size=persistent_conn_max_msg_size,
+        )
+
         await self._ping_daemon()
         return self
 

+ 26 - 6
hivemind/p2p/p2p_daemon_bindings/control.py

@@ -26,6 +26,8 @@ SUPPORT_CONN_PROTOCOLS = (
 SUPPORTED_PROTOS = (protocols.protocol_with_code(proto) for proto in SUPPORT_CONN_PROTOCOLS)
 logger = get_logger(__name__)
 
+DEFAULT_MAX_MSG_SIZE = 4 * 1024 ** 2
+
 
 def parse_conn_protocol(maddr: Multiaddr) -> int:
     proto_codes = set(proto.code for proto in maddr.protocols())
@@ -84,10 +86,13 @@ class ControlClient:
         daemon_connector: DaemonConnector,
         listen_maddr: Multiaddr = Multiaddr(DEFAULT_LISTEN_MADDR),
         *,
-        _initialized_with_create=False,
+        _initialized_with_create: bool = False,
+        persistent_conn_max_msg_size: int = DEFAULT_MAX_MSG_SIZE,
     ) -> None:
         assert _initialized_with_create, "Please use ControlClient.create coroutine to spawn new control instances"
 
+        self.persistent_conn_max_msg_size = persistent_conn_max_msg_size
+
         self.listen_maddr = listen_maddr
         self.daemon_connector = daemon_connector
         self.handlers: Dict[str, StreamHandler] = {}
@@ -107,8 +112,14 @@ class ControlClient:
         daemon_connector: DaemonConnector,
         listen_maddr: Multiaddr = Multiaddr(DEFAULT_LISTEN_MADDR),
         use_persistent_conn: bool = True,
+        persistent_conn_max_msg_size=2 << 22,
     ) -> "ControlClient":
-        control = cls(daemon_connector, listen_maddr, _initialized_with_create=True)
+        control = cls(
+            daemon_connector,
+            listen_maddr,
+            _initialized_with_create=True,
+            persistent_conn_max_msg_size=persistent_conn_max_msg_size,
+        )
 
         if use_persistent_conn:
             await control._ensure_persistent_conn()
@@ -207,12 +218,18 @@ class ControlClient:
         except Exception as e:
             response = p2pd_pb.CallUnaryResponse(error=repr(e).encode())
 
-        await self._pending_messages.put(
-            p2pd_pb.PersistentConnectionRequest(
+        payload = p2pd_pb.PersistentConnectionRequest(callId=call_id.bytes, unaryResponse=response)
+        if payload.ByteSize() <= self.persistent_conn_max_msg_size:
+            await self._pending_messages.put(payload)
+        else:
+            error_msg = p2pd_pb.PersistentConnectionRequest(
                 callId=call_id.bytes,
-                unaryResponse=response,
+                callUnaryResponse=p2pd_pb.CallUnaryResponse(
+                    error=b"response size exceeds message size limit",
+                ),
             )
-        )
+            await self._pending_messages.put(error_msg)
+
         self._handler_tasks.pop(call_id)
 
     async def _cancel_unary_call(self, call_id: UUID):
@@ -255,6 +272,9 @@ class ControlClient:
             callUnary=call_unary_req,
         )
 
+        if req.ByteSize() > self.persistent_conn_max_msg_size:
+            raise P2PDaemonError(f"Message size exceeds set limit {self.persistent_conn_max_msg_size}")
+
         try:
             self._pending_calls[call_id] = asyncio.Future()
             await self._pending_messages.put(req)

+ 19 - 3
hivemind/p2p/p2p_daemon_bindings/p2pclient.py

@@ -10,7 +10,13 @@ from typing import AsyncIterator, Iterable, Sequence, Tuple
 
 from multiaddr import Multiaddr
 
-from hivemind.p2p.p2p_daemon_bindings.control import ControlClient, DaemonConnector, StreamHandler, TUnaryHandler
+from hivemind.p2p.p2p_daemon_bindings.control import (
+    DEFAULT_MAX_MSG_SIZE,
+    ControlClient,
+    DaemonConnector,
+    StreamHandler,
+    TUnaryHandler,
+)
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
 
 
@@ -22,11 +28,21 @@ class Client:
         self.control = None
 
     @classmethod
-    async def create(cls, control_maddr: Multiaddr = None, listen_maddr: Multiaddr = None) -> "Client":
+    async def create(
+        cls,
+        control_maddr: Multiaddr = None,
+        listen_maddr: Multiaddr = None,
+        *,
+        persistent_conn_max_msg_size: int = DEFAULT_MAX_MSG_SIZE,
+    ) -> "Client":
         client = cls(_initialized_with_create=True)
 
         daemon_connector = DaemonConnector(control_maddr=control_maddr)
-        client.control = await ControlClient.create(daemon_connector=daemon_connector, listen_maddr=listen_maddr)
+        client.control = await ControlClient.create(
+            daemon_connector=daemon_connector,
+            listen_maddr=listen_maddr,
+            persistent_conn_max_msg_size=persistent_conn_max_msg_size,
+        )
 
         return client
 

+ 5 - 5
setup.py

@@ -14,9 +14,10 @@ from setuptools import find_packages, setup
 from setuptools.command.build_py import build_py
 from setuptools.command.develop import develop
 
-P2PD_VERSION = "v0.3.5"
-P2PD_CHECKSUM = "affea8ec63dbe2423ef7453718b5798d"
+P2PD_VERSION = "v0.3.6"
+P2PD_CHECKSUM = "627d0c3b475a29331fdfd1667e828f6d"
 LIBP2P_TAR_URL = f"https://github.com/learning-at-home/go-libp2p-daemon/archive/refs/tags/{P2PD_VERSION}.tar.gz"
+P2PD_BINARY_URL = f"https://github.com/learning-at-home/go-libp2p-daemon/releases/download/{P2PD_VERSION}/p2pd"
 
 here = os.path.abspath(os.path.dirname(__file__))
 
@@ -85,11 +86,10 @@ def download_p2p_daemon():
     binary_path = os.path.join(install_path, "p2pd")
     if not os.path.exists(binary_path) or md5(binary_path) != P2PD_CHECKSUM:
         print("Downloading Peer to Peer Daemon")
-        url = f"https://github.com/learning-at-home/go-libp2p-daemon/releases/download/{P2PD_VERSION}/p2pd"
-        urllib.request.urlretrieve(url, binary_path)
+        urllib.request.urlretrieve(P2PD_BINARY_URL, binary_path)
         os.chmod(binary_path, 0o777)
         if md5(binary_path) != P2PD_CHECKSUM:
-            raise RuntimeError(f"Downloaded p2pd binary from {url} does not match with md5 checksum")
+            raise RuntimeError(f"Downloaded p2pd binary from {P2PD_BINARY_URL} does not match with md5 checksum")
 
 
 class BuildPy(build_py):