Browse Source

Add daemon termination support (#348)

Denis Mazur 4 years ago
parent
commit
56cfb4e312

+ 1 - 1
hivemind/dht/protocol.py

@@ -81,7 +81,7 @@ class DHTProtocol(ServicerBase):
 
 
     def __init__(self, *, _initialized_with_create=False):
     def __init__(self, *, _initialized_with_create=False):
         """Internal init method. Please use DHTProtocol.create coroutine to spawn new protocol instances"""
         """Internal init method. Please use DHTProtocol.create coroutine to spawn new protocol instances"""
-        assert _initialized_with_create, " Please use DHTProtocol.create coroutine to spawn new protocol instances "
+        assert _initialized_with_create, "Please use DHTProtocol.create coroutine to spawn new protocol instances"
         super().__init__()
         super().__init__()
 
 
     def get_stub(self, peer: PeerID) -> AuthRPCWrapper:
     def get_stub(self, peer: PeerID) -> AuthRPCWrapper:

+ 6 - 2
hivemind/p2p/p2p_daemon.py

@@ -91,6 +91,7 @@ class P2P:
         use_auto_relay: bool = False,
         use_auto_relay: bool = False,
         relay_hop_limit: int = 0,
         relay_hop_limit: int = 0,
         startup_timeout: float = 15,
         startup_timeout: float = 15,
+        idle_timeout: float = 0,
     ) -> "P2P":
     ) -> "P2P":
         """
         """
         Start a new p2pd process and connect to it.
         Start a new p2pd process and connect to it.
@@ -112,6 +113,8 @@ class P2P:
         :param use_auto_relay: enables autorelay
         :param use_auto_relay: enables autorelay
         :param relay_hop_limit: sets the hop limit for hop relays
         :param relay_hop_limit: sets the hop limit for hop relays
         :param startup_timeout: raise a P2PDaemonError if the daemon does not start in ``startup_timeout`` seconds
         :param startup_timeout: raise a P2PDaemonError if the daemon does not start in ``startup_timeout`` seconds
+        :param idle_timeout: kill daemon if client has been idle for a given number of
+                             seconds before opening persistent streams
         :return: a wrapper for the p2p daemon
         :return: a wrapper for the p2p daemon
         """
         """
 
 
@@ -151,6 +154,7 @@ class P2P:
             relayDiscovery=use_relay_discovery,
             relayDiscovery=use_relay_discovery,
             autoRelay=use_auto_relay,
             autoRelay=use_auto_relay,
             relayHopLimit=relay_hop_limit,
             relayHopLimit=relay_hop_limit,
+            idleTimeout=f"{idle_timeout}s",
             b=need_bootstrap,
             b=need_bootstrap,
             **process_kwargs,
             **process_kwargs,
         )
         )
@@ -168,7 +172,7 @@ class P2P:
             await self.shutdown()
             await self.shutdown()
             raise P2PDaemonError(f"Daemon failed to start in {startup_timeout:.1f} seconds")
             raise P2PDaemonError(f"Daemon failed to start in {startup_timeout:.1f} seconds")
 
 
-        self._client = p2pclient.Client(self._daemon_listen_maddr, self._client_listen_maddr)
+        self._client = await p2pclient.Client.create(self._daemon_listen_maddr, self._client_listen_maddr)
         await self._ping_daemon()
         await self._ping_daemon()
         return self
         return self
 
 
@@ -190,7 +194,7 @@ class P2P:
         self._daemon_listen_maddr = daemon_listen_maddr
         self._daemon_listen_maddr = daemon_listen_maddr
         self._client_listen_maddr = Multiaddr(cls._UNIX_SOCKET_PREFIX + f"p2pclient-{socket_uid}.sock")
         self._client_listen_maddr = Multiaddr(cls._UNIX_SOCKET_PREFIX + f"p2pclient-{socket_uid}.sock")
 
 
-        self._client = p2pclient.Client(self._daemon_listen_maddr, self._client_listen_maddr)
+        self._client = await p2pclient.Client.create(self._daemon_listen_maddr, self._client_listen_maddr)
 
 
         await self._ping_daemon()
         await self._ping_daemon()
         return self
         return self

+ 25 - 13
hivemind/p2p/p2p_daemon_bindings/control.py

@@ -74,8 +74,14 @@ class ControlClient:
     DEFAULT_LISTEN_MADDR = "/unix/tmp/p2pclient.sock"
     DEFAULT_LISTEN_MADDR = "/unix/tmp/p2pclient.sock"
 
 
     def __init__(
     def __init__(
-        self, daemon_connector: DaemonConnector, listen_maddr: Multiaddr = Multiaddr(DEFAULT_LISTEN_MADDR)
+        self,
+        daemon_connector: DaemonConnector,
+        listen_maddr: Multiaddr = Multiaddr(DEFAULT_LISTEN_MADDR),
+        *,
+        _initialized_with_create=False,
     ) -> None:
     ) -> None:
+        assert _initialized_with_create, "Please use ControlClient.create coroutine to spawn new control instances"
+
         self.listen_maddr = listen_maddr
         self.listen_maddr = listen_maddr
         self.daemon_connector = daemon_connector
         self.daemon_connector = daemon_connector
         self.handlers: Dict[str, StreamHandler] = {}
         self.handlers: Dict[str, StreamHandler] = {}
@@ -83,7 +89,6 @@ class ControlClient:
         self._is_persistent_conn_open: bool = False
         self._is_persistent_conn_open: bool = False
         self.unary_handlers: Dict[str, TUnaryHandler] = {}
         self.unary_handlers: Dict[str, TUnaryHandler] = {}
 
 
-        self._ensure_conn_lock = asyncio.Lock()
         self._pending_messages: asyncio.Queue[p2pd_pb.PersistentConnectionRequest] = asyncio.Queue()
         self._pending_messages: asyncio.Queue[p2pd_pb.PersistentConnectionRequest] = asyncio.Queue()
         self._pending_calls: Dict[CallID, asyncio.Future[bytes]] = {}
         self._pending_calls: Dict[CallID, asyncio.Future[bytes]] = {}
         self._handler_tasks: Dict[CallID, asyncio.Task] = {}
         self._handler_tasks: Dict[CallID, asyncio.Task] = {}
@@ -91,6 +96,20 @@ class ControlClient:
         self._read_task: Optional[asyncio.Task] = None
         self._read_task: Optional[asyncio.Task] = None
         self._write_task: Optional[asyncio.Task] = None
         self._write_task: Optional[asyncio.Task] = None
 
 
+    @classmethod
+    async def create(
+        cls,
+        daemon_connector: DaemonConnector,
+        listen_maddr: Multiaddr = Multiaddr(DEFAULT_LISTEN_MADDR),
+        use_persistent_conn: bool = True,
+    ) -> "ControlClient":
+        control = cls(daemon_connector, listen_maddr, _initialized_with_create=True)
+
+        if use_persistent_conn:
+            await control._ensure_persistent_conn()
+
+        return control
+
     async def _handler(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
     async def _handler(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
         pb_stream_info = p2pd_pb.StreamInfo()  # type: ignore
         pb_stream_info = p2pd_pb.StreamInfo()  # type: ignore
         await read_pbmsg_safe(reader, pb_stream_info)
         await read_pbmsg_safe(reader, pb_stream_info)
@@ -184,19 +203,14 @@ class ControlClient:
         )
         )
 
 
     async def _ensure_persistent_conn(self):
     async def _ensure_persistent_conn(self):
-        if not self._is_persistent_conn_open:
-            async with self._ensure_conn_lock:
-                if not self._is_persistent_conn_open:
-                    reader, writer = await self.daemon_connector.open_persistent_connection()
+        reader, writer = await self.daemon_connector.open_persistent_connection()
 
 
-                    self._read_task = asyncio.create_task(self._read_from_persistent_conn(reader))
-                    self._write_task = asyncio.create_task(self._write_to_persistent_conn(writer))
+        self._read_task = asyncio.create_task(self._read_from_persistent_conn(reader))
+        self._write_task = asyncio.create_task(self._write_to_persistent_conn(writer))
 
 
-                    self._is_persistent_conn_open = True
+        self._is_persistent_conn_open = True
 
 
     async def add_unary_handler(self, proto: str, handler: TUnaryHandler):
     async def add_unary_handler(self, proto: str, handler: TUnaryHandler):
-        await self._ensure_persistent_conn()
-
         call_id = uuid.uuid4()
         call_id = uuid.uuid4()
 
 
         add_unary_handler_req = p2pd_pb.AddUnaryHandlerRequest(proto=proto)
         add_unary_handler_req = p2pd_pb.AddUnaryHandlerRequest(proto=proto)
@@ -220,8 +234,6 @@ class ControlClient:
             callUnary=call_unary_req,
             callUnary=call_unary_req,
         )
         )
 
 
-        await self._ensure_persistent_conn()
-
         try:
         try:
             self._pending_calls[call_id] = asyncio.Future()
             self._pending_calls[call_id] = asyncio.Future()
             await self._pending_messages.put(req)
             await self._pending_messages.put(req)

+ 10 - 2
hivemind/p2p/p2p_daemon_bindings/p2pclient.py

@@ -17,9 +17,17 @@ from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, St
 class Client:
 class Client:
     control: ControlClient
     control: ControlClient
 
 
-    def __init__(self, control_maddr: Multiaddr = None, listen_maddr: Multiaddr = None) -> None:
+    def __init__(self) -> None:
+        self.control = None
+
+    @classmethod
+    async def create(cls, control_maddr: Multiaddr = None, listen_maddr: Multiaddr = None) -> "Client":
+        client = cls()
+
         daemon_connector = DaemonConnector(control_maddr=control_maddr)
         daemon_connector = DaemonConnector(control_maddr=control_maddr)
-        self.control = ControlClient(daemon_connector=daemon_connector, listen_maddr=listen_maddr)
+        client.control = await ControlClient.create(daemon_connector=daemon_connector, listen_maddr=listen_maddr)
+
+        return client
 
 
     @asynccontextmanager
     @asynccontextmanager
     async def listen(self) -> AsyncIterator["Client"]:
     async def listen(self) -> AsyncIterator["Client"]:

+ 12 - 5
tests/test_p2p_daemon_bindings.py

@@ -193,7 +193,8 @@ def test_parse_conn_protocol_invalid(maddr_str):
 
 
 
 
 @pytest.mark.parametrize("control_maddr_str", ("/unix/123", "/ip4/127.0.0.1/tcp/6666"))
 @pytest.mark.parametrize("control_maddr_str", ("/unix/123", "/ip4/127.0.0.1/tcp/6666"))
-def test_client_ctor_control_maddr(control_maddr_str):
+@pytest.mark.asyncio
+async def test_client_ctor_control_maddr(control_maddr_str):
     c = DaemonConnector(Multiaddr(control_maddr_str))
     c = DaemonConnector(Multiaddr(control_maddr_str))
     assert c.control_maddr == Multiaddr(control_maddr_str)
     assert c.control_maddr == Multiaddr(control_maddr_str)
 
 
@@ -204,13 +205,19 @@ def test_client_ctor_default_control_maddr():
 
 
 
 
 @pytest.mark.parametrize("listen_maddr_str", ("/unix/123", "/ip4/127.0.0.1/tcp/6666"))
 @pytest.mark.parametrize("listen_maddr_str", ("/unix/123", "/ip4/127.0.0.1/tcp/6666"))
-def test_control_client_ctor_listen_maddr(listen_maddr_str):
-    c = ControlClient(daemon_connector=DaemonConnector(), listen_maddr=Multiaddr(listen_maddr_str))
+@pytest.mark.asyncio
+async def test_control_client_ctor_listen_maddr(listen_maddr_str):
+    c = await ControlClient.create(
+        daemon_connector=DaemonConnector(),
+        listen_maddr=Multiaddr(listen_maddr_str),
+        use_persistent_conn=False,
+    )
     assert c.listen_maddr == Multiaddr(listen_maddr_str)
     assert c.listen_maddr == Multiaddr(listen_maddr_str)
 
 
 
 
-def test_control_client_ctor_default_listen_maddr():
-    c = ControlClient(daemon_connector=DaemonConnector())
+@pytest.mark.asyncio
+async def test_control_client_ctor_default_listen_maddr():
+    c = await ControlClient.create(daemon_connector=DaemonConnector(), use_persistent_conn=False)
     assert c.listen_maddr == Multiaddr(ControlClient.DEFAULT_LISTEN_MADDR)
     assert c.listen_maddr == Multiaddr(ControlClient.DEFAULT_LISTEN_MADDR)
 
 
 
 

+ 1 - 1
tests/test_utils/p2p_daemon.py

@@ -160,7 +160,7 @@ async def _make_p2pd_pair(
     )
     )
     # wait for daemon ready
     # wait for daemon ready
     await p2pd.wait_until_ready()
     await p2pd.wait_until_ready()
-    client = Client(control_maddr=control_maddr, listen_maddr=listen_maddr)
+    client = await Client.create(control_maddr=control_maddr, listen_maddr=listen_maddr)
     try:
     try:
         async with client.listen():
         async with client.listen():
             yield DaemonTuple(daemon=p2pd, client=client)
             yield DaemonTuple(daemon=p2pd, client=client)