p2p_daemon.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369
  1. import asyncio
  2. from copy import deepcopy
  3. from dataclasses import dataclass
  4. from importlib.resources import path
  5. from subprocess import Popen
  6. from typing import List, Optional
  7. import google.protobuf
  8. from multiaddr import Multiaddr
  9. import hivemind.hivemind_cli as cli
  10. import hivemind.p2p.p2p_daemon_bindings.p2pclient as p2pclient
  11. from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, StreamInfo
  12. from hivemind.proto import p2pd_pb2
  13. from hivemind.utils import MSGPackSerializer
  14. from hivemind.utils.logging import get_logger
  15. from hivemind.utils.networking import find_open_port
  16. logger = get_logger(__name__)
  17. P2PD_FILENAME = 'p2pd'
  18. NUM_RETRIES = 3
  19. RETRY_DELAY = 0.4
  20. class P2PInterruptedError(Exception):
  21. pass
  22. @dataclass(frozen=False)
  23. class P2PContext(object):
  24. id: str
  25. port: int
  26. handle_name: str
  27. peer_id: PeerID = None
  28. peer_addr: Multiaddr = None
  29. class P2P:
  30. """
  31. Forks a child process and executes p2pd command with given arguments.
  32. Can be used for peer to peer communication and procedure calls.
  33. Sends SIGKILL to the child in destructor.
  34. """
  35. HEADER_LEN = 8
  36. BYTEORDER = 'big'
  37. PB_HEADER_LEN = 1
  38. RESULT_MESSAGE = b'\x00'
  39. ERROR_MESSAGE = b'\x01'
  40. DHT_MODE_MAPPING = {
  41. 'dht': {'dht': 1},
  42. 'dht_server': {'dhtServer': 1},
  43. 'dht_client': {'dhtClient': 1},
  44. }
  45. FORCE_REACHABILITY_MAPPING = {
  46. 'public': {'forceReachabilityPublic': 1},
  47. 'private': {'forceReachabilityPrivate': 1},
  48. }
  49. def __init__(self):
  50. self._child = None
  51. self._alive = False
  52. self._listen_task = None
  53. self._server_stopped = asyncio.Event()
  54. @classmethod
  55. async def create(cls, *args, quic: bool = True, tls: bool = True, conn_manager: bool = True,
  56. dht_mode: str = 'dht_server', force_reachability: Optional[str] = None,
  57. nat_port_map: bool = True, auto_nat: bool = True, bootstrap: bool = False,
  58. bootstrap_peers: Optional[List[str]] = None, use_global_ipfs: bool = False, host_port: int = None,
  59. daemon_listen_port: int = None, **kwargs):
  60. """
  61. Start a new p2pd process and connect to it.
  62. :param args:
  63. :param quic: Enables the QUIC transport
  64. :param tls: Enables TLS1.3 channel security protocol
  65. :param conn_manager: Enables the Connection Manager
  66. :param dht_mode: DHT mode (dht_client/dht_server/dht)
  67. :param force_reachability: Force reachability mode (public/private)
  68. :param nat_port_map: Enables NAT port mapping
  69. :param auto_nat: Enables the AutoNAT service
  70. :param bootstrap: Connects to bootstrap peers and bootstraps the dht if enabled
  71. :param bootstrap_peers: List of bootstrap peers; defaults to the IPFS DHT peers
  72. :param use_global_ipfs: Bootstrap to global ipfs (works only if bootstrap=True and bootstrap_peers=None)
  73. :param host_port: port for p2p network
  74. :param daemon_listen_port: port for connection daemon and client binding
  75. :param kwargs:
  76. :return: new wrapper for p2p daemon
  77. """
  78. assert not (bootstrap and bootstrap_peers is None and not use_global_ipfs), \
  79. 'Trying to create with bootstrap node without bootstrap nodes list. ' \
  80. 'It is very dangerous, because p2pd connects to global ipfs and it is very unstable. ' \
  81. 'If you really want this, pass use_global_ipfs=True'
  82. assert not (bootstrap_peers is not None and use_global_ipfs), \
  83. 'Non empty bootstrap_nodes and use_global_ipfs=True are incompatible.' \
  84. 'Choose one option: your nodes list (preferable) or global ipfs (very unstable)'
  85. self = cls()
  86. with path(cli, P2PD_FILENAME) as p:
  87. p2pd_path = p
  88. bootstrap_peers = cls._make_bootstrap_peers(bootstrap_peers)
  89. dht = cls.DHT_MODE_MAPPING.get(dht_mode, {'dht': 0})
  90. force_reachability = cls.FORCE_REACHABILITY_MAPPING.get(force_reachability, {})
  91. proc_args = self._make_process_args(
  92. str(p2pd_path), *args,
  93. quic=quic, tls=tls, connManager=conn_manager,
  94. natPortMap=nat_port_map, autonat=auto_nat,
  95. b=bootstrap, **{**bootstrap_peers, **dht, **force_reachability, **kwargs})
  96. self._assign_daemon_ports(host_port, daemon_listen_port)
  97. for try_count in range(NUM_RETRIES):
  98. try:
  99. self._initialize(proc_args)
  100. await self._wait_for_client(RETRY_DELAY * (2 ** try_count))
  101. break
  102. except Exception as e:
  103. logger.debug(f"Failed to initialize p2p daemon: {e}")
  104. self._terminate()
  105. if try_count == NUM_RETRIES - 1:
  106. raise
  107. self._assign_daemon_ports()
  108. return self
  109. @classmethod
  110. async def replicate(cls, daemon_listen_port: int, host_port: int):
  111. """
  112. Connect to existing p2p daemon
  113. :param daemon_listen_port: port for connection daemon and client binding
  114. :param host_port: port for p2p network
  115. :return: new wrapper for existing p2p daemon
  116. """
  117. self = cls()
  118. # There is no child under control
  119. # Use external already running p2pd
  120. self._child = None
  121. self._alive = True
  122. self._assign_daemon_ports(host_port, daemon_listen_port)
  123. self._client_listen_port = find_open_port()
  124. self._client = p2pclient.Client(
  125. Multiaddr(f'/ip4/127.0.0.1/tcp/{self._daemon_listen_port}'),
  126. Multiaddr(f'/ip4/127.0.0.1/tcp/{self._client_listen_port}'))
  127. await self._wait_for_client()
  128. return self
  129. async def wait_for_at_least_n_peers(self, n_peers, attempts=3, delay=1):
  130. for _ in range(attempts):
  131. peers = await self._client.list_peers()
  132. if len(peers) >= n_peers:
  133. return
  134. await asyncio.sleep(delay)
  135. raise RuntimeError('Not enough peers')
  136. def _initialize(self, proc_args: List[str]) -> None:
  137. proc_args = deepcopy(proc_args)
  138. proc_args.extend(self._make_process_args(
  139. hostAddrs=f'/ip4/0.0.0.0/tcp/{self._host_port},/ip4/0.0.0.0/udp/{self._host_port}/quic',
  140. listen=f'/ip4/127.0.0.1/tcp/{self._daemon_listen_port}'
  141. ))
  142. self._child = Popen(args=proc_args, encoding="utf8")
  143. self._alive = True
  144. self._client_listen_port = find_open_port()
  145. self._client = p2pclient.Client(
  146. Multiaddr(f'/ip4/127.0.0.1/tcp/{self._daemon_listen_port}'),
  147. Multiaddr(f'/ip4/127.0.0.1/tcp/{self._client_listen_port}'))
  148. async def _wait_for_client(self, delay=0):
  149. await asyncio.sleep(delay)
  150. encoded = await self._client.identify()
  151. self.id = encoded[0].to_base58()
  152. def _assign_daemon_ports(self, host_port=None, daemon_listen_port=None):
  153. if host_port is None:
  154. host_port = find_open_port()
  155. if daemon_listen_port is None:
  156. daemon_listen_port = find_open_port()
  157. while daemon_listen_port == host_port:
  158. daemon_listen_port = find_open_port()
  159. self._host_port, self._daemon_listen_port = host_port, daemon_listen_port
  160. @staticmethod
  161. async def send_raw_data(byte_str, writer):
  162. request = len(byte_str).to_bytes(P2P.HEADER_LEN, P2P.BYTEORDER) + byte_str
  163. writer.write(request)
  164. @staticmethod
  165. async def send_msgpack(data, writer):
  166. raw_data = MSGPackSerializer.dumps(data)
  167. await P2P.send_raw_data(raw_data, writer)
  168. @staticmethod
  169. async def send_protobuf(protobuf, out_proto_type, writer):
  170. if type(protobuf) != out_proto_type:
  171. raise TypeError('Unary handler returned protobuf of wrong type.')
  172. if out_proto_type == p2pd_pb2.RPCError:
  173. await P2P.send_raw_data(P2P.ERROR_MESSAGE, writer)
  174. else:
  175. await P2P.send_raw_data(P2P.RESULT_MESSAGE, writer)
  176. await P2P.send_raw_data(protobuf.SerializeToString(), writer)
  177. @staticmethod
  178. async def receive_raw_data(reader: asyncio.StreamReader, header_len=HEADER_LEN):
  179. header = await reader.readexactly(header_len)
  180. content_length = int.from_bytes(header, P2P.BYTEORDER)
  181. data = await reader.readexactly(content_length)
  182. return data
  183. @staticmethod
  184. async def receive_msgpack(reader):
  185. return MSGPackSerializer.loads(await P2P.receive_raw_data(reader))
  186. @staticmethod
  187. async def receive_protobuf(in_proto_type, reader):
  188. msg_type = await P2P.receive_raw_data(reader)
  189. if msg_type == P2P.RESULT_MESSAGE:
  190. protobuf = in_proto_type()
  191. protobuf.ParseFromString(await P2P.receive_raw_data(reader))
  192. return protobuf, None
  193. elif msg_type == P2P.ERROR_MESSAGE:
  194. protobuf = p2pd_pb2.RPCError()
  195. protobuf.ParseFromString(await P2P.receive_raw_data(reader))
  196. return None, protobuf
  197. else:
  198. raise TypeError('Invalid Protobuf message type')
  199. @staticmethod
  200. def _handle_stream(handle):
  201. async def do_handle_stream(stream_info, reader, writer):
  202. try:
  203. request = await P2P.receive_raw_data(reader)
  204. except asyncio.IncompleteReadError:
  205. logger.debug("Incomplete read while receiving request from peer")
  206. writer.close()
  207. return
  208. try:
  209. result = handle(request)
  210. await P2P.send_raw_data(result, writer)
  211. finally:
  212. writer.close()
  213. return do_handle_stream
  214. @staticmethod
  215. def _handle_unary_stream(handle, context, in_proto_type, out_proto_type):
  216. async def watchdog(reader: asyncio.StreamReader):
  217. await reader.read(n=1)
  218. raise P2PInterruptedError()
  219. async def do_handle_unary_stream(
  220. stream_info: StreamInfo,
  221. reader: asyncio.StreamReader,
  222. writer: asyncio.StreamWriter) -> None:
  223. try:
  224. try:
  225. request = await P2P.receive_protobuf(in_proto_type, reader)
  226. except asyncio.IncompleteReadError:
  227. logger.debug("Incomplete read while receiving request from peer")
  228. return
  229. except google.protobuf.message.DecodeError as error:
  230. logger.exception(error)
  231. return
  232. context.peer_id, context.peer_addr = stream_info.peer_id, stream_info.addr
  233. done, pending = await asyncio.wait([watchdog(reader), handle(request, context)],
  234. return_when=asyncio.FIRST_COMPLETED)
  235. try:
  236. result = done.pop().result()
  237. await P2P.send_protobuf(result, out_proto_type, writer)
  238. except P2PInterruptedError:
  239. pass
  240. except Exception as exc:
  241. error = p2pd_pb2.RPCError(message=str(exc))
  242. await P2P.send_protobuf(error, p2pd_pb2.RPCError, writer)
  243. finally:
  244. pending_task = pending.pop()
  245. pending_task.cancel()
  246. try:
  247. await pending_task
  248. except asyncio.CancelledError:
  249. pass
  250. finally:
  251. writer.close()
  252. return do_handle_unary_stream
  253. def start_listening(self):
  254. async def listen():
  255. async with self._client.listen():
  256. await self._server_stopped.wait()
  257. self._listen_task = asyncio.create_task(listen())
  258. async def stop_listening(self):
  259. if self._listen_task is not None:
  260. self._server_stopped.set()
  261. self._listen_task.cancel()
  262. try:
  263. await self._listen_task
  264. except asyncio.CancelledError:
  265. self._listen_task = None
  266. self._server_stopped.clear()
  267. async def add_stream_handler(self, name, handle):
  268. if self._listen_task is None:
  269. self.start_listening()
  270. await self._client.stream_handler(name, self._handle_stream(handle))
  271. async def add_unary_handler(self, name, handle, in_proto_type, out_proto_type):
  272. if self._listen_task is None:
  273. self.start_listening()
  274. context = P2PContext(id=self.id, port=self._host_port, handle_name=name)
  275. await self._client.stream_handler(
  276. name, P2P._handle_unary_stream(handle, context, in_proto_type, out_proto_type))
  277. async def call_peer_handler(self, peer_id, handler_name, input_data):
  278. libp2p_peer_id = PeerID.from_base58(peer_id)
  279. stream_info, reader, writer = await self._client.stream_open(libp2p_peer_id, (handler_name,))
  280. try:
  281. await P2P.send_raw_data(input_data, writer)
  282. return await P2P.receive_raw_data(reader)
  283. finally:
  284. writer.close()
  285. def __del__(self):
  286. self._terminate()
  287. @property
  288. def is_alive(self):
  289. return self._alive
  290. async def shutdown(self):
  291. await asyncio.get_event_loop().run_in_executor(None, self._terminate)
  292. def _terminate(self):
  293. self._alive = False
  294. if self._child is not None and self._child.poll() is None:
  295. self._child.kill()
  296. self._child.wait()
  297. @staticmethod
  298. def _make_process_args(*args, **kwargs) -> List[str]:
  299. proc_args = []
  300. proc_args.extend(
  301. str(entry) for entry in args
  302. )
  303. proc_args.extend(
  304. f'-{key}={P2P._convert_process_arg_type(value)}' if value is not None else f'-{key}'
  305. for key, value in kwargs.items()
  306. )
  307. return proc_args
  308. @staticmethod
  309. def _convert_process_arg_type(val):
  310. if isinstance(val, bool):
  311. return 1 if val else 0
  312. return val
  313. @staticmethod
  314. def _make_bootstrap_peers(nodes):
  315. if nodes is None:
  316. return {}
  317. return {'bootstrapPeers': ','.join(nodes)}