
Propagate startup errors from DHT and averager processes (#347)

Currently, a startup failure of the DHT and/or the averager (e.g. when `initial_peers` are unreachable) leads to a deadlock in the parent process: the exception message is printed to stdout, but the parent process does not stop. This is inconvenient for both tests and end users.

This PR propagates startup errors to the parent process, so a failed start raises in the caller instead of hanging.
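For example, a caller that passes unreachable `initial_peers` now gets the original exception re-raised in the parent process instead of a hang. A sketch adapted from the new test in `tests/test_dht.py` (the peer ID is the dummy value used there, and `get_free_port()` ensures nothing listens on the chosen port):

```python
import hivemind
from hivemind.utils.networking import get_free_port

# nothing listens on this address, so bootstrapping is guaranteed to fail
maddr = f"/ip4/127.0.0.1/tcp/{get_free_port()}/p2p/QmdaK4LUeQaKhqSFPRu9N7MvXUEWDxWwtCvPrS444tCgd1"

try:
    dht = hivemind.DHT(initial_peers=[maddr], start=True)  # previously: deadlock
except hivemind.p2p.P2PDaemonError as e:
    print(f"DHT startup failed: {e}")  # "Failed to connect to bootstrap peers"
```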
Alexander Borzunov, 4 years ago
commit d893a0aef9
4 changed files with 71 additions and 37 deletions:
  1. hivemind/averaging/averager.py (+32 -23)
  2. hivemind/dht/__init__.py (+21 -12)
  3. tests/test_dht.py (+16 -0)
  4. tests/test_utils/dht_swarms.py (+2 -2)

+ 32 - 23
hivemind/averaging/averager.py

@@ -170,7 +170,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             allow_state_sharing = not client_mode and not auxiliary
         self.allow_state_sharing = allow_state_sharing

-        self.ready = mp.Event()  # whether the averager process has started (and ready for incoming requests)
+        self._ready = MPFuture()
         # note: we create a background thread weakref and with daemon=True to ensure garbage collection
         background_fetcher = threading.Thread(
             daemon=True,
@@ -214,25 +214,31 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:

             async def _run():
-                self._p2p = await self.dht.replicate_p2p()
-                if not self.client_mode:
-                    await self.add_p2p_handlers(self._p2p, namespace=self.prefix)
-                else:
-                    logger.debug(f"The averager is running in client mode.")
-
-                self._matchmaking = Matchmaking(
-                    self._p2p,
-                    self.schema_hash,
-                    self.dht,
-                    client_mode=self.client_mode,
-                    **self.matchmaking_kwargs,
-                )
-                if not self.client_mode:
-                    asyncio.create_task(self._declare_for_download_periodically())
-
-                self._pending_group_assembled = asyncio.Event()
-                self._pending_group_assembled.set()
-                self.ready.set()
+                try:
+                    self._p2p = await self.dht.replicate_p2p()
+                    if not self.client_mode:
+                        await self.add_p2p_handlers(self._p2p, namespace=self.prefix)
+                    else:
+                        logger.debug(f"The averager is running in client mode.")
+
+                    self._matchmaking = Matchmaking(
+                        self._p2p,
+                        self.schema_hash,
+                        self.dht,
+                        client_mode=self.client_mode,
+                        **self.matchmaking_kwargs,
+                    )
+                    if not self.client_mode:
+                        asyncio.create_task(self._declare_for_download_periodically())
+
+                    self._pending_group_assembled = asyncio.Event()
+                    self._pending_group_assembled.set()
+                except Exception as e:
+                    # Loglevel is DEBUG since normally the exception is propagated to the caller
+                    logger.debug(e, exc_info=True)
+                    self._ready.set_exception(e)
+                    return
+                self._ready.set_result(None)

                 while True:
                     method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
@@ -243,14 +249,17 @@ class DecentralizedAverager(mp.Process, ServicerBase):

             loop.run_until_complete(_run())

-    def run_in_background(self, await_ready=True, timeout=None):
+    def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None:
         """
         """
         Starts averager in a background process. if await_ready, this method will wait until background dht
         Starts averager in a background process. if await_ready, this method will wait until background dht
         is ready to process incoming requests or for :timeout: seconds max.
         is ready to process incoming requests or for :timeout: seconds max.
         """
         """
         self.start()
         self.start()
-        if await_ready and not self.ready.wait(timeout=timeout):
-            raise TimeoutError(f"Server didn't notify .ready in {timeout} seconds")
+        if await_ready:
+            self.wait_until_ready(timeout)
+
+    def wait_until_ready(self, timeout: Optional[float] = None) -> None:
+        self._ready.result(timeout=timeout)

     def shutdown(self) -> None:
         """Shut down the averager process"""

+ 21 - 12
hivemind/dht/__init__.py

@@ -86,7 +86,7 @@ class DHT(mp.Process):
         self._record_validator = CompositeValidator(record_validators)
         self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)
         self.shutdown_timeout = shutdown_timeout
-        self.ready = mp.Event()
+        self._ready = MPFuture()
         self.daemon = daemon

         # These values will be fetched from the child process when requested
@@ -104,13 +104,19 @@ class DHT(mp.Process):
         with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:

             async def _run():
-                self._node = await DHTNode.create(
-                    initial_peers=self.initial_peers,
-                    num_workers=self.num_workers,
-                    record_validator=self._record_validator,
-                    **self.kwargs,
-                )
-                self.ready.set()
+                try:
+                    self._node = await DHTNode.create(
+                        initial_peers=self.initial_peers,
+                        num_workers=self.num_workers,
+                        record_validator=self._record_validator,
+                        **self.kwargs,
+                    )
+                except Exception as e:
+                    # Loglevel is DEBUG since normally the exception is propagated to the caller
+                    logger.debug(e, exc_info=True)
+                    self._ready.set_exception(e)
+                    return
+                self._ready.set_result(None)

                 while True:
                     method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
@@ -122,14 +128,17 @@ class DHT(mp.Process):
             coro = _run()
             loop.run_until_complete(coro)

-    def run_in_background(self, await_ready=True, timeout=None):
+    def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None:
         """
         """
         Starts DHT in a background process. if await_ready, this method will wait until background dht
         Starts DHT in a background process. if await_ready, this method will wait until background dht
         is ready to process incoming requests or for :timeout: seconds max.
         is ready to process incoming requests or for :timeout: seconds max.
         """
         """
         self.start()
         self.start()
-        if await_ready and not self.ready.wait(timeout=timeout):
-            raise TimeoutError(f"DHT didn't notify .ready in {timeout} seconds")
+        if await_ready:
+            self.wait_until_ready(timeout)
+
+    def wait_until_ready(self, timeout: Optional[float] = None) -> None:
+        self._ready.result(timeout=timeout)

     def shutdown(self) -> None:
         """Shut down a running dht process"""
@@ -252,7 +261,7 @@ class DHT(mp.Process):
                 future.set_exception(e)

     def add_validators(self, record_validators: Iterable[RecordValidatorBase]) -> None:
-        if not self.ready.is_set():
+        if not self._ready.done():
             raise RuntimeError(
                 "Can't append new validators before the DHT process has started. "
                 "Consider adding them to the initial list via DHT.__init__(record_validators=...)"

+ 16 - 0
tests/test_dht.py

@@ -1,4 +1,5 @@
 import asyncio
+import concurrent.futures
 import random
 import time

@@ -6,10 +7,25 @@ import pytest
 from multiaddr import Multiaddr

 import hivemind
+from hivemind.utils.networking import get_free_port

 from test_utils.dht_swarms import launch_dht_instances


+@pytest.mark.asyncio
+async def test_startup_error():
+    with pytest.raises(hivemind.p2p.P2PDaemonError, match=r"Failed to connect to bootstrap peers"):
+        hivemind.DHT(
+            initial_peers=[f"/ip4/127.0.0.1/tcp/{get_free_port()}/p2p/QmdaK4LUeQaKhqSFPRu9N7MvXUEWDxWwtCvPrS444tCgd1"],
+            start=True,
+        )
+
+    dht = hivemind.DHT(start=True, await_ready=False)
+    with pytest.raises(concurrent.futures.TimeoutError):
+        dht.wait_until_ready(timeout=0.1)
+    dht.shutdown()
+
+
 @pytest.mark.forked
 def test_get_store(n_peers=10):
     peers = launch_dht_instances(n_peers)

+ 2 - 2
tests/test_utils/dht_swarms.py

@@ -94,7 +94,7 @@ def launch_dht_instances(n_peers: int, **kwargs) -> List[DHT]:
     initial_peers = dhts[0].get_visible_maddrs()

     dhts.extend(DHT(initial_peers=initial_peers, start=True, await_ready=False, **kwargs) for _ in range(n_peers - 1))
-    for instance in dhts[1:]:
-        instance.ready.wait()
+    for process in dhts[1:]:
+        process.wait_until_ready()

     return dhts