|
@@ -26,6 +26,8 @@ SUPPORT_CONN_PROTOCOLS = (
|
|
|
SUPPORTED_PROTOS = (protocols.protocol_with_code(proto) for proto in SUPPORT_CONN_PROTOCOLS)
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
+DEFAULT_MAX_MSG_SIZE = 4 * 1024 ** 2
|
|
|
+
|
|
|
|
|
|
def parse_conn_protocol(maddr: Multiaddr) -> int:
|
|
|
proto_codes = set(proto.code for proto in maddr.protocols())
|
|
@@ -84,10 +86,13 @@ class ControlClient:
|
|
|
daemon_connector: DaemonConnector,
|
|
|
listen_maddr: Multiaddr = Multiaddr(DEFAULT_LISTEN_MADDR),
|
|
|
*,
|
|
|
- _initialized_with_create=False,
|
|
|
+ _initialized_with_create: bool = False,
|
|
|
+ persistent_conn_max_msg_size: int = DEFAULT_MAX_MSG_SIZE,
|
|
|
) -> None:
|
|
|
assert _initialized_with_create, "Please use ControlClient.create coroutine to spawn new control instances"
|
|
|
|
|
|
+ self.persistent_conn_max_msg_size = persistent_conn_max_msg_size
|
|
|
+
|
|
|
self.listen_maddr = listen_maddr
|
|
|
self.daemon_connector = daemon_connector
|
|
|
self.handlers: Dict[str, StreamHandler] = {}
|
|
@@ -107,8 +112,14 @@ class ControlClient:
|
|
|
daemon_connector: DaemonConnector,
|
|
|
listen_maddr: Multiaddr = Multiaddr(DEFAULT_LISTEN_MADDR),
|
|
|
use_persistent_conn: bool = True,
|
|
|
+ persistent_conn_max_msg_size=2 << 22,
|
|
|
) -> "ControlClient":
|
|
|
- control = cls(daemon_connector, listen_maddr, _initialized_with_create=True)
|
|
|
+ control = cls(
|
|
|
+ daemon_connector,
|
|
|
+ listen_maddr,
|
|
|
+ _initialized_with_create=True,
|
|
|
+ persistent_conn_max_msg_size=persistent_conn_max_msg_size,
|
|
|
+ )
|
|
|
|
|
|
if use_persistent_conn:
|
|
|
await control._ensure_persistent_conn()
|
|
@@ -207,12 +218,18 @@ class ControlClient:
|
|
|
except Exception as e:
|
|
|
response = p2pd_pb.CallUnaryResponse(error=repr(e).encode())
|
|
|
|
|
|
- await self._pending_messages.put(
|
|
|
- p2pd_pb.PersistentConnectionRequest(
|
|
|
+ payload = p2pd_pb.PersistentConnectionRequest(callId=call_id.bytes, unaryResponse=response)
|
|
|
+ if payload.ByteSize() <= self.persistent_conn_max_msg_size:
|
|
|
+ await self._pending_messages.put(payload)
|
|
|
+ else:
|
|
|
+ error_msg = p2pd_pb.PersistentConnectionRequest(
|
|
|
callId=call_id.bytes,
|
|
|
- unaryResponse=response,
|
|
|
+ callUnaryResponse=p2pd_pb.CallUnaryResponse(
|
|
|
+ error=b"response size exceeds message size limit",
|
|
|
+ ),
|
|
|
)
|
|
|
- )
|
|
|
+ await self._pending_messages.put(error_msg)
|
|
|
+
|
|
|
self._handler_tasks.pop(call_id)
|
|
|
|
|
|
async def _cancel_unary_call(self, call_id: UUID):
|
|
@@ -255,6 +272,9 @@ class ControlClient:
|
|
|
callUnary=call_unary_req,
|
|
|
)
|
|
|
|
|
|
+ if req.ByteSize() > self.persistent_conn_max_msg_size:
|
|
|
+ raise P2PDaemonError(f"Message size exceeds set limit {self.persistent_conn_max_msg_size}")
|
|
|
+
|
|
|
try:
|
|
|
self._pending_calls[call_id] = asyncio.Future()
|
|
|
await self._pending_messages.put(req)
|