소스 검색

Speed up P2P client creation (#343)

This PR resolves #300 and speeds up the creation of P2P clients by reading the daemon's stdout instead of pinging it repeatedly.

Co-authored-by: Denis Mazur <denismazur8@gmail.com>
Co-authored-by: Alexander Borzunov <hxrussia@gmail.com>
Denis Mazur 4 년 전
부모
커밋
ecdc0965ca

+ 2 - 2
benchmarks/benchmark_throughput.py

@@ -7,7 +7,7 @@ import time
 import torch
 
 import hivemind
-from hivemind import find_open_port
+from hivemind import get_free_port
 from hivemind.moe.server import layers
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger
@@ -66,7 +66,7 @@ def benchmark_throughput(
         or torch.device(device) == torch.device("cpu")
     )
     assert expert_cls in layers.name_to_block
-    port = port or find_open_port()
+    port = port or get_free_port()
     max_batch_size = max_batch_size or batch_size * 4
     num_handlers = max(1, num_handlers or num_clients // 2)
     benchmarking_failed = mp.Event()

+ 2 - 2
hivemind/moe/server/__init__.py

@@ -27,7 +27,7 @@ from hivemind.moe.server.layers import (
 )
 from hivemind.moe.server.runtime import Runtime
 from hivemind.proto.runtime_pb2 import CompressionType
-from hivemind.utils import BatchTensorDescriptor, Endpoint, find_open_port, get_logger, get_port, replace_port
+from hivemind.utils import BatchTensorDescriptor, Endpoint, get_free_port, get_logger, get_port, replace_port
 
 logger = get_logger(__name__)
 
@@ -68,7 +68,7 @@ class Server(threading.Thread):
         super().__init__()
         self.dht, self.experts, self.update_period = dht, expert_backends, update_period
         if get_port(listen_on) is None:
-            listen_on = replace_port(listen_on, new_port=find_open_port())
+            listen_on = replace_port(listen_on, new_port=get_free_port())
         self.listen_on, self.port = listen_on, get_port(listen_on)
 
         self.conn_handlers = [ConnectionHandler(listen_on, self.experts) for _ in range(num_connection_handlers)]

+ 1 - 1
hivemind/p2p/__init__.py

@@ -1,3 +1,3 @@
-from hivemind.p2p.p2p_daemon import P2P, P2PContext, P2PHandlerError
+from hivemind.p2p.p2p_daemon import P2P, P2PContext, P2PDaemonError, P2PHandlerError
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo
 from hivemind.p2p.servicer import ServicerBase, StubBase

+ 40 - 48
hivemind/p2p/p2p_daemon.py

@@ -5,7 +5,6 @@ from collections.abc import AsyncIterable as AsyncIterableABC
 from contextlib import closing, suppress
 from dataclasses import dataclass
 from importlib.resources import path
-from subprocess import Popen
 from typing import Any, AsyncIterator, Awaitable, Callable, List, Optional, Sequence, Tuple, TypeVar, Union
 
 from multiaddr import Multiaddr
@@ -68,8 +67,8 @@ class P2P:
         self.peer_id = None
         self._child = None
         self._alive = False
+        self._reader_task = None
         self._listen_task = None
-        self._server_stopped = asyncio.Event()
 
     @classmethod
     async def create(
@@ -90,9 +89,7 @@ class P2P:
         use_relay_discovery: bool = False,
         use_auto_relay: bool = False,
         relay_hop_limit: int = 0,
-        quiet: bool = True,
-        ping_n_attempts: int = 5,
-        ping_delay: float = 0.4,
+        startup_timeout: float = 15,
     ) -> "P2P":
         """
         Start a new p2pd process and connect to it.
@@ -113,10 +110,7 @@ class P2P:
         :param use_relay_discovery: enables passive discovery for relay
         :param use_auto_relay: enables autorelay
         :param relay_hop_limit: sets the hop limit for hop relays
-        :param quiet: make the daemon process quiet
-        :param ping_n_attempts: try to ping the daemon with this number of attempts after starting it
-        :param ping_delay: wait for ``ping_delay * (2 ** (k - 1))`` seconds before the k-th attempt to ping the daemon
-          (in particular, wait for ``ping_delay`` seconds before the first attempt)
+        :param startup_timeout: raise a P2PDaemonError if the daemon does not start in ``startup_timeout`` seconds
         :return: a wrapper for the p2p daemon
         """
 
@@ -157,37 +151,26 @@ class P2P:
             autoRelay=use_auto_relay,
             relayHopLimit=relay_hop_limit,
             b=need_bootstrap,
-            q=quiet,
             **process_kwargs,
         )
 
-        self._child = Popen(args=proc_args, encoding="utf8")
+        self._child = await asyncio.subprocess.create_subprocess_exec(
+            *proc_args, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.STDOUT
+        )
         self._alive = True
-        self._client = p2pclient.Client(self._daemon_listen_maddr, self._client_listen_maddr)
 
-        await self._ping_daemon_with_retries(ping_n_attempts, ping_delay)
+        ready = asyncio.Future()
+        self._reader_task = asyncio.create_task(self._read_outputs(ready))
+        try:
+            await asyncio.wait_for(ready, startup_timeout)
+        except asyncio.TimeoutError:
+            await self.shutdown()
+            raise P2PDaemonError(f"Daemon failed to start in {startup_timeout:.1f} seconds")
 
+        self._client = p2pclient.Client(self._daemon_listen_maddr, self._client_listen_maddr)
+        await self._ping_daemon()
         return self
 
-    async def _ping_daemon_with_retries(self, ping_n_attempts: int, ping_delay: float) -> None:
-        for try_number in range(ping_n_attempts):
-            await asyncio.sleep(ping_delay * (2 ** try_number))
-
-            if self._child.poll() is not None:  # Process died
-                break
-
-            try:
-                await self._ping_daemon()
-                break
-            except Exception as e:
-                if try_number == ping_n_attempts - 1:
-                    logger.exception("Failed to ping p2pd that has just started")
-                    await self.shutdown()
-                    raise
-
-        if self._child.returncode is not None:
-            raise RuntimeError(f"The p2p daemon has died with return code {self._child.returncode}")
-
     @classmethod
     async def replicate(cls, daemon_listen_maddr: Multiaddr) -> "P2P":
         """
@@ -437,20 +420,10 @@ class P2P:
     def _start_listening(self) -> None:
         async def listen() -> None:
             async with self._client.listen():
-                await self._server_stopped.wait()
+                await asyncio.Future()  # Wait until this task will be cancelled in _terminate()
 
         self._listen_task = asyncio.create_task(listen())
 
-    async def _stop_listening(self) -> None:
-        if self._listen_task is not None:
-            self._server_stopped.set()
-            self._listen_task.cancel()
-            try:
-                await self._listen_task
-            except asyncio.CancelledError:
-                self._listen_task = None
-                self._server_stopped.clear()
-
     async def add_binary_stream_handler(self, name: str, handler: p2pclient.StreamHandler) -> None:
         if self._listen_task is None:
             self._start_listening()
@@ -469,14 +442,19 @@ class P2P:
         return self._alive
 
     async def shutdown(self) -> None:
-        await self._stop_listening()
-        await asyncio.get_event_loop().run_in_executor(None, self._terminate)
+        self._terminate()
+        if self._child is not None:
+            await self._child.wait()
 
     def _terminate(self) -> None:
+        if self._listen_task is not None:
+            self._listen_task.cancel()
+        if self._reader_task is not None:
+            self._reader_task.cancel()
+
         self._alive = False
-        if self._child is not None and self._child.poll() is None:
+        if self._child is not None and self._child.returncode is None:
             self._child.terminate()
-            self._child.wait()
             logger.debug(f"Terminated p2pd with id = {self.peer_id}")
 
             with suppress(FileNotFoundError):
@@ -504,8 +482,22 @@ class P2P:
     def _maddrs_to_str(maddrs: List[Multiaddr]) -> str:
         return ",".join(str(addr) for addr in maddrs)
 
+    async def _read_outputs(self, ready: asyncio.Future) -> None:
+        last_line = None
+        while True:
+            line = await self._child.stdout.readline()
+            if not line:  # Stream closed
+                break
+            last_line = line.rstrip().decode(errors="ignore")
+
+            if last_line.startswith("Peer ID:"):
+                ready.set_result(None)
+
+        if not ready.done():
+            ready.set_exception(P2PDaemonError(f"Daemon failed to start: {last_line}"))
+
 
-class P2PInterruptedError(Exception):
+class P2PDaemonError(RuntimeError):
     pass
 
 

+ 1 - 1
hivemind/utils/networking.py

@@ -30,7 +30,7 @@ def strip_port(endpoint: Endpoint) -> Hostname:
     return endpoint[: endpoint.rindex(":")] if maybe_port.isdigit() or maybe_port == "*" else endpoint
 
 
-def find_open_port(params=(socket.AF_INET, socket.SOCK_STREAM), opt=(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)):
+def get_free_port(params=(socket.AF_INET, socket.SOCK_STREAM), opt=(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)):
     """Finds a tcp port that can be occupied with a socket with *params and use *opt options"""
     try:
         with closing(socket.socket(*params)) as sock:

+ 13 - 1
tests/test_p2p_daemon.py

@@ -9,8 +9,9 @@ import numpy as np
 import pytest
 from multiaddr import Multiaddr
 
-from hivemind.p2p import P2P, P2PHandlerError
+from hivemind.p2p import P2P, P2PDaemonError, P2PHandlerError
 from hivemind.proto import dht_pb2
+from hivemind.utils.networking import get_free_port
 from hivemind.utils.serializer import MSGPackSerializer
 
 
@@ -33,6 +34,17 @@ async def test_daemon_killed_on_del():
     assert not is_process_running(child_pid)
 
 
+@pytest.mark.asyncio
+async def test_startup_error_message():
+    with pytest.raises(P2PDaemonError, match=r"Failed to connect to bootstrap peers"):
+        await P2P.create(
+            initial_peers=[f"/ip4/127.0.0.1/tcp/{get_free_port()}/p2p/QmdaK4LUeQaKhqSFPRu9N7MvXUEWDxWwtCvPrS444tCgd1"]
+        )
+
+    with pytest.raises(P2PDaemonError, match=r"Daemon failed to start in .+ seconds"):
+        await P2P.create(startup_timeout=0.1)  # Test that startup_timeout works
+
+
 @pytest.mark.parametrize(
     "host_maddrs",
     [

+ 1 - 1
tests/test_utils/dht_swarms.py

@@ -18,7 +18,7 @@ def run_node(initial_peers: List[Multiaddr], info_queue: mp.Queue, **kwargs):
         asyncio.set_event_loop(asyncio.new_event_loop())
     loop = asyncio.get_event_loop()
 
-    node = loop.run_until_complete(DHTNode.create(initial_peers=initial_peers, ping_n_attempts=10, **kwargs))
+    node = loop.run_until_complete(DHTNode.create(initial_peers=initial_peers, **kwargs))
     maddrs = loop.run_until_complete(node.get_visible_maddrs())
 
     info_queue.put((node.node_id, node.peer_id, maddrs))

+ 4 - 4
tests/test_utils/p2p_daemon.py

@@ -10,7 +10,7 @@ from typing import NamedTuple
 from multiaddr import Multiaddr, protocols
 from pkg_resources import resource_filename
 
-from hivemind import find_open_port
+from hivemind import get_free_port
 from hivemind.p2p.p2p_daemon_bindings.p2pclient import Client
 
 TIMEOUT_DURATION = 30  # seconds
@@ -57,7 +57,7 @@ class Daemon:
 
     def _run(self):
         cmd_list = [P2PD_PATH, f"-listen={str(self.control_maddr)}"]
-        cmd_list += [f"-hostAddrs=/ip4/127.0.0.1/tcp/{find_open_port()}"]
+        cmd_list += [f"-hostAddrs=/ip4/127.0.0.1/tcp/{get_free_port()}"]
         if self.enable_connmgr:
             cmd_list += ["-connManager=true", "-connLo=1", "-connHi=2", "-connGrace=0"]
         if self.enable_dht:
@@ -129,8 +129,8 @@ async def make_p2pd_pair_unix(enable_control, enable_connmgr, enable_dht, enable
 
 @asynccontextmanager
 async def make_p2pd_pair_ip4(enable_control, enable_connmgr, enable_dht, enable_pubsub):
-    control_maddr = Multiaddr(f"/ip4/127.0.0.1/tcp/{find_open_port()}")
-    listen_maddr = Multiaddr(f"/ip4/127.0.0.1/tcp/{find_open_port()}")
+    control_maddr = Multiaddr(f"/ip4/127.0.0.1/tcp/{get_free_port()}")
+    listen_maddr = Multiaddr(f"/ip4/127.0.0.1/tcp/{get_free_port()}")
     async with _make_p2pd_pair(
         control_maddr=control_maddr,
         listen_maddr=listen_maddr,