123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328 |
- """
- Originally taken from: https://github.com/mhchia/py-libp2p-daemon-bindings
- Licence: MIT
- Author: Kevin Mai-Husan Chia
- """
- import asyncio
- import uuid
- from contextlib import asynccontextmanager, closing
- from typing import AsyncIterator, Awaitable, Callable, Dict, Iterable, Optional, Sequence, Tuple
- from multiaddr import Multiaddr, protocols
- from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
- from hivemind.p2p.p2p_daemon_bindings.utils import DispatchFailure, raise_if_failed, read_pbmsg_safe, write_pbmsg
- from hivemind.proto import p2pd_pb2 as p2pd_pb
- from hivemind.utils.logging import get_logger
# Signature of a callback that serves an inbound stream: it is awaited with the
# stream metadata plus the reader/writer pair of the accepted connection.
StreamHandler = Callable[[StreamInfo, asyncio.StreamReader, asyncio.StreamWriter], Awaitable[None]]

# Protocol codes that a control/listen multiaddr may use as its connection protocol.
SUPPORT_CONN_PROTOCOLS = (
    protocols.P_IP4,
    # protocols.P_IP6,
    protocols.P_UNIX,
)
# Materialized as a tuple rather than a one-shot generator: it is interpolated into
# error messages (a generator would render as "<generator object ...>") and must
# stay iterable more than once.
SUPPORTED_PROTOS = tuple(protocols.protocol_with_code(proto) for proto in SUPPORT_CONN_PROTOCOLS)

logger = get_logger(__name__)
def parse_conn_protocol(maddr: Multiaddr) -> int:
    """
    Return the code of the single supported connection protocol present in ``maddr``.

    :param maddr: multiaddr whose protocols are inspected
    :returns: the protocol code shared by ``maddr`` and ``SUPPORT_CONN_PROTOCOLS``
    :raises ValueError: if ``maddr`` contains none, or more than one, of the supported protocols
    """
    matching_codes = {proto.code for proto in maddr.protocols()} & set(SUPPORT_CONN_PROTOCOLS)
    if len(matching_codes) == 1:
        # Exactly one candidate: unpack it directly
        (code,) = matching_codes
        return code
    raise ValueError(f"connection protocol should be only one protocol out of {SUPPORTED_PROTOS}, maddr={maddr}")
class DaemonConnector:
    """Opens connections to the daemon control endpoint described by ``control_maddr``."""

    DEFAULT_CONTROL_MADDR = "/unix/tmp/p2pd.sock"

    def __init__(self, control_maddr: Multiaddr = Multiaddr(DEFAULT_CONTROL_MADDR)) -> None:
        """
        :param control_maddr: multiaddr of the daemon control endpoint
        :raises ValueError: if ``control_maddr`` does not contain exactly one supported protocol
        """
        self.control_maddr = control_maddr
        # Resolved once here, so an unsupported address is rejected at construction time
        self.proto_code = parse_conn_protocol(self.control_maddr)

    async def open_connection(self) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
        """
        Open a new connection to the daemon over a unix socket or TCP/IPv4.

        :returns: the (reader, writer) pair of the new connection
        :raises ValueError: if the protocol of ``control_maddr`` is not supported
        """
        if self.proto_code == protocols.P_UNIX:
            control_path = self.control_maddr.value_for_protocol(protocols.P_UNIX)
            return await asyncio.open_unix_connection(control_path)
        elif self.proto_code == protocols.P_IP4:
            host = self.control_maddr.value_for_protocol(protocols.P_IP4)
            port = int(self.control_maddr.value_for_protocol(protocols.P_TCP))
            return await asyncio.open_connection(host, port)
        else:
            raise ValueError(f"Protocol not supported: {protocols.protocol_with_code(self.proto_code)}")

    async def open_persistent_connection(self) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
        """
        Open connection to daemon and upgrade it to a persistent one

        :returns: the (reader, writer) pair of the upgraded connection
        """
        reader, writer = await self.open_connection()
        try:
            req = p2pd_pb.Request(type=p2pd_pb.Request.PERSISTENT_CONN_UPGRADE)
            await write_pbmsg(writer, req)
        except BaseException:
            # The caller never receives the writer on failure, so release it here
            writer.close()
            raise
        return reader, writer
# Signature of a unary handler: awaited with the request payload and the PeerID
# of the caller, it produces the response payload.
TUnaryHandler = Callable[[bytes, PeerID], Awaitable[bytes]]

# Identifier type used to match messages on the persistent connection to their calls.
CallID = uuid.UUID
class ControlClient:
    """
    Client for the daemon control API.

    One-shot requests (identify, connect, list_peers, ...) each use a fresh connection
    obtained from ``daemon_connector``. Unary handlers and unary calls share a single
    persistent connection that is opened lazily on first use.
    """

    DEFAULT_LISTEN_MADDR = "/unix/tmp/p2pclient.sock"

    def __init__(
        self, daemon_connector: DaemonConnector, listen_maddr: Multiaddr = Multiaddr(DEFAULT_LISTEN_MADDR)
    ) -> None:
        """
        :param daemon_connector: used to open connections to the daemon
        :param listen_maddr: address on which this client accepts inbound streams
        """
        self.listen_maddr = listen_maddr
        self.daemon_connector = daemon_connector
        # Stream handlers, keyed by protocol name
        self.handlers: Dict[str, StreamHandler] = {}

        self._is_persistent_conn_open: bool = False
        # Unary handlers, keyed by protocol name
        self.unary_handlers: Dict[str, TUnaryHandler] = {}

        self._ensure_conn_lock = asyncio.Lock()
        # Outgoing messages, drained by the write task of the persistent connection
        self._pending_messages: asyncio.Queue[p2pd_pb.PersistentConnectionRequest] = asyncio.Queue()
        # Futures of outgoing unary calls that still await a response
        self._pending_calls: Dict[CallID, asyncio.Future[bytes]] = {}
        # Tasks that are currently serving incoming unary requests
        self._handler_tasks: Dict[CallID, asyncio.Task] = {}

        self._read_task: Optional[asyncio.Task] = None
        self._write_task: Optional[asyncio.Task] = None

    async def _handler(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
        """
        Serve one inbound connection: read its StreamInfo and dispatch to the registered handler.

        :raises DispatchFailure: if no handler is registered for the stream's protocol
        """
        pb_stream_info = p2pd_pb.StreamInfo()  # type: ignore
        await read_pbmsg_safe(reader, pb_stream_info)
        stream_info = StreamInfo.from_protobuf(pb_stream_info)
        try:
            handler = self.handlers[stream_info.proto]
        except KeyError as e:
            # should never enter here... daemon should reject the stream for us.
            writer.close()
            raise DispatchFailure(e) from e
        await handler(stream_info, reader, writer)

    @asynccontextmanager
    async def listen(self) -> AsyncIterator["ControlClient"]:
        """
        Start a server on ``listen_maddr`` for the duration of the context.

        On exit, the read/write tasks of the persistent connection (if any) are cancelled.

        :raises ValueError: if the protocol of ``listen_maddr`` is not supported
        """
        proto_code = parse_conn_protocol(self.listen_maddr)
        if proto_code == protocols.P_UNIX:
            listen_path = self.listen_maddr.value_for_protocol(protocols.P_UNIX)
            server = await asyncio.start_unix_server(self._handler, path=listen_path)
        elif proto_code == protocols.P_IP4:
            host = self.listen_maddr.value_for_protocol(protocols.P_IP4)
            port = int(self.listen_maddr.value_for_protocol(protocols.P_TCP))
            server = await asyncio.start_server(self._handler, port=port, host=host)
        else:
            raise ValueError(f"Protocol not supported: {protocols.protocol_with_code(proto_code)}")

        try:
            async with server:
                yield self
        finally:
            if self._read_task is not None:
                self._read_task.cancel()
            if self._write_task is not None:
                self._write_task.cancel()

    async def _read_from_persistent_conn(self, reader: asyncio.StreamReader):
        """Read messages from the persistent connection forever and route them by call id."""
        while True:
            resp = p2pd_pb.PersistentConnectionResponse()
            await read_pbmsg_safe(reader, resp)

            call_id = uuid.UUID(bytes=resp.callId)

            if resp.HasField("callUnaryResponse"):
                # Outcome of a call issued by call_unary_handler: resolve its future
                if call_id in self._pending_calls and resp.callUnaryResponse.HasField("response"):
                    self._pending_calls[call_id].set_result(resp.callUnaryResponse.response)
                elif call_id in self._pending_calls and resp.callUnaryResponse.HasField("error"):
                    remote_exc = P2PHandlerError(resp.callUnaryResponse.error.decode(errors="ignore"))
                    self._pending_calls[call_id].set_exception(remote_exc)
                else:
                    logger.debug("received unexpected unary call")

            elif resp.HasField("requestHandling"):
                # Incoming request: serve it in its own task so that reading continues meanwhile
                handler_task = asyncio.create_task(self._handle_persistent_request(call_id, resp.requestHandling))
                self._handler_tasks[call_id] = handler_task

            elif call_id in self._handler_tasks and resp.HasField("cancel"):
                self._handler_tasks[call_id].cancel()

    async def _write_to_persistent_conn(self, writer: asyncio.StreamWriter):
        """Send queued messages over the persistent connection; the writer is closed on exit."""
        with closing(writer):
            while True:
                msg = await self._pending_messages.get()
                await write_pbmsg(writer, msg)

    async def _handle_persistent_request(self, call_id: uuid.UUID, request: p2pd_pb.CallUnaryRequest):
        """
        Run the unary handler registered for ``request.proto`` and enqueue its response.

        An exception raised by the handler is reported back as an error response
        carrying ``repr(exception)``.
        """
        try:
            if request.proto not in self.unary_handlers:
                logger.warning(f"Protocol {request.proto} not supported")
                return

            try:
                remote_id = PeerID(request.peer)
                response_payload: bytes = await self.unary_handlers[request.proto](request.data, remote_id)
                response = p2pd_pb.CallUnaryResponse(response=response_payload)
            except Exception as e:
                response = p2pd_pb.CallUnaryResponse(error=repr(e).encode())

            await self._pending_messages.put(
                p2pd_pb.PersistentConnectionRequest(
                    callId=call_id.bytes,
                    unaryResponse=response,
                )
            )
        finally:
            # Always drop the bookkeeping entry, including on the early return above
            # and when this task is cancelled, so _handler_tasks does not accumulate stale tasks
            self._handler_tasks.pop(call_id, None)

    async def _cancel_unary_call(self, call_id: uuid.UUID):
        """Enqueue a cancel message for the outgoing call identified by ``call_id``."""
        await self._pending_messages.put(
            p2pd_pb.PersistentConnectionRequest(
                callId=call_id.bytes,
                cancel=p2pd_pb.Cancel(),
            ),
        )

    async def _ensure_persistent_conn(self):
        """Open the persistent connection and start its read/write tasks, unless already done."""
        if not self._is_persistent_conn_open:
            async with self._ensure_conn_lock:
                # Re-check under the lock: another coroutine may have opened it while we waited
                if not self._is_persistent_conn_open:
                    reader, writer = await self.daemon_connector.open_persistent_connection()

                    self._read_task = asyncio.create_task(self._read_from_persistent_conn(reader))
                    self._write_task = asyncio.create_task(self._write_to_persistent_conn(writer))

                    self._is_persistent_conn_open = True

    async def add_unary_handler(self, proto: str, handler: TUnaryHandler):
        """
        Register ``handler`` for ``proto`` and announce it over the persistent connection.

        :raises ValueError: if a handler for ``proto`` is already assigned
        """
        await self._ensure_persistent_conn()

        call_id = uuid.uuid4()

        add_unary_handler_req = p2pd_pb.AddUnaryHandlerRequest(proto=proto)
        req = p2pd_pb.PersistentConnectionRequest(callId=call_id.bytes, addUnaryHandler=add_unary_handler_req)

        if self.unary_handlers.get(proto):
            raise ValueError(f"Handler for protocol {proto} already assigned")
        self.unary_handlers[proto] = handler

        await self._pending_messages.put(req)

    async def call_unary_handler(self, peer_id: PeerID, proto: str, data: bytes) -> bytes:
        """
        Call the unary handler for ``proto`` on ``peer_id`` and wait for its response.

        If this coroutine is cancelled while waiting, a cancel message is enqueued
        for the call before the cancellation propagates.

        :returns: the response payload
        :raises P2PHandlerError: if the response carries an error
        """
        call_id = uuid.uuid4()
        call_unary_req = p2pd_pb.CallUnaryRequest(
            peer=peer_id.to_bytes(),
            proto=proto,
            data=data,
        )
        req = p2pd_pb.PersistentConnectionRequest(
            callId=call_id.bytes,
            callUnary=call_unary_req,
        )

        await self._ensure_persistent_conn()

        try:
            self._pending_calls[call_id] = asyncio.Future()
            await self._pending_messages.put(req)
            return await self._pending_calls[call_id]
        except asyncio.CancelledError:
            await self._cancel_unary_call(call_id)
            raise
        finally:
            self._pending_calls.pop(call_id, None)

    async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
        """
        Send an IDENTIFY request to the daemon.

        :returns: the peer id and multiaddrs contained in the response
        """
        reader, writer = await self.daemon_connector.open_connection()
        # closing() releases the connection even if the request or the response fails
        with closing(writer):
            req = p2pd_pb.Request(type=p2pd_pb.Request.IDENTIFY)
            await write_pbmsg(writer, req)

            resp = p2pd_pb.Response()  # type: ignore
            await read_pbmsg_safe(reader, resp)

        raise_if_failed(resp)
        peer_id_bytes = resp.identify.id
        maddrs_bytes = resp.identify.addrs

        maddrs = tuple(Multiaddr(maddr_bytes) for maddr_bytes in maddrs_bytes)
        peer_id = PeerID(peer_id_bytes)

        return peer_id, maddrs

    async def connect(self, peer_id: PeerID, maddrs: Iterable[Multiaddr]) -> None:
        """Send a CONNECT request for ``peer_id`` with the given ``maddrs`` to the daemon."""
        reader, writer = await self.daemon_connector.open_connection()
        # closing() releases the connection even if the request or the response fails
        with closing(writer):
            maddrs_bytes = [i.to_bytes() for i in maddrs]
            connect_req = p2pd_pb.ConnectRequest(peer=peer_id.to_bytes(), addrs=maddrs_bytes)
            req = p2pd_pb.Request(type=p2pd_pb.Request.CONNECT, connect=connect_req)
            await write_pbmsg(writer, req)

            resp = p2pd_pb.Response()  # type: ignore
            await read_pbmsg_safe(reader, resp)

        raise_if_failed(resp)

    async def list_peers(self) -> Tuple[PeerInfo, ...]:
        """
        Send a LIST_PEERS request to the daemon.

        :returns: the peers contained in the response
        """
        req = p2pd_pb.Request(type=p2pd_pb.Request.LIST_PEERS)
        reader, writer = await self.daemon_connector.open_connection()
        # closing() releases the connection even if the request or the response fails
        with closing(writer):
            await write_pbmsg(writer, req)
            resp = p2pd_pb.Response()  # type: ignore
            await read_pbmsg_safe(reader, resp)

        raise_if_failed(resp)

        peers = tuple(PeerInfo.from_protobuf(pinfo) for pinfo in resp.peers)
        return peers

    async def disconnect(self, peer_id: PeerID) -> None:
        """Send a DISCONNECT request for ``peer_id`` to the daemon."""
        disconnect_req = p2pd_pb.DisconnectRequest(peer=peer_id.to_bytes())
        req = p2pd_pb.Request(type=p2pd_pb.Request.DISCONNECT, disconnect=disconnect_req)
        reader, writer = await self.daemon_connector.open_connection()
        # closing() releases the connection even if the request or the response fails
        with closing(writer):
            await write_pbmsg(writer, req)
            resp = p2pd_pb.Response()  # type: ignore
            await read_pbmsg_safe(reader, resp)

        raise_if_failed(resp)

    async def stream_open(
        self, peer_id: PeerID, protocols: Sequence[str]
    ) -> Tuple[StreamInfo, asyncio.StreamReader, asyncio.StreamWriter]:
        """
        Send a STREAM_OPEN request for ``peer_id`` and the given ``protocols`` to the daemon.

        On success the connection is handed to the caller, who becomes responsible
        for closing the returned writer.

        :returns: the stream info from the response and the reader/writer of the connection
        """
        reader, writer = await self.daemon_connector.open_connection()
        try:
            stream_open_req = p2pd_pb.StreamOpenRequest(peer=peer_id.to_bytes(), proto=list(protocols))
            req = p2pd_pb.Request(type=p2pd_pb.Request.STREAM_OPEN, streamOpen=stream_open_req)
            await write_pbmsg(writer, req)

            resp = p2pd_pb.Response()  # type: ignore
            await read_pbmsg_safe(reader, resp)
            raise_if_failed(resp)

            pb_stream_info = resp.streamInfo
            stream_info = StreamInfo.from_protobuf(pb_stream_info)
        except BaseException:
            # The caller never receives the writer on failure, so release it here
            writer.close()
            raise

        return stream_info, reader, writer

    async def stream_handler(self, proto: str, handler_cb: StreamHandler) -> None:
        """
        Send a STREAM_HANDLER request for ``proto``, passing ``listen_maddr`` as the address.

        ``handler_cb`` is stored in ``self.handlers`` only if the daemon reports success.
        """
        reader, writer = await self.daemon_connector.open_connection()
        # closing() releases the connection even if the request or the response fails
        with closing(writer):
            listen_path_maddr_bytes = self.listen_maddr.to_bytes()
            stream_handler_req = p2pd_pb.StreamHandlerRequest(addr=listen_path_maddr_bytes, proto=[proto])
            req = p2pd_pb.Request(type=p2pd_pb.Request.STREAM_HANDLER, streamHandler=stream_handler_req)
            await write_pbmsg(writer, req)

            resp = p2pd_pb.Response()  # type: ignore
            await read_pbmsg_safe(reader, resp)

        raise_if_failed(resp)

        # if success, add the handler to the dict
        self.handlers[proto] = handler_cb
class P2PHandlerError(Exception):
    """Signals that the remote side raised an exception while handling a request."""
|