فهرست منبع

Add service checking direct reachability from peers (#195)

Servers joining from behind NATs/firewalls usually take several minutes to join a libp2p relay before they become accessible from the outside Internet. Moreover, requests to such servers are slower and more likely to fail (e.g., if the server switches a relay at the moment). If such servers host certain DHT keys, the swarm may occasionally lose read/write access to these keys, which results in:

- Clients being unable to find any servers hosting a certain block.
- All servers starting rebalancing to the same place to close the alleged "gap" in the swarm.

This PRs modifies servers so that DHT keys are only hosted on **directly reachable** servers (the ones who aren't behind NAT/firewall). This way, DHT becomes more stable and works faster. Of course, trhe servers behind NATs/firewalls still accept requests for running inference/forward/backward for blocks they hold (it's more acceptable for this kind of requests to be slower or fail).

Co-authored-by: Alexander Borzunov <borzunov.alexander@gmail.com>
justheuristic 2 سال پیش
والد
کامیت
771ca590e7
4فایلهای تغییر یافته به همراه254 افزوده شده و 10 حذف شده
  1. 104 0
      src/petals/cli/run_dht.py
  2. 3 0
      src/petals/constants.py
  3. 131 4
      src/petals/server/reachability.py
  4. 16 6
      src/petals/server/server.py

+ 104 - 0
src/petals/cli/run_dht.py

@@ -0,0 +1,104 @@
+"""
+A copy of run_dht.py from hivemind with the ReachabilityProtocol added:
+https://github.com/learning-at-home/hivemind/blob/master/hivemind/hivemind_cli/run_dht.py
+
+This script may be used for launching lightweight CPU machines serving as bootstrap nodes to a Petals swarm.
+
+This may be eventually merged to the hivemind upstream.
+"""
+
+import time
+from argparse import ArgumentParser
+from secrets import token_hex
+
+from hivemind.dht import DHT, DHTNode
+from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from hivemind.utils.networking import log_visible_maddrs
+
+from petals.server.reachability import ReachabilityProtocol
+
+use_hivemind_log_handler("in_root_logger")
+logger = get_logger(__name__)
+
+
+async def report_status(dht: DHT, node: DHTNode):
+    logger.info(
+        f"{len(node.protocol.routing_table.uid_to_peer_id) + 1} DHT nodes (including this one) "
+        f"are in the local routing table "
+    )
+    logger.debug(f"Routing table contents: {node.protocol.routing_table}")
+    logger.info(f"Local storage contains {len(node.protocol.storage)} keys")
+    logger.debug(f"Local storage contents: {node.protocol.storage}")
+
+    # Contact peers and keep the routing table healthy (remove stale PeerIDs)
+    await node.get(f"heartbeat_{token_hex(16)}", latest=True)
+
+
+def main():
+    parser = ArgumentParser()
+    parser.add_argument(
+        "--initial_peers",
+        nargs="*",
+        help="Multiaddrs of the peers that will welcome you into the existing DHT. "
+        "Example: /ip4/203.0.113.1/tcp/31337/p2p/XXXX /ip4/203.0.113.2/tcp/7777/p2p/YYYY",
+    )
+    parser.add_argument(
+        "--host_maddrs",
+        nargs="*",
+        default=["/ip4/0.0.0.0/tcp/0", "/ip6/::/tcp/0"],
+        help="Multiaddrs to listen for external connections from other DHT instances. "
+        "Defaults to all IPv4 interfaces and the TCP protocol: /ip4/0.0.0.0/tcp/0",
+    )
+    parser.add_argument(
+        "--announce_maddrs",
+        nargs="*",
+        help="Visible multiaddrs the host announces for external connections from other DHT instances",
+    )
+    parser.add_argument(
+        "--use_ipfs",
+        action="store_true",
+        help='Use IPFS to find initial_peers. If enabled, you only need to provide the "/p2p/XXXX" '
+        "part of the multiaddrs for the initial_peers "
+        "(no need to specify a particular IPv4/IPv6 host and port)",
+    )
+    parser.add_argument(
+        "--identity_path",
+        help="Path to a private key file. If defined, makes the peer ID deterministic. "
+        "If the file does not exist, writes a new private key to this file.",
+    )
+    parser.add_argument(
+        "--no_relay",
+        action="store_false",
+        dest="use_relay",
+        help="Disable circuit relay functionality in libp2p (see https://docs.libp2p.io/concepts/nat/circuit-relay/)",
+    )
+    parser.add_argument(
+        "--use_auto_relay", action="store_true", help="Look for libp2p relays to reach peers behind NATs/firewalls"
+    )
+    parser.add_argument(
+        "--refresh_period", type=int, default=30, help="Period (in seconds) for fetching the keys from DHT"
+    )
+
+    args = parser.parse_args()
+
+    dht = DHT(
+        start=True,
+        initial_peers=args.initial_peers,
+        host_maddrs=args.host_maddrs,
+        announce_maddrs=args.announce_maddrs,
+        use_ipfs=args.use_ipfs,
+        identity_path=args.identity_path,
+        use_relay=args.use_relay,
+        use_auto_relay=args.use_auto_relay,
+    )
+    log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=args.use_ipfs)
+
+    reachability_protocol = ReachabilityProtocol.attach_to_dht(dht, await_ready=True)
+
+    while True:
+        dht.run_coroutine(report_status, return_future=False)
+        time.sleep(args.refresh_period)
+
+
+if __name__ == "__main__":
+    main()

+ 3 - 0
src/petals/constants.py

@@ -4,3 +4,6 @@ PUBLIC_INITIAL_PEERS = [
     "/dns/bootstrap2.petals.ml/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5",
     "/dns6/bootstrap2.petals.ml/tcp/31338/p2p/QmQGTqmM7NKjV6ggU1ZCap8zWiyKR89RViDXiqehSiCpY5",
 ]
+
+# The reachability API is currently used only when connecting to the public swarm
+REACHABILITY_API_URL = "http://health.petals.ml"

+ 131 - 4
src/petals/server/reachability.py

@@ -1,16 +1,30 @@
+import asyncio
 import math
+import threading
 import time
+from concurrent.futures import Future
+from contextlib import asynccontextmanager
+from functools import partial
+from secrets import token_hex
+from typing import Optional
 
 import requests
-from hivemind.utils.logging import get_logger
+from hivemind.dht import DHT, DHTNode
+from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+from hivemind.p2p import P2P, P2PContext, PeerID, ServicerBase
+from hivemind.proto import dht_pb2
+from hivemind.utils import get_logger
 
-logger = get_logger(__file__)
+from petals.constants import REACHABILITY_API_URL
 
+logger = get_logger(__name__)
 
-def check_reachability(peer_id, wait_time: float = 7 * 60, retry_delay: float = 15) -> None:
+
+def validate_reachability(peer_id, wait_time: float = 7 * 60, retry_delay: float = 15) -> None:
+    """verify that your peer is reachable from a (centralized) validator, whether directly or through a relay"""
     for attempt_no in range(math.floor(wait_time / retry_delay) + 1):
         try:
-            r = requests.get(f"http://health.petals.ml/api/v1/is_reachable/{peer_id}", timeout=10)
+            r = requests.get(f"{REACHABILITY_API_URL}/api/v1/is_reachable/{peer_id}", timeout=10)
             r.raise_for_status()
             response = r.json()
 
@@ -37,3 +51,116 @@ def check_reachability(peer_id, wait_time: float = 7 * 60, retry_delay: float =
         f"        python -m petals.cli.run_server ... --public_ip {response['your_ip']} --port 31337\n"
         f"    4. If it does not help, ask for help in our Discord: https://discord.gg/Wuk8BnrEPH\n"
     )
+
+
+def check_direct_reachability(max_peers: int = 5, threshold: float = 0.5, **kwargs) -> Optional[bool]:
+    """test if your peer is accessible by others in the swarm with the specified network options in **kwargs"""
+
+    async def _check_direct_reachability():
+        target_dht = await DHTNode.create(client_mode=True, **kwargs)
+        try:
+            protocol = ReachabilityProtocol(probe=target_dht.protocol.p2p)
+            async with protocol.serve(target_dht.protocol.p2p):
+                successes = requests = 0
+                for remote_peer in list(target_dht.protocol.routing_table.peer_id_to_uid.keys()):
+                    probe_available = await protocol.call_check(remote_peer=remote_peer, check_peer=target_dht.peer_id)
+                    if probe_available is None:
+                        continue  # remote peer failed to check probe
+                    successes += probe_available
+                    requests += 1
+                    if requests >= max_peers:
+                        break
+
+            logger.info(f"Direct reachability: {successes}/{requests}")
+            return (successes / requests) >= threshold if requests > 0 else None
+        finally:
+            await target_dht.shutdown()
+
+    return RemoteExpertWorker.run_coroutine(_check_direct_reachability())
+
+
+STRIPPED_PROBE_ARGS = dict(
+    dht_mode="client", use_relay=False, auto_nat=False, nat_port_map=False, no_listen=True, startup_timeout=60
+)
+
+
+class ReachabilityProtocol(ServicerBase):
+    """Mini protocol to test if a locally running peer is accessible by other devices in the swarm"""
+
+    def __init__(self, *, probe: Optional[P2P] = None, wait_timeout: float = 5.0):
+        self.probe = probe
+        self.wait_timeout = wait_timeout
+        self._event_loop = self._stop = None
+
+    async def call_check(self, remote_peer: PeerID, *, check_peer: PeerID) -> Optional[bool]:
+        """Returns True if remote_peer can reach check_peer, False if it cannot, None if it did not respond"""
+        try:
+            request = dht_pb2.PingRequest(peer=dht_pb2.NodeInfo(node_id=check_peer.to_bytes()))
+            timeout = self.wait_timeout if check_peer == remote_peer else self.wait_timeout * 2
+            response = await self.get_stub(self.probe, remote_peer).rpc_check(request, timeout=timeout)
+            logger.debug(f"call_check(remote_peer={remote_peer}, check_peer={check_peer}) -> {response.available}")
+            return response.available
+        except Exception as e:
+            logger.debug(f"Requested {remote_peer} to check {check_peer}, but got:", exc_info=True)
+            return None
+
+    async def rpc_check(self, request: dht_pb2.PingRequest, context: P2PContext) -> dht_pb2.PingResponse:
+        """Help another peer to check its reachability"""
+        response = dht_pb2.PingResponse(available=True)
+        check_peer = PeerID(request.peer.node_id)
+        if check_peer != context.local_id:  # remote peer wants us to check someone other than ourselves
+            response.available = await self.call_check(check_peer, check_peer=check_peer) is True
+            logger.info(
+                f"reachability.rpc_check(remote_peer=...{str(context.remote_id)[-6:]}, "
+                f"check_peer=...{str(check_peer)[-6:]}) -> {response.available}"
+            )
+        return response
+
+    @asynccontextmanager
+    async def serve(self, p2p: P2P):
+        try:
+            await self.add_p2p_handlers(p2p)
+            yield self
+        finally:
+            await self.remove_p2p_handlers(p2p)
+
+    @classmethod
+    def attach_to_dht(cls, dht: DHT, await_ready: bool = False, **kwargs) -> Optional["ReachabilityProtocol"]:
+        protocol = cls(**kwargs)
+        ready = Future()
+
+        async def _serve_with_probe():
+            try:
+                common_p2p = await dht.replicate_p2p()
+                protocol._event_loop = asyncio.get_event_loop()
+                protocol._stop = asyncio.Event()
+
+                initial_peers = [str(addr) for addr in await common_p2p.get_visible_maddrs(latest=True)]
+                for info in await common_p2p.list_peers():
+                    initial_peers.extend(f"{addr}/p2p/{info.peer_id}" for addr in info.addrs)
+                protocol.probe = await P2P.create(initial_peers, **STRIPPED_PROBE_ARGS)
+
+                ready.set_result(True)
+                logger.info("Reachability service started")
+
+                async with protocol.serve(common_p2p):
+                    await protocol._stop.wait()
+            except Exception as e:
+                logger.warning(f"Reachability service failed: {repr(e)}")
+                logger.debug("See detailed traceback below:", exc_info=True)
+
+                if not ready.done():
+                    ready.set_exception(e)
+            finally:
+                if protocol is not None and protocol.probe is not None:
+                    await protocol.probe.shutdown()
+                logger.debug("Reachability service shut down")
+
+        threading.Thread(target=partial(asyncio.run, _serve_with_probe()), daemon=True).start()
+        if await_ready:
+            ready.result()  # Propagates startup exceptions, if any
+        return protocol
+
+    def shutdown(self):
+        if self._event_loop is not None and self._stop is not None:
+            self._event_loop.call_soon_threadsafe(self._stop.set)

+ 16 - 6
src/petals/server/server.py

@@ -26,7 +26,7 @@ from petals.server.backend import TransformerBackend
 from petals.server.block_utils import get_block_size
 from petals.server.handler import TransformerConnectionHandler
 from petals.server.memory_cache import MemoryCache
-from petals.server.reachability import check_reachability
+from petals.server.reachability import ReachabilityProtocol, check_direct_reachability, validate_reachability
 from petals.server.throughput import get_dtype_name, get_host_throughput
 from petals.utils.convert_block import check_device_balance, convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
@@ -77,6 +77,7 @@ class Server:
         load_in_8bit: Optional[bool] = None,
         tensor_parallel_devices: Optional[Sequence[torch.device]] = None,
         skip_reachability_check: bool = False,
+        dht_client_mode: Optional[bool] = None,
         use_relay: bool = True,
         use_auto_relay: bool = True,
         **kwargs,
@@ -118,20 +119,27 @@ class Server:
         )
         self.module_uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]
 
+        if dht_client_mode is None:
+            is_reachable = check_direct_reachability(initial_peers=initial_peers, use_relay=False, **kwargs)
+            dht_client_mode = is_reachable is False  # if could not check reachability (returns None), run a full peer
+            logger.info(f"This server will run DHT in {'client' if dht_client_mode else 'full peer'} mode")
         self.dht = DHT(
             initial_peers=initial_peers,
             start=True,
             num_workers=self.block_config.n_layer,
             use_relay=use_relay,
             use_auto_relay=use_auto_relay,
+            client_mode=dht_client_mode,
             **kwargs,
         )
+        self.reachability_protocol = ReachabilityProtocol.attach_to_dht(self.dht) if not dht_client_mode else None
+
         visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()]
         if initial_peers == PUBLIC_INITIAL_PEERS:
             logger.info(f"Connecting to the public swarm, peer_id = {self.dht.peer_id}")
         else:
             logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
-        self.need_reachability_check = not skip_reachability_check and initial_peers == PUBLIC_INITIAL_PEERS
+        self.should_validate_reachability = not skip_reachability_check and initial_peers == PUBLIC_INITIAL_PEERS
 
         if device is None:
             device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -277,7 +285,7 @@ class Server:
                 use_auth_token=self.use_auth_token,
                 load_in_8bit=self.load_in_8bit,
                 tensor_parallel_devices=self.tensor_parallel_devices,
-                need_reachability_check=self.need_reachability_check,
+                should_validate_reachability=self.should_validate_reachability,
                 start=True,
             )
             try:
@@ -335,6 +343,8 @@ class Server:
     def shutdown(self):
         self.stop.set()
 
+        if self.reachability_protocol is not None:
+            self.reachability_protocol.shutdown()
         self.dht.shutdown()
         self.dht.join()
 
@@ -367,7 +377,7 @@ class ModuleContainer(threading.Thread):
         use_auth_token: Optional[str],
         load_in_8bit: bool,
         tensor_parallel_devices: Sequence[torch.device],
-        need_reachability_check: bool,
+        should_validate_reachability: bool,
         **kwargs,
     ) -> ModuleContainer:
         module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
@@ -422,8 +432,8 @@ class ModuleContainer(threading.Thread):
                     max_batch_size=max_batch_size,
                 )
 
-            if need_reachability_check:
-                check_reachability(dht.peer_id)
+            if should_validate_reachability:
+                validate_reachability(dht.peer_id)
         except:
             logger.debug("Shutting down backends")
             for backend in blocks.values():