فهرست منبع

add unary handler support to P2P

Denis Mazur 4 سال پیش
والد
کامیت
013938f65e
3فایلهای تغییر یافته به همراه34 افزوده شده و 22 حذف شده
  1. 6 0
      hivemind/p2p/p2p_daemon.py
  2. 21 21
      hivemind/p2p/p2p_daemon_bindings/control.py
  3. 7 1
      hivemind/p2p/p2p_daemon_bindings/p2pclient.py

+ 6 - 0
hivemind/p2p/p2p_daemon.py

@@ -169,6 +169,12 @@ class P2P:
 
         return self
 
+    async def add_unary_handler(self, proto: str, handler: p2pclient.TUnaryHandler):
+        return await self._client.add_unary_handler(proto, handler)
+
+    async def unary_call(self, peer_id: PeerID, proto: str, data: bytes) -> bytes:
+        return await self._client.unary_call(peer_id, proto, data)
+
     async def _ping_daemon_with_retries(self, ping_n_attempts: int, ping_delay: float) -> None:
         for try_number in range(ping_n_attempts):
             await asyncio.sleep(ping_delay * (2 ** try_number))

+ 21 - 21
hivemind/p2p/p2p_daemon_bindings/control.py

@@ -87,6 +87,22 @@ class ControlClient:
         self.pending_messages: asyncio.Queue[p2pd_pb.Request] = asyncio.Queue()
         self.pending_calls: Dict[CallID, asyncio.Future] = {}
 
+    @asynccontextmanager
+    async def listen(self) -> AsyncIterator["ControlClient"]:
+        proto_code = parse_conn_protocol(self.listen_maddr)
+        if proto_code == protocols.P_UNIX:
+            listen_path = self.listen_maddr.value_for_protocol(protocols.P_UNIX)
+            server = await asyncio.start_unix_server(self._handler, path=listen_path)
+        elif proto_code == protocols.P_IP4:
+            host = self.listen_maddr.value_for_protocol(protocols.P_IP4)
+            port = int(self.listen_maddr.value_for_protocol(protocols.P_TCP))
+            server = await asyncio.start_server(self._handler, port=port, host=host)
+        else:
+            raise ValueError(f"Protocol not supported: {protocols.protocol_with_code(proto_code)}")
+
+        async with server:
+            yield self
+
     async def _read_from_persistent_conn(self, reader: asyncio.StreamReader):
         while True:
             resp: p2pd_pb.Response = p2pd_pb.Response()  # type: ignore
@@ -106,6 +122,11 @@ class ControlClient:
             elif resp.requestHandling:
                 asyncio.create_task(self._handle_persistent_request(resp.requestHandling))
                 pass
+            
+    async def _write_to_persistent_conn(self, writer: asyncio.StreamWriter):
+        while True:
+            msg = await self.pending_messages.get()
+            await write_pbmsg(writer, msg)
 
     async def _handle_persistent_request(self, request):
         assert request.proto in self.unary_handlers
@@ -123,11 +144,6 @@ class ControlClient:
         await self.pending_messages.put(
             p2pd_pb.Request(type=p2pd_pb.Request.UNARY_RESPONSE, response=response))
 
-    async def _write_to_persistent_conn(self, writer: asyncio.StreamWriter):
-        while True:
-            msg = await self.pending_messages.get()
-            await write_pbmsg(writer, msg)
-
     async def _handler(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
         pb_stream_info = p2pd_pb.StreamInfo()  # type: ignore
         await read_pbmsg_safe(reader, pb_stream_info)
@@ -140,22 +156,6 @@ class ControlClient:
             raise DispatchFailure(e)
         await handler(stream_info, reader, writer)
 
-    @asynccontextmanager
-    async def listen(self) -> AsyncIterator["ControlClient"]:
-        proto_code = parse_conn_protocol(self.listen_maddr)
-        if proto_code == protocols.P_UNIX:
-            listen_path = self.listen_maddr.value_for_protocol(protocols.P_UNIX)
-            server = await asyncio.start_unix_server(self._handler, path=listen_path)
-        elif proto_code == protocols.P_IP4:
-            host = self.listen_maddr.value_for_protocol(protocols.P_IP4)
-            port = int(self.listen_maddr.value_for_protocol(protocols.P_TCP))
-            server = await asyncio.start_server(self._handler, port=port, host=host)
-        else:
-            raise ValueError(f"Protocol not supported: {protocols.protocol_with_code(proto_code)}")
-
-        async with server:
-            yield self
-
     async def _ensure_persistent_conn(self):
         if not self._pers_conn_open:
             reader, writer = await self.daemon_connector.open_persistent_connection()

+ 7 - 1
hivemind/p2p/p2p_daemon_bindings/p2pclient.py

@@ -10,7 +10,7 @@ from typing import AsyncIterator, Iterable, Sequence, Tuple
 
 from multiaddr import Multiaddr
 
-from hivemind.p2p.p2p_daemon_bindings.control import ControlClient, DaemonConnector, StreamHandler
+from hivemind.p2p.p2p_daemon_bindings.control import ControlClient, DaemonConnector, StreamHandler, TUnaryHandler
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
 
 
@@ -30,6 +30,12 @@ class Client:
         async with self.control.listen():
             yield self
 
+    async def add_unary_handler(self, proto: str, handler: TUnaryHandler):
+        await self.control.add_unary_handler(proto, handler)
+        
+    async def unary_call(self, peer_id: PeerID, proto: str, data: bytes) -> bytes:
+        return await self.control.unary_call(peer_id, proto, data)
+
     async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
         """
         Get current node peer id and list of addresses