p2p_daemon.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447
  1. import asyncio
  2. import os
  3. import secrets
  4. from contextlib import suppress
  5. from dataclasses import dataclass
  6. from importlib.resources import path
  7. from subprocess import Popen
  8. from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
  9. import google.protobuf
  10. from multiaddr import Multiaddr
  11. import hivemind.hivemind_cli as cli
  12. import hivemind.p2p.p2p_daemon_bindings.p2pclient as p2pclient
  13. from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
  14. from hivemind.proto import p2pd_pb2
  15. from hivemind.utils import MSGPackSerializer
  16. from hivemind.utils.logging import get_logger
# Module-level logger, named after this module.
logger = get_logger(__name__)

# File name of the p2p daemon executable; P2P.create looks it up as a resource of the hivemind_cli package.
P2PD_FILENAME = 'p2pd'
  19. @dataclass(frozen=True)
  20. class P2PContext(object):
  21. handle_name: str
  22. local_id: PeerID
  23. remote_id: PeerID = None
  24. remote_maddr: Multiaddr = None
  25. class P2P:
  26. """
  27. This class is responsible for establishing peer-to-peer connections through NAT and/or firewalls.
  28. It creates and manages a libp2p daemon (https://libp2p.io) in a background process,
  29. then terminates it when P2P is shut down. In order to communicate, a P2P instance should
  30. either use one or more initial_peers that will connect it to the rest of the swarm or
  31. use the public IPFS network (https://ipfs.io).
  32. For incoming connections, P2P instances add RPC handlers that may be accessed by other peers:
  33. - `P2P.add_unary_handler` accepts a protobuf message and returns another protobuf
  34. - `P2P.add_stream_handler` transfers raw data using bi-directional streaming interface
  35. To access these handlers, a P2P instance can `P2P.call_unary_handler`/`P2P.call_stream_handler`,
  36. using the recipient's unique `P2P.id` and the name of the corresponding handler.
  37. """
  38. HEADER_LEN = 8
  39. BYTEORDER = 'big'
  40. PB_HEADER_LEN = 1
  41. RESULT_MESSAGE = b'\x00'
  42. ERROR_MESSAGE = b'\x01'
  43. DHT_MODE_MAPPING = {
  44. 'dht': {'dht': 1},
  45. 'dht_server': {'dhtServer': 1},
  46. 'dht_client': {'dhtClient': 1},
  47. }
  48. FORCE_REACHABILITY_MAPPING = {
  49. 'public': {'forceReachabilityPublic': 1},
  50. 'private': {'forceReachabilityPrivate': 1},
  51. }
  52. _UNIX_SOCKET_PREFIX = '/unix/tmp/hivemind-'
  53. def __init__(self):
  54. self.id = None
  55. self._child = None
  56. self._alive = False
  57. self._listen_task = None
  58. self._server_stopped = asyncio.Event()
  59. @classmethod
  60. async def create(cls,
  61. initial_peers: Optional[Sequence[Union[Multiaddr, str]]] = None,
  62. use_ipfs: bool = False,
  63. host_maddrs: Optional[Sequence[Union[Multiaddr, str]]] = ('/ip4/127.0.0.1/tcp/0',),
  64. announce_maddrs: Optional[Sequence[Union[Multiaddr, str]]] = None,
  65. quic: bool = True, tls: bool = True, conn_manager: bool = True,
  66. dht_mode: str = 'dht_server', force_reachability: Optional[str] = None,
  67. nat_port_map: bool = True, auto_nat: bool = True,
  68. use_relay: bool = True, use_relay_hop: bool = False,
  69. use_relay_discovery: bool = False, use_auto_relay: bool = False, relay_hop_limit: int = 0,
  70. quiet: bool = True,
  71. ping_n_attempts: int = 5, ping_delay: float = 0.4) -> 'P2P':
  72. """
  73. Start a new p2pd process and connect to it.
  74. :param initial_peers: List of bootstrap peers
  75. :param use_ipfs: Bootstrap to IPFS (incompatible with initial_peers)
  76. :param host_maddrs: Multiaddrs to listen for external connections from other p2p instances
  77. :param announce_maddrs: Visible multiaddrs that the peer will announce
  78. for external connections from other p2p instances
  79. :param quic: Enables the QUIC transport
  80. :param tls: Enables TLS1.3 channel security protocol
  81. :param conn_manager: Enables the Connection Manager
  82. :param dht_mode: DHT mode (dht_client/dht_server/dht)
  83. :param force_reachability: Force reachability mode (public/private)
  84. :param nat_port_map: Enables NAT port mapping
  85. :param auto_nat: Enables the AutoNAT service
  86. :param use_relay: enables circuit relay
  87. :param use_relay_hop: enables hop for relay
  88. :param use_relay_discovery: enables passive discovery for relay
  89. :param use_auto_relay: enables autorelay
  90. :param relay_hop_limit: sets the hop limit for hop relays
  91. :param quiet: make the daemon process quiet
  92. :param ping_n_attempts: try to ping the daemon with this number of attempts after starting it
  93. :param ping_delay: wait for ``ping_delay * (2 ** (k - 1))`` seconds before the k-th attempt to ping the daemon
  94. (in particular, wait for ``ping_delay`` seconds before the first attempt)
  95. :return: a wrapper for the p2p daemon
  96. """
  97. assert not (initial_peers and use_ipfs), \
  98. 'User-defined initial_peers and use_ipfs=True are incompatible, please choose one option'
  99. self = cls()
  100. with path(cli, P2PD_FILENAME) as p:
  101. p2pd_path = p
  102. socket_uid = secrets.token_urlsafe(8)
  103. self._daemon_listen_maddr = Multiaddr(cls._UNIX_SOCKET_PREFIX + f'p2pd-{socket_uid}.sock')
  104. self._client_listen_maddr = Multiaddr(cls._UNIX_SOCKET_PREFIX + f'p2pclient-{socket_uid}.sock')
  105. need_bootstrap = bool(initial_peers) or use_ipfs
  106. process_kwargs = cls.DHT_MODE_MAPPING.get(dht_mode, {'dht': 0})
  107. process_kwargs.update(cls.FORCE_REACHABILITY_MAPPING.get(force_reachability, {}))
  108. for param, value in [('bootstrapPeers', initial_peers),
  109. ('hostAddrs', host_maddrs),
  110. ('announceAddrs', announce_maddrs)]:
  111. if value:
  112. process_kwargs[param] = self._maddrs_to_str(value)
  113. proc_args = self._make_process_args(
  114. str(p2pd_path),
  115. listen=self._daemon_listen_maddr,
  116. quic=quic, tls=tls, connManager=conn_manager,
  117. natPortMap=nat_port_map, autonat=auto_nat,
  118. relay=use_relay, relayHop=use_relay_hop, relayDiscovery=use_relay_discovery,
  119. autoRelay=use_auto_relay, relayHopLimit=relay_hop_limit,
  120. b=need_bootstrap, q=quiet, **process_kwargs)
  121. self._child = Popen(args=proc_args, encoding="utf8")
  122. self._alive = True
  123. self._client = p2pclient.Client(self._daemon_listen_maddr, self._client_listen_maddr)
  124. await self._ping_daemon_with_retries(ping_n_attempts, ping_delay)
  125. self.persistent_streams = dict()
  126. return self
  127. async def _ping_daemon_with_retries(self, ping_n_attempts: int, ping_delay: float) -> None:
  128. for try_number in range(ping_n_attempts):
  129. await asyncio.sleep(ping_delay * (2 ** try_number))
  130. if self._child.poll() is not None: # Process died
  131. break
  132. try:
  133. await self._ping_daemon()
  134. break
  135. except Exception as e:
  136. if try_number == ping_n_attempts - 1:
  137. logger.exception('Failed to ping p2pd that has just started')
  138. await self.shutdown()
  139. raise
  140. if self._child.returncode is not None:
  141. raise RuntimeError(f'The p2p daemon has died with return code {self._child.returncode}')
  142. @classmethod
  143. async def replicate(cls, daemon_listen_maddr: Multiaddr) -> 'P2P':
  144. """
  145. Connect to existing p2p daemon
  146. :param daemon_listen_maddr: multiaddr of the existing p2p daemon
  147. :return: new wrapper for the existing p2p daemon
  148. """
  149. self = cls()
  150. # There is no child under control
  151. # Use external already running p2pd
  152. self._child = None
  153. self._alive = True
  154. socket_uid = secrets.token_urlsafe(8)
  155. self._daemon_listen_maddr = daemon_listen_maddr
  156. self._client_listen_maddr = Multiaddr(cls._UNIX_SOCKET_PREFIX + f'p2pclient-{socket_uid}.sock')
  157. self._client = p2pclient.Client(self._daemon_listen_maddr, self._client_listen_maddr)
  158. await self._ping_daemon()
  159. return self
  160. async def _ping_daemon(self) -> None:
  161. self.id, self._visible_maddrs = await self._client.identify()
  162. logger.debug(f'Launched p2pd with id = {self.id}, host multiaddrs = {self._visible_maddrs}')
  163. async def get_visible_maddrs(self, latest: bool = False) -> List[Multiaddr]:
  164. """
  165. Get multiaddrs of the current peer that should be accessible by other peers.
  166. :param latest: ask the P2P daemon to refresh the visible multiaddrs
  167. """
  168. if latest:
  169. _, self._visible_maddrs = await self._client.identify()
  170. if not self._visible_maddrs:
  171. raise ValueError(f"No multiaddrs found for peer {self.id}")
  172. p2p_maddr = Multiaddr(f'/p2p/{self.id.to_base58()}')
  173. return [addr.encapsulate(p2p_maddr) for addr in self._visible_maddrs]
  174. async def list_peers(self) -> List[PeerInfo]:
  175. return list(await self._client.list_peers())
  176. async def wait_for_at_least_n_peers(self, n_peers: int, attempts: int = 3, delay: float = 1) -> None:
  177. for _ in range(attempts):
  178. peers = await self._client.list_peers()
  179. if len(peers) >= n_peers:
  180. return
  181. await asyncio.sleep(delay)
  182. raise RuntimeError('Not enough peers')
  183. @property
  184. def daemon_listen_maddr(self) -> Multiaddr:
  185. return self._daemon_listen_maddr
  186. @staticmethod
  187. async def send_raw_data(data: bytes, writer: asyncio.StreamWriter) -> None:
  188. request = len(data).to_bytes(P2P.HEADER_LEN, P2P.BYTEORDER) + data
  189. writer.write(request)
  190. @staticmethod
  191. async def send_msgpack(data: Any, writer: asyncio.StreamWriter) -> None:
  192. raw_data = MSGPackSerializer.dumps(data)
  193. await P2P.send_raw_data(raw_data, writer)
  194. @staticmethod
  195. async def send_protobuf(protobuf, out_proto_type: type, writer: asyncio.StreamWriter) -> None:
  196. if type(protobuf) != out_proto_type:
  197. raise TypeError('Unary handler returned protobuf of wrong type.')
  198. if out_proto_type == p2pd_pb2.RPCError:
  199. await P2P.send_raw_data(P2P.ERROR_MESSAGE, writer)
  200. else:
  201. await P2P.send_raw_data(P2P.RESULT_MESSAGE, writer)
  202. await P2P.send_raw_data(protobuf.SerializeToString(), writer)
  203. @staticmethod
  204. async def receive_raw_data(reader: asyncio.StreamReader) -> bytes:
  205. header = await reader.readexactly(P2P.HEADER_LEN)
  206. content_length = int.from_bytes(header, P2P.BYTEORDER)
  207. data = await reader.readexactly(content_length)
  208. return data
  209. @staticmethod
  210. async def receive_msgpack(reader: asyncio.StreamReader) -> Any:
  211. return MSGPackSerializer.loads(await P2P.receive_raw_data(reader))
  212. @staticmethod
  213. async def receive_protobuf(in_proto_type: type, reader: asyncio.StreamReader) -> \
  214. Tuple[Any, Optional[p2pd_pb2.RPCError]]:
  215. msg_type = await P2P.receive_raw_data(reader)
  216. if msg_type == P2P.RESULT_MESSAGE:
  217. protobuf = in_proto_type()
  218. protobuf.ParseFromString(await P2P.receive_raw_data(reader))
  219. return protobuf, None
  220. elif msg_type == P2P.ERROR_MESSAGE:
  221. protobuf = p2pd_pb2.RPCError()
  222. protobuf.ParseFromString(await P2P.receive_raw_data(reader))
  223. return None, protobuf
  224. else:
  225. raise TypeError('Invalid Protobuf message type')
  226. @staticmethod
  227. def _handle_stream(handle: Callable[[bytes], bytes]):
  228. async def do_handle_stream(
  229. stream_info: StreamInfo, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
  230. try:
  231. request = await P2P.receive_raw_data(reader)
  232. except asyncio.IncompleteReadError:
  233. logger.debug("Incomplete read while receiving request from peer")
  234. writer.close()
  235. return
  236. try:
  237. result = handle(request)
  238. await P2P.send_raw_data(result, writer)
  239. finally:
  240. writer.close()
  241. return do_handle_stream
  242. def _handle_unary_stream(self, handle: Callable[[Any, P2PContext], Any], handle_name: str,
  243. in_proto_type: type, out_proto_type: type):
  244. async def watchdog(reader: asyncio.StreamReader) -> None:
  245. await reader.read(n=1)
  246. raise P2PInterruptedError()
  247. async def do_handle_unary_stream(stream_info: StreamInfo,
  248. reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
  249. try:
  250. try:
  251. request, err = await P2P.receive_protobuf(in_proto_type, reader)
  252. except asyncio.IncompleteReadError:
  253. logger.debug(f'Incomplete read while receiving request from peer in {handle_name}')
  254. return
  255. except google.protobuf.message.DecodeError as error:
  256. logger.debug(f'Failed to decode request protobuf '
  257. f'of type {in_proto_type} in {handle_name}: {error}')
  258. return
  259. if err is not None:
  260. logger.debug(f'Got an error instead of a request in {handle_name}: {err}')
  261. context = P2PContext(handle_name=handle_name, local_id=self.id,
  262. remote_id=stream_info.peer_id, remote_maddr=stream_info.addr)
  263. done, pending = await asyncio.wait([watchdog(reader), handle(request, context)],
  264. return_when=asyncio.FIRST_COMPLETED)
  265. try:
  266. result = done.pop().result()
  267. await P2P.send_protobuf(result, out_proto_type, writer)
  268. except P2PInterruptedError:
  269. pass
  270. except Exception as exc:
  271. error = p2pd_pb2.RPCError(message=str(exc))
  272. await P2P.send_protobuf(error, p2pd_pb2.RPCError, writer)
  273. finally:
  274. if pending:
  275. for task in pending:
  276. task.cancel()
  277. await asyncio.wait(pending)
  278. finally:
  279. writer.close()
  280. return do_handle_unary_stream
  281. def _start_listening(self) -> None:
  282. async def listen() -> None:
  283. async with self._client.listen():
  284. await self._server_stopped.wait()
  285. self._listen_task = asyncio.create_task(listen())
  286. async def _stop_listening(self) -> None:
  287. if self._listen_task is not None:
  288. self._server_stopped.set()
  289. self._listen_task.cancel()
  290. try:
  291. await self._listen_task
  292. except asyncio.CancelledError:
  293. self._listen_task = None
  294. self._server_stopped.clear()
  295. async def add_stream_handler(self, name: str, handle: Callable[[bytes], bytes]) -> None:
  296. if self._listen_task is None:
  297. self._start_listening()
  298. await self._client.stream_handler(name, self._handle_stream(handle))
  299. async def add_unary_handler(self, name: str, handle: Callable[[Any, P2PContext], Any],
  300. in_proto_type: type, out_proto_type: type) -> None:
  301. if self._listen_task is None:
  302. self._start_listening()
  303. await self._client.stream_handler(
  304. name, self._handle_unary_stream(handle, name, in_proto_type, out_proto_type))
  305. async def call_peer_handler(self, peer_id: PeerID, handler_name: str, input_data: bytes) -> bytes:
  306. stream_info, reader, writer = await self._client.stream_open(peer_id, (handler_name,))
  307. try:
  308. await P2P.send_raw_data(input_data, writer)
  309. return await P2P.receive_raw_data(reader)
  310. finally:
  311. writer.close()
  312. async def create_persistent_stream(self, peer_id: PeerID, handler_name: str):
  313. if stream := self.persistent_streams.get((peer_id, handler_name)):
  314. return stream
  315. stream_info, reader, writer = await self._client.stream_open(peer_id, (handler_name,))
  316. self.persistent_streams[(peer_id, handler_name)] = (stream_info, reader, writer)
  317. return stream_info, reader, writer
  318. async def call_unary_handler(self, peer_id: PeerID, handler_name: str,
  319. request_protobuf: Any, response_proto_type: type) -> Any:
  320. stream_info, reader, writer = await self._client.stream_open(peer_id, (handler_name,))
  321. try:
  322. await P2P.send_protobuf(request_protobuf, type(request_protobuf), writer)
  323. result, err = await P2P.receive_protobuf(response_proto_type, reader)
  324. if err is not None:
  325. raise P2PHandlerError(f'Failed to call unary handler {handler_name} at {peer_id}: {err.message}')
  326. return result
  327. finally:
  328. writer.close()
  329. def __del__(self):
  330. self._terminate()
  331. @property
  332. def is_alive(self) -> bool:
  333. return self._alive
  334. async def shutdown(self) -> None:
  335. await self._stop_listening()
  336. await asyncio.get_event_loop().run_in_executor(None, self._terminate)
  337. def _terminate(self) -> None:
  338. self._alive = False
  339. if self._child is not None and self._child.poll() is None:
  340. self._child.terminate()
  341. self._child.wait()
  342. logger.debug(f'Terminated p2pd with id = {self.id}')
  343. with suppress(FileNotFoundError):
  344. os.remove(self._daemon_listen_maddr['unix'])
  345. with suppress(FileNotFoundError):
  346. os.remove(self._client_listen_maddr['unix'])
  347. @staticmethod
  348. def _make_process_args(*args, **kwargs) -> List[str]:
  349. proc_args = []
  350. proc_args.extend(
  351. str(entry) for entry in args
  352. )
  353. proc_args.extend(
  354. f'-{key}={P2P._convert_process_arg_type(value)}' if value is not None else f'-{key}'
  355. for key, value in kwargs.items()
  356. )
  357. return proc_args
  358. @staticmethod
  359. def _convert_process_arg_type(val: Any) -> Any:
  360. if isinstance(val, bool):
  361. return int(val)
  362. return val
  363. @staticmethod
  364. def _maddrs_to_str(maddrs: List[Multiaddr]) -> str:
  365. return ','.join(str(addr) for addr in maddrs)
  366. class P2PInterruptedError(Exception):
  367. pass
  368. class P2PHandlerError(Exception):
  369. pass