# control.py
  1. """
  2. Originally taken from: https://github.com/mhchia/py-libp2p-daemon-bindings
  3. Licence: MIT
  4. Author: Kevin Mai-Husan Chia
  5. """
  6. import asyncio
  7. import uuid
  8. from contextlib import asynccontextmanager
  9. from typing import AsyncIterator, Awaitable, Callable, Dict, Iterable, Optional, Sequence, Tuple
  10. from multiaddr import Multiaddr, protocols
  11. from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
  12. from hivemind.p2p.p2p_daemon_bindings.utils import DispatchFailure, raise_if_failed, read_pbmsg_safe, write_pbmsg
  13. from hivemind.proto import p2pd_pb2 as p2pd_pb
  14. from hivemind.utils.logging import get_logger
  15. StreamHandler = Callable[[StreamInfo, asyncio.StreamReader, asyncio.StreamWriter], Awaitable[None]]
  16. SUPPORT_CONN_PROTOCOLS = (
  17. protocols.P_IP4,
  18. # protocols.P_IP6,
  19. protocols.P_UNIX,
  20. )
  21. SUPPORTED_PROTOS = (protocols.protocol_with_code(proto) for proto in SUPPORT_CONN_PROTOCOLS)
  22. logger = get_logger(__name__)
  23. def parse_conn_protocol(maddr: Multiaddr) -> int:
  24. proto_codes = set(proto.code for proto in maddr.protocols())
  25. proto_cand = proto_codes.intersection(SUPPORT_CONN_PROTOCOLS)
  26. if len(proto_cand) != 1:
  27. raise ValueError(
  28. f"connection protocol should be only one protocol out of {SUPPORTED_PROTOS}" f", maddr={maddr}"
  29. )
  30. return tuple(proto_cand)[0]
  31. class DaemonConnector:
  32. DEFAULT_CONTROL_MADDR = "/unix/tmp/p2pd.sock"
  33. def __init__(self, control_maddr: Multiaddr = Multiaddr(DEFAULT_CONTROL_MADDR)) -> None:
  34. self.control_maddr = control_maddr
  35. self.proto_code = parse_conn_protocol(self.control_maddr)
  36. async def open_connection(self) -> (asyncio.StreamReader, asyncio.StreamWriter):
  37. if self.proto_code == protocols.P_UNIX:
  38. control_path = self.control_maddr.value_for_protocol(protocols.P_UNIX)
  39. return await asyncio.open_unix_connection(control_path)
  40. elif self.proto_code == protocols.P_IP4:
  41. host = self.control_maddr.value_for_protocol(protocols.P_IP4)
  42. port = int(self.control_maddr.value_for_protocol(protocols.P_TCP))
  43. return await asyncio.open_connection(host, port)
  44. else:
  45. raise ValueError(f"Protocol not supported: {protocols.protocol_with_code(self.proto_code)}")
  46. async def open_persistent_connection(self) -> (asyncio.StreamReader, asyncio.StreamWriter):
  47. """
  48. Open connection to daemon and upgrade it to a persistent one
  49. """
  50. reader, writer = await self.open_connection()
  51. req = p2pd_pb.Request(type=p2pd_pb.Request.PERSISTENT_CONN_UPGRADE)
  52. await write_pbmsg(writer, req)
  53. return reader, writer
# Signature of unary RPC handlers: request payload -> response payload.
TUnaryHandler = Callable[[bytes], bytes]
# Unique identifier of an in-flight unary call (uuid4 per call).
CallID = uuid.UUID
  56. class ControlClient:
  57. DEFAULT_LISTEN_MADDR = "/unix/tmp/p2pclient.sock"
  58. def __init__(
  59. self, daemon_connector: DaemonConnector, listen_maddr: Multiaddr = Multiaddr(DEFAULT_LISTEN_MADDR)
  60. ) -> None:
  61. self.listen_maddr = listen_maddr
  62. self.daemon_connector = daemon_connector
  63. self.handlers: Dict[str, StreamHandler] = {}
  64. # persistent connection readers & writers
  65. self._pers_conn_open: bool = False
  66. self.unary_handlers: Dict[str, TUnaryHandler] = {}
  67. self.pending_messages: asyncio.Queue[p2pd_pb.Request] = asyncio.Queue()
  68. self.pending_calls: Dict[CallID, asyncio.Future] = {}
  69. @asynccontextmanager
  70. async def listen(self) -> AsyncIterator["ControlClient"]:
  71. proto_code = parse_conn_protocol(self.listen_maddr)
  72. if proto_code == protocols.P_UNIX:
  73. listen_path = self.listen_maddr.value_for_protocol(protocols.P_UNIX)
  74. server = await asyncio.start_unix_server(self._handler, path=listen_path)
  75. elif proto_code == protocols.P_IP4:
  76. host = self.listen_maddr.value_for_protocol(protocols.P_IP4)
  77. port = int(self.listen_maddr.value_for_protocol(protocols.P_TCP))
  78. server = await asyncio.start_server(self._handler, port=port, host=host)
  79. else:
  80. raise ValueError(f"Protocol not supported: {protocols.protocol_with_code(proto_code)}")
  81. async with server:
  82. yield self
  83. async def _read_from_persistent_conn(self, reader: asyncio.StreamReader):
  84. while True:
  85. resp: p2pd_pb.Response = p2pd_pb.Response() # type: ignore
  86. await read_pbmsg_safe(reader, resp)
  87. if resp.HasField("callUnaryResponse"):
  88. call_id = uuid.UUID(bytes=resp.callUnaryResponse.callId)
  89. if call_id in self.pending_calls and resp.callUnaryResponse.HasField("result"):
  90. self.pending_calls[call_id].set_result(resp.callUnaryResponse.result)
  91. elif call_id in self.pending_calls and resp.callUnaryResponse.HasField("error"):
  92. remote_exc = RemoteException(str(resp.callUnaryResponse.error))
  93. self.pending_calls[call_id].set_exception(remote_exc)
  94. else:
  95. logger.debug(f"received unexpected unary call")
  96. elif resp.HasField("requestHandling"):
  97. asyncio.create_task(self._handle_persistent_request(resp.requestHandling))
  98. pass
  99. async def _write_to_persistent_conn(self, writer: asyncio.StreamWriter):
  100. while True:
  101. msg = await self.pending_messages.get()
  102. await write_pbmsg(writer, msg)
  103. async def _handle_persistent_request(self, request):
  104. assert request.proto in self.unary_handlers
  105. try:
  106. response_payload: bytes = self.unary_handlers[request.proto](request.data)
  107. response = p2pd_pb.CallUnaryResponse(callId=request.callId, result=response_payload)
  108. except Exception as e:
  109. response = p2pd_pb.CallUnaryResponse(callId=request.callId, error=repr(e))
  110. await self.pending_messages.put(
  111. p2pd_pb.Request(type=p2pd_pb.Request.SEND_RESPONSE_TO_REMOTE, sendResponseToRemote=response)
  112. )
  113. async def _handler(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
  114. pb_stream_info = p2pd_pb.StreamInfo() # type: ignore
  115. await read_pbmsg_safe(reader, pb_stream_info)
  116. stream_info = StreamInfo.from_protobuf(pb_stream_info)
  117. try:
  118. handler = self.handlers[stream_info.proto]
  119. except KeyError as e:
  120. # should never enter here... daemon should reject the stream for us.
  121. writer.close()
  122. raise DispatchFailure(e)
  123. await handler(stream_info, reader, writer)
  124. async def _ensure_persistent_conn(self):
  125. if not self._pers_conn_open:
  126. reader, writer = await self.daemon_connector.open_persistent_connection()
  127. asyncio.create_task(self._read_from_persistent_conn(reader))
  128. asyncio.create_task(self._write_to_persistent_conn(writer))
  129. self._pers_conn_open = True # TODO FIXME
  130. async def add_unary_handler(self, proto: str, handler: TUnaryHandler):
  131. await self._ensure_persistent_conn()
  132. add_unary_handler_req = p2pd_pb.AddUnaryHandlerRequest(proto=proto)
  133. req = p2pd_pb.Request(
  134. type=p2pd_pb.Request.ADD_UNARY_HANDLER,
  135. addUnaryHandler=add_unary_handler_req,
  136. )
  137. await self.pending_messages.put(req)
  138. if self.unary_handlers.get(proto):
  139. raise ValueError(f"Handler for protocol {proto} already assigned")
  140. self.unary_handlers[proto] = handler
  141. async def unary_call(self, peer_id: PeerID, proto: str, data: bytes) -> bytes:
  142. call_id = uuid.uuid4()
  143. call_unary_req = p2pd_pb.CallUnaryRequest(
  144. peer=peer_id.to_bytes(),
  145. proto=proto,
  146. data=data,
  147. callId=call_id.bytes,
  148. )
  149. req = p2pd_pb.Request(
  150. type=p2pd_pb.Request.CALL_UNARY,
  151. callUnary=call_unary_req,
  152. )
  153. await self._ensure_persistent_conn()
  154. try:
  155. self.pending_calls[call_id] = asyncio.Future()
  156. await self.pending_messages.put(req)
  157. return await self.pending_calls[call_id]
  158. finally:
  159. await self.pending_calls.pop(call_id)
  160. async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
  161. reader, writer = await self.daemon_connector.open_connection()
  162. req = p2pd_pb.Request(type=p2pd_pb.Request.IDENTIFY)
  163. await write_pbmsg(writer, req)
  164. resp = p2pd_pb.Response() # type: ignore
  165. await read_pbmsg_safe(reader, resp)
  166. writer.close()
  167. raise_if_failed(resp)
  168. peer_id_bytes = resp.identify.id
  169. maddrs_bytes = resp.identify.addrs
  170. maddrs = tuple(Multiaddr(maddr_bytes) for maddr_bytes in maddrs_bytes)
  171. peer_id = PeerID(peer_id_bytes)
  172. return peer_id, maddrs
  173. async def connect(self, peer_id: PeerID, maddrs: Iterable[Multiaddr]) -> None:
  174. reader, writer = await self.daemon_connector.open_connection()
  175. maddrs_bytes = [i.to_bytes() for i in maddrs]
  176. connect_req = p2pd_pb.ConnectRequest(peer=peer_id.to_bytes(), addrs=maddrs_bytes)
  177. req = p2pd_pb.Request(type=p2pd_pb.Request.CONNECT, connect=connect_req)
  178. await write_pbmsg(writer, req)
  179. resp = p2pd_pb.Response() # type: ignore
  180. await read_pbmsg_safe(reader, resp)
  181. writer.close()
  182. raise_if_failed(resp)
  183. async def list_peers(self) -> Tuple[PeerInfo, ...]:
  184. req = p2pd_pb.Request(type=p2pd_pb.Request.LIST_PEERS)
  185. reader, writer = await self.daemon_connector.open_connection()
  186. await write_pbmsg(writer, req)
  187. resp = p2pd_pb.Response() # type: ignore
  188. await read_pbmsg_safe(reader, resp)
  189. writer.close()
  190. raise_if_failed(resp)
  191. peers = tuple(PeerInfo.from_protobuf(pinfo) for pinfo in resp.peers)
  192. return peers
  193. async def disconnect(self, peer_id: PeerID) -> None:
  194. disconnect_req = p2pd_pb.DisconnectRequest(peer=peer_id.to_bytes())
  195. req = p2pd_pb.Request(type=p2pd_pb.Request.DISCONNECT, disconnect=disconnect_req)
  196. reader, writer = await self.daemon_connector.open_connection()
  197. await write_pbmsg(writer, req)
  198. resp = p2pd_pb.Response() # type: ignore
  199. await read_pbmsg_safe(reader, resp)
  200. writer.close()
  201. raise_if_failed(resp)
  202. async def stream_open(
  203. self, peer_id: PeerID, protocols: Sequence[str]
  204. ) -> Tuple[StreamInfo, asyncio.StreamReader, asyncio.StreamWriter]:
  205. reader, writer = await self.daemon_connector.open_connection()
  206. stream_open_req = p2pd_pb.StreamOpenRequest(peer=peer_id.to_bytes(), proto=list(protocols))
  207. req = p2pd_pb.Request(type=p2pd_pb.Request.STREAM_OPEN, streamOpen=stream_open_req)
  208. await write_pbmsg(writer, req)
  209. resp = p2pd_pb.Response() # type: ignore
  210. await read_pbmsg_safe(reader, resp)
  211. raise_if_failed(resp)
  212. pb_stream_info = resp.streamInfo
  213. stream_info = StreamInfo.from_protobuf(pb_stream_info)
  214. return stream_info, reader, writer
  215. async def stream_handler(self, proto: str, handler_cb: StreamHandler) -> None:
  216. reader, writer = await self.daemon_connector.open_connection()
  217. listen_path_maddr_bytes = self.listen_maddr.to_bytes()
  218. stream_handler_req = p2pd_pb.StreamHandlerRequest(addr=listen_path_maddr_bytes, proto=[proto])
  219. req = p2pd_pb.Request(type=p2pd_pb.Request.STREAM_HANDLER, streamHandler=stream_handler_req)
  220. await write_pbmsg(writer, req)
  221. resp = p2pd_pb.Response() # type: ignore
  222. await read_pbmsg_safe(reader, resp)
  223. writer.close()
  224. raise_if_failed(resp)
  225. # if success, add the handler to the dict
  226. self.handlers[proto] = handler_cb
  227. class RemoteException(Exception):
  228. """
  229. Raised if remote handled a request with an exception
  230. """