Просмотр исходного кода

Fix race condition while reserving ports in P2P (#299)

hivemind.p2p.P2P now uses 3 ports for its operations (an external port for other P2P instances + 2 ports for daemon-client interaction). These ports were chosen with find_open_port() before p2pd was started.

This leads to a race condition because another P2P instance (or another program) may acquire the port in the window between the moment it is chosen and the moment p2pd is started. The race condition becomes extremely probable for tests and benchmarks launching many DHT instances concurrently (since each DHT launches its own P2P instance).

This PR fixes it and makes several refactorings along the way.
Alexander Borzunov 4 года назад
Родитель
Commit
3e21f75d9e

+ 1 - 1
hivemind/p2p/__init__.py

@@ -1,2 +1,2 @@
 from hivemind.p2p.p2p_daemon import P2P, P2PContext, P2PHandlerError
-from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID
+from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo

+ 67 - 63
hivemind/p2p/p2p_daemon.py

@@ -1,4 +1,5 @@
 import asyncio
+import secrets
 from copy import deepcopy
 from dataclasses import dataclass
 from importlib.resources import path
@@ -10,7 +11,7 @@ from multiaddr import Multiaddr
 
 import hivemind.hivemind_cli as cli
 import hivemind.p2p.p2p_daemon_bindings.p2pclient as p2pclient
-from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, StreamInfo
+from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo
 from hivemind.proto import p2pd_pb2
 from hivemind.utils import MSGPackSerializer
 from hivemind.utils.logging import get_logger
@@ -20,8 +21,6 @@ logger = get_logger(__name__)
 
 
 P2PD_FILENAME = 'p2pd'
-NUM_RETRIES = 3
-RETRY_DELAY = 0.4
 
 
 @dataclass(frozen=True)
@@ -63,20 +62,22 @@ class P2P:
     }
 
     def __init__(self):
+        self.id = None
         self._child = None
         self._alive = False
         self._listen_task = None
         self._server_stopped = asyncio.Event()
 
     @classmethod
-    async def create(cls, *args, quic: bool = True, tls: bool = True, conn_manager: bool = True,
+    async def create(cls, *args, quic: bool = False, tls: bool = True, conn_manager: bool = True,
                      dht_mode: str = 'dht_server', force_reachability: Optional[str] = None,
                      nat_port_map: bool = True, auto_nat: bool = True,
-                     bootstrap_peers: Optional[List[Multiaddr]] = None,
-                     use_ipfs: bool = False, external_port: int = None,
-                     daemon_listen_port: int = None, use_relay: bool = True, use_relay_hop: bool = False,
+                     bootstrap_peers: Optional[List[Multiaddr]] = None, use_ipfs: bool = False,
+                     host_maddrs: Optional[List[Multiaddr]] = None,
+                     use_relay: bool = True, use_relay_hop: bool = False,
                      use_relay_discovery: bool = False, use_auto_relay: bool = False, relay_hop_limit: int = 0,
-                     **kwargs) -> 'P2P':
+                     quiet: bool = True,
+                     ping_n_retries: int = 3, ping_retry_delay: float = 0.4, **kwargs) -> 'P2P':
         """
         Start a new p2pd process and connect to it.
         :param quic: Enables the QUIC transport
@@ -89,13 +90,13 @@ class P2P:
         :param bootstrap: Connects to bootstrap peers and bootstraps the dht if enabled
         :param bootstrap_peers: List of bootstrap peers
         :param use_ipfs: Bootstrap to IPFS (works only if bootstrap=True and bootstrap_peers=None)
-        :param external_port: port for external connections from other p2p instances
-        :param daemon_listen_port: port for connection daemon and client binding
+        :param host_maddrs: multiaddresses for external connections from other p2p instances
         :param use_relay: enables circuit relay
         :param use_relay_hop: enables hop for relay
         :param use_relay_discovery: enables passive discovery for relay
         :param use_auto_relay: enables autorelay
         :param relay_hop_limit: sets the hop limit for hop relays
+        :param quiet: make the daemon process quiet
         :param args: positional CLI arguments for the p2p daemon
         :param kwargs: keyword CLI arguments for the p2p daemon
         :return: a wrapper for the p2p daemon
@@ -108,40 +109,59 @@ class P2P:
         with path(cli, P2PD_FILENAME) as p:
             p2pd_path = p
 
+        socket_uid = secrets.token_urlsafe(8)
+        self._daemon_listen_maddr = Multiaddr(f'/unix/tmp/hivemind-p2pd-{socket_uid}.sock')
+        self._client_listen_maddr = Multiaddr(f'/unix/tmp/hivemind-p2pclient-{socket_uid}.sock')
+
         need_bootstrap = bool(bootstrap_peers) or use_ipfs
         bootstrap_peers = cls._make_bootstrap_peers(bootstrap_peers)
         dht = cls.DHT_MODE_MAPPING.get(dht_mode, {'dht': 0})
         force_reachability = cls.FORCE_REACHABILITY_MAPPING.get(force_reachability, {})
+        host_maddrs = {'hostAddrs': ','.join(str(maddr) for maddr in host_maddrs)} if host_maddrs else {}
         proc_args = self._make_process_args(
             str(p2pd_path), *args,
+            listen=self._daemon_listen_maddr,
             quic=quic, tls=tls, connManager=conn_manager,
             natPortMap=nat_port_map, autonat=auto_nat,
             relay=use_relay, relayHop=use_relay_hop, relayDiscovery=use_relay_discovery,
             autoRelay=use_auto_relay, relayHopLimit=relay_hop_limit,
-            b=need_bootstrap, **{**bootstrap_peers, **dht, **force_reachability, **kwargs})
-        self._assign_daemon_ports(external_port, daemon_listen_port)
+            b=need_bootstrap, q=quiet, **{**bootstrap_peers, **dht, **force_reachability, **host_maddrs, **kwargs})
+
+        self._initialize(proc_args)
+        await self._ping_daemon_with_retries(ping_n_retries, ping_retry_delay)
+
+        return self
+
+    def _initialize(self, proc_args: List[str]) -> None:
+        self._child = Popen(args=proc_args, encoding="utf8")
+        self._alive = True
+        self._client = p2pclient.Client(self._daemon_listen_maddr, self._client_listen_maddr)
+
+    async def _ping_daemon_with_retries(self, ping_n_retries: int, ping_retry_delay: float) -> None:
+        for try_number in range(ping_n_retries):
+            await asyncio.sleep(ping_retry_delay * (2 ** try_number))
+
+            if self._child.poll() is not None:  # Process died
+                break
 
-        for try_count in range(NUM_RETRIES):
             try:
-                self._initialize(proc_args)
-                await self._wait_for_client(RETRY_DELAY * (2 ** try_count))
+                await self._ping_daemon()
                 break
             except Exception as e:
-                logger.debug(f"Failed to initialize p2p daemon: {e}")
-                self._terminate()
-                if try_count == NUM_RETRIES - 1:
+                if try_number == ping_n_retries - 1:
+                    logger.error(f'Failed to ping p2pd: {e}')
+                    await self.shutdown()
                     raise
-                self._assign_daemon_ports()
 
-        return self
+        if self._child.returncode is not None:
+            raise RuntimeError(f'The p2p daemon has died with return code {self._child.returncode}')
 
     @classmethod
-    async def replicate(cls, daemon_listen_port: int, external_port: int) -> 'P2P':
+    async def replicate(cls, daemon_listen_maddr: Multiaddr) -> 'P2P':
         """
         Connect to existing p2p daemon
-        :param daemon_listen_port: port for connection daemon and client binding
-        :param external_port: port for external connections from other p2p instances
-        :return: new wrapper for existing p2p daemon
+        :param daemon_listen_maddr: multiaddr of the existing p2p daemon
+        :return: new wrapper for the existing p2p daemon
         """
 
         self = cls()
@@ -149,14 +169,20 @@ class P2P:
         # Use external already running p2pd
         self._child = None
         self._alive = True
-        self._assign_daemon_ports(external_port, daemon_listen_port)
-        self._client_listen_port = find_open_port()
-        self._client = p2pclient.Client(
-            Multiaddr(f'/ip4/127.0.0.1/tcp/{self._daemon_listen_port}'),
-            Multiaddr(f'/ip4/127.0.0.1/tcp/{self._client_listen_port}'))
-        await self._wait_for_client()
+
+        socket_uid = secrets.token_urlsafe(8)
+        self._daemon_listen_maddr = daemon_listen_maddr
+        self._client_listen_maddr = Multiaddr(f'/unix/tmp/hivemind-p2pclient-{socket_uid}.sock')
+
+        self._client = p2pclient.Client(self._daemon_listen_maddr, self._client_listen_maddr)
+
+        await self._ping_daemon()
         return self
 
+    async def _ping_daemon(self) -> None:
+        self.id, maddrs = await self._client.identify()
+        logger.debug(f'Launched p2pd with id = {self.id}, host multiaddrs = {maddrs}')
+
     async def identify_maddrs(self) -> List[Multiaddr]:
         _, maddrs = await self._client.identify()
         if not maddrs:
@@ -165,6 +191,9 @@ class P2P:
         p2p_maddr = Multiaddr(f'/p2p/{self.id.to_base58()}')
         return [addr.encapsulate(p2p_maddr) for addr in maddrs]
 
+    async def list_peers(self) -> List[PeerInfo]:
+        return list(await self._client.list_peers())
+
     async def wait_for_at_least_n_peers(self, n_peers: int, attempts: int = 3, delay: float = 1) -> None:
         for _ in range(attempts):
             peers = await self._client.list_peers()
@@ -174,36 +203,9 @@ class P2P:
 
         raise RuntimeError('Not enough peers')
 
-    def _initialize(self, proc_args: List[str]) -> None:
-        proc_args = deepcopy(proc_args)
-        proc_args.extend(self._make_process_args(
-            hostAddrs=f'/ip4/0.0.0.0/tcp/{self._external_port},/ip4/0.0.0.0/udp/{self._external_port}/quic',
-            listen=f'/ip4/127.0.0.1/tcp/{self._daemon_listen_port}'
-        ))
-        self._child = Popen(args=proc_args, encoding="utf8")
-        self._alive = True
-        self._client_listen_port = find_open_port()
-        self._client = p2pclient.Client(
-            Multiaddr(f'/ip4/127.0.0.1/tcp/{self._daemon_listen_port}'),
-            Multiaddr(f'/ip4/127.0.0.1/tcp/{self._client_listen_port}'))
-
-    async def _wait_for_client(self, delay: float = 0) -> None:
-        await asyncio.sleep(delay)
-        self.id, _ = await self._client.identify()
-
-    def _assign_daemon_ports(self, external_port: int = None, daemon_listen_port: int = None) -> None:
-        if external_port is None:
-            external_port = find_open_port()
-        if daemon_listen_port is None:
-            daemon_listen_port = find_open_port()
-            while daemon_listen_port == external_port:
-                daemon_listen_port = find_open_port()
-
-        self._external_port, self._daemon_listen_port = external_port, daemon_listen_port
-
     @property
-    def external_port(self) -> int:
-        return self._external_port
+    def daemon_listen_maddr(self) -> Multiaddr:
+        return self._daemon_listen_maddr
 
     @staticmethod
     async def send_raw_data(data: bytes, writer: asyncio.StreamWriter) -> None:
@@ -314,14 +316,14 @@ class P2P:
 
         return do_handle_unary_stream
 
-    def start_listening(self) -> None:
+    def _start_listening(self) -> None:
         async def listen() -> None:
             async with self._client.listen():
                 await self._server_stopped.wait()
 
         self._listen_task = asyncio.create_task(listen())
 
-    async def stop_listening(self) -> None:
+    async def _stop_listening(self) -> None:
         if self._listen_task is not None:
             self._server_stopped.set()
             self._listen_task.cancel()
@@ -333,13 +335,13 @@ class P2P:
 
     async def add_stream_handler(self, name: str, handle: Callable[[bytes], bytes]) -> None:
         if self._listen_task is None:
-            self.start_listening()
+            self._start_listening()
         await self._client.stream_handler(name, self._handle_stream(handle))
 
     async def add_unary_handler(self, name: str, handle: Callable[[Any, P2PContext], Any],
                                 in_proto_type: type, out_proto_type: type) -> None:
         if self._listen_task is None:
-            self.start_listening()
+            self._start_listening()
         await self._client.stream_handler(
             name, self._handle_unary_stream(handle, name, in_proto_type, out_proto_type))
 
@@ -372,6 +374,7 @@ class P2P:
         return self._alive
 
     async def shutdown(self) -> None:
+        await self._stop_listening()
         await asyncio.get_event_loop().run_in_executor(None, self._terminate)
 
     def _terminate(self) -> None:
@@ -379,6 +382,7 @@ class P2P:
         if self._child is not None and self._child.poll() is None:
             self._child.terminate()
             self._child.wait()
+            logger.debug(f'Terminated p2pd with id = {self.id}')
 
     @staticmethod
     def _make_process_args(*args, **kwargs) -> List[str]:

+ 2 - 2
hivemind/p2p/p2p_daemon_bindings/control.py

@@ -105,10 +105,10 @@ class ControlClient:
             )
 
         async with server:
-            logger.info(f"DaemonConnector {self} starts listening to {self.listen_maddr}")
+            logger.debug(f"DaemonConnector {self} starts listening to {self.listen_maddr}")
             yield self
 
-        logger.info(f"DaemonConnector {self} closed")
+        logger.debug(f"DaemonConnector {self} closed")
 
     async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
         reader, writer = await self.daemon_connector.open_connection()

+ 45 - 29
tests/test_p2p_daemon.py

@@ -1,5 +1,6 @@
 import asyncio
 import multiprocessing as mp
+import socket
 import subprocess
 from functools import partial
 from typing import List
@@ -9,10 +10,10 @@ import pytest
 import torch
 from multiaddr import Multiaddr
 
-from hivemind.p2p import P2P, P2PHandlerError
-from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID
+from hivemind.p2p import P2P, P2PHandlerError, PeerID, PeerInfo
 from hivemind.proto import dht_pb2, runtime_pb2
 from hivemind.utils import MSGPackSerializer
+from hivemind.utils.networking import find_open_port
 from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor
 
 
@@ -21,11 +22,7 @@ def is_process_running(pid: int) -> bool:
 
 
 async def replicate_if_needed(p2p: P2P, replicate: bool) -> P2P:
-    return await P2P.replicate(p2p._daemon_listen_port, p2p.external_port) if replicate else p2p
-
-
-def bootstrap_addr(external_port: int, id_: str) -> Multiaddr:
-    return Multiaddr(f'/ip4/127.0.0.1/tcp/{external_port}/p2p/{id_}')
+    return await P2P.replicate(p2p.daemon_listen_maddr) if replicate else p2p
 
 
 async def bootstrap_from(daemons: List[P2P]) -> List[Multiaddr]:
@@ -46,26 +43,51 @@ async def test_daemon_killed_on_del():
     assert not is_process_running(child_pid)
 
 
+@pytest.mark.asyncio
+async def test_error_for_wrong_daemon_arguments():
+    with pytest.raises(RuntimeError):
+        await P2P.create(unknown_argument=True)
+
+
 @pytest.mark.asyncio
 async def test_server_client_connection():
     server = await P2P.create()
-    peers = await server._client.list_peers()
+    peers = await server.list_peers()
     assert len(peers) == 0
 
     nodes = await bootstrap_from([server])
     client = await P2P.create(bootstrap_peers=nodes)
     await client.wait_for_at_least_n_peers(1)
 
-    peers = await client._client.list_peers()
+    peers = await client.list_peers()
     assert len(peers) == 1
-    peers = await server._client.list_peers()
+    peers = await server.list_peers()
+    assert len(peers) == 1
+
+
+@pytest.mark.asyncio
+async def test_quic_transport():
+    server_port = find_open_port((socket.AF_INET, socket.SOCK_DGRAM))
+    server = await P2P.create(quic=True, host_maddrs=[Multiaddr(f'/ip4/127.0.0.1/udp/{server_port}/quic')])
+    peers = await server.list_peers()
+    assert len(peers) == 0
+
+    nodes = await bootstrap_from([server])
+    client_port = find_open_port((socket.AF_INET, socket.SOCK_DGRAM))
+    client = await P2P.create(quic=True, host_maddrs=[Multiaddr(f'/ip4/127.0.0.1/udp/{client_port}/quic')],
+                              bootstrap_peers=nodes)
+    await client.wait_for_at_least_n_peers(1)
+
+    peers = await client.list_peers()
+    assert len(peers) == 1
+    peers = await server.list_peers()
     assert len(peers) == 1
 
 
 @pytest.mark.asyncio
 async def test_daemon_replica_does_not_affect_primary():
     p2p_daemon = await P2P.create()
-    p2p_replica = await P2P.replicate(p2p_daemon._daemon_listen_port, p2p_daemon.external_port)
+    p2p_replica = await P2P.replicate(p2p_daemon.daemon_listen_maddr)
 
     child_pid = p2p_daemon._child.pid
     assert is_process_running(child_pid)
@@ -140,7 +162,7 @@ async def test_call_unary_handler(should_cancel, replicate, handle_name="handle"
             nonlocal handler_cancelled
             handler_cancelled = True
         return dht_pb2.PingResponse(
-            peer=dht_pb2.NodeInfo(node_id=server.id.to_bytes(), rpc_port=server.external_port),
+            peer=dht_pb2.NodeInfo(node_id=server.id.to_bytes()),
             sender_endpoint=context.handle_name, available=True)
 
     server_pid = server_primary._child.pid
@@ -156,10 +178,10 @@ async def test_call_unary_handler(should_cancel, replicate, handle_name="handle"
     await client.wait_for_at_least_n_peers(1)
 
     ping_request = dht_pb2.PingRequest(
-        peer=dht_pb2.NodeInfo(node_id=client.id.to_bytes(), rpc_port=client.external_port),
+        peer=dht_pb2.NodeInfo(node_id=client.id.to_bytes()),
         validate=True)
     expected_response = dht_pb2.PingResponse(
-        peer=dht_pb2.NodeInfo(node_id=server.id.to_bytes(), rpc_port=server.external_port),
+        peer=dht_pb2.NodeInfo(node_id=server.id.to_bytes()),
         sender_endpoint=handle_name, available=True)
 
     if should_cancel:
@@ -174,7 +196,7 @@ async def test_call_unary_handler(should_cancel, replicate, handle_name="handle"
         assert actual_response == expected_response
         assert not handler_cancelled
 
-    await server.stop_listening()
+    await server.shutdown()
     await server_primary.shutdown()
     assert not is_process_running(server_pid)
 
@@ -199,14 +221,13 @@ async def test_call_unary_handler_error(handle_name="handle"):
     await client.wait_for_at_least_n_peers(1)
 
     ping_request = dht_pb2.PingRequest(
-        peer=dht_pb2.NodeInfo(node_id=client.id.to_bytes(), rpc_port=client.external_port),
+        peer=dht_pb2.NodeInfo(node_id=client.id.to_bytes()),
         validate=True)
 
     with pytest.raises(P2PHandlerError) as excinfo:
         await client.call_unary_handler(server.id, handle_name, ping_request, dht_pb2.PingResponse)
     assert 'boom' in str(excinfo.value)
 
-    await server.stop_listening()
     await server.shutdown()
     await client.shutdown()
 
@@ -239,7 +260,6 @@ async def test_call_peer_single_process(test_input, expected, handle, handler_na
     result = MSGPackSerializer.loads(result_msgp)
     assert result == expected
 
-    await server.stop_listening()
     await server.shutdown()
     assert not is_process_running(server_pid)
 
@@ -254,11 +274,10 @@ async def run_server(handler_name, server_side, client_side, response_received):
     assert is_process_running(server_pid)
 
     server_side.send(server.id)
-    server_side.send(server.external_port)
+    server_side.send(await server.identify_maddrs())
     while response_received.value == 0:
         await asyncio.sleep(0.5)
 
-    await server.stop_listening()
     await server.shutdown()
     assert not is_process_running(server_pid)
 
@@ -280,10 +299,9 @@ async def test_call_peer_different_processes():
     proc.start()
 
     peer_id = client_side.recv()
-    peer_port = client_side.recv()
+    peer_maddrs = client_side.recv()
 
-    nodes = [bootstrap_addr(peer_port, peer_id)]
-    client = await P2P.create(bootstrap_peers=nodes)
+    client = await P2P.create(bootstrap_peers=peer_maddrs)
     client_pid = client._child.pid
     assert is_process_running(client_pid)
 
@@ -328,7 +346,6 @@ async def test_call_peer_torch_square(test_input, expected, handler_name="handle
     result = deserialize_torch_tensor(result)
     assert torch.allclose(result, expected)
 
-    await server.stop_listening()
     await server.shutdown()
     await client.shutdown()
 
@@ -361,7 +378,6 @@ async def test_call_peer_torch_add(test_input, expected, handler_name="handle"):
     result = deserialize_torch_tensor(result)
     assert torch.allclose(result, expected)
 
-    await server.stop_listening()
     await server.shutdown()
     await client.shutdown()
 
@@ -390,9 +406,10 @@ async def test_call_peer_error(replicate, handler_name="handle"):
     result = await client.call_peer_handler(server.id, handler_name, inp_msgp)
     assert result == b'something went wrong :('
 
-    await server.stop_listening()
     await server_primary.shutdown()
+    await server.shutdown()
     await client_primary.shutdown()
+    await client.shutdown()
 
 
 @pytest.mark.asyncio
@@ -423,8 +440,8 @@ async def test_handlers_on_different_replicas(handler_name="handle"):
     result = await client.call_peer_handler(server_id, handler_name + '2', b'3')
     assert result == b"replica2"
 
-    await server_replica1.stop_listening()
-    await server_replica2.stop_listening()
+    await server_replica1.shutdown()
+    await server_replica2.shutdown()
 
     # Primary does not handle replicas protocols
     with pytest.raises(Exception):
@@ -432,6 +449,5 @@ async def test_handlers_on_different_replicas(handler_name="handle"):
     with pytest.raises(Exception):
         await client.call_peer_handler(server_id, handler_name + '2', b'')
 
-    await server_primary.stop_listening()
     await server_primary.shutdown()
     await client.shutdown()