utils.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. """
  2. Originally taken from: https://github.com/mhchia/py-libp2p-daemon-bindings
  3. Licence: MIT
  4. Author: Kevin Mai-Husan Chia
  5. """
  6. import asyncio
  7. from google.protobuf.message import Message as PBMessage
  8. from hivemind.proto import p2pd_pb2 as p2pd_pb
  9. DEFAULT_MAX_BITS: int = 64
  10. class ControlFailure(Exception):
  11. pass
  12. class DispatchFailure(Exception):
  13. pass
  14. async def write_unsigned_varint(stream: asyncio.StreamWriter, integer: int, max_bits: int = DEFAULT_MAX_BITS) -> None:
  15. max_int = 1 << max_bits
  16. if integer < 0:
  17. raise ValueError(f"negative integer: {integer}")
  18. if integer >= max_int:
  19. raise ValueError(f"integer too large: {integer}")
  20. while True:
  21. value = integer & 0x7F
  22. integer >>= 7
  23. if integer != 0:
  24. value |= 0x80
  25. byte = value.to_bytes(1, "big")
  26. stream.write(byte)
  27. if integer == 0:
  28. break
  29. async def read_unsigned_varint(stream: asyncio.StreamReader, max_bits: int = DEFAULT_MAX_BITS) -> int:
  30. max_int = 1 << max_bits
  31. iteration = 0
  32. result = 0
  33. has_next = True
  34. while has_next:
  35. data = await stream.readexactly(1)
  36. c = data[0]
  37. value = c & 0x7F
  38. result |= value << (iteration * 7)
  39. has_next = (c & 0x80) != 0
  40. iteration += 1
  41. if result >= max_int:
  42. raise ValueError(f"Varint overflowed: {result}")
  43. return result
  44. def raise_if_failed(response: p2pd_pb.Response) -> None:
  45. if response.type == p2pd_pb.Response.ERROR:
  46. raise ControlFailure(f"Connect failed. msg={response.error.msg}")
  47. async def write_pbmsg(stream: asyncio.StreamWriter, pbmsg: PBMessage) -> None:
  48. size = pbmsg.ByteSize()
  49. await write_unsigned_varint(stream, size)
  50. msg_bytes: bytes = pbmsg.SerializeToString()
  51. stream.write(msg_bytes)
  52. async def read_pbmsg_safe(stream: asyncio.StreamReader, pbmsg: PBMessage) -> None:
  53. len_msg_bytes = await read_unsigned_varint(stream)
  54. msg_bytes = await stream.readexactly(len_msg_bytes)
  55. pbmsg.ParseFromString(msg_bytes)