# p2p_daemon.py
  1. import asyncio
  2. import copy
  3. from pathlib import Path
  4. import pickle
  5. import subprocess
  6. import typing as tp
  7. import warnings
  8. import google.protobuf
  9. from multiaddr import Multiaddr
  10. import p2pclient
  11. from libp2p.peer.id import ID
  12. from hivemind.utils.networking import find_open_port
  13. class P2PContext(object):
  14. def __init__(self, ours_id, ours_port, handle_name):
  15. self.peer_id = None
  16. self.peer_addr = None
  17. self.ours_id = ours_id
  18. self.ours_port = ours_port
  19. self.handle_name = handle_name
  20. class P2P(object):
  21. """
  22. Forks a child process and executes p2pd command with given arguments.
  23. Can be used for peer to peer communication and procedure calls.
  24. Sends SIGKILL to the child in destructor.
  25. """
  26. P2PD_RELATIVE_PATH = 'hivemind_cli/p2pd'
  27. NUM_RETRIES = 3
  28. RETRY_DELAY = 0.4
  29. HEADER_LEN = 8
  30. BYTEORDER = 'big'
  31. class IncompleteRead(Exception):
  32. pass
  33. class InterruptedError(Exception):
  34. pass
  35. def __init__(self):
  36. self._child = None
  37. self._listen_task = None
  38. self._server_stopped = asyncio.Event()
  39. @classmethod
  40. async def create(cls, *args, quic=1, tls=1, conn_manager=1, dht_client=1,
  41. nat_port_map=True, auto_nat=True, bootstrap=True,
  42. host_port: int = None, daemon_listen_port: int = None, **kwargs):
  43. self = cls()
  44. p2pd_path = Path(__file__).resolve().parents[1] / P2P.P2PD_RELATIVE_PATH
  45. proc_args = self._make_process_args(
  46. str(p2pd_path), *args,
  47. quic=quic, tls=tls, connManager=conn_manager,
  48. dhtClient=dht_client, natPortMap=nat_port_map,
  49. autonat=auto_nat, b=bootstrap, **kwargs)
  50. self._assign_daemon_ports(host_port, daemon_listen_port)
  51. for try_count in range(self.NUM_RETRIES):
  52. try:
  53. self._initialize(proc_args)
  54. await self._identify_client(P2P.RETRY_DELAY * (2 ** try_count))
  55. except Exception as exc:
  56. warnings.warn("Failed to initialize p2p daemon: " + str(exc), RuntimeWarning)
  57. self._kill_child()
  58. if try_count == P2P.NUM_RETRIES - 1:
  59. raise
  60. self._assign_daemon_ports()
  61. continue
  62. break
  63. return self
  64. def _initialize(self, proc_args: tp.List[str]) -> None:
  65. proc_args = copy.deepcopy(proc_args)
  66. proc_args.extend(self._make_process_args(
  67. hostAddrs=f'/ip4/0.0.0.0/tcp/{self._host_port},/ip4/0.0.0.0/udp/{self._host_port}/quic',
  68. listen=f'/ip4/127.0.0.1/tcp/{self._daemon_listen_port}'
  69. ))
  70. self._child = subprocess.Popen(
  71. args=proc_args,
  72. stdin=subprocess.PIPE, stdout=subprocess.PIPE,
  73. stderr=subprocess.PIPE, encoding="utf8"
  74. )
  75. self._client_listen_port = find_open_port()
  76. self._client = p2pclient.Client(
  77. Multiaddr(f'/ip4/127.0.0.1/tcp/{self._daemon_listen_port}'),
  78. Multiaddr(f'/ip4/127.0.0.1/tcp/{self._client_listen_port}'))
  79. async def _identify_client(self, delay):
  80. await asyncio.sleep(delay)
  81. encoded = await self._client.identify()
  82. self.id = encoded[0].to_base58()
  83. def _assign_daemon_ports(self, host_port=None, daemon_listen_port=None):
  84. self._host_port, self._daemon_listen_port = host_port, daemon_listen_port
  85. if host_port is None:
  86. self._host_port = find_open_port()
  87. if daemon_listen_port is None:
  88. self._daemon_listen_port = find_open_port()
  89. while self._daemon_listen_port == self._host_port:
  90. self._daemon_listen_port = find_open_port()
  91. @staticmethod
  92. async def send_raw_data(byte_str, stream):
  93. request = len(byte_str).to_bytes(P2P.HEADER_LEN, P2P.BYTEORDER) + byte_str
  94. await stream.send_all(request)
  95. @staticmethod
  96. async def send_data(data, stream):
  97. await P2P.send_raw_data(pickle.dumps(data), stream)
  98. @staticmethod
  99. async def send_protobuf(protobuf, out_proto_type, stream):
  100. if type(protobuf) != out_proto_type:
  101. error = TypeError('Unary handler returned protobuf of wrong type.')
  102. await P2P.send_raw_data(pickle.dumps(error), stream)
  103. raise error
  104. await P2P.send_raw_data(protobuf.SerializeToString(), stream)
  105. @staticmethod
  106. async def receive_exactly(stream, n_bytes, max_bytes=1 << 16):
  107. buffer = bytearray()
  108. while len(buffer) < n_bytes:
  109. data = await stream.receive_some(min(max_bytes, n_bytes - len(buffer)))
  110. if len(data) == 0:
  111. raise P2P.IncompleteRead()
  112. buffer.extend(data)
  113. return bytes(buffer)
  114. @staticmethod
  115. async def receive_raw_data(stream):
  116. header = await P2P.receive_exactly(stream, P2P.HEADER_LEN)
  117. content_length = int.from_bytes(header, P2P.BYTEORDER)
  118. data = await P2P.receive_exactly(stream, content_length)
  119. return data
  120. @staticmethod
  121. async def receive_data(stream):
  122. return pickle.loads(await P2P.receive_raw_data(stream))
  123. @staticmethod
  124. async def receive_protobuf(in_proto_type, stream):
  125. protobuf = in_proto_type()
  126. protobuf.ParseFromString(await P2P.receive_raw_data(stream))
  127. return protobuf
  128. @staticmethod
  129. def _handle_stream(handle):
  130. async def do_handle_stream(stream_info, stream):
  131. try:
  132. request = await P2P.receive_data(stream)
  133. except P2P.IncompleteRead:
  134. warnings.warn("Incomplete read while receiving request from peer", RuntimeWarning)
  135. await stream.close()
  136. return
  137. try:
  138. result = handle(request)
  139. await P2P.send_data(result, stream)
  140. except Exception as exc:
  141. await P2P.send_data(exc, stream)
  142. finally:
  143. await stream.close()
  144. return do_handle_stream
  145. @staticmethod
  146. def _handle_unary_stream(handle, context, in_proto_type, out_proto_type):
  147. async def watchdog(stream):
  148. await stream.receive_some(max_bytes=1)
  149. raise P2P.InterruptedError()
  150. async def do_handle_unary_stream(stream_info, stream):
  151. try:
  152. try:
  153. request = await P2P.receive_protobuf(in_proto_type, stream)
  154. except P2P.IncompleteRead:
  155. warnings.warn("Incomplete read while receiving request from peer",
  156. RuntimeWarning)
  157. return
  158. except google.protobuf.message.DecodeError as error:
  159. warnings.warn(repr(error), RuntimeWarning)
  160. return
  161. context.peer_id, context.peer_addr = stream_info.peer_id, stream_info.addr
  162. done, pending = await asyncio.wait([watchdog(stream), handle(request, context)],
  163. return_when=asyncio.FIRST_COMPLETED)
  164. try:
  165. result = done.pop().result()
  166. await P2P.send_protobuf(result, out_proto_type, stream)
  167. except P2P.InterruptedError:
  168. pass
  169. except Exception as exc:
  170. await P2P.send_data(exc, stream)
  171. finally:
  172. pending_task = pending.pop()
  173. pending_task.cancel()
  174. try:
  175. await pending_task
  176. except asyncio.CancelledError:
  177. pass
  178. finally:
  179. await stream.close()
  180. return do_handle_unary_stream
  181. def start_listening(self):
  182. async def listen():
  183. async with self._client.listen():
  184. await self._server_stopped.wait()
  185. self._listen_task = asyncio.create_task(listen())
  186. async def stop_listening(self):
  187. if self._listen_task is not None:
  188. self._server_stopped.set()
  189. self._listen_task.cancel()
  190. try:
  191. await self._listen_task
  192. except asyncio.CancelledError:
  193. self._listen_task = None
  194. self._server_stopped.clear()
  195. async def add_stream_handler(self, name, handle):
  196. if self._listen_task is None:
  197. self.start_listening()
  198. await self._client.stream_handler(name, P2P._handle_stream(handle))
  199. async def add_unary_handler(self, name, handle, in_proto_type, out_proto_type):
  200. if self._listen_task is None:
  201. self.start_listening()
  202. context = P2PContext(ours_id=self.id, ours_port=self._host_port, handle_name=name)
  203. await self._client.stream_handler(
  204. name, P2P._handle_unary_stream(handle, context, in_proto_type, out_proto_type))
  205. async def call_peer_handler(self, peer_id, handler_name, input_data):
  206. libp2p_peer_id = ID.from_base58(peer_id)
  207. stream_info, stream = await self._client.stream_open(libp2p_peer_id, (handler_name,))
  208. try:
  209. await P2P.send_data(input_data, stream)
  210. return await P2P.receive_data(stream)
  211. finally:
  212. await stream.close()
  213. def __del__(self):
  214. self._kill_child()
  215. def _kill_child(self):
  216. if self._child is not None and self._child.poll() is None:
  217. self._child.kill()
  218. self._child.wait()
  219. def _make_process_args(self, *args, **kwargs) -> tp.List[str]:
  220. proc_args = []
  221. proc_args.extend(
  222. str(entry) for entry in args
  223. )
  224. proc_args.extend(
  225. f'-{key}={value}' if value is not None else f'-{key}'
  226. for key, value in kwargs.items()
  227. )
  228. return proc_args