@@ -14,21 +14,23 @@ from hivemind.moe.server.layers import add_custom_models_from_file
 from hivemind.moe.server.runtime import Runtime
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.logging import get_logger
-from transformers import BloomConfig
+from transformers import PretrainedConfig
 
-from petals.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
 from petals.constants import PUBLIC_INITIAL_PEERS
 from petals.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
 from petals.dht_utils import declare_active_modules, get_remote_module_infos
 from petals.server import block_selection
 from petals.server.backend import TransformerBackend, merge_inference_pools_inplace
 from petals.server.block_utils import get_block_size, resolve_block_dtype
+from petals.server.from_pretrained import DTYPE_MAP, load_pretrained_block
 from petals.server.handler import TransformerConnectionHandler
 from petals.server.memory_cache import MemoryCache
 from petals.server.reachability import ReachabilityProtocol, check_direct_reachability, validate_reachability
 from petals.server.throughput import get_dtype_name, get_server_throughput
+from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.convert_block import check_device_balance, convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
+from petals.utils.version import get_compatible_model_repo
 
 logger = get_logger(__name__)
 
@@ -53,7 +55,7 @@ class Server:
         max_batch_size: int = 2048,
         inference_max_length: int = 2048,
         torch_dtype: str = "auto",
-        revision: str = "main",
+        revision: Optional[str] = None,
         cache_dir: Optional[str] = None,
         max_disk_space: Optional[int] = None,
         attn_cache_tokens: int = 8192,
@@ -83,25 +85,32 @@ class Server:
     ):
         """Create a server with one or more bloom blocks. See run_server.py for documentation."""
 
+        converted_model_name_or_path = get_compatible_model_repo(converted_model_name_or_path)
         self.converted_model_name_or_path = converted_model_name_or_path
+
         self.num_handlers = num_handlers
         self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
         self.inference_max_length = inference_max_length
         self.compression = compression
         self.stats_report_interval, self.update_period = stats_report_interval, update_period
         self.prefetch_batches, self.sender_threads = prefetch_batches, sender_threads
-        self.use_auth_token = use_auth_token
+        self.revision, self.use_auth_token = revision, use_auth_token
 
         if custom_module_path is not None:
             add_custom_models_from_file(custom_module_path)
 
+        self.block_config = AutoDistributedConfig.from_pretrained(
+            converted_model_name_or_path,
+            use_auth_token=use_auth_token,
+            revision=revision,
+        )
+
         if prefix is None:
-            prefix = converted_model_name_or_path
-            assert UID_DELIMITER not in prefix and CHAIN_DELIMITER not in prefix, (
-                f"Cannot use model name as prefix (contains '{UID_DELIMITER}' or '{CHAIN_DELIMITER}'); "
-                f"Please specify --prefix manually when starting a server"
-            )
-            logger.debug(f"Automatic dht prefix: {prefix}")
+            prefix = self.block_config.dht_prefix
+        assert UID_DELIMITER not in prefix and CHAIN_DELIMITER not in prefix, (
+            f"DHT prefix should not contain '{UID_DELIMITER}' or '{CHAIN_DELIMITER}'. "
+            f"Please specify another --prefix manually when starting a server"
+        )
         self.prefix = prefix
 
         if expiration is None:
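
The hunk above is the heart of this change: the server no longer hard-codes `BloomConfig`, but resolves a model-specific config through `AutoDistributedConfig`, and the DHT prefix now defaults to the config's `dht_prefix` instead of the raw model name. A minimal sketch of the dispatch pattern such an auto-class typically follows — the class body and mapping below are illustrative, not Petals' actual internals:

    # Hypothetical sketch of an AutoConfig-style dispatcher keyed on model_type.
    from transformers import AutoConfig, PretrainedConfig

    class AutoDistributedConfigSketch:
        _model_type_to_config = {}  # e.g. {"bloom": <bloom config cls>, ...} (hypothetical)

        @classmethod
        def from_pretrained(cls, model_name_or_path: str, **kwargs) -> PretrainedConfig:
            base = AutoConfig.from_pretrained(model_name_or_path, **kwargs)
            config_cls = cls._model_type_to_config[base.model_type]
            return config_cls.from_pretrained(model_name_or_path, **kwargs)

Whatever the internals, the contract `Server.__init__` relies on is only that the returned config exposes `num_hidden_layers`, `hidden_size`, and `dht_prefix`.
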
@@ -111,12 +120,9 @@ class Server:
         self.request_timeout = request_timeout
         self.session_timeout, self.step_timeout = session_timeout, step_timeout
 
-        self.block_config = BloomConfig.from_pretrained(
-            converted_model_name_or_path,
-            use_auth_token=use_auth_token,
-            revision=revision,
-        )
-        self.module_uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]
+        self.module_uids = [
+            f"{self.prefix}.{block_index}" for block_index in range(self.block_config.num_hidden_layers)
+        ]
 
         if dht_client_mode is None:
             is_reachable = check_direct_reachability(initial_peers=initial_peers, use_relay=False, **kwargs)
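
Module UIDs are plain strings of the form `{prefix}.{block_index}` — the `.` is `UID_DELIMITER`, which is why the assert in the earlier hunk forbids it inside the prefix. A quick sketch with a hypothetical prefix:

    # Hypothetical prefix and depth, just to show the UID scheme.
    prefix, num_hidden_layers = "my-model-hf", 3
    module_uids = [f"{prefix}.{block_index}" for block_index in range(num_hidden_layers)]
    print(module_uids)  # ['my-model-hf.0', 'my-model-hf.1', 'my-model-hf.2']
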
@@ -125,7 +131,7 @@ class Server:
         self.dht = DHT(
             initial_peers=initial_peers,
             start=True,
-            num_workers=self.block_config.n_layer,
+            num_workers=self.block_config.num_hidden_layers,
             use_relay=use_relay,
             use_auto_relay=use_auto_relay,
             client_mode=dht_client_mode,
@@ -161,10 +167,10 @@ class Server:
         if load_in_8bit is None:
             load_in_8bit = device.type == "cuda"
         self.load_in_8bit = load_in_8bit
-        logger.info(f"Model weights will be loaded in {get_dtype_name(torch_dtype, load_in_8bit)} format")
+        logger.info(f"Model weights are loaded in {get_dtype_name(torch_dtype, load_in_8bit)} format")
 
-        max_values_in_cache = 2 * self.block_config.hidden_size * attn_cache_tokens
-        self._cache_bytes_per_block = max_values_in_cache * torch.finfo(self.torch_dtype).bits // 8
+        cache_values_per_block = 2 * self.block_config.hidden_size * attn_cache_tokens
+        self._cache_bytes_per_block = cache_values_per_block * torch.finfo(self.torch_dtype).bits // 8
 
         assert num_blocks is None or block_indices is None, "Please specify num_blocks or block_indices, not both"
         if num_blocks is None and block_indices is None:
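
To make the cache arithmetic above concrete: the factor of 2 accounts for attention keys and values, and `torch.finfo(dtype).bits // 8` converts a value count into bytes. A worked example with a hypothetical `hidden_size` of 14336 and the default `attn_cache_tokens = 8192` in float16:

    import torch

    hidden_size, attn_cache_tokens = 14336, 8192  # hidden_size is hypothetical
    cache_values_per_block = 2 * hidden_size * attn_cache_tokens  # keys + values
    cache_bytes_per_block = cache_values_per_block * torch.finfo(torch.float16).bits // 8
    print(cache_bytes_per_block / 2**20)  # 448.0 MiB per served block
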
@@ -192,6 +198,7 @@ class Server:
         assert isinstance(throughput, float) or throughput in ["auto", "eval"]
         if throughput in ["auto", "eval"]:
             throughput = get_server_throughput(
+                converted_model_name_or_path,
                 self.block_config,
                 device,
                 torch_dtype,
@@ -239,11 +246,12 @@ class Server:
         num_blocks = math.floor((total_memory - autograd_memory) / (block_size + self._cache_bytes_per_block))
         assert num_blocks >= 1, "Your GPU does not have enough memory to serve at least one block"
 
+        num_blocks = min(num_blocks, self.block_config.num_hidden_layers)
         logger.info(
             f"Server will fill all your GPU memory with {num_blocks} transformer blocks. "
             f"If you want to leave some free GPU memory, please specify a lesser --num_blocks manually"
         )
-        return min(num_blocks, self.block_config.n_layer)
+        return num_blocks
 
     def run(self):
         while True:
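
The hunk above moves the `min(..., num_hidden_layers)` cap before the log message, so the reported block count can no longer exceed the model's depth. A small sketch of the same estimate with made-up byte counts, showing the cap taking effect:

    import math

    total_memory = 24 * 2**30             # hypothetical: a 24 GiB GPU
    autograd_memory = 2 * 2**30           # hypothetical: reserved for backprop
    block_size = 3 * 2**30                # hypothetical: bytes of weights per block
    cache_bytes_per_block = 448 * 2**20   # attention cache per block (see above)
    num_hidden_layers = 4                 # hypothetical: a shallow model

    num_blocks = math.floor((total_memory - autograd_memory) / (block_size + cache_bytes_per_block))
    num_blocks = min(num_blocks, num_hidden_layers)  # the cap, now applied before logging
    print(num_blocks)  # 4 (memory alone would have allowed 6)
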
@@ -274,6 +282,7 @@ class Server:
                 step_timeout=self.step_timeout,
                 prefetch_batches=self.prefetch_batches,
                 sender_threads=self.sender_threads,
+                revision=self.revision,
                 use_auth_token=self.use_auth_token,
                 load_in_8bit=self.load_in_8bit,
                 tensor_parallel_devices=self.tensor_parallel_devices,
@@ -352,7 +361,7 @@ class ModuleContainer(threading.Thread):
         dht: DHT,
         prefix: str,
         converted_model_name_or_path: str,
-        block_config: BloomConfig,
+        block_config: PretrainedConfig,
         attn_cache_bytes: int,
         alloc_timeout: float,
         throughput: float,
@@ -366,6 +375,7 @@ class ModuleContainer(threading.Thread):
         compression: CompressionType,
         update_period: float,
         expiration: Optional[float],
+        revision: Optional[str],
         use_auth_token: Optional[str],
         load_in_8bit: bool,
         tensor_parallel_devices: Sequence[torch.device],
@@ -394,14 +404,14 @@ class ModuleContainer(threading.Thread):
                 block = load_pretrained_block(
                     converted_model_name_or_path,
                     block_index,
-                    block_config,
+                    config=block_config,
                     torch_dtype=torch_dtype,
+                    revision=revision,
                     use_auth_token=use_auth_token,
                     cache_dir=cache_dir,
                     max_disk_space=max_disk_space,
                 )
                 block = convert_block(block, block_config, tensor_parallel_devices, device, load_in_8bit, freeze=True)
-
                 blocks[module_uid] = TransformerBackend(
                     module_uid,
                     block,
@@ -564,13 +574,9 @@ class ModuleContainer(threading.Thread):
 
         self.ready.clear()
 
+        logger.debug("Shutting down connection handlers")
         for handler in self.conn_handlers:
             handler.shutdown()
-        logger.debug("Connection handlers terminated")
-
-        if self.checkpoint_saver is not None:
-            self.checkpoint_saver.stop.set()
-            self.checkpoint_saver.join()
 
         logger.debug(f"Shutting down pools")
         for pool in self.runtime.pools: