123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263 |
- from contextlib import AbstractContextManager
- from socket import socket
- from typing import Tuple
- class Connection(AbstractContextManager):
- header_size = 4 # number of characters in all headers
- payload_length_size = 8 # number of bytes used to encode payload length
- __slots__ = ('conn', 'addr')
- def __init__(self, conn: socket, addr: Tuple[str, int]):
- self.conn, self.addr = conn, addr
- @staticmethod
- def create(host: str, port: int):
- sock = socket()
- addr = (host, port)
- sock.connect(addr)
- return Connection(sock, addr)
- def send_raw(self, header: str, content: bytes):
- self.conn.send(header.encode())
- self.conn.send(len(content).to_bytes(self.payload_length_size, byteorder='big'))
- total_sent = 0
- while total_sent < len(content):
- sent = self.conn.send(content[total_sent:])
- if sent == 0:
- raise RuntimeError("socket connection broken")
- total_sent = total_sent + sent
- def recv_header(self) -> str:
- return self.conn.recv(self.header_size).decode()
- def recv_raw(self, max_package: int = 2048) -> bytes:
- length = int.from_bytes(self.conn.recv(self.payload_length_size), byteorder='big')
- chunks = []
- bytes_recd = 0
- while bytes_recd < length:
- chunk = self.conn.recv(min(length - bytes_recd, max_package))
- if chunk == b'':
- raise RuntimeError("socket connection broken")
- chunks.append(chunk)
- bytes_recd = bytes_recd + len(chunk)
- ret = b''.join(chunks)
- assert len(ret) == length
- return ret
- def recv_message(self) -> Tuple[str, bytes]:
- return self.recv_header(), self.recv_raw()
- def __exit__(self, *exc_info):
- self.conn.close()
- def find_open_port():
- try:
- sock = socket()
- sock.bind(('', 0))
- return sock.getsockname()[1]
- except:
- raise ValueError("Could not find open port")
|