
Support libp2p relays for NAT traversal (#186)

- Added relay options to servers
- Enabled relay options by default (see the sketch below)
- Bumped the hivemind dependency to 1.1.5
- Moved the reachability check so that it runs after blocks are loaded

Co-authored-by: Alexander Borzunov <borzunov.alexander@gmail.com>
Egiazarian Vage, 2 years ago
commit 93bed7da5a
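
For context, the relay options land as two keyword arguments that hivemind forwards to its libp2p daemon. A minimal sketch of what they enable, assuming hivemind 1.1.5 and that PUBLIC_INITIAL_PEERS is importable from petals.constants:

    import hivemind
    from petals.constants import PUBLIC_INITIAL_PEERS

    # With both flags on, a peer behind a NAT can still be reached by others
    # through a public libp2p relay instead of a direct dial.
    dht = hivemind.DHT(
        initial_peers=PUBLIC_INITIAL_PEERS,
        use_relay=True,       # allow traffic to be forwarded through relay nodes
        use_auto_relay=True,  # look for public relays if we are not reachable directly
        start=True,
    )
    print("Visible multiaddrs:", dht.get_visible_maddrs())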

+ 1 - 1
setup.cfg

@@ -37,7 +37,7 @@ install_requires =
     huggingface-hub==0.11.1
     transformers==4.25.1
     speedtest-cli==2.1.3
-    hivemind==1.1.3
+    hivemind==1.1.5
     tensor_parallel==1.0.23
     humanfriendly
     async-timeout>=4.0.2

+ 3 - 0
src/petals/cli/run_server.py

@@ -38,6 +38,9 @@ def main():
                              'This is a simplified way to set the --announce_maddrs option (see below).'
                              'Default: server announces IPv4/IPv6 addresses of your network interfaces')
 
+    parser.add_argument("--no_auto_relay", action="store_false", dest="use_auto_relay",
+                        help="Do not look for libp2p relays to reach peers behind NATs/firewalls")
+
     parser.add_argument('--host_maddrs', nargs='+', required=False,
                         help='Multiaddrs to listen for external connections from other peers')
     parser.add_argument('--announce_maddrs', nargs='+', required=False,
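
The new flag uses argparse's inverted-boolean idiom: action="store_false" together with dest="use_auto_relay" makes the option default to True, and passing the flag flips it to False. A self-contained demonstration of that behavior:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--no_auto_relay", action="store_false", dest="use_auto_relay",
                        help="Do not look for libp2p relays to reach peers behind NATs/firewalls")

    print(parser.parse_args([]).use_auto_relay)                   # True  (relays enabled by default)
    print(parser.parse_args(["--no_auto_relay"]).use_auto_relay)  # False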

+ 2 - 0
src/petals/client/remote_model.py

@@ -107,6 +107,8 @@ class DistributedBloomModel(_LowCPUMemoryMixin, BloomModel):
                 num_workers=n_layer,
                 startup_timeout=config.daemon_startup_timeout,
                 start=True,
+                use_relay=True,
+                use_auto_relay=True,
             )
         )
         assert isinstance(dht, hivemind.DHT) and dht.is_alive(), "dht must be a running hivemind.DHT instance"
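
Note that the client hard-codes use_relay=True and use_auto_relay=True rather than exposing a flag, presumably because a client behind a NAT always benefits from relays and, unlike the server, has no CLI surface where an opt-out would live.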

+ 39 - 0
src/petals/server/reachability.py

@@ -0,0 +1,39 @@
+import math
+import time
+
+import requests
+from hivemind.utils.logging import get_logger
+
+logger = get_logger(__file__)
+
+
+def check_reachability(peer_id, wait_time: float = 7 * 60, retry_delay: float = 15) -> None:
+    for attempt_no in range(math.floor(wait_time / retry_delay) + 1):
+        try:
+            r = requests.get(f"http://health.petals.ml/api/v1/is_reachable/{peer_id}", timeout=10)
+            r.raise_for_status()
+            response = r.json()
+
+            if response["success"]:
+                logger.info("Server is reachable from the Internet. It will appear at http://health.petals.ml soon")
+                return
+
+            if attempt_no == 0:
+                # Usually, libp2p manages to set up relays before we finish loading blocks.
+                # In other cases, we may need to wait for up to `wait_time` seconds before it's done.
+                logger.info("Detected a NAT or a firewall, connecting to libp2p relays. This takes a few minutes")
+            time.sleep(retry_delay)
+        except Exception as e:
+            logger.warning(f"Skipping reachability check because health.petals.ml is down: {repr(e)}")
+            return
+
+    raise RuntimeError(
+        f"Server has not become reachable from the Internet:\n\n"
+        f"{response['message']}\n\n"
+        f"You need to fix your port forwarding and/or firewall settings. How to do that:\n\n"
+        f"    1. Choose a specific port for the Petals server, for example, 31337.\n"
+        f"    2. Ensure that this port is accessible from the Internet and not blocked by your firewall.\n"
+        f"    3. Add these arguments to explicitly announce your IP address and port to other peers:\n"
+        f"        python -m petals.cli.run_server ... --public_ip {response['your_ip']} --port 31337\n"
+        f"    4. If it does not help, ask for help in our Discord: https://discord.gg/Wuk8BnrEPH\n"
+    )

+ 21 - 29
src/petals/server/server.py

@@ -10,7 +10,6 @@ from typing import Dict, List, Optional, Sequence, Union
 
 import numpy as np
 import psutil
-import requests
 import torch
 from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
 from hivemind.moe.server.layers import add_custom_models_from_file
@@ -28,6 +27,7 @@ from petals.server.backend import TransformerBackend
 from petals.server.block_utils import get_block_size
 from petals.server.handler import TransformerConnectionHandler
 from petals.server.memory_cache import MemoryCache
+from petals.server.reachability import check_reachability
 from petals.server.throughput import get_host_throughput
 from petals.utils.convert_block import check_device_balance, convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
@@ -78,6 +78,8 @@ class Server:
         load_in_8bit: Optional[bool] = None,
         tensor_parallel_devices: Optional[Sequence[torch.device]] = None,
         skip_reachability_check: bool = False,
+        use_relay: bool = True,
+        use_auto_relay: bool = True,
         **kwargs,
     ):
         """Create a server with one or more bloom blocks. See run_server.py for documentation."""
@@ -117,14 +119,20 @@ class Server:
         )
         self.module_uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]
 
-        self.dht = DHT(initial_peers=initial_peers, start=True, num_workers=self.block_config.n_layer, **kwargs)
+        self.dht = DHT(
+            initial_peers=initial_peers,
+            start=True,
+            num_workers=self.block_config.n_layer,
+            use_relay=use_relay,
+            use_auto_relay=use_auto_relay,
+            **kwargs,
+        )
         visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()]
         if initial_peers == PUBLIC_INITIAL_PEERS:
             logger.info(f"Connecting to the public swarm, peer_id = {self.dht.peer_id}")
-            if not skip_reachability_check:
-                self._check_reachability()
         else:
             logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
+        self.need_reachability_check = not skip_reachability_check and initial_peers == PUBLIC_INITIAL_PEERS
 
         if device is None:
             device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -196,35 +204,14 @@ class Server:
 
         self.stop = threading.Event()
 
-    def _check_reachability(self):
-        try:
-            r = requests.get(f"http://health.petals.ml/api/v1/is_reachable/{self.dht.peer_id}", timeout=10)
-            r.raise_for_status()
-            response = r.json()
-        except Exception as e:
-            logger.warning(f"Skipping reachability check because health.petals.ml is down: {repr(e)}")
-            return
-
-        if not response["success"]:
-            # This happens only if health.petals.ml is up and explicitly told us that we are unreachable
-            raise RuntimeError(
-                f"Server is not reachable from the Internet:\n\n"
-                f"{response['message']}\n\n"
-                f"You need to fix your port forwarding and/or firewall settings. How to do that:\n\n"
-                f"    1. Choose a specific port for the Petals server, for example, 31337.\n"
-                f"    2. Ensure that this port is accessible from the Internet and not blocked by your firewall.\n"
-                f"    3. Add these arguments to explicitly announce your IP address and port to other peers:\n"
-                f"        python -m petals.cli.run_server ... --public_ip {response['your_ip']} --port 31337\n"
-                f"    4. If it does not help, ask for help in our Discord: https://discord.gg/Wuk8BnrEPH\n"
-            )
-
-        logger.info("Server is reachable from the Internet, it will appear at http://health.petals.ml soon")
-
     def _choose_num_blocks(self) -> int:
         assert (
             self.converted_model_name_or_path == "bigscience/bloom-petals"
         ), "If you use a model other than bigscience/bloom-petals, please specify --num_blocks manually"
-        assert self.device.type == "cuda", "If you run a non-GPU server, please specify --num_blocks manually"
+        assert self.device.type == "cuda", (
+            "GPU is not available. If you want to run a CPU-only server, please specify --num_blocks. "
+            "CPU-only servers in the public swarm are discouraged since they are much slower"
+        )
         num_devices = len(self.tensor_parallel_devices) if self.tensor_parallel_devices else 1
 
         if num_devices > 1:
@@ -287,6 +274,7 @@ class Server:
                 use_auth_token=self.use_auth_token,
                 load_in_8bit=self.load_in_8bit,
                 tensor_parallel_devices=self.tensor_parallel_devices,
+                need_reachability_check=self.need_reachability_check,
                 start=True,
             )
             try:
@@ -380,6 +368,7 @@ class ModuleContainer(threading.Thread):
         use_auth_token: Optional[str],
         load_in_8bit: bool,
         tensor_parallel_devices: Sequence[torch.device],
+        need_reachability_check: bool,
         **kwargs,
     ) -> ModuleContainer:
         module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
@@ -433,6 +422,9 @@ class ModuleContainer(threading.Thread):
                     min_batch_size=min_batch_size,
                     max_batch_size=max_batch_size,
                 )
+
+            if need_reachability_check:
+                check_reachability(dht.peer_id)
         except:
             logger.debug("Shutting down backends")
             for backend in blocks.values():
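
The net effect of the server.py changes: the reachability check no longer runs in Server.__init__ right after the DHT starts, but in ModuleContainer.create once all blocks are loaded and converted, which typically gives libp2p enough time to negotiate a relay before the health service is asked whether the peer is reachable.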