|
@@ -64,14 +64,14 @@ class Server:
|
|
|
expiration: Optional[float] = None,
|
|
|
request_timeout: float = 3 * 60,
|
|
|
session_timeout: float = 30 * 60,
|
|
|
- step_timeout: float = 5 * 60,
|
|
|
+ step_timeout: float = 60,
|
|
|
prefetch_batches: int = 1,
|
|
|
sender_threads: int = 1,
|
|
|
balance_quality: float = 0.75,
|
|
|
mean_balance_check_period: float = 60,
|
|
|
mean_block_selection_delay: float = 0.5,
|
|
|
use_auth_token: Optional[str] = None,
|
|
|
- load_in_8bit: bool = False,
|
|
|
+ load_in_8bit: Optional[bool] = None,
|
|
|
**kwargs,
|
|
|
):
|
|
|
"""Create a server with one or more bloom blocks. See run_server.py for documentation."""
|
|
@@ -81,12 +81,10 @@ class Server:
|
|
|
self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
|
|
|
self.inference_max_length = inference_max_length
|
|
|
self.cache_dir = cache_dir
|
|
|
- self.attn_cache_size = attn_cache_size
|
|
|
self.compression = compression
|
|
|
self.stats_report_interval, self.update_period = stats_report_interval, update_period
|
|
|
self.prefetch_batches, self.sender_threads = prefetch_batches, sender_threads
|
|
|
self.use_auth_token = use_auth_token
|
|
|
- self.load_in_8bit = load_in_8bit
|
|
|
|
|
|
if custom_module_path is not None:
|
|
|
add_custom_models_from_file(custom_module_path)
|
|
@@ -114,15 +112,16 @@ class Server:
|
|
|
else:
|
|
|
logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
|
|
|
|
|
|
- device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
+ if device is None:
|
|
|
+ device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
+ device = torch.device(device)
|
|
|
self.device = device
|
|
|
|
|
|
- self.memory_cache = MemoryCache(device, attn_cache_size, alloc_timeout)
|
|
|
-
|
|
|
- if isinstance(torch_dtype, str):
|
|
|
- torch_dtype = DTYPE_MAP[torch_dtype]
|
|
|
- assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
|
|
|
- self.torch_dtype = torch_dtype
|
|
|
+ if load_in_8bit is None:
|
|
|
+ load_in_8bit = device.type == "cuda"
|
|
|
+ if load_in_8bit:
|
|
|
+ logger.info("Model weights will be loaded in 8-bit format")
|
|
|
+ self.load_in_8bit = load_in_8bit
|
|
|
|
|
|
self.block_config = BloomConfig.from_pretrained(
|
|
|
converted_model_name_or_path,
|
|
@@ -131,13 +130,6 @@ class Server:
|
|
|
)
|
|
|
self.module_uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]
|
|
|
|
|
|
- assert isinstance(throughput, float) or throughput in ["auto", "eval"]
|
|
|
- if throughput in ["auto", "eval"]:
|
|
|
- throughput = get_host_throughput(
|
|
|
- self.block_config, device, torch_dtype, load_in_8bit=load_in_8bit, force_eval=(throughput == "eval")
|
|
|
- )
|
|
|
- self.throughput = throughput
|
|
|
-
|
|
|
assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both"
|
|
|
if block_indices is not None:
|
|
|
try:
|
|
@@ -147,7 +139,28 @@ class Server:
|
|
|
logger.error(f"Failed to parse --block_indices ({e}), must be start:end (e.g. 0:18)")
|
|
|
raise
|
|
|
block_indices = range(first_block_index, last_block_index)
|
|
|
+ num_blocks = len(block_indices)
|
|
|
self.strict_block_indices, self.num_blocks = block_indices, num_blocks
|
|
|
+
|
|
|
+ gib = 1024**3
|
|
|
+ if attn_cache_size is None:
|
|
|
+ # Hidden size is 14336 for the bigscience/bloom-petals model. For other models, scale accordingly
|
|
|
+ attn_cache_size = 0.5 * gib * num_blocks * self.block_config.hidden_size / 14336
|
|
|
+ logger.info(f"Attention cache for all blocks will consume up to {attn_cache_size / gib:.2f} GiB")
|
|
|
+ self.memory_cache = MemoryCache(device, attn_cache_size, alloc_timeout)
|
|
|
+
|
|
|
+ if isinstance(torch_dtype, str):
|
|
|
+ torch_dtype = DTYPE_MAP[torch_dtype]
|
|
|
+ assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
|
|
|
+ self.torch_dtype = torch_dtype
|
|
|
+
|
|
|
+ assert isinstance(throughput, float) or throughput in ["auto", "eval"]
|
|
|
+ if throughput in ["auto", "eval"]:
|
|
|
+ throughput = get_host_throughput(
|
|
|
+ self.block_config, device, torch_dtype, load_in_8bit=load_in_8bit, force_eval=(throughput == "eval")
|
|
|
+ )
|
|
|
+ self.throughput = throughput
|
|
|
+
|
|
|
self.balance_quality = balance_quality
|
|
|
self.mean_balance_check_period = mean_balance_check_period
|
|
|
self.mean_block_selection_delay = mean_block_selection_delay
|
|
@@ -213,7 +226,6 @@ class Server:
|
|
|
def _choose_blocks(self) -> List[int]:
|
|
|
if self.strict_block_indices is not None:
|
|
|
return self.strict_block_indices
|
|
|
- assert self.num_blocks is not None
|
|
|
|
|
|
# If multiple servers (e.g., launched on the same machine by a script) get to this line at the same time,
|
|
|
# this delay decreases the probability of a race condition while choosing the best blocks to serve.
|