@@ -6,33 +6,36 @@ import multiprocessing as mp
 import random
 import threading
 import time
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Sequence, Union

-import numpy as np
-import psutil
-import requests
+import hivemind
 import torch
 from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
 from hivemind.moe.server.layers import add_custom_models_from_file
 from hivemind.moe.server.runtime import Runtime
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.logging import get_logger
-from transformers import BloomConfig
+from transformers import PretrainedConfig

-from petals.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
-from petals.constants import PUBLIC_INITIAL_PEERS
-from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
+import petals
+from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
+from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerInfo, ServerState
 from petals.dht_utils import declare_active_modules, get_remote_module_infos
 from petals.server import block_selection
-from petals.server.backend import TransformerBackend
-from petals.server.block_utils import get_block_size
+from petals.server.backend import TransformerBackend, merge_inference_pools_inplace
+from petals.server.block_utils import get_block_size, resolve_block_dtype
+from petals.server.from_pretrained import load_pretrained_block
 from petals.server.handler import TransformerConnectionHandler
 from petals.server.memory_cache import MemoryCache
-from petals.server.throughput import get_host_throughput
-from petals.utils.convert_8bit import replace_8bit_linear
-from petals.utils.disk_cache import DEFAULT_CACHE_DIR
+from petals.server.reachability import ReachabilityProtocol, check_direct_reachability, validate_reachability
+from petals.server.throughput import get_dtype_name, get_server_throughput
+from petals.utils.auto_config import AutoDistributedConfig
+from petals.utils.convert_block import QuantType, check_device_balance, convert_block
+from petals.utils.ping import PingAggregator
+from petals.utils.random import sample_up_to
+from petals.utils.version import get_compatible_model_repo

-logger = get_logger(__file__)
+logger = get_logger(__name__)


 class Server:
@@ -45,26 +48,28 @@ class Server:
         self,
         *,
         initial_peers: List[str],
-        prefix: Optional[str],
+        dht_prefix: Optional[str],
         converted_model_name_or_path: str,
+        public_name: Optional[str] = None,
         throughput: Union[float, str],
         num_blocks: Optional[int] = None,
         block_indices: Optional[str] = None,
         num_handlers: int = 8,
+        inference_max_length: Optional[int] = None,
         min_batch_size: int = 1,
-        max_batch_size: int = 2048,
-        inference_max_length: int = 2048,
+        max_batch_size: Optional[int] = None,
+        max_chunk_size_bytes: int = 256 * 1024 * 1024,
+        attn_cache_tokens: Optional[int] = None,
         torch_dtype: str = "auto",
-        revision: str = "main",
+        revision: Optional[str] = None,
         cache_dir: Optional[str] = None,
         max_disk_space: Optional[int] = None,
-        attn_cache_size: Optional[int] = None,
-        alloc_timeout: float = 60,
+        alloc_timeout: float = 5,
         device: Optional[Union[str, torch.device]] = None,
         compression=CompressionType.NONE,
         stats_report_interval: Optional[int] = None,
         custom_module_path=None,
-        update_period: float = 150,
+        update_period: float = 60,
         expiration: Optional[float] = None,
         request_timeout: float = 3 * 60,
         session_timeout: float = 30 * 60,
@@ -73,34 +78,44 @@ class Server:
         sender_threads: int = 1,
         balance_quality: float = 0.75,
         mean_balance_check_period: float = 120,
-        mean_block_selection_delay: float = 2.5,
-        use_auth_token: Optional[str] = None,
-        load_in_8bit: Optional[bool] = None,
+        mean_block_selection_delay: float = 5,
+        token: Optional[Union[str, bool]] = None,
+        quant_type: Optional[QuantType] = None,
+        tensor_parallel_devices: Optional[Sequence[torch.device]] = None,
         skip_reachability_check: bool = False,
+        reachable_via_relay: Optional[bool] = None,
+        use_relay: bool = True,
+        use_auto_relay: bool = True,
+        adapters: Sequence[str] = (),
         **kwargs,
     ):
         """Create a server with one or more bloom blocks. See run_server.py for documentation."""

+        converted_model_name_or_path = get_compatible_model_repo(converted_model_name_or_path)
         self.converted_model_name_or_path = converted_model_name_or_path
+
         self.num_handlers = num_handlers
-        self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
-        self.inference_max_length = inference_max_length
         self.compression = compression
         self.stats_report_interval, self.update_period = stats_report_interval, update_period
         self.prefetch_batches, self.sender_threads = prefetch_batches, sender_threads
-        self.use_auth_token = use_auth_token
+        self.revision, self.token = revision, token

         if custom_module_path is not None:
             add_custom_models_from_file(custom_module_path)

-        if prefix is None:
-            prefix = converted_model_name_or_path
-            assert UID_DELIMITER not in prefix and CHAIN_DELIMITER not in prefix, (
-                f"Cannot use model name as prefix (contains '{UID_DELIMITER}' or '{CHAIN_DELIMITER}'); "
-                f"Please specify --prefix manually when starting a server"
-            )
-            logger.info(f"Automatic dht prefix: {prefix}")
-        self.prefix = prefix
+        self.block_config = AutoDistributedConfig.from_pretrained(
+            converted_model_name_or_path,
+            use_auth_token=token,
+            revision=revision,
+        )
+
+        if dht_prefix is None:
+            dht_prefix = self.block_config.dht_prefix
+        assert UID_DELIMITER not in dht_prefix and CHAIN_DELIMITER not in dht_prefix, (
+            f"DHT prefix should not contain '{UID_DELIMITER}' or '{CHAIN_DELIMITER}'. "
+            f"Please specify another --dht_prefix manually when starting a server"
+        )
+        self.dht_prefix = dht_prefix

         if expiration is None:
             expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
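For context on the `--dht_prefix` validation above: the prefix is joined with block indices to form module UIDs, and UIDs are joined into chains, so it must not contain either delimiter. A minimal sketch, assuming `UID_DELIMITER == "."` and `CHAIN_DELIMITER == " "` as defined in `petals.data_structures` (the prefix string itself is hypothetical):

```python
UID_DELIMITER = "."    # joins a DHT prefix with a block index into a module UID
CHAIN_DELIMITER = " "  # joins module UIDs into a chain of blocks for one request

dht_prefix = "SomeModel-hf"  # hypothetical prefix derived from a model repo name
module_uids = [f"{dht_prefix}{UID_DELIMITER}{i}" for i in range(3)]
assert module_uids == ["SomeModel-hf.0", "SomeModel-hf.1", "SomeModel-hf.2"]

# A chain of consecutive blocks, as a request would address them:
chain = CHAIN_DELIMITER.join(module_uids)  # "SomeModel-hf.0 SomeModel-hf.1 SomeModel-hf.2"
# If dht_prefix itself contained "." or " ", these strings could not be split back unambiguously.
```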
@@ -109,75 +124,127 @@ class Server:
         self.request_timeout = request_timeout
         self.session_timeout, self.step_timeout = session_timeout, step_timeout

-        self.block_config = BloomConfig.from_pretrained(
-            converted_model_name_or_path,
-            use_auth_token=use_auth_token,
-            revision=revision,
+        self.module_uids = [
+            f"{self.dht_prefix}{UID_DELIMITER}{block_index}"
+            for block_index in range(self.block_config.num_hidden_layers)
+        ]
+
+        if reachable_via_relay is None:
+            is_reachable = check_direct_reachability(initial_peers=initial_peers, use_relay=False, **kwargs)
+            reachable_via_relay = is_reachable is False  # if can't check reachability (returns None), run a full peer
+        logger.info(f"This server is accessible {'via relays' if reachable_via_relay else 'directly'}")
+        self.dht = DHT(
+            initial_peers=initial_peers,
+            start=True,
+            num_workers=self.block_config.num_hidden_layers,
+            use_relay=use_relay,
+            use_auto_relay=use_auto_relay,
+            client_mode=reachable_via_relay,
+            **kwargs,
         )
-        self.module_uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]
+        self.reachability_protocol = ReachabilityProtocol.attach_to_dht(self.dht) if not reachable_via_relay else None

-        self.dht = DHT(initial_peers=initial_peers, start=True, num_workers=self.block_config.n_layer, **kwargs)
         visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()]
         if initial_peers == PUBLIC_INITIAL_PEERS:
-            logger.info(f"Connecting to the public swarm, peer_id = {self.dht.peer_id}")
-            if not skip_reachability_check:
-                self._check_reachability()
+            logger.info("Connecting to the public swarm")
         else:
-            logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
+            logger.info(f"Connecting to a private swarm, initial peers: {initial_peers}")
+        logger.info(f"Running a server on {visible_maddrs_str}")
+        self.should_validate_reachability = not skip_reachability_check and initial_peers == PUBLIC_INITIAL_PEERS

         if device is None:
             device = "cuda" if torch.cuda.is_available() else "cpu"
         device = torch.device(device)
+        if device.type == "cuda" and device.index is None:
+            device = torch.device(device.type, index=0)
         self.device = device

-        if isinstance(torch_dtype, str):
-            torch_dtype = DTYPE_MAP[torch_dtype]
-        assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
+        torch_dtype = resolve_block_dtype(self.block_config, DTYPE_MAP[torch_dtype])
         self.torch_dtype = torch_dtype

-        if load_in_8bit is None:
-            load_in_8bit = device.type == "cuda"
-        if load_in_8bit:
-            logger.info("Model weights will be loaded in 8-bit format")
-        self.load_in_8bit = load_in_8bit
+        if tensor_parallel_devices is None:
+            tensor_parallel_devices = (device,)
+        self.tensor_parallel_devices = tuple(map(torch.device, tensor_parallel_devices))
+        if len(self.tensor_parallel_devices) > 1:
+            logger.info(f"Model weights will be split between {', '.join(tensor_parallel_devices)}")
+            check_device_balance(self.tensor_parallel_devices)
+
+        if quant_type is None:
+            if device.type == "cuda":
+                quant_type = QuantType.NF4 if self.block_config.model_type == "llama" else QuantType.INT8
+            else:
+                quant_type = QuantType.NONE
+        self.quant_type = quant_type
+        logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, quant_type)} format")
+
+        is_multiquery_attn = self.block_config.num_key_value_groups > 1
+        if max_batch_size is None:
+            max_batch_size = 8192 if is_multiquery_attn else 2048
+        if inference_max_length is None:
+            inference_max_length = 8192 if is_multiquery_attn else 2048
+        self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
+        self.inference_max_length = inference_max_length
+        self.max_chunk_size_bytes = max_chunk_size_bytes
+
+        # For attention cache in GPU or RAM
+        if attn_cache_tokens is None:
+            attn_cache_tokens = 32768 if is_multiquery_attn else 8192
+        cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens
+        cache_values_per_block //= self.block_config.num_key_value_groups
+        self._cache_bytes_per_block = cache_values_per_block * torch.finfo(self.torch_dtype).bits // 8
+
+        # For disk cache
+        self.cache_dir = cache_dir
+        self.max_disk_space = max_disk_space
+        self.adapters = adapters

         assert num_blocks is None or block_indices is None, "Please specify num_blocks or block_indices, not both"
         if num_blocks is None and block_indices is None:
             num_blocks = self._choose_num_blocks()
+        if num_blocks is not None:
+            num_blocks = min(num_blocks, self.block_config.num_hidden_layers)
         if block_indices is not None:
             try:
                 first_block_index, last_block_index = block_indices.split(":")
                 first_block_index, last_block_index = map(int, map(str.strip, (first_block_index, last_block_index)))
             except Exception as e:
-                logger.error(f"Failed to parse --block_indices ({e}), must be start:end (e.g. 0:18)")
-                raise
+                raise ValueError(f"Failed to parse `--block_indices {block_indices}`, must be start:end (e.g. 0:18)")
             block_indices = range(first_block_index, last_block_index)
             num_blocks = len(block_indices)
         self.strict_block_indices, self.num_blocks = block_indices, num_blocks

         gib = 1024**3
-        if attn_cache_size is None:
-            # Hidden size is 14336 for the bigscience/bloom-petals model. For other models, scale accordingly
-            attn_cache_size = 0.5 * gib * num_blocks * self.block_config.hidden_size / 14336
-        self.attn_cache_size, self.alloc_timeout = attn_cache_size, alloc_timeout
-        logger.info(f"Attention cache for all blocks will consume up to {attn_cache_size / gib:.2f} GiB")
-
-        if cache_dir is None:
-            cache_dir = DEFAULT_CACHE_DIR
-        self.cache_dir = cache_dir
-        self.max_disk_space = max_disk_space
+        self.attn_cache_bytes = self._cache_bytes_per_block * num_blocks
+        logger.info(f"Attention cache for all blocks will consume up to {self.attn_cache_bytes / gib:.2f} GiB")
+
+        self.alloc_timeout = alloc_timeout

         assert isinstance(throughput, float) or throughput in ["auto", "eval"]
         if throughput in ["auto", "eval"]:
-            throughput = get_host_throughput(
+            throughput_info = get_server_throughput(
+                converted_model_name_or_path,
                 self.block_config,
                 device,
                 torch_dtype,
-                load_in_8bit=load_in_8bit,
+                num_blocks=num_blocks,
+                quant_type=quant_type,
+                tensor_parallel_devices=self.tensor_parallel_devices,
+                reachable_via_relay=reachable_via_relay,
                 force_eval=(throughput == "eval"),
                 cache_dir=cache_dir,
             )
-        self.throughput = throughput
+        else:
+            throughput_info = {"throughput": throughput}
+        self.server_info = ServerInfo(
+            state=ServerState.JOINING,
+            public_name=public_name,
+            version=petals.__version__,
+            adapters=tuple(adapters),
+            torch_dtype=str(torch_dtype).replace("torch.", ""),
+            quant_type=quant_type.name.lower(),
+            using_relay=reachable_via_relay,
+            **throughput_info,
+        )

         self.balance_quality = balance_quality
         self.mean_balance_check_period = mean_balance_check_period
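To make the new token-based cache sizing concrete, here is the `_cache_bytes_per_block` arithmetic from the hunk above, replayed with illustrative numbers (a BLOOM-like hidden size, one key-value group, float16 — none of these are taken from a real config):

```python
# Replaying the cache formula with hypothetical values
hidden_size = 14336           # e.g. a BLOOM-176B-sized block
attn_cache_tokens = 8192      # the default when multi-query attention is absent
num_key_value_groups = 1      # no KV grouping in this example
dtype_bits = 16               # float16

cache_values_per_block = 2 * hidden_size * attn_cache_tokens  # keys and values
cache_values_per_block //= num_key_value_groups
cache_bytes_per_block = cache_values_per_block * dtype_bits // 8

print(f"{cache_bytes_per_block / 1024**3:.2f} GiB")  # ~0.44 GiB per block
```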
@@ -185,65 +252,72 @@ class Server:

         self.stop = threading.Event()

-    def _check_reachability(self):
-        try:
-            r = requests.get(f"http://health.petals.ml/api/v1/is_reachable/{self.dht.peer_id}", timeout=10)
-            r.raise_for_status()
-            response = r.json()
-        except Exception as e:
-            logger.warning(f"Skipping reachability check because health.petals.ml is down: {repr(e)}")
-            return
-
-        if not response["success"]:
-            # This happens only if health.petals.ml is up and explicitly told us that we are unreachable
-            raise RuntimeError(
-                f"Server is not reachable from the Internet:\n\n"
-                f"{response['message']}\n\n"
-                f"You need to fix your port forwarding and/or firewall settings. How to do that:\n\n"
-                f"    1. Choose a specific port for the Petals server, for example, 31337.\n"
-                f"    2. Ensure that this port is accessible from the Internet and not blocked by your firewall.\n"
-                f"    3. Add these arguments to explicitly announce your IP address and port to other peers:\n"
-                f"        python -m petals.cli.run_server ... --public_ip {response['your_ip']} --port 31337\n"
-                f"    4. If it does not help, ask for help in our Discord: https://discord.gg/Wuk8BnrEPH\n"
+    def _choose_num_blocks(self) -> int:
+        assert self.device.type == "cuda", (
+            "GPU is not available. If you want to run a CPU-only server, please specify --num_blocks. "
+            "CPU-only servers in the public swarm are discouraged since they are much slower"
+        )
+        num_devices = len(self.tensor_parallel_devices) if self.tensor_parallel_devices else 1
+
+        if num_devices > 1:
+            memory_per_device = tuple(
+                torch.cuda.get_device_properties(device).total_memory for device in self.tensor_parallel_devices
             )
+            total_memory = min(memory_per_device) * num_devices
+            if max(memory_per_device) / min(memory_per_device) > 1.5:
+                raise ValueError(
+                    "GPU devices have highly uneven memory, which makes tensor parallelism inefficient. "
+                    "Please launch individual servers on each GPU or set --num_blocks manually to "
+                    "override this exception."
+                )
+        else:
+            total_memory = torch.cuda.get_device_properties(self.device).total_memory

-        logger.info("Server is reachable from the Internet, it will appear at http://health.petals.ml soon")
+        gib = 1024**3
+        # Estimate of GPU memory used in rpc_backward (2 GiB for BLOOM, proportional for other models)
+        autograd_memory = 2 * gib * num_devices / 14336 * self.block_config.hidden_size

-    def _choose_num_blocks(self) -> int:
-        assert (
-            self.converted_model_name_or_path == "bigscience/bloom-petals"
-        ), "If you use a model other than bigscience/bloom-petals, please specify --num_blocks manually"
-        assert self.device.type == "cuda", "If you run a non-GPU server, please specify --num_blocks manually"
+        block_size = get_block_size(self.block_config, "memory", dtype=self.torch_dtype, quant_type=self.quant_type)
+        total_memory_per_block = block_size + self._cache_bytes_per_block
+        if self.adapters:
+            # Delay import of petals.utils.peft to avoid unnecessary import of bitsandbytes
+            from petals.utils.peft import estimate_adapter_memory_per_block

-        total_memory = torch.cuda.get_device_properties(self.device).total_memory
-        block_size = get_block_size(self.block_config, "memory", dtype=self.torch_dtype, load_in_8bit=self.load_in_8bit)
-        gib = 1024**3
-        attn_cache_per_block = 0.5 * gib  # TODO: This does not account for manually set --attn_cache_size
+            total_memory_per_block += estimate_adapter_memory_per_block(
+                self.block_config,
+                self.torch_dtype,
+                self.adapters,
+                token=self.token,
+                cache_dir=self.cache_dir,
+                max_disk_space=self.max_disk_space,
+            )

-        num_blocks = math.floor((total_memory - 2 * gib) / (block_size + attn_cache_per_block))
+        num_blocks = math.floor((total_memory - autograd_memory) / total_memory_per_block)
         assert num_blocks >= 1, "Your GPU does not have enough memory to serve at least one block"

+        num_blocks = min(num_blocks, self.block_config.num_hidden_layers)
         logger.info(
-            f"Server will fill all your GPU memory with {num_blocks} transformer blocks. "
+            f"Server will fill your GPU memory with {num_blocks} transformer blocks. "
             f"If you want to leave some free GPU memory, please specify a lesser --num_blocks manually"
         )
-        return min(num_blocks, self.block_config.n_layer)
+        return num_blocks

     def run(self):
         while True:
             block_indices = self._choose_blocks()
             self.module_container = ModuleContainer.create(
                 dht=self.dht,
-                prefix=self.prefix,
+                dht_prefix=self.dht_prefix,
                 converted_model_name_or_path=self.converted_model_name_or_path,
                 block_config=self.block_config,
-                attn_cache_size=self.attn_cache_size,
+                attn_cache_bytes=self.attn_cache_bytes,
                 alloc_timeout=self.alloc_timeout,
-                throughput=self.throughput,
+                server_info=self.server_info,
                 block_indices=block_indices,
                 num_handlers=self.num_handlers,
                 min_batch_size=self.min_batch_size,
                 max_batch_size=self.max_batch_size,
+                max_chunk_size_bytes=self.max_chunk_size_bytes,
                 inference_max_length=self.inference_max_length,
                 torch_dtype=self.torch_dtype,
                 cache_dir=self.cache_dir,
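The reworked `_choose_num_blocks` above reduces to one formula: blocks = floor((total GPU memory − autograd reserve) / (block weights + attention cache per block)). A worked example with assumed sizes (the per-block weight and cache footprints below are illustrative, not measured):

```python
import math

gib = 1024**3
total_memory = 80 * gib             # hypothetical: a single 80 GiB GPU
autograd_memory = 2 * gib           # rpc_backward reserve for hidden_size 14336, num_devices = 1
block_size = 3.5 * gib              # assumed per-block weight footprint (depends on dtype/quantization)
cache_bytes_per_block = 0.44 * gib  # from the attention-cache arithmetic shown earlier

num_blocks = math.floor((total_memory - autograd_memory) / (block_size + cache_bytes_per_block))
print(num_blocks)  # 19 blocks under these assumptions
```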
@@ -258,8 +332,11 @@ class Server:
                 step_timeout=self.step_timeout,
                 prefetch_batches=self.prefetch_batches,
                 sender_threads=self.sender_threads,
-                use_auth_token=self.use_auth_token,
-                load_in_8bit=self.load_in_8bit,
+                revision=self.revision,
+                token=self.token,
+                quant_type=self.quant_type,
+                tensor_parallel_devices=self.tensor_parallel_devices,
+                should_validate_reachability=self.should_validate_reachability,
                 start=True,
             )
             try:
@@ -286,10 +363,6 @@ class Server:
             del self.module_container
             gc.collect()  # In particular, this closes unused file descriptors

-            cur_proc = psutil.Process()
-            num_fds = [proc.num_fds() for proc in [cur_proc] + cur_proc.children(recursive=True)]
-            logger.info(f"Cleaning up, left {sum(num_fds)} open file descriptors")
-
             if self.device.type == "cuda":
                 torch.cuda.empty_cache()

@@ -308,19 +381,21 @@ class Server:
         # If multiple servers (e.g., launched on the same machine by a script) get to this line at the same time,
         # this delay decreases the probability of a race condition while choosing the best blocks to serve.
         time.sleep(random.random() * 2 * self.mean_block_selection_delay)
-        module_infos = get_remote_module_infos(self.dht, self.module_uids, expiration_time=np.inf)
+        module_infos = get_remote_module_infos(self.dht, self.module_uids, latest=True)
         return block_selection.choose_best_blocks(self.num_blocks, module_infos)

     def _should_choose_other_blocks(self) -> bool:
         if self.strict_block_indices is not None:
             return False

-        module_infos = get_remote_module_infos(self.dht, self.module_uids, expiration_time=np.inf)
+        module_infos = get_remote_module_infos(self.dht, self.module_uids, latest=True)
         return block_selection.should_choose_other_blocks(self.dht.peer_id, module_infos, self.balance_quality)

     def shutdown(self):
         self.stop.set()

+        if self.reachability_protocol is not None:
+            self.reachability_protocol.shutdown()
         self.dht.shutdown()
         self.dht.join()

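A quick sanity check of the start-up jitter used in `_choose_blocks` above (not part of the server code): `random.random() * 2 * mean_block_selection_delay` is uniform on [0, 2·mean], so the average delay stays at the configured mean — 5 s with the new default — while still spreading concurrent servers apart:

```python
import random

mean_block_selection_delay = 5  # the new default from this diff
samples = [random.random() * 2 * mean_block_selection_delay for _ in range(100_000)]
print(sum(samples) / len(samples))  # ~5.0: the jitter preserves the configured mean
```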
@@ -334,15 +409,16 @@ class ModuleContainer(threading.Thread):
         cls,
         *,
         dht: DHT,
-        prefix: str,
+        dht_prefix: str,
         converted_model_name_or_path: str,
-        block_config: BloomConfig,
-        attn_cache_size: int,
+        block_config: PretrainedConfig,
+        attn_cache_bytes: int,
         alloc_timeout: float,
-        throughput: float,
+        server_info: ServerInfo,
         block_indices: List[int],
         min_batch_size: int,
         max_batch_size: int,
+        max_chunk_size_bytes: int,
         torch_dtype: torch.dtype,
         cache_dir: str,
         max_disk_space: int,
@@ -350,89 +426,99 @@ class ModuleContainer(threading.Thread):
         compression: CompressionType,
         update_period: float,
         expiration: Optional[float],
-        use_auth_token: Optional[str],
-        load_in_8bit: bool,
+        revision: Optional[str],
+        token: Optional[Union[str, bool]],
+        quant_type: QuantType,
+        tensor_parallel_devices: Sequence[torch.device],
+        should_validate_reachability: bool,
        **kwargs,
    ) -> ModuleContainer:
-        module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
-        joining_announcer = ModuleAnnouncerThread(
+        module_uids = [f"{dht_prefix}{UID_DELIMITER}{block_index}" for block_index in block_indices]
+        memory_cache = MemoryCache(attn_cache_bytes, alloc_timeout)
+
+        server_info.state = ServerState.JOINING
+        dht_announcer = ModuleAnnouncerThread(
             module_uids,
             dht,
-            ServerState.JOINING,
-            throughput=throughput,
+            server_info,
+            block_config=block_config,
+            memory_cache=memory_cache,
             update_period=update_period,
             expiration=expiration,
             daemon=True,
         )
-        joining_announcer.start()
+        dht_announcer.start()
         logger.info(f"Announced that blocks {block_indices} are joining")

-        memory_cache = MemoryCache(device, attn_cache_size, alloc_timeout)
+        assert len(tensor_parallel_devices) >= 1 and all(isinstance(d, torch.device) for d in tensor_parallel_devices)
+
         blocks = {}
         try:
             for module_uid, block_index in zip(module_uids, block_indices):
                 block = load_pretrained_block(
                     converted_model_name_or_path,
                     block_index,
-                    block_config,
+                    config=block_config,
                     torch_dtype=torch_dtype,
-                    use_auth_token=use_auth_token,
+                    revision=revision,
+                    token=token,
+                    cache_dir=cache_dir,
+                    max_disk_space=max_disk_space,
+                )
+                block = convert_block(
+                    block,
+                    block_index,
+                    block_config,
+                    tensor_parallel_devices,
+                    device,
+                    quant_type,
+                    adapters=server_info.adapters,
+                    freeze=True,
+                    token=token,
                     cache_dir=cache_dir,
                     max_disk_space=max_disk_space,
                 )
-
-                if load_in_8bit:
-                    block = replace_8bit_linear(block)
-
-                block = block.to(device)
-                for param in block.parameters():
-                    param.requires_grad = False
-
-                backend_dtype = block.input_layernorm.weight.dtype if torch_dtype == "auto" else torch_dtype
                 blocks[module_uid] = TransformerBackend(
                     module_uid,
                     block,
+                    config=block_config,
                     memory_cache=memory_cache,
-                    backend_dtype=backend_dtype,
+                    backend_dtype=torch_dtype,
+                    max_chunk_size_bytes=max_chunk_size_bytes,
                     args_schema=(
                         BatchTensorDescriptor(
-                            1, 2048, block_config.hidden_size, dtype=backend_dtype, compression=compression
+                            1, 2048, block_config.hidden_size, dtype=torch_dtype, compression=compression
                         ),
                     ),
                     kwargs_schema={},
                     outputs_schema=(
                         BatchTensorDescriptor(
-                            1, 2048, block_config.hidden_size, dtype=backend_dtype, compression=compression
+                            1, 2048, block_config.hidden_size, dtype=torch_dtype, compression=compression
                         ),
                     ),
                     min_batch_size=min_batch_size,
                     max_batch_size=max_batch_size,
                 )
+
+            merge_inference_pools_inplace(blocks)
+
+            if should_validate_reachability:
+                validate_reachability(dht.peer_id)
         except:
             logger.debug("Shutting down backends")
             for backend in blocks.values():
                 backend.shutdown()

-            joining_announcer.stop.set()
-            joining_announcer.join()
-            declare_active_modules(
-                dht,
-                module_uids,
-                expiration_time=get_dht_time() + expiration,
-                state=ServerState.OFFLINE,
-                throughput=throughput,
-            )
+            dht_announcer.announce(ServerState.OFFLINE)
             logger.info(f"Announced that blocks {module_uids} are offline")
             raise
-        else:
-            joining_announcer.stop.set()
-            joining_announcer.join()

         return cls(
             dht,
+            dht_prefix,
             blocks,
-            throughput=throughput,
-            device=device,
+            dht_announcer=dht_announcer,
+            server_info=server_info,
             update_period=update_period,
             expiration=expiration,
             **kwargs,
@@ -441,11 +527,13 @@ class ModuleContainer(threading.Thread):
     def __init__(
         self,
         dht: DHT,
+        dht_prefix: str,
         module_backends: Dict[str, TransformerBackend],
         *,
         inference_max_length: int,
         num_handlers: int,
-        throughput: float,
+        dht_announcer: ModuleAnnouncerThread,
+        server_info: ServerInfo,
         update_period: float,
         expiration: Optional[float] = None,
         request_timeout: float,
@@ -457,29 +545,31 @@ class ModuleContainer(threading.Thread):
         super().__init__()

         self.dht, self.module_backends = dht, module_backends
-        self.throughput, self.update_period, self.expiration = throughput, update_period, expiration
+        self.server_info, self.update_period, self.expiration = server_info, update_period, expiration
+
+        handler_event_queues = [mp.Queue() for _ in range(num_handlers)]
         self.conn_handlers = [
             TransformerConnectionHandler(
                 dht,
                 self.module_backends,
+                adapters=server_info.adapters,
+                dht_prefix=dht_prefix,
+                handler_event_queues=handler_event_queues,
+                handler_index=i,
                 inference_max_length=inference_max_length,
                 request_timeout=request_timeout,
                 session_timeout=session_timeout,
                 step_timeout=step_timeout,
+                quant_type=QuantType[server_info.quant_type.upper()],
             )
-            for _ in range(num_handlers)
+            for i in range(num_handlers)
         ]
-        self.runtime = Runtime(self.module_backends, **kwargs)
-        self.online_announcer = ModuleAnnouncerThread(
-            list(self.module_backends.keys()),
-            dht,
-            ServerState.ONLINE,
-            throughput=throughput,
-            update_period=update_period,
-            expiration=expiration,
-            daemon=True,
-        )
-        self.checkpoint_saver = None  # no need to save checkpoints since we do not change model state
+
+        self.runtime = RuntimeWithDeduplicatedPools(self.module_backends, device=None, **kwargs)
+        # note: We set device=None in runtime to avoid moving all modules to device 0 in runtime.run(). tensor_parallel has already moved it as needed.
+
+        dht_announcer.announce(ServerState.ONLINE)
+        self.dht_announcer = dht_announcer

         if start:
             self.run_in_background(await_ready=True)
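Note how `quant_type` travels through the DHT in this design: the `Server` constructor stores `quant_type.name.lower()` in `ServerInfo`, and each connection handler restores the enum with `QuantType[server_info.quant_type.upper()]`. A self-contained sketch of that round trip, assuming `QuantType` is a plain `Enum` with members `NONE`, `INT8`, and `NF4`:

```python
from enum import Enum

class QuantType(Enum):  # stand-in for petals.utils.convert_block.QuantType
    NONE = 0
    INT8 = 1
    NF4 = 2

serialized = QuantType.NF4.name.lower()   # "nf4" — a plain string, safe to publish in the DHT
restored = QuantType[serialized.upper()]  # look the member up again by name
assert restored is QuantType.NF4
```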
@@ -489,14 +579,6 @@ class ModuleContainer(threading.Thread):
         Runs ModuleContainer in the current thread. Initializes dht if necessary, starts connection handlers,
         runs Runtime (self.runtime) to process incoming requests.
         """
-        if not self.dht.is_alive():
-            self.dht.run_in_background(await_ready=True)
-
-        self.online_announcer.start()
-
-        if self.checkpoint_saver is not None:
-            self.checkpoint_saver.start()
-
         for handler in self.conn_handlers:
             handler.run_in_background()

@@ -535,27 +617,14 @@ class ModuleContainer(threading.Thread):
         Please note that terminating container otherwise (e.g. by killing processes) may result in zombie processes.
         If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
         """
-        self.online_announcer.stop.set()
-        self.online_announcer.join()
-
-        declare_active_modules(
-            self.dht,
-            self.module_backends.keys(),
-            expiration_time=get_dht_time() + self.expiration,
-            state=ServerState.OFFLINE,
-            throughput=self.throughput,
-        )
+        self.dht_announcer.announce(ServerState.OFFLINE)
         logger.info(f"Announced that blocks {list(self.module_backends.keys())} are offline")

         self.ready.clear()

+        logger.debug("Shutting down connection handlers")
         for handler in self.conn_handlers:
             handler.shutdown()
-        logger.debug("Connection handlers terminated")
-
-        if self.checkpoint_saver is not None:
-            self.checkpoint_saver.stop.set()
-            self.checkpoint_saver.join()

         logger.debug(f"Shutting down pools")
         for pool in self.runtime.pools:
@@ -579,30 +648,85 @@ class ModuleAnnouncerThread(threading.Thread):
         self,
         module_uids: List[str],
         dht: DHT,
-        state: ServerState,
+        server_info: ServerInfo,
         *,
-        throughput: float,
-        update_period: float = 30,
+        block_config: PretrainedConfig,
+        memory_cache: MemoryCache,
+        update_period: float,
         expiration: float,
+        max_pinged: int = 5,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.module_uids = module_uids
         self.dht = dht
-        self.state = state
-        self.throughput = throughput
+        self.server_info = server_info
+        self.memory_cache = memory_cache
+
+        self.bytes_per_token = block_config.hidden_size * torch.finfo(DTYPE_MAP[server_info.torch_dtype]).bits // 8
+        self.bytes_per_token //= block_config.num_key_value_groups
+
         self.update_period = update_period
         self.expiration = expiration
-        self.stop = threading.Event()
+        self.trigger = threading.Event()
+
+        self.max_pinged = max_pinged
+        dht_prefix = module_uids[0].split(UID_DELIMITER)[0]
+        block_indices = [int(uid.split(UID_DELIMITER)[-1]) for uid in module_uids]
+        start_block, end_block = min(block_indices), max(block_indices) + 1
+        self.next_uids = [f"{dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block + 1, end_block + 1)]
+        self.ping_aggregator = PingAggregator(self.dht)

     def run(self) -> None:
         while True:
+            start_time = time.perf_counter()
+
+            self.server_info.cache_tokens_left = self.memory_cache.bytes_left // self.bytes_per_token
+            if self.server_info.state != ServerState.OFFLINE:
+                self._ping_next_servers()
+                self.server_info.next_pings = {
+                    peer_id.to_base58(): rtt for peer_id, rtt in self.ping_aggregator.to_dict().items()
+                }
+            else:
+                self.server_info.next_pings = None  # No need to ping if we're disconnecting
+
             declare_active_modules(
                 self.dht,
                 self.module_uids,
+                self.server_info,
                 expiration_time=get_dht_time() + self.expiration,
-                state=self.state,
-                throughput=self.throughput,
             )
-            if self.stop.wait(self.update_period):
+            if self.server_info.state == ServerState.OFFLINE:
                 break
+
+            delay = self.update_period - (time.perf_counter() - start_time)
+            if delay < 0:
+                logger.warning(
+                    f"Declaring blocks to DHT takes more than --update_period, consider increasing it (currently {self.update_period})"
+                )
+            self.trigger.wait(max(delay, 0))
+            self.trigger.clear()
+
+    def announce(self, state: ServerState) -> None:
+        self.server_info.state = state
+        self.trigger.set()
+        if state == ServerState.OFFLINE:
+            self.join()
+
+    def _ping_next_servers(self) -> Dict[hivemind.PeerID, float]:
+        module_infos = get_remote_module_infos(self.dht, self.next_uids, latest=True)
+        middle_servers = {peer_id for info in module_infos[:-1] if info is not None for peer_id in info.servers}
+        pinged_servers = set(sample_up_to(middle_servers, self.max_pinged))
+        pinged_servers.discard(self.dht.peer_id)
+        if module_infos[-1] is not None:
+            # Sample servers hosting the block after the last one (most likely continuations) separately
+            pinged_servers |= set(sample_up_to(module_infos[-1].servers, self.max_pinged))
+        self.ping_aggregator.ping(list(pinged_servers))
+
+
+class RuntimeWithDeduplicatedPools(Runtime):
+    """A version of hivemind.moe.server.runtime.Runtime that allows multiple backends to reuse a task pool"""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.pools = tuple(set(self.pools))