ソースを参照

reimplement protobuf handlers

Denis Mazur 4 年 前
コミット
044b2321fe
2 ファイル変更48 行追加18 行削除
  1. 47 17
      hivemind/p2p/p2p_daemon.py
  2. 1 1
      hivemind/p2p/p2p_daemon_bindings/control.py

+ 47 - 17
hivemind/p2p/p2p_daemon.py

@@ -28,7 +28,6 @@ class P2PContext(object):
     handle_name: str
     local_id: PeerID
     remote_id: PeerID = None
-    remote_maddr: Multiaddr = None
 
 
 class P2P:
@@ -169,11 +168,11 @@ class P2P:
 
         return self
 
-    async def add_unary_handler(self, proto: str, handler: p2pclient.TUnaryHandler):
-        return await self._client.add_unary_handler(proto, handler)
+    async def add_unary_handler(self, handle_name: str, handler: p2pclient.TUnaryHandler):
+        return await self._client.add_unary_handler(handle_name, handler)
 
-    async def unary_call(self, peer_id: PeerID, proto: str, data: bytes) -> bytes:
-        return await self._client.unary_call(peer_id, proto, data)
+    async def unary_call(self, peer_id: PeerID, handle_name: str, data: bytes) -> bytes:
+        return await self._client.unary_call(peer_id, handle_name, data)
 
     async def _ping_daemon_with_retries(self, ping_n_attempts: int, ping_delay: float) -> None:
         for try_number in range(ping_n_attempts):
@@ -323,7 +322,6 @@ class P2P:
                 handle_name=name,
                 local_id=self.id,
                 remote_id=stream_info.peer_id,
-                remote_maddr=stream_info.addr,
             )
             requests = asyncio.Queue(max_prefetch)
 
@@ -408,16 +406,12 @@ class P2P:
                              (not just ``TInputProtobuf``) as input.
         """
 
-        async def _stream_handler(requests: P2P.TInputStream, context: P2PContext) -> P2P.TOutputStream:
-            if stream_input:
-                input = requests
-            else:
-                count = 0
-                async for input in requests:
-                    count += 1
-                if count != 1:
-                    raise ValueError(f"Got {count} requests for handler {name} instead of one")
+        if not stream_input:
+            await self._add_protobuf_unary_handler(name, handler, input_protobuf_type)
+            return
 
+        async def _stream_handler(requests: P2P.TInputStream, context: P2PContext) -> P2P.TOutputStream:
+            input = requests
             output = handler(input, context)
 
             if isinstance(output, AsyncIterableABC):
@@ -428,6 +422,27 @@ class P2P:
 
         await self._add_protobuf_stream_handler(name, _stream_handler, input_protobuf_type)
 
+    # only registers request-response handlers
+    async def _add_protobuf_unary_handler(
+            self,
+            handle_name: str,
+            handler: Callable[[TInputProtobuf, P2PContext], Awaitable[TOutputProtobuf]],
+            input_protobuf_type: type,
+    ) -> None:
+
+        async def _unary_handler(request: bytes) -> bytes:
+            input_serialized = input_protobuf_type().FromString(request)
+            context = P2PContext(
+                handle_name=handle_name,
+                local_id=self.id,
+                # TODO: add remote id
+            )
+
+            response = await handler(input_serialized, context)
+            return response.SerializeToString()
+
+        await self.add_unary_handler(handle_name, _unary_handler)
+
     async def call_protobuf_handler(
         self,
         peer_id: PeerID,
@@ -435,8 +450,11 @@ class P2P:
         input: Union[TInputProtobuf, TInputStream],
         output_protobuf_type: type,
     ) -> Awaitable[TOutputProtobuf]:
-        requests = input if isinstance(input, AsyncIterableABC) else aiter(input)
-        responses = self._iterate_protobuf_stream_handler(peer_id, name, requests, output_protobuf_type)
+
+        if not isinstance(input, AsyncIterableABC):
+            return self._call_unary_protobuf_handler(peer_id, name, input, output_protobuf_type)
+
+        responses = self._iterate_protobuf_stream_handler(peer_id, name, input, output_protobuf_type)
 
         count = 0
         async for response in responses:
@@ -445,6 +463,18 @@ class P2P:
             raise ValueError(f"Got {count} responses from handler {name} instead of one")
         return response
 
+    async def _call_unary_protobuf_handler(
+            self,
+            peer_id: PeerID,
+            handle_name: str,
+            input: TInputProtobuf,
+            output_protobuf_type: type,
+    ) -> Awaitable[TOutputProtobuf]:
+        serialized_input = input.SerializeToString()
+        response = await self.unary_call(peer_id, handle_name, serialized_input)
+        return output_protobuf_type().FromString(response)
+
+
     def iterate_protobuf_handler(
         self,
         peer_id: PeerID,

+ 1 - 1
hivemind/p2p/p2p_daemon_bindings/control.py

@@ -66,7 +66,7 @@ class DaemonConnector:
         return reader, writer
 
 
-TUnaryHandler = Callable[[bytes], bytes]
+TUnaryHandler = Callable[[bytes], Awaitable[bytes]]
 CallID = uuid.UUID