
Merge branch 'master' into unary-handlers

Denis Mazur, 4 years ago
commit 619931d5bd

+ 32 - 23
hivemind/averaging/averager.py

@@ -170,7 +170,7 @@ class DecentralizedAverager(mp.Process, ServicerBase):
             allow_state_sharing = not client_mode and not auxiliary
         self.allow_state_sharing = allow_state_sharing
 
-        self.ready = mp.Event()  # whether the averager process has started (and ready for incoming requests)
+        self._ready = MPFuture()
         # note: we create a background thread weakref and with daemon=True to ensure garbage collection
         background_fetcher = threading.Thread(
             daemon=True,
@@ -214,25 +214,31 @@ class DecentralizedAverager(mp.Process, ServicerBase):
         with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
 
             async def _run():
-                self._p2p = await self.dht.replicate_p2p()
-                if not self.client_mode:
-                    await self.add_p2p_handlers(self._p2p, namespace=self.prefix)
-                else:
-                    logger.debug(f"The averager is running in client mode.")
-
-                self._matchmaking = Matchmaking(
-                    self._p2p,
-                    self.schema_hash,
-                    self.dht,
-                    client_mode=self.client_mode,
-                    **self.matchmaking_kwargs,
-                )
-                if not self.client_mode:
-                    asyncio.create_task(self._declare_for_download_periodically())
-
-                self._pending_group_assembled = asyncio.Event()
-                self._pending_group_assembled.set()
-                self.ready.set()
+                try:
+                    self._p2p = await self.dht.replicate_p2p()
+                    if not self.client_mode:
+                        await self.add_p2p_handlers(self._p2p, namespace=self.prefix)
+                    else:
+                        logger.debug(f"The averager is running in client mode.")
+
+                    self._matchmaking = Matchmaking(
+                        self._p2p,
+                        self.schema_hash,
+                        self.dht,
+                        client_mode=self.client_mode,
+                        **self.matchmaking_kwargs,
+                    )
+                    if not self.client_mode:
+                        asyncio.create_task(self._declare_for_download_periodically())
+
+                    self._pending_group_assembled = asyncio.Event()
+                    self._pending_group_assembled.set()
+                except Exception as e:
+                    # Loglevel is DEBUG since normally the exception is propagated to the caller
+                    logger.debug(e, exc_info=True)
+                    self._ready.set_exception(e)
+                    return
+                self._ready.set_result(None)
 
                 while True:
                     method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
@@ -243,14 +249,17 @@ class DecentralizedAverager(mp.Process, ServicerBase):
 
             loop.run_until_complete(_run())
 
-    def run_in_background(self, await_ready=True, timeout=None):
+    def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None:
         """
         Starts averager in a background process. if await_ready, this method will wait until background dht
         is ready to process incoming requests or for :timeout: seconds max.
         """
         self.start()
-        if await_ready and not self.ready.wait(timeout=timeout):
-            raise TimeoutError(f"Server didn't notify .ready in {timeout} seconds")
+        if await_ready:
+            self.wait_until_ready(timeout)
+
+    def wait_until_ready(self, timeout: Optional[float] = None) -> None:
+        self._ready.result(timeout=timeout)
 
     def shutdown(self) -> None:
         """Shut down the averager process"""

+ 21 - 12
hivemind/dht/__init__.py

@@ -86,7 +86,7 @@ class DHT(mp.Process):
         self._record_validator = CompositeValidator(record_validators)
         self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)
         self.shutdown_timeout = shutdown_timeout
-        self.ready = mp.Event()
+        self._ready = MPFuture()
         self.daemon = daemon
 
         # These values will be fetched from the child process when requested
@@ -104,13 +104,19 @@ class DHT(mp.Process):
         with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
 
             async def _run():
-                self._node = await DHTNode.create(
-                    initial_peers=self.initial_peers,
-                    num_workers=self.num_workers,
-                    record_validator=self._record_validator,
-                    **self.kwargs,
-                )
-                self.ready.set()
+                try:
+                    self._node = await DHTNode.create(
+                        initial_peers=self.initial_peers,
+                        num_workers=self.num_workers,
+                        record_validator=self._record_validator,
+                        **self.kwargs,
+                    )
+                except Exception as e:
+                    # Loglevel is DEBUG since normally the exception is propagated to the caller
+                    logger.debug(e, exc_info=True)
+                    self._ready.set_exception(e)
+                    return
+                self._ready.set_result(None)
 
                 while True:
                     method, args, kwargs = await loop.run_in_executor(pipe_awaiter, self._inner_pipe.recv)
@@ -122,14 +128,17 @@ class DHT(mp.Process):
             coro = _run()
             loop.run_until_complete(coro)
 
-    def run_in_background(self, await_ready=True, timeout=None):
+    def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None:
         """
         Starts DHT in a background process. if await_ready, this method will wait until background dht
         is ready to process incoming requests or for :timeout: seconds max.
         """
         self.start()
-        if await_ready and not self.ready.wait(timeout=timeout):
-            raise TimeoutError(f"DHT didn't notify .ready in {timeout} seconds")
+        if await_ready:
+            self.wait_until_ready(timeout)
+
+    def wait_until_ready(self, timeout: Optional[float] = None) -> None:
+        self._ready.result(timeout=timeout)
 
     def shutdown(self) -> None:
         """Shut down a running dht process"""
@@ -252,7 +261,7 @@ class DHT(mp.Process):
                 future.set_exception(e)
 
     def add_validators(self, record_validators: Iterable[RecordValidatorBase]) -> None:
-        if not self.ready.is_set():
+        if not self._ready.done():
             raise RuntimeError(
                 "Can't append new validators before the DHT process has started. "
                 "Consider adding them to the initial list via DHT.__init__(record_validators=...)"

+ 5 - 3
hivemind/hivemind_cli/run_server.py

@@ -21,8 +21,9 @@ def main():
                         help="'localhost' for local connections only, '0.0.0.0' for ipv4 '[::]' for ipv6")
     parser.add_argument('--num_experts', type=int, default=None, required=False, help="The number of experts to serve")
     parser.add_argument('--expert_pattern', type=str, default=None, required=False,
-                        help='all expert uids will follow this pattern, e.g. "myexpert.[0:256].[0:1024]" will sample random expert uids'
-                             ' between myexpert.0.0 and myexpert.255.1023 . Use either num_experts and this or expert_uids')
+                        help='all expert uids will follow this pattern, e.g. "myexpert.[0:256].[0:1024]" will'
+                             ' sample random expert uids between myexpert.0.0 and myexpert.255.1023 . Use either'
+                             ' num_experts and this or expert_uids')
     parser.add_argument('--expert_uids', type=str, nargs="*", default=None, required=False,
                         help="specify the exact list of expert uids to create. Use either this or num_experts"
                              " and expert_pattern, not both")
@@ -42,7 +43,8 @@ def main():
     parser.add_argument('--optimizer', type=str, default='adam', required=False, help='adam, sgd or none')
     parser.add_argument('--scheduler', type=str, choices=schedule_name_to_scheduler.keys(), default='none',
                         help='LR scheduler type to use')
-    parser.add_argument('--num_warmup_steps', type=int, required=False, help='The number of warmup steps for LR schedule')
+    parser.add_argument('--num_warmup_steps', type=int, required=False,
+                        help='The number of warmup steps for LR schedule')
     parser.add_argument('--num_total_steps', type=int, required=False, help='The total number of steps for LR schedule')
     parser.add_argument('--clip_grad_norm', type=float, required=False, help='Maximum gradient norm used for clipping')
 

+ 2 - 1
hivemind/moe/server/expert_backend.py

@@ -187,7 +187,8 @@ class ExpertBackend:
 
     def get_stats(self) -> Dict:
         """
-        Return current expert training statistics (number of updates, number of processed examples after last optimizer step)
+        Return current expert training statistics (number of updates, number of processed examples after
+        last optimizer step)
         """
         return {"updates": self.update_count, "examples_processed": self.examples_processed}
 

+ 16 - 0
tests/test_dht.py

@@ -1,4 +1,5 @@
 import asyncio
+import concurrent.futures
 import random
 import time
 
@@ -6,10 +7,25 @@ import pytest
 from multiaddr import Multiaddr
 
 import hivemind
+from hivemind.utils.networking import get_free_port
 
 from test_utils.dht_swarms import launch_dht_instances
 
 
+@pytest.mark.asyncio
+async def test_startup_error():
+    with pytest.raises(hivemind.p2p.P2PDaemonError, match=r"Failed to connect to bootstrap peers"):
+        hivemind.DHT(
+            initial_peers=[f"/ip4/127.0.0.1/tcp/{get_free_port()}/p2p/QmdaK4LUeQaKhqSFPRu9N7MvXUEWDxWwtCvPrS444tCgd1"],
+            start=True,
+        )
+
+    dht = hivemind.DHT(start=True, await_ready=False)
+    with pytest.raises(concurrent.futures.TimeoutError):
+        dht.wait_until_ready(timeout=0.1)
+    dht.shutdown()
+
+
 @pytest.mark.forked
 def test_get_store(n_peers=10):
     peers = launch_dht_instances(n_peers)

+ 2 - 2
tests/test_utils/dht_swarms.py

@@ -94,7 +94,7 @@ def launch_dht_instances(n_peers: int, **kwargs) -> List[DHT]:
     initial_peers = dhts[0].get_visible_maddrs()
 
     dhts.extend(DHT(initial_peers=initial_peers, start=True, await_ready=False, **kwargs) for _ in range(n_peers - 1))
-    for instance in dhts[1:]:
-        instance.ready.wait()
+    for process in dhts[1:]:
+        process.wait_until_ready()
 
     return dhts