123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551 |
- import asyncio
- import os
- import secrets
- from collections.abc import AsyncIterable as AsyncIterableABC
- from contextlib import closing, suppress
- from dataclasses import dataclass
- from importlib.resources import path
- from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
- from google.protobuf.message import Message
- from multiaddr import Multiaddr
- import hivemind.hivemind_cli as cli
- import hivemind.p2p.p2p_daemon_bindings.p2pclient as p2pclient
- from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
- from hivemind.p2p.p2p_daemon_bindings.control import P2PHandlerError
- from hivemind.proto.p2pd_pb2 import RPCError
- from hivemind.utils.asyncio import aiter, asingle
- from hivemind.utils.logging import get_logger
# Module-level logger, named after this module
logger = get_logger(__name__)
# File name of the p2p daemon executable; `P2P.create` resolves it as a resource of `hivemind.hivemind_cli`
P2PD_FILENAME = "p2pd"
@dataclass(frozen=True)
class P2PContext(object):
    """
    Immutable description of a single handler invocation; handlers receive it as their second argument.

    :param handle_name: name under which the invoked handler was registered
    :param local_id: peer id of the P2P instance that runs the handler
    :param remote_id: peer id of the calling peer; None when it was not provided
    """

    handle_name: str
    local_id: PeerID
    remote_id: Optional[PeerID] = None
class P2P:
    """
    This class is responsible for establishing peer-to-peer connections through NAT and/or firewalls.
    It creates and manages a libp2p daemon (https://libp2p.io) in a background process,
    then terminates it when P2P is shut down. In order to communicate, a P2P instance should
    either use one or more initial_peers that will connect it to the rest of the swarm or
    use the public IPFS network (https://ipfs.io).

    For incoming connections, P2P instances add RPC handlers that may be accessed by other peers:
      - `P2P.add_protobuf_handler` accepts a protobuf message and returns another protobuf
      - `P2P.add_binary_stream_handler` transfers raw data using bi-directional streaming interface

    To access these handlers, a P2P instance can `P2P.call_protobuf_handler`/`P2P.call_binary_stream_handler`,
    using the recipient's unique `P2P.peer_id` and the name of the corresponding handler.
    """

    # Size (in bytes) and byte order of the length prefix written by `send_raw_data`
    HEADER_LEN = 8
    BYTEORDER = "big"
    # One-byte tags that `send_protobuf` writes before the payload: regular message vs. RPCError
    MESSAGE_MARKER = b"\x00"
    ERROR_MARKER = b"\x01"
    # Empty RPCError that `_iterate_protobuf_stream_handler` sends after the last request
    END_OF_STREAM = RPCError()

    # Maps the `dht_mode` argument of `create` to the daemon command-line options it enables
    DHT_MODE_MAPPING = {
        "dht": {"dht": 1},
        "dht_server": {"dhtServer": 1},
        "dht_client": {"dhtClient": 1},
    }
    # Maps the `force_reachability` argument of `create` to the daemon command-line options it enables
    FORCE_REACHABILITY_MAPPING = {
        "public": {"forceReachabilityPublic": 1},
        "private": {"forceReachabilityPrivate": 1},
    }
    # Common prefix of the unix-socket multiaddrs generated for the daemon and the client
    _UNIX_SOCKET_PREFIX = "/unix/tmp/hivemind-"
- def __init__(self):
- self.peer_id = None
- self._child = None
- self._alive = False
- self._reader_task = None
- self._listen_task = None
@classmethod
async def create(
    cls,
    initial_peers: Optional[Sequence[Union[Multiaddr, str]]] = None,
    use_ipfs: bool = False,
    host_maddrs: Optional[Sequence[Union[Multiaddr, str]]] = ("/ip4/127.0.0.1/tcp/0",),
    announce_maddrs: Optional[Sequence[Union[Multiaddr, str]]] = None,
    quic: bool = True,
    tls: bool = True,
    conn_manager: bool = True,
    dht_mode: str = "dht_server",
    force_reachability: Optional[str] = None,
    nat_port_map: bool = True,
    auto_nat: bool = True,
    use_relay: bool = True,
    use_relay_hop: bool = False,
    use_relay_discovery: bool = False,
    use_auto_relay: bool = False,
    relay_hop_limit: int = 0,
    startup_timeout: float = 15,
) -> "P2P":
    """
    Start a new p2pd process and connect to it.

    :param initial_peers: List of bootstrap peers
    :param use_ipfs: Bootstrap to IPFS (incompatible with initial_peers)
    :param host_maddrs: Multiaddrs to listen for external connections from other p2p instances
    :param announce_maddrs: Visible multiaddrs that the peer will announce
      for external connections from other p2p instances
    :param quic: Enables the QUIC transport
    :param tls: Enables TLS1.3 channel security protocol
    :param conn_manager: Enables the Connection Manager
    :param dht_mode: DHT mode (dht_client/dht_server/dht)
    :param force_reachability: Force reachability mode (public/private)
    :param nat_port_map: Enables NAT port mapping
    :param auto_nat: Enables the AutoNAT service
    :param use_relay: enables circuit relay
    :param use_relay_hop: enables hop for relay
    :param use_relay_discovery: enables passive discovery for relay
    :param use_auto_relay: enables autorelay
    :param relay_hop_limit: sets the hop limit for hop relays
    :param startup_timeout: raise a P2PDaemonError if the daemon does not start in ``startup_timeout`` seconds
    :return: a wrapper for the p2p daemon
    :raises P2PDaemonError: if the daemon does not report readiness within ``startup_timeout`` seconds
      or its output closes before it does
    """
    assert not (
        initial_peers and use_ipfs
    ), "User-defined initial_peers and use_ipfs=True are incompatible, please choose one option"

    self = cls()
    with path(cli, P2PD_FILENAME) as p:
        p2pd_path = p

    # A random suffix keeps the socket names of concurrently running instances apart
    socket_uid = secrets.token_urlsafe(8)
    self._daemon_listen_maddr = Multiaddr(cls._UNIX_SOCKET_PREFIX + f"p2pd-{socket_uid}.sock")
    self._client_listen_maddr = Multiaddr(cls._UNIX_SOCKET_PREFIX + f"p2pclient-{socket_uid}.sock")

    need_bootstrap = bool(initial_peers) or use_ipfs

    # Copy the mapping entry before extending it: `.get` returns the dict stored in the class-level
    # DHT_MODE_MAPPING, and mutating that one would leak this call's options into later `create` calls
    process_kwargs = dict(cls.DHT_MODE_MAPPING.get(dht_mode, {"dht": 0}))
    process_kwargs.update(cls.FORCE_REACHABILITY_MAPPING.get(force_reachability, {}))
    for param, value in [
        ("bootstrapPeers", initial_peers),
        ("hostAddrs", host_maddrs),
        ("announceAddrs", announce_maddrs),
    ]:
        if value:
            process_kwargs[param] = self._maddrs_to_str(value)

    proc_args = self._make_process_args(
        str(p2pd_path),
        listen=self._daemon_listen_maddr,
        quic=quic,
        tls=tls,
        connManager=conn_manager,
        natPortMap=nat_port_map,
        autonat=auto_nat,
        relay=use_relay,
        relayHop=use_relay_hop,
        relayDiscovery=use_relay_discovery,
        autoRelay=use_auto_relay,
        relayHopLimit=relay_hop_limit,
        b=need_bootstrap,
        **process_kwargs,
    )

    # stderr is merged into stdout so that `_read_outputs` sees everything the daemon prints
    self._child = await asyncio.subprocess.create_subprocess_exec(
        *proc_args, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.STDOUT
    )
    self._alive = True

    # `_read_outputs` resolves `ready` once the daemon prints a line starting with "Peer ID:"
    ready = asyncio.Future()
    self._reader_task = asyncio.create_task(self._read_outputs(ready))
    try:
        await asyncio.wait_for(ready, startup_timeout)
    except asyncio.TimeoutError:
        await self.shutdown()
        raise P2PDaemonError(f"Daemon failed to start in {startup_timeout:.1f} seconds")

    self._client = p2pclient.Client(self._daemon_listen_maddr, self._client_listen_maddr)
    await self._ping_daemon()
    return self
@classmethod
async def replicate(cls, daemon_listen_maddr: Multiaddr) -> "P2P":
    """
    Attach a new wrapper to a p2p daemon that is already running.

    :param daemon_listen_maddr: multiaddr of the existing p2p daemon
    :return: new wrapper for the existing p2p daemon
    """
    replica = cls()
    # The daemon is external and already running, so this wrapper has no child process
    replica._child = None
    replica._alive = True

    # Each replica talks to the shared daemon through its own, randomly named client socket
    client_socket_name = f"p2pclient-{secrets.token_urlsafe(8)}.sock"
    replica._daemon_listen_maddr = daemon_listen_maddr
    replica._client_listen_maddr = Multiaddr(cls._UNIX_SOCKET_PREFIX + client_socket_name)
    replica._client = p2pclient.Client(replica._daemon_listen_maddr, replica._client_listen_maddr)

    await replica._ping_daemon()
    return replica
async def _ping_daemon(self) -> None:
    """Ask the daemon to identify itself and cache the peer id and multiaddrs it returns."""
    identity = await self._client.identify()
    self.peer_id, self._visible_maddrs = identity
    logger.debug(f"Launched p2pd with peer id = {self.peer_id}, host multiaddrs = {self._visible_maddrs}")
async def get_visible_maddrs(self, latest: bool = False) -> List[Multiaddr]:
    """
    Get multiaddrs of the current peer that should be accessible by other peers.

    :param latest: ask the P2P daemon to refresh the visible multiaddrs
    :return: the cached multiaddrs, each one encapsulating ``/p2p/<peer id>``
    :raises ValueError: if no multiaddrs are known for this peer
    """
    if latest:
        # Only the multiaddrs are refreshed; the peer id returned by identify() is ignored here
        _, self._visible_maddrs = await self._client.identify()

    if not self._visible_maddrs:
        raise ValueError(f"No multiaddrs found for peer {self.peer_id}")

    peer_suffix = Multiaddr(f"/p2p/{self.peer_id.to_base58()}")
    return [maddr.encapsulate(peer_suffix) for maddr in self._visible_maddrs]
async def list_peers(self) -> List[PeerInfo]:
    """Return the peers reported by the client's ``list_peers`` call, materialized as a list."""
    peers = await self._client.list_peers()
    return list(peers)
async def wait_for_at_least_n_peers(self, n_peers: int, attempts: int = 3, delay: float = 1) -> None:
    """
    Poll the daemon until it reports at least ``n_peers`` peers.

    :param n_peers: minimal number of peers to wait for
    :param attempts: how many times to query the list of peers
    :param delay: pause (in seconds) after every unsuccessful attempt
    :raises RuntimeError: if there are still fewer than ``n_peers`` peers after all attempts
    """
    for _attempt in range(attempts):
        if len(await self._client.list_peers()) >= n_peers:
            return
        await asyncio.sleep(delay)

    raise RuntimeError("Not enough peers")
@property
def daemon_listen_maddr(self) -> Multiaddr:
    """Multiaddr of the daemon's control socket; it can be passed to `P2P.replicate`."""
    maddr = self._daemon_listen_maddr
    return maddr
- @staticmethod
- async def send_raw_data(data: bytes, writer: asyncio.StreamWriter, *, chunk_size: int = 2 ** 16) -> None:
- writer.write(len(data).to_bytes(P2P.HEADER_LEN, P2P.BYTEORDER))
- data = memoryview(data)
- for offset in range(0, len(data), chunk_size):
- writer.write(data[offset : offset + chunk_size])
- await writer.drain()
- @staticmethod
- async def receive_raw_data(reader: asyncio.StreamReader) -> bytes:
- header = await reader.readexactly(P2P.HEADER_LEN)
- content_length = int.from_bytes(header, P2P.BYTEORDER)
- data = await reader.readexactly(content_length)
- return data
# Generic placeholders for the request / response protobuf message types in handler signatures
TInputProtobuf = TypeVar("TInputProtobuf")
TOutputProtobuf = TypeVar("TOutputProtobuf")
@staticmethod
async def send_protobuf(protobuf: Union[TOutputProtobuf, RPCError], writer: asyncio.StreamWriter) -> None:
    """
    Send a protobuf message: a one-byte marker followed by the length-prefixed serialized message.

    :param protobuf: message to send; an ``RPCError`` is tagged with ``P2P.ERROR_MARKER``,
      anything else with ``P2P.MESSAGE_MARKER``
    :param writer: stream to write to
    """
    marker = P2P.ERROR_MARKER if isinstance(protobuf, RPCError) else P2P.MESSAGE_MARKER
    writer.write(marker)
    await P2P.send_raw_data(protobuf.SerializeToString(), writer)
@staticmethod
async def receive_protobuf(
    input_protobuf_type: Message, reader: asyncio.StreamReader
) -> Tuple[Optional[TInputProtobuf], Optional[RPCError]]:
    """
    Read one message written by `P2P.send_protobuf`.

    :param input_protobuf_type: protobuf class used to parse a regular (non-error) message
    :param reader: stream to read from
    :return: ``(message, None)`` for a regular message or ``(None, error)`` for an ``RPCError``
    :raises TypeError: if the marker byte is neither ``P2P.MESSAGE_MARKER`` nor ``P2P.ERROR_MARKER``
    """
    marker = await reader.readexactly(1)

    if marker == P2P.MESSAGE_MARKER:
        message = input_protobuf_type()
        message.ParseFromString(await P2P.receive_raw_data(reader))
        return message, None

    if marker == P2P.ERROR_MARKER:
        error = RPCError()
        error.ParseFromString(await P2P.receive_raw_data(reader))
        return None, error

    raise TypeError("Invalid Protobuf message type")
# Async streams of request / response protobufs, as consumed and produced by streaming handlers
TInputStream = AsyncIterator[TInputProtobuf]
TOutputStream = AsyncIterator[TOutputProtobuf]
async def _add_protobuf_stream_handler(
    self,
    name: str,
    handler: Callable[[TInputStream, P2PContext], TOutputStream],
    input_protobuf_type: Message,
    max_prefetch: int = 5,
) -> None:
    """
    Register a streaming protobuf handler on top of a binary stream handler.

    :param name: name of the handler
    :param handler: called with an async iterator over the parsed requests and a P2PContext;
      every item it yields is sent back to the caller
    :param input_protobuf_type: protobuf class used to parse incoming requests
    :param max_prefetch: Maximum number of items to prefetch from the request stream.
      ``max_prefetch <= 0`` means unlimited.

    :note: Since the cancel messages are sent via the input stream,
      they will not be received while the prefetch buffer is full.
    """

    async def _handle_stream(
        stream_info: StreamInfo, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
    ) -> None:
        context = P2PContext(
            handle_name=name,
            local_id=self.peer_id,
            remote_id=stream_info.peer_id,
        )
        # Buffer between the receiving loop below and the handler; None marks the end of the stream
        requests = asyncio.Queue(max_prefetch)

        async def _read_stream() -> P2P.TInputStream:
            # Expose the queue to the handler as an async iterator that stops at the None sentinel
            while True:
                request = await requests.get()
                if request is None:
                    break
                yield request

        async def _process_stream() -> None:
            # Run the handler and forward each of its responses to the caller
            try:
                async for response in handler(_read_stream(), context):
                    await P2P.send_protobuf(response, writer)
            except Exception as e:
                logger.warning("Exception while processing stream and sending responses:", exc_info=True)
                # Sometimes `e` is a connection error, so we won't be able to report the error to the caller
                with suppress(Exception):
                    await P2P.send_protobuf(RPCError(message=str(e)), writer)

        with closing(writer):
            processing_task = asyncio.create_task(_process_stream())
            try:
                while True:
                    receive_task = asyncio.create_task(P2P.receive_protobuf(input_protobuf_type, reader))
                    # Wake up when either the handler finishes or the next request arrives
                    await asyncio.wait({processing_task, receive_task}, return_when=asyncio.FIRST_COMPLETED)

                    if processing_task.done():
                        # The handler is finished, so further requests are not needed
                        receive_task.cancel()
                        return

                    if receive_task.done():
                        try:
                            request, _ = await receive_task
                        except asyncio.IncompleteReadError:  # Connection is closed (the client cancelled or died)
                            return
                        await requests.put(request)  # `request` is None for the end-of-stream message
            except Exception:
                logger.warning("Exception while receiving requests:", exc_info=True)
            finally:
                # Whichever way the loop ends, the handler must not outlive the stream
                processing_task.cancel()

    await self.add_binary_stream_handler(name, _handle_stream)
async def _iterate_protobuf_stream_handler(
    self, peer_id: PeerID, name: str, requests: TInputStream, output_protobuf_type: Message
) -> TOutputStream:
    """
    Call a streaming protobuf handler of a remote peer and yield its responses as they arrive.

    :param peer_id: peer that runs the handler
    :param name: name of the handler
    :param requests: async iterator over the requests to send
    :param output_protobuf_type: protobuf class used to parse the responses
    :raises P2PHandlerError: if the remote side responds with an error message
    """
    _, reader, writer = await self.call_binary_stream_handler(peer_id, name)

    async def _write_to_stream() -> None:
        # Send every request, then tell the remote side that there will be no more of them
        async for request in requests:
            await P2P.send_protobuf(request, writer)
        await P2P.send_protobuf(P2P.END_OF_STREAM, writer)

    with closing(writer):
        # Requests are written in the background while responses are read below
        writing_task = asyncio.create_task(_write_to_stream())
        try:
            while True:
                try:
                    response, err = await P2P.receive_protobuf(output_protobuf_type, reader)
                except asyncio.IncompleteReadError:  # Connection is closed
                    break

                if err is not None:
                    raise P2PHandlerError(f"Failed to call handler `{name}` at {peer_id}: {err.message}")
                yield response

            # Awaiting the writer here re-raises any exception that happened while sending requests
            await writing_task
        finally:
            writing_task.cancel()
async def add_protobuf_handler(
    self,
    name: str,
    handler: Callable[
        [Union[TInputProtobuf, TInputStream], P2PContext], Union[Awaitable[TOutputProtobuf], TOutputStream]
    ],
    input_protobuf_type: Message,
    *,
    stream_input: bool = False,
    stream_output: bool = False,
) -> None:
    """
    Register a protobuf handler under the given name.

    :param name: name of the handler
    :param handler: function that processes a request (or a stream of requests) and a P2PContext
    :param input_protobuf_type: protobuf class used to parse incoming requests
    :param stream_input: If True, assume ``handler`` to take ``TInputStream``
      (not just ``TInputProtobuf``) as input.
    :param stream_output: If True, assume ``handler`` to return ``TOutputStream``
    """
    if not (stream_input or stream_output):
        # Plain request-response handlers are registered as unary handlers
        await self._add_protobuf_unary_handler(name, handler, input_protobuf_type)
        return

    async def _stream_handler(requests: P2P.TInputStream, context: P2PContext) -> P2P.TOutputStream:
        if stream_input:
            handler_input = requests
        else:
            handler_input = await asingle(requests)

        result = handler(handler_input, context)
        if isinstance(result, AsyncIterableABC):
            async for item in result:
                yield item
        else:
            # An awaitable result is turned into a stream with a single item
            yield await result

    await self._add_protobuf_stream_handler(name, _stream_handler, input_protobuf_type)
# only registers request-response handlers
async def _add_protobuf_unary_handler(
    self,
    handle_name: str,
    handler: Callable[[TInputProtobuf, P2PContext], Awaitable[TOutputProtobuf]],
    input_protobuf_type: Message,
) -> None:
    """
    Register a request-response (unary) handler. Unary requests and responses
    are sent through persistent multiplexed connections to the daemon for the
    sake of reducing the number of open files.

    :param handle_name: name of the handler (protocol id)
    :param handler: function handling the unary requests
    :param input_protobuf_type: protobuf type of the request
    """

    async def _unary_handler(request: bytes, remote_id: PeerID) -> bytes:
        # Bytes in, bytes out: parse the request, run the user's handler, serialize its response
        parsed_request = input_protobuf_type.FromString(request)
        call_context = P2PContext(handle_name=handle_name, local_id=self.peer_id, remote_id=remote_id)
        reply = await handler(parsed_request, call_context)
        return reply.SerializeToString()

    await self._client.add_unary_handler(handle_name, _unary_handler)
async def call_protobuf_handler(
    self,
    peer_id: PeerID,
    name: str,
    input: Union[TInputProtobuf, TInputStream],
    output_protobuf_type: Message,
) -> TOutputProtobuf:
    """
    Call a protobuf handler of a remote peer and return its single response.

    :param peer_id: peer that runs the handler
    :param name: name of the handler
    :param input: a single request, or an async iterable of requests for a handler with streaming input
    :param output_protobuf_type: protobuf class used to parse the response
    :return: the response of the handler
    """
    if not isinstance(input, AsyncIterableABC):
        return await self._call_unary_protobuf_handler(peer_id, name, input, output_protobuf_type)

    # `input` is known to be an async iterable here, so it is passed on as the request stream as is
    responses = self._iterate_protobuf_stream_handler(peer_id, name, input, output_protobuf_type)
    return await asingle(responses)
async def _call_unary_protobuf_handler(
    self,
    peer_id: PeerID,
    handle_name: str,
    input: TInputProtobuf,
    output_protobuf_type: Message,
) -> TOutputProtobuf:
    """
    Call a unary (request-response) handler of a remote peer.

    :param peer_id: peer that runs the handler
    :param handle_name: name of the handler (protocol id)
    :param input: request protobuf
    :param output_protobuf_type: protobuf class used to parse the response
    :return: the parsed response
    """
    serialized_input = input.SerializeToString()
    response = await self._client.call_unary_handler(peer_id, handle_name, serialized_input)
    # FromString is called on the message class itself (as `_unary_handler` does for requests);
    # there is no need to construct a throwaway instance first
    return output_protobuf_type.FromString(response)
def iterate_protobuf_handler(
    self,
    peer_id: PeerID,
    name: str,
    input: Union[TInputProtobuf, TInputStream],
    output_protobuf_type: Message,
) -> TOutputStream:
    """
    Call a protobuf handler of a remote peer and return an async iterator over its responses.

    :param peer_id: peer that runs the handler
    :param name: name of the handler
    :param input: a single request or an async iterable of requests
    :param output_protobuf_type: protobuf class used to parse the responses
    """
    if isinstance(input, AsyncIterableABC):
        request_stream = input
    else:
        # A single request is wrapped into an async iterator
        request_stream = aiter(input)
    return self._iterate_protobuf_stream_handler(peer_id, name, request_stream, output_protobuf_type)
def _start_listening(self) -> None:
    """Start a background task that keeps ``self._client.listen()`` open until the task is cancelled."""

    async def _keep_listening() -> None:
        async with self._client.listen():
            # This future never completes; the task ends only when it is cancelled in _terminate()
            await asyncio.Future()

    self._listen_task = asyncio.create_task(_keep_listening())
async def add_binary_stream_handler(self, name: str, handler: p2pclient.StreamHandler) -> None:
    """
    Register a handler for raw bi-directional streams opened by other peers under ``name``.

    :param name: name of the handler
    :param handler: callback passed to the client's ``stream_handler``
    """
    already_listening = self._listen_task is not None
    if not already_listening:
        # The listening task is started lazily, together with the first stream handler
        self._start_listening()
    await self._client.stream_handler(name, handler)
async def call_binary_stream_handler(
    self, peer_id: PeerID, handler_name: str
) -> Tuple[StreamInfo, asyncio.StreamReader, asyncio.StreamWriter]:
    """
    Open a raw bi-directional stream to the handler ``handler_name`` of the peer ``peer_id``.

    :return: whatever the client's ``stream_open`` returns (annotated as stream info, reader, writer)
    """
    protocols = (handler_name,)
    stream = await self._client.stream_open(peer_id, protocols)
    return stream
def __del__(self):
    """Release the resources held by this instance (see `_terminate`) when it is garbage-collected."""
    self._terminate()
@property
def is_alive(self) -> bool:
    """True from the moment `create` / `replicate` marks the instance as started until `_terminate` runs."""
    alive = self._alive
    return alive
async def shutdown(self) -> None:
    """Terminate this instance and, if it owns a daemon process, wait for that process to exit."""
    self._terminate()

    child = self._child
    if child is None:
        # Nothing to wait for: no daemon process is owned by this instance
        return
    await child.wait()
- def _terminate(self) -> None:
- if self._listen_task is not None:
- self._listen_task.cancel()
- if self._reader_task is not None:
- self._reader_task.cancel()
- self._alive = False
- if self._child is not None and self._child.returncode is None:
- self._child.terminate()
- logger.debug(f"Terminated p2pd with id = {self.peer_id}")
- with suppress(FileNotFoundError):
- os.remove(self._daemon_listen_maddr["unix"])
- with suppress(FileNotFoundError):
- os.remove(self._client_listen_maddr["unix"])
- @staticmethod
- def _make_process_args(*args, **kwargs) -> List[str]:
- proc_args = []
- proc_args.extend(str(entry) for entry in args)
- proc_args.extend(
- f"-{key}={P2P._convert_process_arg_type(value)}" if value is not None else f"-{key}"
- for key, value in kwargs.items()
- )
- return proc_args
- @staticmethod
- def _convert_process_arg_type(val: Any) -> Any:
- if isinstance(val, bool):
- return int(val)
- return val
- @staticmethod
- def _maddrs_to_str(maddrs: List[Multiaddr]) -> str:
- return ",".join(str(addr) for addr in maddrs)
- async def _read_outputs(self, ready: asyncio.Future) -> None:
- last_line = None
- while True:
- line = await self._child.stdout.readline()
- if not line: # Stream closed
- break
- last_line = line.rstrip().decode(errors="ignore")
- if last_line.startswith("Peer ID:"):
- ready.set_result(None)
- if not ready.done():
- ready.set_exception(P2PDaemonError(f"Daemon failed to start: {last_line}"))
class P2PDaemonError(RuntimeError):
    """Raised when the p2p daemon fails to start (see `P2P.create` and `P2P._read_outputs`)."""
|