浏览代码

chunked read/write proto

justheuristic 3 年之前
父节点
当前提交
a150f1dc21
共有 1 个文件被更改,包括 19 次插入12 次删除
  1. 19 12
      hivemind/p2p/p2p_daemon.py

+ 19 - 12
hivemind/p2p/p2p_daemon.py

@@ -260,32 +260,31 @@ class P2P:
         return self._daemon_listen_maddr
 
     @staticmethod
-    async def send_raw_data(data: bytes, writer: asyncio.StreamWriter, *, chunk_size: int = 2 ** 16) -> None:
+    async def send_raw_data(data: bytes, writer: asyncio.StreamWriter, *, drain: bool = True) -> None:
         writer.write(len(data).to_bytes(P2P.HEADER_LEN, P2P.BYTEORDER))
-        data = memoryview(data)
-        for offset in range(0, len(data), chunk_size):
-            writer.write(data[offset : offset + chunk_size])
+        writer.write(data)
+        if drain:
             await writer.drain()
 
     @staticmethod
-    async def receive_raw_data(reader: asyncio.StreamReader, *, chunk_size: int = 2 ** 16) -> bytes:
+    async def receive_raw_data(reader: asyncio.StreamReader) -> bytes:
         header = await reader.readexactly(P2P.HEADER_LEN)
         content_length = int.from_bytes(header, P2P.BYTEORDER)
-        data = bytearray(content_length)
-        for offset in range(0, content_length, chunk_size):
-            data[offset : offset + chunk_size] = await reader.readexactly(min(chunk_size, len(data) - offset))
-        return data
+        return await reader.readexactly(content_length)
 
     TInputProtobuf = TypeVar("TInputProtobuf")
     TOutputProtobuf = TypeVar("TOutputProtobuf")
 
     @staticmethod
-    async def send_protobuf(protobuf: Union[TOutputProtobuf, RPCError], writer: asyncio.StreamWriter) -> None:
+    async def send_protobuf(protobuf: Union[TOutputProtobuf, RPCError], writer: asyncio.StreamWriter, *, chunk_size: int = 2 ** 16) -> None:
         if isinstance(protobuf, RPCError):
             writer.write(P2P.ERROR_MARKER)
         else:
             writer.write(P2P.MESSAGE_MARKER)
-        await P2P.send_raw_data(protobuf.SerializeToString(), writer)
+        data = memoryview(protobuf.SerializeToString())
+        writer.write(len(data).to_bytes(P2P.HEADER_LEN, P2P.BYTEORDER))
+        for offset in range(0, len(data), chunk_size):
+            writer.write(data[offset : offset + chunk_size], writer, drain=True)
 
     @staticmethod
     async def receive_protobuf(
@@ -294,7 +293,15 @@ class P2P:
         msg_type = await reader.readexactly(1)
         if msg_type == P2P.MESSAGE_MARKER:
             protobuf = input_protobuf_type()
-            protobuf.ParseFromString(await P2P.receive_raw_data(reader))
+            header = await reader.readexactly(P2P.HEADER_LEN)
+            content_length = int.from_bytes(header, P2P.BYTEORDER)
+            data = bytearray(content_length)
+            offset = 0
+            while offset < len(data):
+                chunk = await P2P.receive_raw_data(reader)
+                buf[offset : offset + len(chunk)] = chunk
+                offset += len(chunk)
+            protobuf.ParseFromString(data)
             return protobuf, None
         elif msg_type == P2P.ERROR_MARKER:
             protobuf = RPCError()