control.py 15 KB


  1. """
  2. Originally taken from: https://github.com/mhchia/py-libp2p-daemon-bindings
  3. Licence: MIT
  4. Author: Kevin Mai-Husan Chia
  5. """
  6. import asyncio
  7. from contextlib import asynccontextmanager, closing
  8. from typing import AsyncIterator, Awaitable, Callable, Dict, Iterable, Optional, Sequence, Tuple
  9. from uuid import UUID, uuid4
  10. from multiaddr import Multiaddr, protocols
  11. from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
  12. from hivemind.p2p.p2p_daemon_bindings.utils import DispatchFailure, raise_if_failed, read_pbmsg_safe, write_pbmsg
  13. from hivemind.proto import p2pd_pb2 as p2pd_pb
  14. from hivemind.utils.logging import get_logger
  15. StreamHandler = Callable[[StreamInfo, asyncio.StreamReader, asyncio.StreamWriter], Awaitable[None]]
  16. SUPPORT_CONN_PROTOCOLS = (
  17. protocols.P_IP4,
  18. # protocols.P_IP6,
  19. protocols.P_UNIX,
  20. )
  21. SUPPORTED_PROTOS = (protocols.protocol_with_code(proto) for proto in SUPPORT_CONN_PROTOCOLS)
  22. logger = get_logger(__name__)
  23. DEFAULT_MAX_MSG_SIZE = 4 * 1024 ** 2
  24. def parse_conn_protocol(maddr: Multiaddr) -> int:
  25. proto_codes = set(proto.code for proto in maddr.protocols())
  26. proto_cand = proto_codes.intersection(SUPPORT_CONN_PROTOCOLS)
  27. if len(proto_cand) != 1:
  28. raise ValueError(
  29. f"connection protocol should be only one protocol out of {SUPPORTED_PROTOS}" f", maddr={maddr}"
  30. )
  31. return tuple(proto_cand)[0]
  32. class DaemonConnector:
  33. DEFAULT_CONTROL_MADDR = "/unix/tmp/p2pd.sock"
  34. def __init__(self, control_maddr: Multiaddr = Multiaddr(DEFAULT_CONTROL_MADDR)) -> None:
  35. self.control_maddr = control_maddr
  36. self.proto_code = parse_conn_protocol(self.control_maddr)
  37. async def open_connection(self) -> (asyncio.StreamReader, asyncio.StreamWriter):
  38. if self.proto_code == protocols.P_UNIX:
  39. control_path = self.control_maddr.value_for_protocol(protocols.P_UNIX)
  40. return await asyncio.open_unix_connection(control_path)
  41. elif self.proto_code == protocols.P_IP4:
  42. host = self.control_maddr.value_for_protocol(protocols.P_IP4)
  43. port = int(self.control_maddr.value_for_protocol(protocols.P_TCP))
  44. return await asyncio.open_connection(host, port)
  45. else:
  46. raise ValueError(f"Protocol not supported: {protocols.protocol_with_code(self.proto_code)}")
  47. async def open_persistent_connection(self) -> (asyncio.StreamReader, asyncio.StreamWriter):
  48. """
  49. Open connection to daemon and upgrade it to a persistent one
  50. """
  51. reader, writer = await self.open_connection()
  52. req = p2pd_pb.Request(type=p2pd_pb.Request.PERSISTENT_CONN_UPGRADE)
  53. await write_pbmsg(writer, req)
  54. response = p2pd_pb.Response()
  55. await read_pbmsg_safe(reader, response)
  56. if response.type == "ERROR":
  57. raise P2PDaemonError(response.error.msg)
  58. return reader, writer
# Unary handler: awaited with the request payload and the caller's PeerID, resolves to the response payload
TUnaryHandler = Callable[[bytes, PeerID], Awaitable[bytes]]
# Identifier pairing persistent-connection requests with their replies (a fresh uuid4 is generated per call)
CallID = UUID
  61. class ControlClient:
  62. DEFAULT_LISTEN_MADDR = "/unix/tmp/p2pclient.sock"
  63. def __init__(
  64. self,
  65. daemon_connector: DaemonConnector,
  66. listen_maddr: Multiaddr = Multiaddr(DEFAULT_LISTEN_MADDR),
  67. *,
  68. _initialized_with_create: bool = False,
  69. persistent_conn_max_msg_size: int = DEFAULT_MAX_MSG_SIZE,
  70. ) -> None:
  71. assert _initialized_with_create, "Please use ControlClient.create coroutine to spawn new control instances"
  72. self.persistent_conn_max_msg_size = persistent_conn_max_msg_size
  73. self.listen_maddr = listen_maddr
  74. self.daemon_connector = daemon_connector
  75. self.handlers: Dict[str, StreamHandler] = {}
  76. self.unary_handlers: Dict[str, TUnaryHandler] = {}
  77. self._pending_messages: asyncio.Queue[p2pd_pb.PersistentConnectionRequest] = asyncio.Queue()
  78. self._pending_calls: Dict[CallID, asyncio.Future[bytes]] = {}
  79. self._handler_tasks: Dict[CallID, asyncio.Task] = {}
  80. self._read_task: Optional[asyncio.Task] = None
  81. self._write_task: Optional[asyncio.Task] = None
  82. @classmethod
  83. async def create(
  84. cls,
  85. daemon_connector: DaemonConnector,
  86. listen_maddr: Multiaddr = Multiaddr(DEFAULT_LISTEN_MADDR),
  87. use_persistent_conn: bool = True,
  88. persistent_conn_max_msg_size=2 << 22,
  89. ) -> "ControlClient":
  90. control = cls(
  91. daemon_connector,
  92. listen_maddr,
  93. _initialized_with_create=True,
  94. persistent_conn_max_msg_size=persistent_conn_max_msg_size,
  95. )
  96. if use_persistent_conn:
  97. await control._ensure_persistent_conn()
  98. return control
  99. def close(self) -> None:
  100. if self._read_task is not None:
  101. self._read_task.cancel()
  102. if self._write_task is not None:
  103. self._write_task.cancel()
  104. def __del__(self):
  105. self.close()
  106. async def _handler(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
  107. pb_stream_info = p2pd_pb.StreamInfo() # type: ignore
  108. await read_pbmsg_safe(reader, pb_stream_info)
  109. stream_info = StreamInfo.from_protobuf(pb_stream_info)
  110. try:
  111. handler = self.handlers[stream_info.proto]
  112. except KeyError as e:
  113. # should never enter here... daemon should reject the stream for us.
  114. writer.close()
  115. raise DispatchFailure(e)
  116. await handler(stream_info, reader, writer)
  117. @asynccontextmanager
  118. async def listen(self) -> AsyncIterator["ControlClient"]:
  119. proto_code = parse_conn_protocol(self.listen_maddr)
  120. if proto_code == protocols.P_UNIX:
  121. listen_path = self.listen_maddr.value_for_protocol(protocols.P_UNIX)
  122. server = await asyncio.start_unix_server(self._handler, path=listen_path)
  123. elif proto_code == protocols.P_IP4:
  124. host = self.listen_maddr.value_for_protocol(protocols.P_IP4)
  125. port = int(self.listen_maddr.value_for_protocol(protocols.P_TCP))
  126. server = await asyncio.start_server(self._handler, port=port, host=host)
  127. else:
  128. raise ValueError(f"Protocol not supported: {protocols.protocol_with_code(proto_code)}")
  129. async with server:
  130. yield self
  131. async def _read_from_persistent_conn(self, reader: asyncio.StreamReader):
  132. while True:
  133. resp = p2pd_pb.PersistentConnectionResponse()
  134. try:
  135. await read_pbmsg_safe(reader, resp)
  136. except asyncio.IncompleteReadError:
  137. break
  138. call_id = UUID(bytes=resp.callId)
  139. if resp.HasField("callUnaryResponse"):
  140. if call_id in self._pending_calls and resp.callUnaryResponse.HasField("response"):
  141. self._pending_calls[call_id].set_result(resp.callUnaryResponse.response)
  142. elif call_id in self._pending_calls and resp.callUnaryResponse.HasField("error"):
  143. remote_exc = P2PHandlerError(resp.callUnaryResponse.error.decode(errors="ignore"))
  144. self._pending_calls[call_id].set_exception(remote_exc)
  145. else:
  146. logger.debug(f"Received unexpected unary call: {resp}")
  147. elif resp.HasField("requestHandling"):
  148. handler_task = asyncio.create_task(self._handle_persistent_request(call_id, resp.requestHandling))
  149. self._handler_tasks[call_id] = handler_task
  150. elif call_id in self._handler_tasks and resp.HasField("cancel"):
  151. self._handler_tasks[call_id].cancel()
  152. elif call_id in self._pending_calls and resp.HasField("daemonError"):
  153. daemon_exc = P2PDaemonError(resp.daemonError.message)
  154. self._pending_calls[call_id].set_exception(daemon_exc)
  155. elif call_id in self._pending_calls:
  156. self._pending_calls[call_id].set_result(None)
  157. else:
  158. logger.debug(f"Received unexpected response from daemon: {resp}")
  159. async def _write_to_persistent_conn(self, writer: asyncio.StreamWriter):
  160. with closing(writer):
  161. while True:
  162. msg = await self._pending_messages.get()
  163. await write_pbmsg(writer, msg)
  164. async def _handle_persistent_request(self, call_id: UUID, request: p2pd_pb.CallUnaryRequest):
  165. if request.proto not in self.unary_handlers:
  166. logger.warning(f"Protocol {request.proto} not supported")
  167. return
  168. try:
  169. remote_id = PeerID(request.peer)
  170. response_payload: bytes = await self.unary_handlers[request.proto](request.data, remote_id)
  171. response = p2pd_pb.CallUnaryResponse(response=response_payload)
  172. except Exception as e:
  173. response = p2pd_pb.CallUnaryResponse(error=repr(e).encode())
  174. payload = p2pd_pb.PersistentConnectionRequest(callId=call_id.bytes, unaryResponse=response)
  175. if payload.ByteSize() <= self.persistent_conn_max_msg_size:
  176. await self._pending_messages.put(payload)
  177. else:
  178. error_msg = p2pd_pb.PersistentConnectionRequest(
  179. callId=call_id.bytes,
  180. callUnaryResponse=p2pd_pb.CallUnaryResponse(
  181. error=b"response size exceeds message size limit",
  182. ),
  183. )
  184. await self._pending_messages.put(error_msg)
  185. self._handler_tasks.pop(call_id)
  186. async def _cancel_unary_call(self, call_id: UUID):
  187. await self._pending_messages.put(
  188. p2pd_pb.PersistentConnectionRequest(
  189. callId=call_id.bytes,
  190. cancel=p2pd_pb.Cancel(),
  191. ),
  192. )
  193. async def _ensure_persistent_conn(self):
  194. reader, writer = await self.daemon_connector.open_persistent_connection()
  195. self._read_task = asyncio.create_task(self._read_from_persistent_conn(reader))
  196. self._write_task = asyncio.create_task(self._write_to_persistent_conn(writer))
  197. async def add_unary_handler(self, proto: str, handler: TUnaryHandler):
  198. call_id = uuid4()
  199. add_unary_handler_req = p2pd_pb.AddUnaryHandlerRequest(proto=proto)
  200. req = p2pd_pb.PersistentConnectionRequest(callId=call_id.bytes, addUnaryHandler=add_unary_handler_req)
  201. if self.unary_handlers.get(proto):
  202. raise P2PDaemonError(f"Handler for protocol {proto} already registered")
  203. self.unary_handlers[proto] = handler
  204. self._pending_calls[call_id] = asyncio.Future()
  205. await self._pending_messages.put(req)
  206. await self._pending_calls[call_id]
  207. async def call_unary_handler(self, peer_id: PeerID, proto: str, data: bytes) -> bytes:
  208. call_id = uuid4()
  209. call_unary_req = p2pd_pb.CallUnaryRequest(
  210. peer=peer_id.to_bytes(),
  211. proto=proto,
  212. data=data,
  213. )
  214. req = p2pd_pb.PersistentConnectionRequest(
  215. callId=call_id.bytes,
  216. callUnary=call_unary_req,
  217. )
  218. if req.ByteSize() > self.persistent_conn_max_msg_size:
  219. raise P2PDaemonError(f"Message size exceeds set limit {self.persistent_conn_max_msg_size}")
  220. try:
  221. self._pending_calls[call_id] = asyncio.Future()
  222. await self._pending_messages.put(req)
  223. return await self._pending_calls[call_id]
  224. except asyncio.CancelledError:
  225. await self._cancel_unary_call(call_id)
  226. raise
  227. finally:
  228. self._pending_calls.pop(call_id, None)
  229. async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
  230. reader, writer = await self.daemon_connector.open_connection()
  231. req = p2pd_pb.Request(type=p2pd_pb.Request.IDENTIFY)
  232. await write_pbmsg(writer, req)
  233. resp = p2pd_pb.Response() # type: ignore
  234. await read_pbmsg_safe(reader, resp)
  235. writer.close()
  236. raise_if_failed(resp)
  237. peer_id_bytes = resp.identify.id
  238. maddrs_bytes = resp.identify.addrs
  239. maddrs = tuple(Multiaddr(maddr_bytes) for maddr_bytes in maddrs_bytes)
  240. peer_id = PeerID(peer_id_bytes)
  241. return peer_id, maddrs
  242. async def connect(self, peer_id: PeerID, maddrs: Iterable[Multiaddr]) -> None:
  243. reader, writer = await self.daemon_connector.open_connection()
  244. maddrs_bytes = [i.to_bytes() for i in maddrs]
  245. connect_req = p2pd_pb.ConnectRequest(peer=peer_id.to_bytes(), addrs=maddrs_bytes)
  246. req = p2pd_pb.Request(type=p2pd_pb.Request.CONNECT, connect=connect_req)
  247. await write_pbmsg(writer, req)
  248. resp = p2pd_pb.Response() # type: ignore
  249. await read_pbmsg_safe(reader, resp)
  250. writer.close()
  251. raise_if_failed(resp)
  252. async def list_peers(self) -> Tuple[PeerInfo, ...]:
  253. req = p2pd_pb.Request(type=p2pd_pb.Request.LIST_PEERS)
  254. reader, writer = await self.daemon_connector.open_connection()
  255. await write_pbmsg(writer, req)
  256. resp = p2pd_pb.Response() # type: ignore
  257. await read_pbmsg_safe(reader, resp)
  258. writer.close()
  259. raise_if_failed(resp)
  260. peers = tuple(PeerInfo.from_protobuf(pinfo) for pinfo in resp.peers)
  261. return peers
  262. async def disconnect(self, peer_id: PeerID) -> None:
  263. disconnect_req = p2pd_pb.DisconnectRequest(peer=peer_id.to_bytes())
  264. req = p2pd_pb.Request(type=p2pd_pb.Request.DISCONNECT, disconnect=disconnect_req)
  265. reader, writer = await self.daemon_connector.open_connection()
  266. await write_pbmsg(writer, req)
  267. resp = p2pd_pb.Response() # type: ignore
  268. await read_pbmsg_safe(reader, resp)
  269. writer.close()
  270. raise_if_failed(resp)
  271. async def stream_open(
  272. self, peer_id: PeerID, protocols: Sequence[str]
  273. ) -> Tuple[StreamInfo, asyncio.StreamReader, asyncio.StreamWriter]:
  274. reader, writer = await self.daemon_connector.open_connection()
  275. stream_open_req = p2pd_pb.StreamOpenRequest(peer=peer_id.to_bytes(), proto=list(protocols))
  276. req = p2pd_pb.Request(type=p2pd_pb.Request.STREAM_OPEN, streamOpen=stream_open_req)
  277. await write_pbmsg(writer, req)
  278. resp = p2pd_pb.Response() # type: ignore
  279. await read_pbmsg_safe(reader, resp)
  280. raise_if_failed(resp)
  281. pb_stream_info = resp.streamInfo
  282. stream_info = StreamInfo.from_protobuf(pb_stream_info)
  283. return stream_info, reader, writer
  284. async def stream_handler(self, proto: str, handler_cb: StreamHandler) -> None:
  285. reader, writer = await self.daemon_connector.open_connection()
  286. listen_path_maddr_bytes = self.listen_maddr.to_bytes()
  287. stream_handler_req = p2pd_pb.StreamHandlerRequest(addr=listen_path_maddr_bytes, proto=[proto])
  288. req = p2pd_pb.Request(type=p2pd_pb.Request.STREAM_HANDLER, streamHandler=stream_handler_req)
  289. await write_pbmsg(writer, req)
  290. resp = p2pd_pb.Response() # type: ignore
  291. await read_pbmsg_safe(reader, resp)
  292. writer.close()
  293. raise_if_failed(resp)
  294. # if success, add the handler to the dict
  295. self.handlers[proto] = handler_cb
  296. class P2PHandlerError(Exception):
  297. """
  298. Raised if remote handled a request with an exception
  299. """
  300. class P2PDaemonError(Exception):
  301. """
  302. Raised if daemon failed to handle request
  303. """