servicer.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. import asyncio
  2. import inspect
  3. from dataclasses import dataclass
  4. from typing import Any, AsyncIterator, List, Optional, Tuple, Type, get_type_hints
  5. from hivemind.p2p.p2p_daemon import P2P
  6. from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID
  @dataclass()
  class RPCHandler:
      """Description of a single rpc_* method of a servicer: its name, message types and streaming mode."""

      # name of the rpc_* method; also combined with the servicer class name into the P2P handle name
      method_name: str
      # type of the request message (the element type if the request is streamed)
      request_type: type
      # type of the response message (the element type if the response is streamed)
      response_type: type
      # True when the request parameter is annotated as AsyncIterator[...] / AsyncIterable[...]
      stream_input: bool
      # True when the return value is annotated as AsyncIterator[...] / AsyncIterable[...]
      stream_output: bool
  class StubBase:
      """
      Common parent of the P2P RPC stub classes; the interface mimics gRPC stubs.

      For each servicer class (DHT, averager, etc.), the servicer machinery creates a subclass of StubBase
      and attaches one rpc_* method per rpc_* method of the servicer. Calling such a method forwards the
      request to the remote peer.
      """

      def __init__(self, p2p: P2P, peer: PeerID, namespace: Optional[str]):
          # state consulted by the generated rpc_* methods: the transport, the target peer and the handle namespace
          self._p2p, self._peer, self._namespace = p2p, peer, namespace
  class ServicerBase:
      """
      Base class for P2P RPC servicers (e.g. DHT, averager, MoE server). The interface mimics gRPC servicers.

      - ``add_p2p_handlers(self, p2p)`` registers all rpc_* methods of the derived class as P2P handlers, allowing
        other peers to call them. It uses type annotations for the ``request`` parameter and the return value
        to infer protobufs the methods operate with.

      - ``get_stub(p2p, peer)`` (a classmethod) creates a stub with all rpc_* methods. Calls to the stub methods
        are translated to calls to the remote peer.
      """

      # Lazily built per-class caches (filled by _collect_rpc_handlers). A class counts as "collected" only when the
      # handler list is stored in its *own* __dict__; the None defaults here merely mean "not collected yet".
      _rpc_handlers: Optional[List[RPCHandler]] = None
      _stub_type: Optional[Type[StubBase]] = None

      @classmethod
      def _collect_rpc_handlers(cls) -> None:
          """
          Describe every rpc_* method of ``cls`` as an RPCHandler and build the matching stub class.

          Results are cached on ``cls`` itself. The cache check looks only at ``cls.__dict__``. A plain attribute
          lookup would also find a parent servicer's cache, so a subclass collected after its parent would reuse
          the parent's handler list (missing the subclass's own rpc_* methods). It would also reuse the parent's
          stub type, whose handle names carry the parent's class name.

          :raises ValueError: if an rpc_* method takes fewer than three positional arguments, or lacks a type
            annotation for its request parameter or its return value
          """
          if cls.__dict__.get("_rpc_handlers") is not None:
              return

          # Accumulate into locals and publish only after everything succeeded. Otherwise a ValueError halfway
          # through would leave a partial handler list that short-circuits later calls while _stub_type is unset.
          handlers: List[RPCHandler] = []
          for method_name, method in inspect.getmembers(cls, predicate=inspect.isfunction):
              if not method_name.startswith("rpc_"):
                  continue

              spec = inspect.getfullargspec(method)
              if len(spec.args) < 3:
                  raise ValueError(
                      f"{method_name} is expected to have at least three positional arguments "
                      f"(self: TServicer, request: TInputProtobuf, context: hivemind.p2p.P2PContext)"
                  )
              request_arg = spec.args[1]  # spec.args[0] is expected to be self
              hints = get_type_hints(method)
              try:
                  request_type = hints[request_arg]
                  response_type = hints["return"]
              except KeyError as e:
                  raise ValueError(
                      f"{method_name} is expected to have type annotations "
                      f"like `dht_pb2.FindRequest` or `AsyncIterator[dht_pb2.FindRequest]` "
                      f"for the `{request_arg}` parameter and the return value"
                  ) from e

              # AsyncIterator[X] / AsyncIterable[X] marks a streamed side; keep only the element type X
              request_type, stream_input = cls._strip_iterator_hint(request_type)
              response_type, stream_output = cls._strip_iterator_hint(response_type)
              handlers.append(RPCHandler(method_name, request_type, response_type, stream_input, stream_output))

          stub_type = type(
              f"{cls.__name__}Stub",
              (StubBase,),
              {handler.method_name: cls._make_rpc_caller(handler) for handler in handlers},
          )
          # Publish the stub type first: a non-None _rpc_handlers in cls.__dict__ is what marks cls as collected.
          cls._stub_type = stub_type
          cls._rpc_handlers = handlers

      @classmethod
      def _make_rpc_caller(cls, handler: RPCHandler):
          """
          Build the stub method that forwards calls of ``handler.method_name`` to a remote peer.

          The returned function is attached to a StubBase subclass, so ``self`` is a stub with ``_p2p``, ``_peer``
          and ``_namespace``. ``input`` is a single request message, or an AsyncIterator of them when the handler
          streams its input.

          - Streaming-output handlers: the function returns the result of ``p2p.iterate_protobuf_handler``
            (annotated as an AsyncIterator of responses). A non-None ``timeout`` raises ValueError.
          - Other handlers: the function is a coroutine awaiting ``p2p.call_protobuf_handler``, wrapped in
            ``asyncio.wait_for(..., timeout=timeout)``.
          """
          input_type = AsyncIterator[handler.request_type] if handler.stream_input else handler.request_type

          if handler.stream_output:

              def caller(
                  self: StubBase, input: input_type, timeout: None = None
              ) -> AsyncIterator[handler.response_type]:
                  if timeout is not None:
                      raise ValueError("Timeouts for handlers returning streams are not supported")
                  return self._p2p.iterate_protobuf_handler(
                      self._peer,
                      cls._get_handle_name(self._namespace, handler.method_name),
                      input,
                      handler.response_type,
                  )

          else:

              async def caller(
                  self: StubBase, input: input_type, timeout: Optional[float] = None
              ) -> handler.response_type:
                  return await asyncio.wait_for(
                      self._p2p.call_protobuf_handler(
                          self._peer,
                          cls._get_handle_name(self._namespace, handler.method_name),
                          input,
                          handler.response_type,
                      ),
                      timeout=timeout,
                  )

          caller.__name__ = handler.method_name
          # Let reprs and tracebacks show "<Servicer>Stub.rpc_x" instead of "..._make_rpc_caller.<locals>.caller"
          caller.__qualname__ = f"{cls.__name__}Stub.{handler.method_name}"
          return caller

      async def add_p2p_handlers(self, p2p: P2P, wrapper: Any = None, *, namespace: Optional[str] = None) -> None:
          """
          Register every rpc_* method of this servicer as a protobuf handler on ``p2p``.

          :param p2p: the P2P instance to register the handlers on
          :param wrapper: if given, the rpc_* methods are looked up on this object instead of on ``self``
          :param namespace: optional prefix of the handle names; stubs must be created with the same namespace
            (both sides build the names with ``_get_handle_name``)
          """
          self._collect_rpc_handlers()

          servicer = self if wrapper is None else wrapper
          for handler in self._rpc_handlers:
              await p2p.add_protobuf_handler(
                  self._get_handle_name(namespace, handler.method_name),
                  getattr(servicer, handler.method_name),
                  handler.request_type,
                  stream_input=handler.stream_input,
              )

      @classmethod
      def get_stub(cls, p2p: P2P, peer: PeerID, *, namespace: Optional[str] = None) -> StubBase:
          """
          Create a stub whose rpc_* methods call the corresponding handlers of ``peer`` through ``p2p``.

          :param namespace: must match the namespace the remote servicer passed to ``add_p2p_handlers``
          """
          cls._collect_rpc_handlers()
          return cls._stub_type(p2p, peer, namespace)

      @classmethod
      def _get_handle_name(cls, namespace: Optional[str], method_name: str) -> str:
          """Return ``"<ClassName>.<method_name>"``, prefixed with ``"<namespace>::"`` when a namespace is given."""
          handle_name = f"{cls.__name__}.{method_name}"
          if namespace is not None:
              handle_name = f"{namespace}::{handle_name}"
          return handle_name

      @staticmethod
      def _strip_iterator_hint(hint: type) -> Tuple[type, bool]:
          """
          Unwrap an async-stream type annotation.

          :returns: ``(X, True)`` if ``hint`` is ``AsyncIterator[X]`` or ``AsyncIterable[X]``, spelled either via
            ``typing`` or via ``collections.abc`` (PEP 585); otherwise ``(hint, False)``
          """
          import collections.abc

          # typing.AsyncIterator[X] has _name "AsyncIterator" (kept for compatibility with the original check);
          # both typing and collections.abc generic aliases expose the abc class as __origin__.
          origin = getattr(hint, "__origin__", None)
          is_stream = origin in (collections.abc.AsyncIterator, collections.abc.AsyncIterable) or getattr(
              hint, "_name", None
          ) in ("AsyncIterator", "AsyncIterable")
          if is_stream:
              return hint.__args__[0], True
          return hint, False