Просмотр исходного кода

Support inheritance and arbitrary parameter names for rpc_* methods in ServicerBase

Aleksandr Borzunov 4 лет назад
Родитель
Сommit
27951768cb
2 измененных файлов с 14 добавлено и 8 удалено
  1. 10 4
      hivemind/p2p/servicer.py
  2. 4 4
      tests/test_p2p_servicer.py

+ 10 - 4
hivemind/p2p/servicer.py

@@ -1,4 +1,5 @@
 import asyncio
+import inspect
 from dataclasses import dataclass
 from typing import Any, AsyncIterator, Optional, Tuple, get_type_hints
 
@@ -45,19 +46,24 @@ class ServicerBase:
         class_name = self.__class__.__name__
 
         self._rpc_handlers = []
-        for method_name, method in self.__class__.__dict__.items():
-            if method_name.startswith("rpc_") and callable(method):
+        for method_name, method in inspect.getmembers(self.__class__, predicate=inspect.isfunction):
+            if method_name.startswith("rpc_"):
                 handle_name = f"{class_name}.{method_name}"
 
+                spec = inspect.getfullargspec(method)
+                if len(spec.args) < 3:
+                    raise ValueError(f"{handle_name} is expected to at least three positional arguments "
+                                     f"(self: TServicer, request: TInputProtobuf, context: hivemind.p2p.P2PContext)")
+                request_arg = spec.args[1]
                 hints = get_type_hints(method)
                 try:
-                    request_type = hints["request"]
+                    request_type = hints[request_arg]
                     response_type = hints["return"]
                 except KeyError:
                     raise ValueError(
                         f"{handle_name} is expected to have type annotations "
                         f"like `dht_pb2.FindRequest` or `AsyncIterator[dht_pb2.FindRequest]` "
-                        f"for the `request` parameter and the return value"
+                        f"for the `{request_arg}` parameter and the return value"
                     )
                 request_type, stream_input = self._strip_iterator_hint(request_type)
                 response_type, stream_output = self._strip_iterator_hint(response_type)

+ 4 - 4
tests/test_p2p_servicer.py

@@ -33,9 +33,9 @@ async def test_unary_unary(server_client):
 @pytest.mark.asyncio
 async def test_stream_unary(server_client):
     class ExampleServicer(ServicerBase):
-        async def rpc_sum(self, request: AsyncIterator[test_pb2.TestRequest], _: P2PContext) -> test_pb2.TestResponse:
+        async def rpc_sum(self, numbers: AsyncIterator[test_pb2.TestRequest], _: P2PContext) -> test_pb2.TestResponse:
             result = 0
-            async for item in request:
+            async for item in numbers:
                 result += item.number
             return test_pb2.TestResponse(number=result)
 
@@ -76,9 +76,9 @@ async def test_unary_stream(server_client):
 async def test_stream_stream(server_client):
     class ExampleServicer(ServicerBase):
         async def rpc_powers(
-            self, request: AsyncIterator[test_pb2.TestRequest], _: P2PContext
+            self, stream: AsyncIterator[test_pb2.TestRequest], _: P2PContext
         ) -> AsyncIterator[test_pb2.TestResponse]:
-            async for item in request:
+            async for item in stream:
                 yield test_pb2.TestResponse(number=item.number ** 2)
                 yield test_pb2.TestResponse(number=item.number ** 3)