
Make server use smart defaults (#115)

Summary:

```python
parser.add_argument('--attn_cache_size', type=str, default=None,
                    help='The size of GPU memory allocated for storing past attention keys/values between inference steps. '
                         'Examples: 500MB, 1.2GB, 1073741824 (bytes). Note that 1KB != 1KiB here. '
                         'Default: 0.5GiB * num_blocks * hidden_size / 14336. '
                         'The latter is the hidden size of the bigscience/bloom-petals model.')

parser.add_argument('--request_timeout', type=float, required=False, default=3 * 60,
                    help='Timeout (in seconds) for the whole rpc_forward/rpc_backward/rpc_forward_stream/rpc_backward_stream request')
parser.add_argument('--session_timeout', type=float, required=False, default=30 * 60,
                    help='Timeout (in seconds) for the whole inference session')
parser.add_argument('--step_timeout', type=float, required=False, default=60,
                    help="Timeout (in seconds) for waiting the next step's inputs inside an inference session")

parser.add_argument('--load_in_8bit', type=bool, default=None,
                    help="Convert the loaded model into mixed-8bit quantized model. Default: True if GPU is available")
```
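
For intuition about the new `--attn_cache_size` default, here is a minimal worked example (a sketch, assuming 14336 as the hidden size of bigscience/bloom-petals and taking 70 as that model's full block count):

```python
# Quick sanity check of the default formula (illustrative, not repository code).
# 14336 is the hidden size of bigscience/bloom-petals, so the scaling factor is 1.0 there.
gib = 1024**3
hidden_size = 14336

for num_blocks in (1, 10, 70):
    attn_cache_size = 0.5 * gib * num_blocks * hidden_size / 14336
    print(f"{num_blocks:>2} blocks -> {attn_cache_size / gib:>4.1f} GiB")

# Expected output:
#  1 blocks ->  0.5 GiB
# 10 blocks ->  5.0 GiB
# 70 blocks -> 35.0 GiB
```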

Co-authored-by: justheuristic <justheuristic@gmail.com>
Alexander Borzunov, 2 years ago
commit 643a054170
4 changed files with 54 additions and 38 deletions
  1. src/petals/cli/run_prod_server.sh (+1, -4)
  2. src/petals/cli/run_server.py (+10, -7)
  3. src/petals/server/server.py (+31, -19)
  4. src/petals/server/throughput.py (+12, -8)

+ 1 - 4
src/petals/cli/run_prod_server.sh

@@ -5,8 +5,5 @@ export HIVEMIND_COLORS=true
 while true; do
         pkill -f p2p
         pkill -f run_server
-        python -m petals.cli.run_server bigscience/bloom-petals \
-                --block_indices $1 \
-                --torch_dtype bfloat16 --load_in_8bit \
-                --attn_cache_size $2 2>&1 | tee log_`date '+%F_%H:%M:%S'`
+        python -m petals.cli.run_server bigscience/bloom-petals "$@" 2>&1 | tee log_`date '+%F_%H:%M:%S'`
 done
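
With this change the production script simply forwards its arguments to `petals.cli.run_server`, so the block range and other options are supplied at call time instead of being hard-coded, e.g. `bash src/petals/cli/run_prod_server.sh --block_indices 0:18` (the 0:18 range is only an illustration); any other `run_server` flag can be appended the same way.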

+ 10 - 7
src/petals/cli/run_server.py

@@ -55,8 +55,10 @@ def main():
                         help="Use this dtype to store block weights and do computations. "
                              "By default, respect the dtypes in the pre-trained state dict.")
     parser.add_argument('--attn_cache_size', type=str, default=None,
-                        help='The size of GPU memory allocated for storing past attention keys/values between inference'
-                             ' steps; examples: 500MB or 1.2GB or 1073741824 (bytes); be warned: 1KB != 1KiB')
+                        help='The size of GPU memory allocated for storing past attention keys/values between inference steps. '
+                             'Examples: 500MB, 1.2GB, 1073741824 (bytes). Note that 1KB != 1KiB here. '
+                             'Default: 0.5GiB * num_blocks * hidden_size / 14336. '
+                             'The latter is the hidden size of the bigscience/bloom-petals model.')
     parser.add_argument('--alloc_timeout', type=float, default=60,
                         help='If the cache is full, the server will wait for this number of seconds hoping that some memory will be freed '
                              'before rejecting the request')
@@ -76,11 +78,11 @@ def main():
     parser.add_argument('--expiration', type=float, required=False, default=None,
                         help='DHT entries will expire after this many seconds')
     parser.add_argument('--request_timeout', type=float, required=False, default=3 * 60,
-                        help='Timeout for the whole rpc_forward/rpc_backward/rpc_forward_stream/rpc_backward_stream request')
+                        help='Timeout (in seconds) for the whole rpc_forward/rpc_backward/rpc_forward_stream/rpc_backward_stream request')
     parser.add_argument('--session_timeout', type=float, required=False, default=30 * 60,
-                        help='Timeout for the whole inference session')
-    parser.add_argument('--step_timeout', type=float, required=False, default=5 * 60,
-                        help="Timeout for waiting the next step's inputs inside an inference session")
+                        help='Timeout (in seconds) for the whole inference session')
+    parser.add_argument('--step_timeout', type=float, required=False, default=60,
+                        help="Timeout (in seconds) for waiting the next step's inputs inside an inference session")
 
     group = parser.add_mutually_exclusive_group()
     group.add_argument('--initial_peers', type=str, nargs='*', required=False, default=PUBLIC_INITIAL_PEERS,
@@ -106,7 +108,8 @@ def main():
                         help="Check the swarm's balance every N seconds (and rebalance it if necessary)")
 
     parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
-    parser.add_argument('--load_in_8bit', action='store_true', help='Convert the loaded model into mixed-8bit quantized model.')
+    parser.add_argument('--load_in_8bit', type=bool, default=None,
+                        help="Convert the loaded model into mixed-8bit quantized model. Default: True if GPU is available")
 
     # fmt:on
     args = vars(parser.parse_args())
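
A side note on the `--attn_cache_size` format: the argument stays a string and is converted to bytes elsewhere. The sketch below is a hypothetical parser (not the repository's actual code) that illustrates the decimal-unit convention the help text warns about: "1KB" means 10**3 bytes, while "1KiB" means 2**10.

```python
import re

# Hypothetical size parser for strings like "500MB", "1.2GB", or plain byte counts.
_UNITS = {"": 1, "KB": 10**3, "MB": 10**6, "GB": 10**9,
          "KIB": 2**10, "MIB": 2**20, "GIB": 2**30}

def parse_size(size: str) -> int:
    match = re.fullmatch(r"\s*([\d.]+)\s*([A-Za-z]*)\s*", size)
    if match is None:
        raise ValueError(f"Could not parse size: {size!r}")
    value, unit = float(match.group(1)), match.group(2).upper()
    if unit not in _UNITS:
        raise ValueError(f"Unknown unit: {unit!r}")
    return int(value * _UNITS[unit])

assert parse_size("500MB") == 500 * 10**6   # decimal megabytes
assert parse_size("1073741824") == 2**30    # plain bytes, happens to equal 1 GiB
```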

+ 31 - 19
src/petals/server/server.py

@@ -64,14 +64,14 @@ class Server:
         expiration: Optional[float] = None,
         request_timeout: float = 3 * 60,
         session_timeout: float = 30 * 60,
-        step_timeout: float = 5 * 60,
+        step_timeout: float = 60,
         prefetch_batches: int = 1,
         sender_threads: int = 1,
         balance_quality: float = 0.75,
         mean_balance_check_period: float = 60,
         mean_block_selection_delay: float = 0.5,
         use_auth_token: Optional[str] = None,
-        load_in_8bit: bool = False,
+        load_in_8bit: Optional[bool] = None,
         **kwargs,
     ):
         """Create a server with one or more bloom blocks. See run_server.py for documentation."""
@@ -81,12 +81,10 @@ class Server:
         self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
         self.inference_max_length = inference_max_length
         self.cache_dir = cache_dir
-        self.attn_cache_size = attn_cache_size
         self.compression = compression
         self.stats_report_interval, self.update_period = stats_report_interval, update_period
         self.prefetch_batches, self.sender_threads = prefetch_batches, sender_threads
         self.use_auth_token = use_auth_token
-        self.load_in_8bit = load_in_8bit
 
         if custom_module_path is not None:
             add_custom_models_from_file(custom_module_path)
@@ -114,15 +112,16 @@ class Server:
         else:
             logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")
 
-        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+        if device is None:
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+        device = torch.device(device)
         self.device = device
 
-        self.memory_cache = MemoryCache(device, attn_cache_size, alloc_timeout)
-
-        if isinstance(torch_dtype, str):
-            torch_dtype = DTYPE_MAP[torch_dtype]
-        assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
-        self.torch_dtype = torch_dtype
+        if load_in_8bit is None:
+            load_in_8bit = device.type == "cuda"
+        if load_in_8bit:
+            logger.info("Model weights will be loaded in 8-bit format")
+        self.load_in_8bit = load_in_8bit
 
         self.block_config = BloomConfig.from_pretrained(
             converted_model_name_or_path,
@@ -131,13 +130,6 @@ class Server:
         )
         self.module_uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]
 
-        assert isinstance(throughput, float) or throughput in ["auto", "eval"]
-        if throughput in ["auto", "eval"]:
-            throughput = get_host_throughput(
-                self.block_config, device, torch_dtype, load_in_8bit=load_in_8bit, force_eval=(throughput == "eval")
-            )
-        self.throughput = throughput
-
         assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both"
         if block_indices is not None:
             try:
@@ -147,7 +139,28 @@ class Server:
                 logger.error(f"Failed to parse --block_indices ({e}), must be start:end (e.g. 0:18)")
                 raise
             block_indices = range(first_block_index, last_block_index)
+            num_blocks = len(block_indices)
         self.strict_block_indices, self.num_blocks = block_indices, num_blocks
+
+        gib = 1024**3
+        if attn_cache_size is None:
+            # Hidden size is 14336 for the bigscience/bloom-petals model. For other models, scale accordingly
+            attn_cache_size = 0.5 * gib * num_blocks * self.block_config.hidden_size / 14336
+        logger.info(f"Attention cache for all blocks will consume up to {attn_cache_size / gib:.2f} GiB")
+        self.memory_cache = MemoryCache(device, attn_cache_size, alloc_timeout)
+
+        if isinstance(torch_dtype, str):
+            torch_dtype = DTYPE_MAP[torch_dtype]
+        assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
+        self.torch_dtype = torch_dtype
+
+        assert isinstance(throughput, float) or throughput in ["auto", "eval"]
+        if throughput in ["auto", "eval"]:
+            throughput = get_host_throughput(
+                self.block_config, device, torch_dtype, load_in_8bit=load_in_8bit, force_eval=(throughput == "eval")
+            )
+        self.throughput = throughput
+
         self.balance_quality = balance_quality
         self.mean_balance_check_period = mean_balance_check_period
         self.mean_block_selection_delay = mean_block_selection_delay
@@ -213,7 +226,6 @@ class Server:
     def _choose_blocks(self) -> List[int]:
         if self.strict_block_indices is not None:
             return self.strict_block_indices
-        assert self.num_blocks is not None
 
         # If multiple servers (e.g., launched on the same machine by a script) get to this line at the same time,
         # this delay decreases the probability of a race condition while choosing the best blocks to serve.
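
Condensed from the diff above, the smart-default logic in `Server.__init__` now boils down to roughly the following (a sketch; `resolve_defaults` is just an illustrative name, and `num_blocks` / `hidden_size` are taken as plain inputs rather than from the parsed config):

```python
import torch

def resolve_defaults(device=None, load_in_8bit=None, attn_cache_size=None,
                     *, num_blocks: int, hidden_size: int):
    # Pick the GPU when one is available, otherwise fall back to CPU.
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    # Default to 8-bit weights only when running on a GPU (matching the new CLI default).
    if load_in_8bit is None:
        load_in_8bit = device.type == "cuda"

    # 0.5 GiB of attention cache per hosted block, scaled by the model's hidden size
    # relative to bigscience/bloom-petals (hidden size 14336).
    gib = 1024**3
    if attn_cache_size is None:
        attn_cache_size = 0.5 * gib * num_blocks * hidden_size / 14336
    return device, load_in_8bit, attn_cache_size
```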

+ 12 - 8
src/petals/server/throughput.py

@@ -6,6 +6,7 @@ import tempfile
 import time
 from hashlib import sha256
 from pathlib import Path
+from typing import Union
 
 import torch
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
@@ -26,7 +27,7 @@ DEFAULT_LOCK_PATH = Path(tempfile.gettempdir(), "petals", "throughput.lock")
 def get_host_throughput(
     config: BloomConfig,
     device: torch.device,
-    torch_dtype: torch.dtype,
+    dtype: Union[str, torch.dtype],
     *,
     load_in_8bit: bool,
     force_eval: bool = False,
@@ -42,7 +43,7 @@ def get_host_throughput(
 
         cache_key = f"config_{sha256(str(config).encode()).hexdigest()[-16:]}"
         cache_key += f"_device_{_get_device_name(device).replace(' ', '_')}"
-        cache_key += f"_dtype_{_get_dtype_name(torch_dtype, load_in_8bit)}"
+        cache_key += f"_dtype_{_get_dtype_name(dtype, load_in_8bit)}"
 
         cache = {}
         try:
@@ -55,7 +56,7 @@ def get_host_throughput(
             cache = {}
 
         if cache_key not in cache:
-            cache[cache_key] = measure_throughput_info(config, device, torch_dtype, load_in_8bit=load_in_8bit)
+            cache[cache_key] = measure_throughput_info(config, device, dtype, load_in_8bit=load_in_8bit)
 
             try:
                 os.makedirs(cache_path.parent, exist_ok=True)
@@ -70,7 +71,7 @@ def get_host_throughput(
 def measure_throughput_info(
     config: BloomConfig,
     device: torch.device,
-    dtype: torch.dtype,
+    dtype: Union[str, torch.dtype],
     *,
     load_in_8bit: bool,
 ) -> float:
@@ -106,7 +107,7 @@ def measure_network_rps(config: BloomConfig) -> float:
 def measure_compute_rps(
     config: BloomConfig,
     device: torch.device,
-    dtype: torch.dtype,
+    dtype: Union[str, torch.dtype],
     *,
     load_in_8bit: bool,
     n_tokens: int = 16,
@@ -114,7 +115,10 @@ def measure_compute_rps(
     layer_index: int = 0,
 ) -> float:
     with torch.inference_mode():
-        block = BloomBlock(config, layer_index).to(dtype)
+        block = BloomBlock(config, layer_index)
+        if dtype != "auto":
+            block = block.to(dtype)
+        input_dtype = block.input_layernorm.weight.dtype
         if load_in_8bit:
             block = replace_8bit_linear(block)
         block = block.to(device)
@@ -122,8 +126,8 @@ def measure_compute_rps(
         cache = None
         elapsed = 0
         for step in range(n_steps + 1):
-            dummy_input = torch.randn(n_tokens, 1, config.hidden_size, device=device, dtype=dtype)
-            alibi = build_alibi_tensor(step + 1, config.num_attention_heads, device=device, dtype=dtype)
+            dummy_input = torch.randn(n_tokens, 1, config.hidden_size, device=device, dtype=input_dtype)
+            alibi = build_alibi_tensor(step + 1, config.num_attention_heads, device=device, dtype=input_dtype)
 
             start_time = time.perf_counter()
             _, cache = block.forward(dummy_input, alibi=alibi, use_cache=True, layer_past=cache)
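
The throughput changes let `dtype` arrive as the string "auto": in that case the benchmark block keeps the dtype it was constructed with instead of being cast, and the dummy inputs and ALiBi tensor are created with `input_dtype` read from an actual parameter (`block.input_layernorm.weight`), so the measurement uses the dtype the block really runs in.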