ソースを参照

Extract _get_handler_name method in Servicer

Aleksandr Borzunov 4 年 前
コミット
b0173f6816
1 ファイル変更34 行追加29 行削除
  1. 34 29
      hivemind/p2p/servicer.py

+ 34 - 29
hivemind/p2p/servicer.py

@@ -10,7 +10,6 @@ from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID
 @dataclass
 class RPCHandler:
     method_name: str
-    handle_name: str
     request_type: type
     response_type: type
     stream_input: bool
@@ -45,38 +44,18 @@ class ServicerBase:
     _rpc_handlers: Optional[List[RPCHandler]] = None
     _stub_type: Optional[Type[StubBase]] = None
 
-    async def add_p2p_handlers(self, p2p: P2P, wrapper: Any = None) -> None:
-        self._collect_rpc_handlers()
-
-        servicer = self if wrapper is None else wrapper
-        for handler in self._rpc_handlers:
-            await p2p.add_protobuf_handler(
-                handler.handle_name,
-                getattr(servicer, handler.method_name),
-                handler.request_type,
-                stream_input=handler.stream_input,
-            )
-
-    @classmethod
-    def get_stub(cls, p2p: P2P, peer: PeerID) -> StubBase:
-        cls._collect_rpc_handlers()
-        return cls._stub_type(p2p, peer)
-
     @classmethod
     def _collect_rpc_handlers(cls) -> None:
         if cls._rpc_handlers is not None:
             return
 
-        class_name = cls.__name__
         cls._rpc_handlers = []
         for method_name, method in inspect.getmembers(cls, predicate=inspect.isfunction):
             if method_name.startswith("rpc_"):
-                handle_name = f"{class_name}.{method_name}"
-
                 spec = inspect.getfullargspec(method)
                 if len(spec.args) < 3:
                     raise ValueError(
-                        f"{handle_name} is expected to at least three positional arguments "
+                        f"{method_name} is expected to at least three positional arguments "
                         f"(self: TServicer, request: TInputProtobuf, context: hivemind.p2p.P2PContext)"
                     )
                 request_arg = spec.args[1]
@@ -86,7 +65,7 @@ class ServicerBase:
                     response_type = hints["return"]
                 except KeyError:
                     raise ValueError(
-                        f"{handle_name} is expected to have type annotations "
+                        f"{method_name} is expected to have type annotations "
                         f"like `dht_pb2.FindRequest` or `AsyncIterator[dht_pb2.FindRequest]` "
                         f"for the `{request_arg}` parameter and the return value"
                     )
@@ -94,17 +73,17 @@ class ServicerBase:
                 response_type, stream_output = cls._strip_iterator_hint(response_type)
 
                 cls._rpc_handlers.append(
-                    RPCHandler(method_name, handle_name, request_type, response_type, stream_input, stream_output)
+                    RPCHandler(method_name, request_type, response_type, stream_input, stream_output)
                 )
 
         cls._stub_type = type(
-            f"{class_name}Stub",
+            f"{cls.__name__}Stub",
             (StubBase,),
             {handler.method_name: cls._make_rpc_caller(handler) for handler in cls._rpc_handlers},
         )
 
-    @staticmethod
-    def _make_rpc_caller(handler: RPCHandler):
+    @classmethod
+    def _make_rpc_caller(cls, handler: RPCHandler):
         input_type = AsyncIterator[handler.request_type] if handler.stream_input else handler.request_type
 
         # This method will be added to a new Stub type (a subclass of StubBase)
@@ -118,7 +97,7 @@ class ServicerBase:
 
                 return self._p2p.iterate_protobuf_handler(
                     self._peer,
-                    handler.handle_name,
+                    cls._get_handle_name(handler.method_name),
                     input,
                     handler.response_type,
                 )
@@ -129,13 +108,39 @@ class ServicerBase:
                 self: StubBase, input: input_type, timeout: Optional[float] = None
             ) -> handler.response_type:
                 return await asyncio.wait_for(
-                    self._p2p.call_protobuf_handler(self._peer, handler.handle_name, input, handler.response_type),
+                    self._p2p.call_protobuf_handler(
+                        self._peer,
+                        cls._get_handle_name(handler.method_name),
+                        input,
+                        handler.response_type,
+                    ),
                     timeout=timeout,
                 )
 
         caller.__name__ = handler.method_name
         return caller
 
+    async def add_p2p_handlers(self, p2p: P2P, wrapper: Any = None) -> None:
+        self._collect_rpc_handlers()
+
+        servicer = self if wrapper is None else wrapper
+        for handler in self._rpc_handlers:
+            await p2p.add_protobuf_handler(
+                self._get_handle_name(handler.method_name),
+                getattr(servicer, handler.method_name),
+                handler.request_type,
+                stream_input=handler.stream_input,
+            )
+
+    @classmethod
+    def get_stub(cls, p2p: P2P, peer: PeerID) -> StubBase:
+        cls._collect_rpc_handlers()
+        return cls._stub_type(p2p, peer)
+
+    @classmethod
+    def _get_handle_name(cls, method_name: str) -> str:
+        return f"{cls.__name__}.{method_name}"
+
     @staticmethod
     def _strip_iterator_hint(hint: type) -> Tuple[type, bool]:
         if hasattr(hint, "_name") and hint._name in ("AsyncIterator", "AsyncIterable"):