|
@@ -10,7 +10,6 @@ from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID
|
|
|
@dataclass
|
|
|
class RPCHandler:
|
|
|
method_name: str
|
|
|
- handle_name: str
|
|
|
request_type: type
|
|
|
response_type: type
|
|
|
stream_input: bool
|
|
@@ -45,38 +44,18 @@ class ServicerBase:
|
|
|
_rpc_handlers: Optional[List[RPCHandler]] = None
|
|
|
_stub_type: Optional[Type[StubBase]] = None
|
|
|
|
|
|
- async def add_p2p_handlers(self, p2p: P2P, wrapper: Any = None) -> None:
|
|
|
- self._collect_rpc_handlers()
|
|
|
-
|
|
|
- servicer = self if wrapper is None else wrapper
|
|
|
- for handler in self._rpc_handlers:
|
|
|
- await p2p.add_protobuf_handler(
|
|
|
- handler.handle_name,
|
|
|
- getattr(servicer, handler.method_name),
|
|
|
- handler.request_type,
|
|
|
- stream_input=handler.stream_input,
|
|
|
- )
|
|
|
-
|
|
|
- @classmethod
|
|
|
- def get_stub(cls, p2p: P2P, peer: PeerID) -> StubBase:
|
|
|
- cls._collect_rpc_handlers()
|
|
|
- return cls._stub_type(p2p, peer)
|
|
|
-
|
|
|
@classmethod
|
|
|
def _collect_rpc_handlers(cls) -> None:
|
|
|
if cls._rpc_handlers is not None:
|
|
|
return
|
|
|
|
|
|
- class_name = cls.__name__
|
|
|
cls._rpc_handlers = []
|
|
|
for method_name, method in inspect.getmembers(cls, predicate=inspect.isfunction):
|
|
|
if method_name.startswith("rpc_"):
|
|
|
- handle_name = f"{class_name}.{method_name}"
|
|
|
-
|
|
|
spec = inspect.getfullargspec(method)
|
|
|
if len(spec.args) < 3:
|
|
|
raise ValueError(
|
|
|
- f"{handle_name} is expected to at least three positional arguments "
|
|
|
+ f"{method_name} is expected to at least three positional arguments "
|
|
|
f"(self: TServicer, request: TInputProtobuf, context: hivemind.p2p.P2PContext)"
|
|
|
)
|
|
|
request_arg = spec.args[1]
|
|
@@ -86,7 +65,7 @@ class ServicerBase:
|
|
|
response_type = hints["return"]
|
|
|
except KeyError:
|
|
|
raise ValueError(
|
|
|
- f"{handle_name} is expected to have type annotations "
|
|
|
+ f"{method_name} is expected to have type annotations "
|
|
|
f"like `dht_pb2.FindRequest` or `AsyncIterator[dht_pb2.FindRequest]` "
|
|
|
f"for the `{request_arg}` parameter and the return value"
|
|
|
)
|
|
@@ -94,17 +73,17 @@ class ServicerBase:
|
|
|
response_type, stream_output = cls._strip_iterator_hint(response_type)
|
|
|
|
|
|
cls._rpc_handlers.append(
|
|
|
- RPCHandler(method_name, handle_name, request_type, response_type, stream_input, stream_output)
|
|
|
+ RPCHandler(method_name, request_type, response_type, stream_input, stream_output)
|
|
|
)
|
|
|
|
|
|
cls._stub_type = type(
|
|
|
- f"{class_name}Stub",
|
|
|
+ f"{cls.__name__}Stub",
|
|
|
(StubBase,),
|
|
|
{handler.method_name: cls._make_rpc_caller(handler) for handler in cls._rpc_handlers},
|
|
|
)
|
|
|
|
|
|
- @staticmethod
|
|
|
- def _make_rpc_caller(handler: RPCHandler):
|
|
|
+ @classmethod
|
|
|
+ def _make_rpc_caller(cls, handler: RPCHandler):
|
|
|
input_type = AsyncIterator[handler.request_type] if handler.stream_input else handler.request_type
|
|
|
|
|
|
# This method will be added to a new Stub type (a subclass of StubBase)
|
|
@@ -118,7 +97,7 @@ class ServicerBase:
|
|
|
|
|
|
return self._p2p.iterate_protobuf_handler(
|
|
|
self._peer,
|
|
|
- handler.handle_name,
|
|
|
+ cls._get_handle_name(handler.method_name),
|
|
|
input,
|
|
|
handler.response_type,
|
|
|
)
|
|
@@ -129,13 +108,39 @@ class ServicerBase:
|
|
|
self: StubBase, input: input_type, timeout: Optional[float] = None
|
|
|
) -> handler.response_type:
|
|
|
return await asyncio.wait_for(
|
|
|
- self._p2p.call_protobuf_handler(self._peer, handler.handle_name, input, handler.response_type),
|
|
|
+ self._p2p.call_protobuf_handler(
|
|
|
+ self._peer,
|
|
|
+ cls._get_handle_name(handler.method_name),
|
|
|
+ input,
|
|
|
+ handler.response_type,
|
|
|
+ ),
|
|
|
timeout=timeout,
|
|
|
)
|
|
|
|
|
|
caller.__name__ = handler.method_name
|
|
|
return caller
|
|
|
|
|
|
+ async def add_p2p_handlers(self, p2p: P2P, wrapper: Any = None) -> None:
|
|
|
+ self._collect_rpc_handlers()
|
|
|
+
|
|
|
+ servicer = self if wrapper is None else wrapper
|
|
|
+ for handler in self._rpc_handlers:
|
|
|
+ await p2p.add_protobuf_handler(
|
|
|
+ self._get_handle_name(handler.method_name),
|
|
|
+ getattr(servicer, handler.method_name),
|
|
|
+ handler.request_type,
|
|
|
+ stream_input=handler.stream_input,
|
|
|
+ )
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def get_stub(cls, p2p: P2P, peer: PeerID) -> StubBase:
|
|
|
+ cls._collect_rpc_handlers()
|
|
|
+ return cls._stub_type(p2p, peer)
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def _get_handle_name(cls, method_name: str) -> str:
|
|
|
+ return f"{cls.__name__}.{method_name}"
|
|
|
+
|
|
|
@staticmethod
|
|
|
def _strip_iterator_hint(hint: type) -> Tuple[type, bool]:
|
|
|
if hasattr(hint, "_name") and hint._name in ("AsyncIterator", "AsyncIterable"):
|