@@ -3,13 +3,16 @@ from __future__ import annotations
 import gc
 import math
 import multiprocessing as mp
+import os
 import random
 import threading
 import time
 from typing import Dict, List, Optional, Sequence, Union
 
 import hivemind
+import psutil
 import torch
+import torch.mps
 from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
 from hivemind.moe.server.layers import add_custom_models_from_file
 from hivemind.moe.server.runtime import Runtime
@@ -19,7 +22,7 @@ from transformers import PretrainedConfig
 
 import petals
 from petals.constants import DTYPE_MAP, PUBLIC_INITIAL_PEERS
-from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerInfo, ServerState
+from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModelInfo, ServerInfo, ServerState
 from petals.server import block_selection
 from petals.server.backend import TransformerBackend, merge_inference_pools_inplace
 from petals.server.block_utils import get_block_size, resolve_block_dtype
@@ -31,6 +34,7 @@ from petals.server.throughput import get_dtype_name, get_server_throughput
 from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.convert_block import QuantType, check_device_balance, convert_block
 from petals.utils.dht import declare_active_modules, get_remote_module_infos
+from petals.utils.misc import get_size_in_bytes
 from petals.utils.ping import PingAggregator
 from petals.utils.random import sample_up_to
 from petals.utils.version import get_compatible_model_repo
@@ -59,12 +63,12 @@ class Server:
         min_batch_size: int = 1,
         max_batch_size: Optional[int] = None,
         max_chunk_size_bytes: int = 256 * 1024 * 1024,
+        max_alloc_timeout: float = 600,
         attn_cache_tokens: Optional[int] = None,
         torch_dtype: str = "auto",
         revision: Optional[str] = None,
         cache_dir: Optional[str] = None,
         max_disk_space: Optional[int] = None,
-        alloc_timeout: float = 5,
         device: Optional[Union[str, torch.device]] = None,
         compression=CompressionType.NONE,
         stats_report_interval: Optional[int] = None,
@@ -153,13 +157,25 @@ class Server:
         self.should_validate_reachability = not skip_reachability_check and initial_peers == PUBLIC_INITIAL_PEERS
 
         if device is None:
-            device = "cuda" if torch.cuda.is_available() else "cpu"
+            if torch.cuda.is_available():
+                device = "cuda"
+            elif torch.backends.mps.is_available():
+                device = "mps"
+            else:
+                device = "cpu"
         device = torch.device(device)
         if device.type == "cuda" and device.index is None:
             device = torch.device(device.type, index=0)
         self.device = device
 
         torch_dtype = resolve_block_dtype(self.block_config, DTYPE_MAP[torch_dtype])
+        if device.type == "cpu" and torch_dtype == torch.float16:
+            raise ValueError(
+                f"Type float16 is not supported on CPU. Please use --torch_dtype float32 or --torch_dtype bfloat16"
+            )
+        if device.type == "mps" and torch_dtype == torch.bfloat16:
+            logger.warning(f"Type bfloat16 is not supported on MPS, using float16 instead")
+            torch_dtype = torch.float16
         self.torch_dtype = torch_dtype
 
         if tensor_parallel_devices is None:
@@ -185,13 +201,14 @@ class Server:
         self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
         self.inference_max_length = inference_max_length
         self.max_chunk_size_bytes = max_chunk_size_bytes
+        self.max_alloc_timeout = max_alloc_timeout
 
         # For attention cache in GPU or RAM
         if attn_cache_tokens is None:
             attn_cache_tokens = 32768 if is_multiquery_attn else 8192
         cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens
         cache_values_per_block //= self.block_config.num_key_value_groups
-        self._cache_bytes_per_block = cache_values_per_block * torch.finfo(self.torch_dtype).bits // 8
+        self._cache_bytes_per_block = cache_values_per_block * get_size_in_bytes(self.torch_dtype)
 
         # For disk cache
         self.cache_dir = cache_dir
@@ -217,8 +234,6 @@ class Server:
         self.attn_cache_bytes = self._cache_bytes_per_block * num_blocks
         logger.info(f"Attention cache for all blocks will consume up to {self.attn_cache_bytes / gib:.2f} GiB")
 
-        self.alloc_timeout = alloc_timeout
-
         assert isinstance(throughput, float) or throughput in ["auto", "eval"]
         if throughput in ["auto", "eval"]:
             throughput_info = get_server_throughput(
@@ -245,21 +260,26 @@ class Server:
             using_relay=reachable_via_relay,
             **throughput_info,
         )
+        self.model_info = ModelInfo(num_blocks=self.block_config.num_hidden_layers)
+        if not os.path.isdir(converted_model_name_or_path):
+            self.model_info.repository = "https://huggingface.co/" + converted_model_name_or_path
 
         self.balance_quality = balance_quality
         self.mean_balance_check_period = mean_balance_check_period
         self.mean_block_selection_delay = mean_block_selection_delay
 
+        self.module_container = None
         self.stop = threading.Event()
 
     def _choose_num_blocks(self) -> int:
-        assert self.device.type == "cuda", (
+        assert self.device.type in ("cuda", "mps"), (
             "GPU is not available. If you want to run a CPU-only server, please specify --num_blocks. "
             "CPU-only servers in the public swarm are discouraged since they are much slower"
         )
         num_devices = len(self.tensor_parallel_devices) if self.tensor_parallel_devices else 1
 
         if num_devices > 1:
+            assert self.device.type == "cuda", f"Tensor parallelism is not supported on {self.device.type.upper()}"
             memory_per_device = tuple(
                 torch.cuda.get_device_properties(device).total_memory for device in self.tensor_parallel_devices
             )
@@ -270,8 +290,10 @@ class Server:
                     "Please launch individual servers on each GPU or set --num_blocks manually to "
                     "override this exception."
                 )
-        else:
+        elif self.device.type == "cuda":
             total_memory = torch.cuda.get_device_properties(self.device).total_memory
+        else:
+            total_memory = psutil.virtual_memory().total
 
         gib = 1024**3
         # Estimate of GPU memory used in rpc_backward (2 GiB for BLOOM, proportional for other models)
@@ -311,13 +333,14 @@ class Server:
                 converted_model_name_or_path=self.converted_model_name_or_path,
                 block_config=self.block_config,
                 attn_cache_bytes=self.attn_cache_bytes,
-                alloc_timeout=self.alloc_timeout,
                 server_info=self.server_info,
+                model_info=self.model_info,
                 block_indices=block_indices,
                 num_handlers=self.num_handlers,
                 min_batch_size=self.min_batch_size,
                 max_batch_size=self.max_batch_size,
                 max_chunk_size_bytes=self.max_chunk_size_bytes,
+                max_alloc_timeout=self.max_alloc_timeout,
                 inference_max_length=self.inference_max_length,
                 torch_dtype=self.torch_dtype,
                 cache_dir=self.cache_dir,
@@ -360,7 +383,7 @@ class Server:
             self._clean_memory_and_fds()
 
     def _clean_memory_and_fds(self):
-        del self.module_container
+        self.module_container = None
         gc.collect()  # In particular, this closes unused file descriptors
 
         if self.device.type == "cuda":
@@ -373,6 +396,8 @@ class Server:
                 f"Cleaning up, left {allocated_vram / gib:.1f} GiB allocated memory, "
                 f"{reserved_vram / gib:.1f} GiB reserved memory"
             )
+        elif self.device.type == "mps":
+            torch.mps.empty_cache()
 
     def _choose_blocks(self) -> List[int]:
         if self.strict_block_indices is not None:
@@ -391,8 +416,10 @@ class Server:
         module_infos = get_remote_module_infos(self.dht, self.module_uids, latest=True)
         return block_selection.should_choose_other_blocks(self.dht.peer_id, module_infos, self.balance_quality)
 
-    def shutdown(self):
+    def shutdown(self, timeout: Optional[float] = 5):
         self.stop.set()
+        if self.module_container is not None and self.module_container.is_alive():
+            self.module_container.join(timeout)
 
         if self.reachability_protocol is not None:
             self.reachability_protocol.shutdown()
@@ -413,12 +440,13 @@ class ModuleContainer(threading.Thread):
         converted_model_name_or_path: str,
         block_config: PretrainedConfig,
         attn_cache_bytes: int,
-        alloc_timeout: float,
         server_info: ServerInfo,
+        model_info: ModelInfo,
         block_indices: List[int],
         min_batch_size: int,
         max_batch_size: int,
         max_chunk_size_bytes: int,
+        max_alloc_timeout: float,
         torch_dtype: torch.dtype,
         cache_dir: str,
         max_disk_space: int,
@@ -434,13 +462,14 @@ class ModuleContainer(threading.Thread):
         **kwargs,
     ) -> ModuleContainer:
         module_uids = [f"{dht_prefix}{UID_DELIMITER}{block_index}" for block_index in block_indices]
-        memory_cache = MemoryCache(attn_cache_bytes, alloc_timeout)
+        memory_cache = MemoryCache(attn_cache_bytes, max_alloc_timeout)
 
         server_info.state = ServerState.JOINING
         dht_announcer = ModuleAnnouncerThread(
             module_uids,
             dht,
             server_info,
+            model_info,
             block_config=block_config,
             memory_cache=memory_cache,
             update_period=update_period,
@@ -649,6 +678,7 @@ class ModuleAnnouncerThread(threading.Thread):
         module_uids: List[str],
         dht: DHT,
         server_info: ServerInfo,
+        model_info: ModelInfo,
         *,
         block_config: PretrainedConfig,
         memory_cache: MemoryCache,
@@ -661,9 +691,10 @@ class ModuleAnnouncerThread(threading.Thread):
         self.module_uids = module_uids
         self.dht = dht
         self.server_info = server_info
+        self.model_info = model_info
         self.memory_cache = memory_cache
 
-        self.bytes_per_token = block_config.hidden_size * torch.finfo(DTYPE_MAP[server_info.torch_dtype]).bits // 8
+        self.bytes_per_token = block_config.hidden_size * get_size_in_bytes(DTYPE_MAP[server_info.torch_dtype])
         self.bytes_per_token //= block_config.num_key_value_groups
 
         self.update_period = update_period
@@ -671,10 +702,10 @@ class ModuleAnnouncerThread(threading.Thread):
         self.trigger = threading.Event()
 
         self.max_pinged = max_pinged
-        dht_prefix = module_uids[0].split(UID_DELIMITER)[0]
+        self.dht_prefix = module_uids[0].split(UID_DELIMITER)[0]
         block_indices = [int(uid.split(UID_DELIMITER)[-1]) for uid in module_uids]
         start_block, end_block = min(block_indices), max(block_indices) + 1
-        self.next_uids = [f"{dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block + 1, end_block + 1)]
+        self.next_uids = [f"{self.dht_prefix}{UID_DELIMITER}{i}" for i in range(start_block + 1, end_block + 1)]
         self.ping_aggregator = PingAggregator(self.dht)
 
     def run(self) -> None:
@@ -698,6 +729,13 @@ class ModuleAnnouncerThread(threading.Thread):
             )
             if self.server_info.state == ServerState.OFFLINE:
                 break
+            if not self.dht_prefix.startswith("_"):  # Not private
+                self.dht.store(
+                    key="_petals.models",
+                    subkey=self.dht_prefix,
+                    value=self.model_info.to_dict(),
+                    expiration_time=get_dht_time() + self.expiration,
+                )
 
             delay = self.update_period - (time.perf_counter() - start_time)
             if delay < 0: