|
@@ -1,4 +1,5 @@
|
|
|
import asyncio
|
|
|
+import inspect
|
|
|
from dataclasses import dataclass
|
|
|
from typing import Any, AsyncIterator, Optional, Tuple, get_type_hints
|
|
|
|
|
@@ -45,19 +46,24 @@ class ServicerBase:
|
|
|
class_name = self.__class__.__name__
|
|
|
|
|
|
self._rpc_handlers = []
|
|
|
- for method_name, method in self.__class__.__dict__.items():
|
|
|
- if method_name.startswith("rpc_") and callable(method):
|
|
|
+ for method_name, method in inspect.getmembers(self.__class__, predicate=inspect.isfunction):
|
|
|
+ if method_name.startswith("rpc_"):
|
|
|
handle_name = f"{class_name}.{method_name}"
|
|
|
|
|
|
+ spec = inspect.getfullargspec(method)
|
|
|
+ if len(spec.args) < 3:
|
|
|
+ raise ValueError(f"{handle_name} is expected to at least three positional arguments "
|
|
|
+ f"(self: TServicer, request: TInputProtobuf, context: hivemind.p2p.P2PContext)")
|
|
|
+ request_arg = spec.args[1]
|
|
|
hints = get_type_hints(method)
|
|
|
try:
|
|
|
- request_type = hints["request"]
|
|
|
+ request_type = hints[request_arg]
|
|
|
response_type = hints["return"]
|
|
|
except KeyError:
|
|
|
raise ValueError(
|
|
|
f"{handle_name} is expected to have type annotations "
|
|
|
f"like `dht_pb2.FindRequest` or `AsyncIterator[dht_pb2.FindRequest]` "
|
|
|
- f"for the `request` parameter and the return value"
|
|
|
+ f"for the `{request_arg}` parameter and the return value"
|
|
|
)
|
|
|
request_type, stream_input = self._strip_iterator_hint(request_type)
|
|
|
response_type, stream_output = self._strip_iterator_hint(response_type)
|