Browse Source

Speed up P2P client creation (#343)

This PR resolves #300 and speeds up the creation of P2P clients by reading the daemon's stdout instead of pinging it repeatedly.

Co-authored-by: Denis Mazur <denismazur8@gmail.com>
Co-authored-by: Alexander Borzunov <hxrussia@gmail.com>
Denis Mazur 4 năm trước cách đây
mục cha
commit
ecdc0965ca

+ 2 - 2
benchmarks/benchmark_throughput.py

@@ -7,7 +7,7 @@ import time
 import torch
 import torch
 
 
 import hivemind
 import hivemind
-from hivemind import find_open_port
+from hivemind import get_free_port
 from hivemind.moe.server import layers
 from hivemind.moe.server import layers
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.limits import increase_file_limit
 from hivemind.utils.logging import get_logger
 from hivemind.utils.logging import get_logger
@@ -66,7 +66,7 @@ def benchmark_throughput(
         or torch.device(device) == torch.device("cpu")
         or torch.device(device) == torch.device("cpu")
     )
     )
     assert expert_cls in layers.name_to_block
     assert expert_cls in layers.name_to_block
-    port = port or find_open_port()
+    port = port or get_free_port()
     max_batch_size = max_batch_size or batch_size * 4
     max_batch_size = max_batch_size or batch_size * 4
     num_handlers = max(1, num_handlers or num_clients // 2)
     num_handlers = max(1, num_handlers or num_clients // 2)
     benchmarking_failed = mp.Event()
     benchmarking_failed = mp.Event()

+ 2 - 2
hivemind/moe/server/__init__.py

@@ -27,7 +27,7 @@ from hivemind.moe.server.layers import (
 )
 )
 from hivemind.moe.server.runtime import Runtime
 from hivemind.moe.server.runtime import Runtime
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.proto.runtime_pb2 import CompressionType
-from hivemind.utils import BatchTensorDescriptor, Endpoint, find_open_port, get_logger, get_port, replace_port
+from hivemind.utils import BatchTensorDescriptor, Endpoint, get_free_port, get_logger, get_port, replace_port
 
 
 logger = get_logger(__name__)
 logger = get_logger(__name__)
 
 
@@ -68,7 +68,7 @@ class Server(threading.Thread):
         super().__init__()
         super().__init__()
         self.dht, self.experts, self.update_period = dht, expert_backends, update_period
         self.dht, self.experts, self.update_period = dht, expert_backends, update_period
         if get_port(listen_on) is None:
         if get_port(listen_on) is None:
-            listen_on = replace_port(listen_on, new_port=find_open_port())
+            listen_on = replace_port(listen_on, new_port=get_free_port())
         self.listen_on, self.port = listen_on, get_port(listen_on)
         self.listen_on, self.port = listen_on, get_port(listen_on)
 
 
         self.conn_handlers = [ConnectionHandler(listen_on, self.experts) for _ in range(num_connection_handlers)]
         self.conn_handlers = [ConnectionHandler(listen_on, self.experts) for _ in range(num_connection_handlers)]

+ 1 - 1
hivemind/p2p/__init__.py

@@ -1,3 +1,3 @@
-from hivemind.p2p.p2p_daemon import P2P, P2PContext, P2PHandlerError
+from hivemind.p2p.p2p_daemon import P2P, P2PContext, P2PDaemonError, P2PHandlerError
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo
 from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo
 from hivemind.p2p.servicer import ServicerBase, StubBase
 from hivemind.p2p.servicer import ServicerBase, StubBase

+ 40 - 48
hivemind/p2p/p2p_daemon.py

@@ -5,7 +5,6 @@ from collections.abc import AsyncIterable as AsyncIterableABC
 from contextlib import closing, suppress
 from contextlib import closing, suppress
 from dataclasses import dataclass
 from dataclasses import dataclass
 from importlib.resources import path
 from importlib.resources import path
-from subprocess import Popen
 from typing import Any, AsyncIterator, Awaitable, Callable, List, Optional, Sequence, Tuple, TypeVar, Union
 from typing import Any, AsyncIterator, Awaitable, Callable, List, Optional, Sequence, Tuple, TypeVar, Union
 
 
 from multiaddr import Multiaddr
 from multiaddr import Multiaddr
@@ -68,8 +67,8 @@ class P2P:
         self.peer_id = None
         self.peer_id = None
         self._child = None
         self._child = None
         self._alive = False
         self._alive = False
+        self._reader_task = None
         self._listen_task = None
         self._listen_task = None
-        self._server_stopped = asyncio.Event()
 
 
     @classmethod
     @classmethod
     async def create(
     async def create(
@@ -90,9 +89,7 @@ class P2P:
         use_relay_discovery: bool = False,
         use_relay_discovery: bool = False,
         use_auto_relay: bool = False,
         use_auto_relay: bool = False,
         relay_hop_limit: int = 0,
         relay_hop_limit: int = 0,
-        quiet: bool = True,
-        ping_n_attempts: int = 5,
-        ping_delay: float = 0.4,
+        startup_timeout: float = 15,
     ) -> "P2P":
     ) -> "P2P":
         """
         """
         Start a new p2pd process and connect to it.
         Start a new p2pd process and connect to it.
@@ -113,10 +110,7 @@ class P2P:
         :param use_relay_discovery: enables passive discovery for relay
         :param use_relay_discovery: enables passive discovery for relay
         :param use_auto_relay: enables autorelay
         :param use_auto_relay: enables autorelay
         :param relay_hop_limit: sets the hop limit for hop relays
         :param relay_hop_limit: sets the hop limit for hop relays
-        :param quiet: make the daemon process quiet
-        :param ping_n_attempts: try to ping the daemon with this number of attempts after starting it
-        :param ping_delay: wait for ``ping_delay * (2 ** (k - 1))`` seconds before the k-th attempt to ping the daemon
-          (in particular, wait for ``ping_delay`` seconds before the first attempt)
+        :param startup_timeout: raise a P2PDaemonError if the daemon does not start in ``startup_timeout`` seconds
         :return: a wrapper for the p2p daemon
         :return: a wrapper for the p2p daemon
         """
         """
 
 
@@ -157,37 +151,26 @@ class P2P:
             autoRelay=use_auto_relay,
             autoRelay=use_auto_relay,
             relayHopLimit=relay_hop_limit,
             relayHopLimit=relay_hop_limit,
             b=need_bootstrap,
             b=need_bootstrap,
-            q=quiet,
             **process_kwargs,
             **process_kwargs,
         )
         )
 
 
-        self._child = Popen(args=proc_args, encoding="utf8")
+        self._child = await asyncio.subprocess.create_subprocess_exec(
+            *proc_args, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.STDOUT
+        )
         self._alive = True
         self._alive = True
-        self._client = p2pclient.Client(self._daemon_listen_maddr, self._client_listen_maddr)
 
 
-        await self._ping_daemon_with_retries(ping_n_attempts, ping_delay)
+        ready = asyncio.Future()
+        self._reader_task = asyncio.create_task(self._read_outputs(ready))
+        try:
+            await asyncio.wait_for(ready, startup_timeout)
+        except asyncio.TimeoutError:
+            await self.shutdown()
+            raise P2PDaemonError(f"Daemon failed to start in {startup_timeout:.1f} seconds")
 
 
+        self._client = p2pclient.Client(self._daemon_listen_maddr, self._client_listen_maddr)
+        await self._ping_daemon()
         return self
         return self
 
 
-    async def _ping_daemon_with_retries(self, ping_n_attempts: int, ping_delay: float) -> None:
-        for try_number in range(ping_n_attempts):
-            await asyncio.sleep(ping_delay * (2 ** try_number))
-
-            if self._child.poll() is not None:  # Process died
-                break
-
-            try:
-                await self._ping_daemon()
-                break
-            except Exception as e:
-                if try_number == ping_n_attempts - 1:
-                    logger.exception("Failed to ping p2pd that has just started")
-                    await self.shutdown()
-                    raise
-
-        if self._child.returncode is not None:
-            raise RuntimeError(f"The p2p daemon has died with return code {self._child.returncode}")
-
     @classmethod
     @classmethod
     async def replicate(cls, daemon_listen_maddr: Multiaddr) -> "P2P":
     async def replicate(cls, daemon_listen_maddr: Multiaddr) -> "P2P":
         """
         """
@@ -437,20 +420,10 @@ class P2P:
     def _start_listening(self) -> None:
     def _start_listening(self) -> None:
         async def listen() -> None:
         async def listen() -> None:
             async with self._client.listen():
             async with self._client.listen():
-                await self._server_stopped.wait()
+                await asyncio.Future()  # Wait until this task will be cancelled in _terminate()
 
 
         self._listen_task = asyncio.create_task(listen())
         self._listen_task = asyncio.create_task(listen())
 
 
-    async def _stop_listening(self) -> None:
-        if self._listen_task is not None:
-            self._server_stopped.set()
-            self._listen_task.cancel()
-            try:
-                await self._listen_task
-            except asyncio.CancelledError:
-                self._listen_task = None
-                self._server_stopped.clear()
-
     async def add_binary_stream_handler(self, name: str, handler: p2pclient.StreamHandler) -> None:
     async def add_binary_stream_handler(self, name: str, handler: p2pclient.StreamHandler) -> None:
         if self._listen_task is None:
         if self._listen_task is None:
             self._start_listening()
             self._start_listening()
@@ -469,14 +442,19 @@ class P2P:
         return self._alive
         return self._alive
 
 
     async def shutdown(self) -> None:
     async def shutdown(self) -> None:
-        await self._stop_listening()
-        await asyncio.get_event_loop().run_in_executor(None, self._terminate)
+        self._terminate()
+        if self._child is not None:
+            await self._child.wait()
 
 
     def _terminate(self) -> None:
     def _terminate(self) -> None:
+        if self._listen_task is not None:
+            self._listen_task.cancel()
+        if self._reader_task is not None:
+            self._reader_task.cancel()
+
         self._alive = False
         self._alive = False
-        if self._child is not None and self._child.poll() is None:
+        if self._child is not None and self._child.returncode is None:
             self._child.terminate()
             self._child.terminate()
-            self._child.wait()
             logger.debug(f"Terminated p2pd with id = {self.peer_id}")
             logger.debug(f"Terminated p2pd with id = {self.peer_id}")
 
 
             with suppress(FileNotFoundError):
             with suppress(FileNotFoundError):
@@ -504,8 +482,22 @@ class P2P:
     def _maddrs_to_str(maddrs: List[Multiaddr]) -> str:
     def _maddrs_to_str(maddrs: List[Multiaddr]) -> str:
         return ",".join(str(addr) for addr in maddrs)
         return ",".join(str(addr) for addr in maddrs)
 
 
+    async def _read_outputs(self, ready: asyncio.Future) -> None:
+        last_line = None
+        while True:
+            line = await self._child.stdout.readline()
+            if not line:  # Stream closed
+                break
+            last_line = line.rstrip().decode(errors="ignore")
+
+            if last_line.startswith("Peer ID:"):
+                ready.set_result(None)
+
+        if not ready.done():
+            ready.set_exception(P2PDaemonError(f"Daemon failed to start: {last_line}"))
+
 
 
-class P2PInterruptedError(Exception):
+class P2PDaemonError(RuntimeError):
     pass
     pass
 
 
 
 

+ 1 - 1
hivemind/utils/networking.py

@@ -30,7 +30,7 @@ def strip_port(endpoint: Endpoint) -> Hostname:
     return endpoint[: endpoint.rindex(":")] if maybe_port.isdigit() or maybe_port == "*" else endpoint
     return endpoint[: endpoint.rindex(":")] if maybe_port.isdigit() or maybe_port == "*" else endpoint
 
 
 
 
-def find_open_port(params=(socket.AF_INET, socket.SOCK_STREAM), opt=(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)):
+def get_free_port(params=(socket.AF_INET, socket.SOCK_STREAM), opt=(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)):
     """Finds a tcp port that can be occupied with a socket with *params and use *opt options"""
     """Finds a tcp port that can be occupied with a socket with *params and use *opt options"""
     try:
     try:
         with closing(socket.socket(*params)) as sock:
         with closing(socket.socket(*params)) as sock:

+ 13 - 1
tests/test_p2p_daemon.py

@@ -9,8 +9,9 @@ import numpy as np
 import pytest
 import pytest
 from multiaddr import Multiaddr
 from multiaddr import Multiaddr
 
 
-from hivemind.p2p import P2P, P2PHandlerError
+from hivemind.p2p import P2P, P2PDaemonError, P2PHandlerError
 from hivemind.proto import dht_pb2
 from hivemind.proto import dht_pb2
+from hivemind.utils.networking import get_free_port
 from hivemind.utils.serializer import MSGPackSerializer
 from hivemind.utils.serializer import MSGPackSerializer
 
 
 
 
@@ -33,6 +34,17 @@ async def test_daemon_killed_on_del():
     assert not is_process_running(child_pid)
     assert not is_process_running(child_pid)
 
 
 
 
+@pytest.mark.asyncio
+async def test_startup_error_message():
+    with pytest.raises(P2PDaemonError, match=r"Failed to connect to bootstrap peers"):
+        await P2P.create(
+            initial_peers=[f"/ip4/127.0.0.1/tcp/{get_free_port()}/p2p/QmdaK4LUeQaKhqSFPRu9N7MvXUEWDxWwtCvPrS444tCgd1"]
+        )
+
+    with pytest.raises(P2PDaemonError, match=r"Daemon failed to start in .+ seconds"):
+        await P2P.create(startup_timeout=0.1)  # Test that startup_timeout works
+
+
 @pytest.mark.parametrize(
 @pytest.mark.parametrize(
     "host_maddrs",
     "host_maddrs",
     [
     [

+ 1 - 1
tests/test_utils/dht_swarms.py

@@ -18,7 +18,7 @@ def run_node(initial_peers: List[Multiaddr], info_queue: mp.Queue, **kwargs):
         asyncio.set_event_loop(asyncio.new_event_loop())
         asyncio.set_event_loop(asyncio.new_event_loop())
     loop = asyncio.get_event_loop()
     loop = asyncio.get_event_loop()
 
 
-    node = loop.run_until_complete(DHTNode.create(initial_peers=initial_peers, ping_n_attempts=10, **kwargs))
+    node = loop.run_until_complete(DHTNode.create(initial_peers=initial_peers, **kwargs))
     maddrs = loop.run_until_complete(node.get_visible_maddrs())
     maddrs = loop.run_until_complete(node.get_visible_maddrs())
 
 
     info_queue.put((node.node_id, node.peer_id, maddrs))
     info_queue.put((node.node_id, node.peer_id, maddrs))

+ 4 - 4
tests/test_utils/p2p_daemon.py

@@ -10,7 +10,7 @@ from typing import NamedTuple
 from multiaddr import Multiaddr, protocols
 from multiaddr import Multiaddr, protocols
 from pkg_resources import resource_filename
 from pkg_resources import resource_filename
 
 
-from hivemind import find_open_port
+from hivemind import get_free_port
 from hivemind.p2p.p2p_daemon_bindings.p2pclient import Client
 from hivemind.p2p.p2p_daemon_bindings.p2pclient import Client
 
 
 TIMEOUT_DURATION = 30  # seconds
 TIMEOUT_DURATION = 30  # seconds
@@ -57,7 +57,7 @@ class Daemon:
 
 
     def _run(self):
     def _run(self):
         cmd_list = [P2PD_PATH, f"-listen={str(self.control_maddr)}"]
         cmd_list = [P2PD_PATH, f"-listen={str(self.control_maddr)}"]
-        cmd_list += [f"-hostAddrs=/ip4/127.0.0.1/tcp/{find_open_port()}"]
+        cmd_list += [f"-hostAddrs=/ip4/127.0.0.1/tcp/{get_free_port()}"]
         if self.enable_connmgr:
         if self.enable_connmgr:
             cmd_list += ["-connManager=true", "-connLo=1", "-connHi=2", "-connGrace=0"]
             cmd_list += ["-connManager=true", "-connLo=1", "-connHi=2", "-connGrace=0"]
         if self.enable_dht:
         if self.enable_dht:
@@ -129,8 +129,8 @@ async def make_p2pd_pair_unix(enable_control, enable_connmgr, enable_dht, enable
 
 
 @asynccontextmanager
 @asynccontextmanager
 async def make_p2pd_pair_ip4(enable_control, enable_connmgr, enable_dht, enable_pubsub):
 async def make_p2pd_pair_ip4(enable_control, enable_connmgr, enable_dht, enable_pubsub):
-    control_maddr = Multiaddr(f"/ip4/127.0.0.1/tcp/{find_open_port()}")
-    listen_maddr = Multiaddr(f"/ip4/127.0.0.1/tcp/{find_open_port()}")
+    control_maddr = Multiaddr(f"/ip4/127.0.0.1/tcp/{get_free_port()}")
+    listen_maddr = Multiaddr(f"/ip4/127.0.0.1/tcp/{get_free_port()}")
     async with _make_p2pd_pair(
     async with _make_p2pd_pair(
         control_maddr=control_maddr,
         control_maddr=control_maddr,
         listen_maddr=listen_maddr,
         listen_maddr=listen_maddr,