
Propagate startup errors from DHT and averager processes (#347)

Currently, a startup failure of the DHT and/or the averager (e.g. when `initial_peers` are unreachable) deadlocks the parent process: the exception message is printed to stdout, but the parent never stops. This is inconvenient for both tests and end users.

This PR propagates startup errors to the parent process.
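With this change, a failed startup surfaces as an ordinary exception in the parent instead of a hang. A minimal sketch of the new behavior (the multiaddr below is a placeholder for an unreachable bootstrap peer):

```python
import hivemind

try:
    # before this PR, an unreachable bootstrap peer could deadlock here;
    # now the child process forwards its startup exception to the parent
    dht = hivemind.DHT(
        initial_peers=["/ip4/127.0.0.1/tcp/12345/p2p/QmExamplePeerId"],  # placeholder multiaddr
        start=True,
    )
except hivemind.p2p.P2PDaemonError as e:
    print(f"DHT failed to start: {e}")
```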
Alexander Borzunov committed 4 years ago
parent commit d893a0aef9
4 changed files with 71 additions and 37 deletions
  1. hivemind/averaging/averager.py (+32 −23)
  2. hivemind/dht/__init__.py (+21 −12)
  3. tests/test_dht.py (+16 −0)
  4. tests/test_utils/dht_swarms.py (+2 −2)

hivemind/averaging/averager.py (+32 −23)

@@ -170,7 +170,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             allow_state_sharing = not client_mode and not auxiliary
         self.allow_state_sharing = allow_state_sharing
 
-        self.ready = mp.Event()  # whether the averager process has started (and ready for incoming requests)
+        self._ready = MPFuture()
         # note: we keep only a weakref to the background thread and set daemon=True to ensure garbage collection
         background_fetcher = threading.Thread(
             daemon=True,
@@ -214,25 +214,31 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
 
             async def _run():
-                self._p2p = await self.dht.replicate_p2p()
-                if not self.client_mode:
-                    await self.add_p2p_handlers(self._p2p, namespace=self.prefix)
-                else:
-                    logger.debug(f"The averager is running in client mode.")
-
-                self._matchmaking = Matchmaking(
-                    self._p2p,
-                    self.schema_hash,
-                    self.dht,
-                    client_mode=self.client_mode,
-                    **self.matchmaking_kwargs,
-                )
-                if not self.client_mode:
-                    asyncio.create_task(self._declare_for_download_periodically())
-
-                self._pending_group_assembled = asyncio.Event()
-                self._pending_group_assembled.set()
-                self.ready.set()
+                try:
+                    self._p2p = await self.dht.replicate_p2p()
+                    if not self.client_mode:
+                        await self.add_p2p_handlers(self._p2p, namespace=self.prefix)
+                    else:
+                        logger.debug(f"The averager is running in client mode.")
+
+                    self._matchmaking = Matchmaking(
+                        self._p2p,
+                        self.schema_hash,
+                        self.dht,
+                        client_mode=self.client_mode,
+                        **self.matchmaking_kwargs,
+                    )
+                    if not self.client_mode:
+                        asyncio.create_task(self._declare_for_download_periodically())
+
+                    self._pending_group_assembled = asyncio.Event()
+                    self._pending_group_assembled.set()
+                except Exception as e:
+                    # Loglevel is DEBUG since normally the exception is propagated to the caller
+                    logger.debug(e, exc_info=True)
+                    self._ready.set_exception(e)
+                    return
+                self._ready.set_result(None)
 
                 while True:
                     method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
@@ -243,14 +249,17 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
             loop.run_until_complete(_run())
 
-    def run_in_background(self, await_ready=True, timeout=None):
+    def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None:
         """
         Starts averager in a background process. If await_ready, this method will wait until the background
         averager is ready to process incoming requests, or for at most :timeout: seconds.
         """
         self.start()
-        if await_ready and not self.ready.wait(timeout=timeout):
-            raise TimeoutError(f"Server didn't notify .ready in {timeout} seconds")
+        if await_ready:
+            self.wait_until_ready(timeout)
+
+    def wait_until_ready(self, timeout: Optional[float] = None) -> None:
+        self._ready.result(timeout=timeout)
 
     def shutdown(self) -> None:
         """Shut down the averager process"""

hivemind/dht/__init__.py (+21 −12)

@@ -86,7 +86,7 @@ class DHT(mp.Process):
         self._record_validator = CompositeValidator(record_validators)
         self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)
         self.shutdown_timeout = shutdown_timeout
-        self.ready = mp.Event()
+        self._ready = MPFuture()
         self.daemon = daemon
 
         # These values will be fetched from the child process when requested
@@ -104,13 +104,19 @@ class DHT(mp.Process):
         with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
 
             async def _run():
-                self._node = await DHTNode.create(
-                    initial_peers=self.initial_peers,
-                    num_workers=self.num_workers,
-                    record_validator=self._record_validator,
-                    **self.kwargs,
-                )
-                self.ready.set()
+                try:
+                    self._node = await DHTNode.create(
+                        initial_peers=self.initial_peers,
+                        num_workers=self.num_workers,
+                        record_validator=self._record_validator,
+                        **self.kwargs,
+                    )
+                except Exception as e:
+                    # Loglevel is DEBUG since normally the exception is propagated to the caller
+                    logger.debug(e, exc_info=True)
+                    self._ready.set_exception(e)
+                    return
+                self._ready.set_result(None)
 
                 while True:
                     method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
@@ -122,14 +128,17 @@ class DHT(mp.Process):
             coro = _run()
             loop.run_until_complete(coro)
 
-    def run_in_background(self, await_ready=True, timeout=None):
+    def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None:
         """
         Starts DHT in a background process. If await_ready, this method will wait until the background DHT
         is ready to process incoming requests, or for at most :timeout: seconds.
         """
         self.start()
-        if await_ready and not self.ready.wait(timeout=timeout):
-            raise TimeoutError(f"DHT didn't notify .ready in {timeout} seconds")
+        if await_ready:
+            self.wait_until_ready(timeout)
+
+    def wait_until_ready(self, timeout: Optional[float] = None) -> None:
+        self._ready.result(timeout=timeout)
 
     def shutdown(self) -> None:
         """Shut down a running dht process"""
@@ -252,7 +261,7 @@ class DHT(mp.Process):
                 future.set_exception(e)
 
     def add_validators(self, record_validators: Iterable[RecordValidatorBase]) -> None:
-        if not self.ready.is_set():
+        if not self._ready.done():
             raise RuntimeError(
                 "Can't append new validators before the DHT process has started. "
                 "Consider adding them to the initial list via DHT.__init__(record_validators=...)"

tests/test_dht.py (+16 −0)

@@ -1,4 +1,5 @@
 import asyncio
+import concurrent.futures
 import random
 import time
 
@@ -6,10 +7,25 @@ import pytest
 from multiaddr import Multiaddr
 
 import hivemind
+from hivemind.utils.networking import get_free_port
 
 from test_utils.dht_swarms import launch_dht_instances
 
 
+@pytest.mark.asyncio
+async def test_startup_error():
+    with pytest.raises(hivemind.p2p.P2PDaemonError, match=r"Failed to connect to bootstrap peers"):
+        hivemind.DHT(
+            initial_peers=[f"/ip4/127.0.0.1/tcp/{get_free_port()}/p2p/QmdaK4LUeQaKhqSFPRu9N7MvXUEWDxWwtCvPrS444tCgd1"],
+            start=True,
+        )
+
+    dht = hivemind.DHT(start=True, await_ready=False)
+    with pytest.raises(concurrent.futures.TimeoutError):
+        dht.wait_until_ready(timeout=0.1)
+    dht.shutdown()
+
+
 @pytest.mark.forked
 def test_get_store(n_peers=10):
     peers = launch_dht_instances(n_peers)

tests/test_utils/dht_swarms.py (+2 −2)

@@ -94,7 +94,7 @@ def launch_dht_instances(n_peers: int, **kwargs) -> List[DHT]:
     initial_peers = dhts[0].get_visible_maddrs()
 
     dhts.extend(DHT(initial_peers=initial_peers, start=True, await_ready=False, **kwargs) for _ in range(n_peers - 1))
-    for instance in dhts[1:]:
-        instance.ready.wait()
+    for process in dhts[1:]:
+        process.wait_until_ready()
 
     return dhts