p2p_daemon.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580
  1. import asyncio
  2. import os
  3. import secrets
  4. from collections.abc import AsyncIterable as AsyncIterableABC
  5. from contextlib import closing, suppress
  6. from dataclasses import dataclass
  7. from importlib.resources import path
  8. from subprocess import Popen
  9. from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
  10. from google.protobuf.message import Message
  11. from multiaddr import Multiaddr
  12. import hivemind.hivemind_cli as cli
  13. import hivemind.p2p.p2p_daemon_bindings.p2pclient as p2pclient
  14. from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
  15. from hivemind.p2p.p2p_daemon_bindings.control import P2PHandlerError
  16. from hivemind.proto.p2pd_pb2 import RPCError
  17. from hivemind.utils.asyncio import aiter
  18. from hivemind.utils.logging import get_logger
  19. logger = get_logger(__name__)
  20. P2PD_FILENAME = "p2pd"
  21. @dataclass(frozen=True)
  22. class P2PContext(object):
  23. handle_name: str
  24. local_id: PeerID
  25. remote_id: PeerID = None
  26. class P2P:
  27. """
  28. This class is responsible for establishing peer-to-peer connections through NAT and/or firewalls.
  29. It creates and manages a libp2p daemon (https://libp2p.io) in a background process,
  30. then terminates it when P2P is shut down. In order to communicate, a P2P instance should
  31. either use one or more initial_peers that will connect it to the rest of the swarm or
  32. use the public IPFS network (https://ipfs.io).
  33. For incoming connections, P2P instances add RPC handlers that may be accessed by other peers:
  34. - `P2P.add_protobuf_handler` accepts a protobuf message and returns another protobuf
  35. - `P2P.add_binary_stream_handler` transfers raw data using bi-directional streaming interface
  36. To access these handlers, a P2P instance can `P2P.call_protobuf_handler`/`P2P.call_binary_stream_handler`,
  37. using the recipient's unique `P2P.id` and the name of the corresponding handler.
  38. """
  39. HEADER_LEN = 8
  40. BYTEORDER = "big"
  41. MESSAGE_MARKER = b"\x00"
  42. ERROR_MARKER = b"\x01"
  43. END_OF_STREAM = RPCError()
  44. DHT_MODE_MAPPING = {
  45. "dht": {"dht": 1},
  46. "dht_server": {"dhtServer": 1},
  47. "dht_client": {"dhtClient": 1},
  48. }
  49. FORCE_REACHABILITY_MAPPING = {
  50. "public": {"forceReachabilityPublic": 1},
  51. "private": {"forceReachabilityPrivate": 1},
  52. }
  53. _UNIX_SOCKET_PREFIX = "/unix/tmp/hivemind-"
  54. def __init__(self):
  55. self.id = None
  56. self._child = None
  57. self._alive = False
  58. self._listen_task = None
  59. self._server_stopped = asyncio.Event()
  60. @classmethod
  61. async def create(
  62. cls,
  63. initial_peers: Optional[Sequence[Union[Multiaddr, str]]] = None,
  64. use_ipfs: bool = False,
  65. host_maddrs: Optional[Sequence[Union[Multiaddr, str]]] = ("/ip4/127.0.0.1/tcp/0",),
  66. announce_maddrs: Optional[Sequence[Union[Multiaddr, str]]] = None,
  67. quic: bool = True,
  68. tls: bool = True,
  69. conn_manager: bool = True,
  70. dht_mode: str = "dht_server",
  71. force_reachability: Optional[str] = None,
  72. nat_port_map: bool = True,
  73. auto_nat: bool = True,
  74. use_relay: bool = True,
  75. use_relay_hop: bool = False,
  76. use_relay_discovery: bool = False,
  77. use_auto_relay: bool = False,
  78. relay_hop_limit: int = 0,
  79. quiet: bool = True,
  80. ping_n_attempts: int = 5,
  81. ping_delay: float = 0.4,
  82. ) -> "P2P":
  83. """
  84. Start a new p2pd process and connect to it.
  85. :param initial_peers: List of bootstrap peers
  86. :param use_ipfs: Bootstrap to IPFS (incompatible with initial_peers)
  87. :param host_maddrs: Multiaddrs to listen for external connections from other p2p instances
  88. :param announce_maddrs: Visible multiaddrs that the peer will announce
  89. for external connections from other p2p instances
  90. :param quic: Enables the QUIC transport
  91. :param tls: Enables TLS1.3 channel security protocol
  92. :param conn_manager: Enables the Connection Manager
  93. :param dht_mode: DHT mode (dht_client/dht_server/dht)
  94. :param force_reachability: Force reachability mode (public/private)
  95. :param nat_port_map: Enables NAT port mapping
  96. :param auto_nat: Enables the AutoNAT service
  97. :param use_relay: enables circuit relay
  98. :param use_relay_hop: enables hop for relay
  99. :param use_relay_discovery: enables passive discovery for relay
  100. :param use_auto_relay: enables autorelay
  101. :param relay_hop_limit: sets the hop limit for hop relays
  102. :param quiet: make the daemon process quiet
  103. :param ping_n_attempts: try to ping the daemon with this number of attempts after starting it
  104. :param ping_delay: wait for ``ping_delay * (2 ** (k - 1))`` seconds before the k-th attempt to ping the daemon
  105. (in particular, wait for ``ping_delay`` seconds before the first attempt)
  106. :return: a wrapper for the p2p daemon
  107. """
  108. assert not (
  109. initial_peers and use_ipfs
  110. ), "User-defined initial_peers and use_ipfs=True are incompatible, please choose one option"
  111. self = cls()
  112. with path(cli, P2PD_FILENAME) as p:
  113. p2pd_path = p
  114. socket_uid = secrets.token_urlsafe(8)
  115. self._daemon_listen_maddr = Multiaddr(cls._UNIX_SOCKET_PREFIX + f"p2pd-{socket_uid}.sock")
  116. self._client_listen_maddr = Multiaddr(cls._UNIX_SOCKET_PREFIX + f"p2pclient-{socket_uid}.sock")
  117. need_bootstrap = bool(initial_peers) or use_ipfs
  118. process_kwargs = cls.DHT_MODE_MAPPING.get(dht_mode, {"dht": 0})
  119. process_kwargs.update(cls.FORCE_REACHABILITY_MAPPING.get(force_reachability, {}))
  120. for param, value in [
  121. ("bootstrapPeers", initial_peers),
  122. ("hostAddrs", host_maddrs),
  123. ("announceAddrs", announce_maddrs),
  124. ]:
  125. if value:
  126. process_kwargs[param] = self._maddrs_to_str(value)
  127. proc_args = self._make_process_args(
  128. str(p2pd_path),
  129. listen=self._daemon_listen_maddr,
  130. quic=quic,
  131. tls=tls,
  132. connManager=conn_manager,
  133. natPortMap=nat_port_map,
  134. autonat=auto_nat,
  135. relay=use_relay,
  136. relayHop=use_relay_hop,
  137. relayDiscovery=use_relay_discovery,
  138. autoRelay=use_auto_relay,
  139. relayHopLimit=relay_hop_limit,
  140. b=need_bootstrap,
  141. q=quiet,
  142. **process_kwargs,
  143. )
  144. self._child = Popen(args=proc_args, encoding="utf8")
  145. self._alive = True
  146. self._client = p2pclient.Client(self._daemon_listen_maddr, self._client_listen_maddr)
  147. await self._ping_daemon_with_retries(ping_n_attempts, ping_delay)
  148. return self
  149. async def _ping_daemon_with_retries(self, ping_n_attempts: int, ping_delay: float) -> None:
  150. for try_number in range(ping_n_attempts):
  151. await asyncio.sleep(ping_delay * (2 ** try_number))
  152. if self._child.poll() is not None: # Process died
  153. break
  154. try:
  155. await self._ping_daemon()
  156. break
  157. except Exception as e:
  158. if try_number == ping_n_attempts - 1:
  159. logger.exception("Failed to ping p2pd that has just started")
  160. await self.shutdown()
  161. raise
  162. if self._child.returncode is not None:
  163. raise RuntimeError(f"The p2p daemon has died with return code {self._child.returncode}")
  164. @classmethod
  165. async def replicate(cls, daemon_listen_maddr: Multiaddr) -> "P2P":
  166. """
  167. Connect to existing p2p daemon
  168. :param daemon_listen_maddr: multiaddr of the existing p2p daemon
  169. :return: new wrapper for the existing p2p daemon
  170. """
  171. self = cls()
  172. # There is no child under control
  173. # Use external already running p2pd
  174. self._child = None
  175. self._alive = True
  176. socket_uid = secrets.token_urlsafe(8)
  177. self._daemon_listen_maddr = daemon_listen_maddr
  178. self._client_listen_maddr = Multiaddr(cls._UNIX_SOCKET_PREFIX + f"p2pclient-{socket_uid}.sock")
  179. self._client = p2pclient.Client(self._daemon_listen_maddr, self._client_listen_maddr)
  180. await self._ping_daemon()
  181. return self
  182. async def _ping_daemon(self) -> None:
  183. self.id, self._visible_maddrs = await self._client.identify()
  184. logger.debug(f"Launched p2pd with id = {self.id}, host multiaddrs = {self._visible_maddrs}")
  185. async def get_visible_maddrs(self, latest: bool = False) -> List[Multiaddr]:
  186. """
  187. Get multiaddrs of the current peer that should be accessible by other peers.
  188. :param latest: ask the P2P daemon to refresh the visible multiaddrs
  189. """
  190. if latest:
  191. _, self._visible_maddrs = await self._client.identify()
  192. if not self._visible_maddrs:
  193. raise ValueError(f"No multiaddrs found for peer {self.id}")
  194. p2p_maddr = Multiaddr(f"/p2p/{self.id.to_base58()}")
  195. return [addr.encapsulate(p2p_maddr) for addr in self._visible_maddrs]
  196. async def list_peers(self) -> List[PeerInfo]:
  197. return list(await self._client.list_peers())
  198. async def wait_for_at_least_n_peers(self, n_peers: int, attempts: int = 3, delay: float = 1) -> None:
  199. for _ in range(attempts):
  200. peers = await self._client.list_peers()
  201. if len(peers) >= n_peers:
  202. return
  203. await asyncio.sleep(delay)
  204. raise RuntimeError("Not enough peers")
  205. @property
  206. def daemon_listen_maddr(self) -> Multiaddr:
  207. return self._daemon_listen_maddr
  208. @staticmethod
  209. async def send_raw_data(data: bytes, writer: asyncio.StreamWriter, *, chunk_size: int = 2 ** 16) -> None:
  210. writer.write(len(data).to_bytes(P2P.HEADER_LEN, P2P.BYTEORDER))
  211. data = memoryview(data)
  212. for offset in range(0, len(data), chunk_size):
  213. writer.write(data[offset : offset + chunk_size])
  214. await writer.drain()
  215. @staticmethod
  216. async def receive_raw_data(reader: asyncio.StreamReader) -> bytes:
  217. header = await reader.readexactly(P2P.HEADER_LEN)
  218. content_length = int.from_bytes(header, P2P.BYTEORDER)
  219. data = await reader.readexactly(content_length)
  220. return data
  221. TInputProtobuf = TypeVar("TInputProtobuf")
  222. TOutputProtobuf = TypeVar("TOutputProtobuf")
  223. @staticmethod
  224. async def send_protobuf(protobuf: Union[TOutputProtobuf, RPCError], writer: asyncio.StreamWriter) -> None:
  225. if isinstance(protobuf, RPCError):
  226. writer.write(P2P.ERROR_MARKER)
  227. else:
  228. writer.write(P2P.MESSAGE_MARKER)
  229. await P2P.send_raw_data(protobuf.SerializeToString(), writer)
  230. @staticmethod
  231. async def receive_protobuf(
  232. input_protobuf_type: Message, reader: asyncio.StreamReader
  233. ) -> Tuple[Optional[TInputProtobuf], Optional[RPCError]]:
  234. msg_type = await reader.readexactly(1)
  235. if msg_type == P2P.MESSAGE_MARKER:
  236. protobuf = input_protobuf_type()
  237. protobuf.ParseFromString(await P2P.receive_raw_data(reader))
  238. return protobuf, None
  239. elif msg_type == P2P.ERROR_MARKER:
  240. protobuf = RPCError()
  241. protobuf.ParseFromString(await P2P.receive_raw_data(reader))
  242. return None, protobuf
  243. else:
  244. raise TypeError("Invalid Protobuf message type")
  245. TInputStream = AsyncIterator[TInputProtobuf]
  246. TOutputStream = AsyncIterator[TOutputProtobuf]
  247. async def _add_protobuf_stream_handler(
  248. self,
  249. name: str,
  250. handler: Callable[[TInputStream, P2PContext], TOutputStream],
  251. input_protobuf_type: Message,
  252. max_prefetch: int = 5,
  253. ) -> None:
  254. """
  255. :param max_prefetch: Maximum number of items to prefetch from the request stream.
  256. ``max_prefetch <= 0`` means unlimited.
  257. :note: Since the cancel messages are sent via the input stream,
  258. they will not be received while the prefetch buffer is full.
  259. """
  260. if self._listen_task is None:
  261. self._start_listening()
  262. async def _handle_stream(
  263. stream_info: StreamInfo, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
  264. ) -> None:
  265. context = P2PContext(
  266. handle_name=name,
  267. local_id=self.id,
  268. remote_id=stream_info.peer_id,
  269. )
  270. requests = asyncio.Queue(max_prefetch)
  271. async def _read_stream() -> P2P.TInputStream:
  272. while True:
  273. request = await requests.get()
  274. if request is None:
  275. break
  276. yield request
  277. async def _process_stream() -> None:
  278. try:
  279. async for response in handler(_read_stream(), context):
  280. await P2P.send_protobuf(response, writer)
  281. except Exception as e:
  282. logger.warning("Exception while processing stream and sending responses:", exc_info=True)
  283. await P2P.send_protobuf(RPCError(message=str(e)), writer)
  284. with closing(writer):
  285. processing_task = asyncio.create_task(_process_stream())
  286. try:
  287. while True:
  288. receive_task = asyncio.create_task(P2P.receive_protobuf(input_protobuf_type, reader))
  289. await asyncio.wait({processing_task, receive_task}, return_when=asyncio.FIRST_COMPLETED)
  290. if processing_task.done():
  291. receive_task.cancel()
  292. return
  293. if receive_task.done():
  294. try:
  295. request, _ = await receive_task
  296. except asyncio.IncompleteReadError: # Connection is closed (the client cancelled or died)
  297. return
  298. await requests.put(request) # `request` is None for the end-of-stream message
  299. except Exception:
  300. logger.warning("Exception while receiving requests:", exc_info=True)
  301. finally:
  302. processing_task.cancel()
  303. await self._client.stream_handler(name, _handle_stream)
  304. async def _iterate_protobuf_stream_handler(
  305. self, peer_id: PeerID, name: str, requests: TInputStream, output_protobuf_type: Message
  306. ) -> TOutputStream:
  307. _, reader, writer = await self._client.stream_open(peer_id, (name,))
  308. async def _write_to_stream() -> None:
  309. async for request in requests:
  310. await P2P.send_protobuf(request, writer)
  311. await P2P.send_protobuf(P2P.END_OF_STREAM, writer)
  312. with closing(writer):
  313. writing_task = asyncio.create_task(_write_to_stream())
  314. try:
  315. while True:
  316. try:
  317. response, err = await P2P.receive_protobuf(output_protobuf_type, reader)
  318. except asyncio.IncompleteReadError: # Connection is closed
  319. break
  320. if err is not None:
  321. raise P2PHandlerError(f"Failed to call handler `{name}` at {peer_id}: {err.message}")
  322. yield response
  323. await writing_task
  324. finally:
  325. writing_task.cancel()
  326. async def add_protobuf_handler(
  327. self,
  328. name: str,
  329. handler: Callable[
  330. [Union[TInputProtobuf, TInputStream], P2PContext], Union[Awaitable[TOutputProtobuf], TOutputStream]
  331. ],
  332. input_protobuf_type: Message,
  333. *,
  334. stream_input: bool = False,
  335. stream_output: bool = False,
  336. ) -> None:
  337. """
  338. :param stream_input: If True, assume ``handler`` to take ``TInputStream``
  339. (not just ``TInputProtobuf``) as input.
  340. :param stream_output: If True, assume ``handler`` to return ``TOutputStream``
  341. """
  342. if not stream_input and not stream_output:
  343. await self._add_protobuf_unary_handler(name, handler, input_protobuf_type)
  344. return
  345. async def _stream_handler(requests: P2P.TInputStream, context: P2PContext) -> P2P.TOutputStream:
  346. if stream_input:
  347. input = requests
  348. else:
  349. count = 0
  350. async for input in requests:
  351. count += 1
  352. if count != 1:
  353. raise ValueError(f"Got {count} requests for handler {name} instead of one")
  354. output = handler(input, context)
  355. if isinstance(output, AsyncIterableABC):
  356. async for item in output:
  357. yield item
  358. else:
  359. yield await output
  360. await self._add_protobuf_stream_handler(name, _stream_handler, input_protobuf_type)
  361. # only registers request-response handlers
  362. async def _add_protobuf_unary_handler(
  363. self,
  364. handle_name: str,
  365. handler: Callable[[TInputProtobuf, P2PContext], Awaitable[TOutputProtobuf]],
  366. input_protobuf_type: Message,
  367. ) -> None:
  368. """
  369. Register a request-response (unary) handler. Unary requests and responses
  370. are sent through persistent multiplexed connections to the daemon for the
  371. sake of reducing the number of open files.
  372. :param handle_name: name of the handler (protocol id)
  373. :param handler: function handling the unary requests
  374. :param input_protobuf_type: protobuf type of the request
  375. """
  376. async def _unary_handler(request: bytes, remote_id: PeerID) -> bytes:
  377. input_serialized = input_protobuf_type.FromString(request)
  378. context = P2PContext(
  379. handle_name=handle_name,
  380. local_id=self.id,
  381. remote_id=remote_id,
  382. )
  383. response = await handler(input_serialized, context)
  384. return response.SerializeToString()
  385. await self._client.add_unary_handler(handle_name, _unary_handler)
  386. async def call_protobuf_handler(
  387. self,
  388. peer_id: PeerID,
  389. name: str,
  390. input: Union[TInputProtobuf, TInputStream],
  391. output_protobuf_type: Message,
  392. ) -> Awaitable[TOutputProtobuf]:
  393. if not isinstance(input, AsyncIterableABC):
  394. return await self._call_unary_protobuf_handler(peer_id, name, input, output_protobuf_type)
  395. responses = self._iterate_protobuf_stream_handler(peer_id, name, input, output_protobuf_type)
  396. count = 0
  397. async for response in responses:
  398. count += 1
  399. if count != 1:
  400. raise ValueError(f"Got {count} responses from handler {name} instead of one")
  401. return response
  402. async def _call_unary_protobuf_handler(
  403. self,
  404. peer_id: PeerID,
  405. handle_name: str,
  406. input: TInputProtobuf,
  407. output_protobuf_type: Message,
  408. ) -> Awaitable[TOutputProtobuf]:
  409. serialized_input = input.SerializeToString()
  410. response = await self._client.call_unary_handler(peer_id, handle_name, serialized_input)
  411. return output_protobuf_type().FromString(response)
  412. def iterate_protobuf_handler(
  413. self,
  414. peer_id: PeerID,
  415. name: str,
  416. input: Union[TInputProtobuf, TInputStream],
  417. output_protobuf_type: Message,
  418. ) -> TOutputStream:
  419. requests = input if isinstance(input, AsyncIterableABC) else aiter(input)
  420. return self._iterate_protobuf_stream_handler(peer_id, name, requests, output_protobuf_type)
  421. def _start_listening(self) -> None:
  422. async def listen() -> None:
  423. async with self._client.listen():
  424. await self._server_stopped.wait()
  425. self._listen_task = asyncio.create_task(listen())
  426. async def _stop_listening(self) -> None:
  427. if self._listen_task is not None:
  428. self._server_stopped.set()
  429. self._listen_task.cancel()
  430. try:
  431. await self._listen_task
  432. except asyncio.CancelledError:
  433. self._listen_task = None
  434. self._server_stopped.clear()
  435. async def add_binary_stream_handler(self, name: str, handler: p2pclient.StreamHandler) -> None:
  436. if self._listen_task is None:
  437. self._start_listening()
  438. await self._client.stream_handler(name, handler)
  439. async def call_binary_stream_handler(
  440. self, peer_id: PeerID, handler_name: str
  441. ) -> Tuple[StreamInfo, asyncio.StreamReader, asyncio.StreamWriter]:
  442. return await self._client.stream_open(peer_id, (handler_name,))
  443. def __del__(self):
  444. self._terminate()
  445. @property
  446. def is_alive(self) -> bool:
  447. return self._alive
  448. async def shutdown(self) -> None:
  449. await self._stop_listening()
  450. await asyncio.get_event_loop().run_in_executor(None, self._terminate)
  451. def _terminate(self) -> None:
  452. self._alive = False
  453. if self._client.control._write_task is not None:
  454. self._client.control._write_task.cancel()
  455. if self._client.control._read_task is not None:
  456. self._client.control._read_task.cancel()
  457. if self._child is not None and self._child.poll() is None:
  458. self._child.terminate()
  459. self._child.wait()
  460. logger.debug(f"Terminated p2pd with id = {self.id}")
  461. with suppress(FileNotFoundError):
  462. os.remove(self._daemon_listen_maddr["unix"])
  463. with suppress(FileNotFoundError):
  464. os.remove(self._client_listen_maddr["unix"])
  465. @staticmethod
  466. def _make_process_args(*args, **kwargs) -> List[str]:
  467. proc_args = []
  468. proc_args.extend(str(entry) for entry in args)
  469. proc_args.extend(
  470. f"-{key}={P2P._convert_process_arg_type(value)}" if value is not None else f"-{key}"
  471. for key, value in kwargs.items()
  472. )
  473. return proc_args
  474. @staticmethod
  475. def _convert_process_arg_type(val: Any) -> Any:
  476. if isinstance(val, bool):
  477. return int(val)
  478. return val
  479. @staticmethod
  480. def _maddrs_to_str(maddrs: List[Multiaddr]) -> str:
  481. return ",".join(str(addr) for addr in maddrs)
  482. class P2PInterruptedError(Exception):
  483. pass