connection.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. from contextlib import AbstractContextManager
  2. from socket import socket
  3. from typing import Tuple
  4. class Connection(AbstractContextManager):
  5. header_size = 4 # number of characters in all headers
  6. payload_length_size = 8 # number of bytes used to encode payload length
  7. __slots__ = ('conn', 'addr')
  8. def __init__(self, conn: socket, addr: Tuple[str, int]):
  9. self.conn, self.addr = conn, addr
  10. @staticmethod
  11. def create(host: str, port: int):
  12. sock = socket()
  13. addr = (host, port)
  14. sock.connect(addr)
  15. return Connection(sock, addr)
  16. def send_raw(self, header: str, content: bytes):
  17. self.conn.send(header.encode())
  18. self.conn.send(len(content).to_bytes(self.payload_length_size, byteorder='big'))
  19. total_sent = 0
  20. while total_sent < len(content):
  21. sent = self.conn.send(content[total_sent:])
  22. if sent == 0:
  23. raise RuntimeError("socket connection broken")
  24. total_sent = total_sent + sent
  25. def recv_header(self) -> str:
  26. return self.conn.recv(self.header_size).decode()
  27. def recv_raw(self, max_package: int = 2048) -> bytes:
  28. length = int.from_bytes(self.conn.recv(self.payload_length_size), byteorder='big')
  29. chunks = []
  30. bytes_recd = 0
  31. while bytes_recd < length:
  32. chunk = self.conn.recv(min(length - bytes_recd, max_package))
  33. if chunk == b'':
  34. raise RuntimeError("socket connection broken")
  35. chunks.append(chunk)
  36. bytes_recd = bytes_recd + len(chunk)
  37. ret = b''.join(chunks)
  38. assert len(ret) == length
  39. return ret
  40. def recv_message(self) -> Tuple[str, bytes]:
  41. return self.recv_header(), self.recv_raw()
  42. def __exit__(self, *exc_info):
  43. self.conn.close()
  44. def find_open_port():
  45. try:
  46. sock = socket()
  47. sock.bind(('', 0))
  48. return sock.getsockname()[1]
  49. except:
  50. raise ValueError("Could not find open port")