@@ -10,7 +10,6 @@ from typing import Dict, List, Optional, Sequence, Union

 import numpy as np
 import psutil
-import requests
 import torch
 from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
 from hivemind.moe.server.layers import add_custom_models_from_file
@@ -28,6 +27,7 @@ from petals.server.backend import TransformerBackend
 from petals.server.block_utils import get_block_size
 from petals.server.handler import TransformerConnectionHandler
 from petals.server.memory_cache import MemoryCache
+from petals.server.reachability import check_reachability
 from petals.server.throughput import get_host_throughput
 from petals.utils.convert_block import check_device_balance, convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
@@ -78,6 +78,8 @@ class Server:
         load_in_8bit: Optional[bool] = None,
         tensor_parallel_devices: Optional[Sequence[torch.device]] = None,
         skip_reachability_check: bool = False,
+        use_relay: bool = True,
+        use_auto_relay: bool = True,
         **kwargs,
     ):
         """Create a server with one or more bloom blocks. See run_server.py for documentation."""
@@ -117,14 +119,20 @@ class Server:
         )
         self.module_uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]

-        self.dht = DHT(initial_peers=initial_peers, start=True, num_workers=self.block_config.n_layer, **kwargs)
+        self.dht = DHT(
+            initial_peers=initial_peers,
+            start=True,
+            num_workers=self.block_config.n_layer,
+            use_relay=use_relay,
+            use_auto_relay=use_auto_relay,
+            **kwargs,
+        )
         visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()]
         if initial_peers == PUBLIC_INITIAL_PEERS:
             logger.info(f"Connecting to the public swarm, peer_id = {self.dht.peer_id}")
-            if not skip_reachability_check:
-                self._check_reachability()
         else:
             logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
+        self.need_reachability_check = not skip_reachability_check and initial_peers == PUBLIC_INITIAL_PEERS

         if device is None:
             device = "cuda" if torch.cuda.is_available() else "cpu"
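The two new flags are forwarded verbatim to `hivemind.DHT`, which hands them to its underlying libp2p layer: `use_relay` allows traffic to be routed through relay nodes, and `use_auto_relay` makes a peer behind NAT search for public relays so that it can be reached at all. A minimal sketch of how a caller might opt out, assuming the other parameter names match the attributes referenced in this file and that `PUBLIC_INITIAL_PEERS` is importable from `petals.constants`:

```python
# Illustrative sketch only: use_relay/use_auto_relay come from this diff;
# the import paths and the other parameter names are assumptions.
from petals.constants import PUBLIC_INITIAL_PEERS
from petals.server.server import Server

server = Server(
    converted_model_name_or_path="bigscience/bloom-petals",
    initial_peers=PUBLIC_INITIAL_PEERS,
    use_relay=False,       # never route traffic through third-party relay nodes
    use_auto_relay=False,  # never search for a relay when behind NAT
)
server.run()  # assumed blocking entry point, mirroring petals.cli.run_server
```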
@@ -196,35 +204,14 @@ class Server:

         self.stop = threading.Event()

-    def _check_reachability(self):
-        try:
-            r = requests.get(f"http://health.petals.ml/api/v1/is_reachable/{self.dht.peer_id}", timeout=10)
-            r.raise_for_status()
-            response = r.json()
-        except Exception as e:
-            logger.warning(f"Skipping reachability check because health.petals.ml is down: {repr(e)}")
-            return
-
-        if not response["success"]:
-            # This happens only if health.petals.ml is up and explicitly told us that we are unreachable
-            raise RuntimeError(
-                f"Server is not reachable from the Internet:\n\n"
-                f"{response['message']}\n\n"
-                f"You need to fix your port forwarding and/or firewall settings. How to do that:\n\n"
-                f"    1. Choose a specific port for the Petals server, for example, 31337.\n"
-                f"    2. Ensure that this port is accessible from the Internet and not blocked by your firewall.\n"
-                f"    3. Add these arguments to explicitly announce your IP address and port to other peers:\n"
-                f"        python -m petals.cli.run_server ... --public_ip {response['your_ip']} --port 31337\n"
-                f"    4. If it does not help, ask for help in our Discord: https://discord.gg/Wuk8BnrEPH\n"
-            )
-
-        logger.info("Server is reachable from the Internet, it will appear at http://health.petals.ml soon")
-
     def _choose_num_blocks(self) -> int:
         assert (
             self.converted_model_name_or_path == "bigscience/bloom-petals"
         ), "If you use a model other than bigscience/bloom-petals, please specify --num_blocks manually"
-        assert self.device.type == "cuda", "If you run a non-GPU server, please specify --num_blocks manually"
+        assert self.device.type == "cuda", (
+            "GPU is not available. If you want to run a CPU-only server, please specify --num_blocks. "
+            "CPU-only servers in the public swarm are discouraged since they are much slower"
+        )
         num_devices = len(self.tensor_parallel_devices) if self.tensor_parallel_devices else 1

         if num_devices > 1:
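The HTTP probe deleted above does not disappear from the codebase: this diff replaces it with the `check_reachability` helper imported from the new `petals.server.reachability` module and invoked further down, once the blocks are loaded. The helper's body is not part of this diff; a plausible sketch, reusing the logic removed here, might look like:

```python
# Hypothetical sketch of petals/server/reachability.py (its real body is not shown
# in this diff); it reuses the probe that used to live in Server._check_reachability.
import requests
from hivemind import PeerID
from hivemind.utils.logging import get_logger

logger = get_logger(__name__)


def check_reachability(peer_id: PeerID) -> None:
    try:
        r = requests.get(f"http://health.petals.ml/api/v1/is_reachable/{peer_id}", timeout=10)
        r.raise_for_status()
        response = r.json()
    except Exception as e:
        logger.warning(f"Skipping reachability check because health.petals.ml is down: {repr(e)}")
        return

    if not response["success"]:
        # health.petals.ml is up and explicitly told us that we are unreachable
        raise RuntimeError(f"Server is not reachable from the Internet:\n\n{response['message']}")

    logger.info("Server is reachable from the Internet, it will appear at http://health.petals.ml soon")
```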
@@ -287,6 +274,7 @@ class Server:
             use_auth_token=self.use_auth_token,
             load_in_8bit=self.load_in_8bit,
             tensor_parallel_devices=self.tensor_parallel_devices,
+            need_reachability_check=self.need_reachability_check,
             start=True,
         )
         try:
@@ -380,6 +368,7 @@ class ModuleContainer(threading.Thread):
         use_auth_token: Optional[str],
         load_in_8bit: bool,
         tensor_parallel_devices: Sequence[torch.device],
+        need_reachability_check: bool,
         **kwargs,
     ) -> ModuleContainer:
         module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
@@ -433,6 +422,9 @@ class ModuleContainer(threading.Thread):
                     min_batch_size=min_batch_size,
                     max_batch_size=max_batch_size,
                 )
+
+            if need_reachability_check:
+                check_reachability(dht.peer_id)
         except:
             logger.debug("Shutting down backends")
             for backend in blocks.values():
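Two properties of the new call site are worth noting. First, the check now runs only after every block has been loaded; plausibly this is because a peer relying on `use_auto_relay` needs time to discover a relay before it becomes reachable, so probing immediately after the DHT starts could yield false negatives. Second, the call sits inside the surrounding `try` block, so a failed check flows into the same `except:` cleanup path that shuts down the already-initialized backends.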