123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153 |
- import asyncio
- import inspect
- from dataclasses import dataclass
- from typing import Any, AsyncIterator, List, Optional, Tuple, Type, get_type_hints
- from hivemind.p2p.p2p_daemon import P2P
- from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID
- @dataclass
- class RPCHandler:
- method_name: str
- request_type: type
- response_type: type
- stream_input: bool
- stream_output: bool
- class StubBase:
- """
- Base class for P2P RPC stubs. The interface mimicks gRPC stubs.
- Servicer derives stub classes for particular services (e.g. DHT, averager, etc.) from StubBase,
- adding the necessary rpc_* methods. Calls to these methods are translated to calls to the remote peer.
- """
- def __init__(self, p2p: P2P, peer: PeerID, namespace: Optional[str]):
- self._p2p = p2p
- self._peer = peer
- self._namespace = namespace
- class ServicerBase:
- """
- Base class for P2P RPC servicers (e.g. DHT, averager, MoE server). The interface mimicks gRPC servicers.
- - ``add_p2p_handlers(self, p2p)`` registers all rpc_* methods of the derived class as P2P handlers, allowing
- other peers to call them. It uses type annotations for the ``request`` parameter and the return value
- to infer protobufs the methods operate with.
- - ``get_stub(self, p2p, peer)`` creates a stub with all rpc_* methods. Calls to the stub methods are translated
- to calls to the remote peer.
- """
- _rpc_handlers: Optional[List[RPCHandler]] = None
- _stub_type: Optional[Type[StubBase]] = None
- @classmethod
- def _collect_rpc_handlers(cls) -> None:
- if cls._rpc_handlers is not None:
- return
- cls._rpc_handlers = []
- for method_name, method in inspect.getmembers(cls, predicate=inspect.isfunction):
- if method_name.startswith("rpc_"):
- spec = inspect.getfullargspec(method)
- if len(spec.args) < 3:
- raise ValueError(
- f"{method_name} is expected to at least three positional arguments "
- f"(self: TServicer, request: TInputProtobuf, context: hivemind.p2p.P2PContext)"
- )
- request_arg = spec.args[1]
- hints = get_type_hints(method)
- try:
- request_type = hints[request_arg]
- response_type = hints["return"]
- except KeyError:
- raise ValueError(
- f"{method_name} is expected to have type annotations "
- f"like `dht_pb2.FindRequest` or `AsyncIterator[dht_pb2.FindRequest]` "
- f"for the `{request_arg}` parameter and the return value"
- )
- request_type, stream_input = cls._strip_iterator_hint(request_type)
- response_type, stream_output = cls._strip_iterator_hint(response_type)
- cls._rpc_handlers.append(
- RPCHandler(method_name, request_type, response_type, stream_input, stream_output)
- )
- cls._stub_type = type(
- f"{cls.__name__}Stub",
- (StubBase,),
- {handler.method_name: cls._make_rpc_caller(handler) for handler in cls._rpc_handlers},
- )
- @classmethod
- def _make_rpc_caller(cls, handler: RPCHandler):
- input_type = AsyncIterator[handler.request_type] if handler.stream_input else handler.request_type
- # This method will be added to a new Stub type (a subclass of StubBase)
- if handler.stream_output:
- def caller(
- self: StubBase, input: input_type, timeout: None = None
- ) -> AsyncIterator[handler.response_type]:
- if timeout is not None:
- raise ValueError("Timeouts for handlers returning streams are not supported")
- return self._p2p.iterate_protobuf_handler(
- self._peer,
- cls._get_handle_name(self._namespace, handler.method_name),
- input,
- handler.response_type,
- )
- else:
- async def caller(
- self: StubBase, input: input_type, timeout: Optional[float] = None
- ) -> handler.response_type:
- return await asyncio.wait_for(
- self._p2p.call_protobuf_handler(
- self._peer,
- cls._get_handle_name(self._namespace, handler.method_name),
- input,
- handler.response_type,
- ),
- timeout=timeout,
- )
- caller.__name__ = handler.method_name
- return caller
- async def add_p2p_handlers(self, p2p: P2P, wrapper: Any = None, *, namespace: Optional[str] = None) -> None:
- self._collect_rpc_handlers()
- servicer = self if wrapper is None else wrapper
- for handler in self._rpc_handlers:
- await p2p.add_protobuf_handler(
- self._get_handle_name(namespace, handler.method_name),
- getattr(servicer, handler.method_name),
- handler.request_type,
- stream_input=handler.stream_input,
- )
- @classmethod
- def get_stub(cls, p2p: P2P, peer: PeerID, *, namespace: Optional[str] = None) -> StubBase:
- cls._collect_rpc_handlers()
- return cls._stub_type(p2p, peer, namespace)
- @classmethod
- def _get_handle_name(cls, namespace: Optional[str], method_name: str) -> str:
- handle_name = f"{cls.__name__}.{method_name}"
- if namespace is not None:
- handle_name = f"{namespace}::{handle_name}"
- return handle_name
- @staticmethod
- def _strip_iterator_hint(hint: type) -> Tuple[type, bool]:
- if hasattr(hint, "_name") and hint._name in ("AsyncIterator", "AsyncIterable"):
- return hint.__args__[0], True
- return hint, False
|