
Suppress quantization warning and fix dtype defaults in compute benchmark (#117)

Alexander Borzunov, 2 years ago
parent commit f72c220404

src/petals/cli/run_server.py (+7 -5)

@@ -107,9 +107,10 @@ def main():
     parser.add_argument("--mean_balance_check_period", type=float, default=60,
                         help="Check the swarm's balance every N seconds (and rebalance it if necessary)")
 
-    parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
-    parser.add_argument('--load_in_8bit', type=bool, default=None,
-                        help="Convert the loaded model into mixed-8bit quantized model. Default: True if GPU is available")
+    parser.add_argument("--use_auth_token", action='store_true', help="auth token for from_pretrained")
+    parser.add_argument('--load_in_8bit', type=str, default=None,
+                        help="Convert the loaded model into mixed-8bit quantized model. "
+                             "Default: True if GPU is available. Use `--load_in_8bit False` to disable this")
 
     # fmt:on
     args = vars(parser.parse_args())
@@ -133,8 +134,9 @@ def main():
     if args.pop("new_swarm"):
         args["initial_peers"] = []
 
-    use_auth_token = args.pop("use_auth_token")
-    args["use_auth_token"] = True if use_auth_token in ("True", "true", "") else use_auth_token
+    load_in_8bit = args.pop("load_in_8bit")
+    if load_in_8bit is not None:
+        args["load_in_8bit"] = load_in_8bit.lower() in ["true", "1"]
 
     server = Server(**args, compression=compression, attn_cache_size=attn_cache_size)
     try:

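Note: the switch from `type=bool` to `type=str` for `--load_in_8bit` matters because argparse applies the type callable to the raw string, and `bool("False")` is True for any non-empty string, so the flag could never actually be turned off. Below is a minimal standalone sketch of the new parsing behavior (not part of the patch; the flag name is reused only for illustration):

    import argparse

    parser = argparse.ArgumentParser()
    # Old behavior with type=bool: the string "False" became True,
    # because bool() is truthy for every non-empty string.
    parser.add_argument("--load_in_8bit", type=str, default=None)
    args = parser.parse_args(["--load_in_8bit", "False"])

    load_in_8bit = args.load_in_8bit
    if load_in_8bit is not None:
        load_in_8bit = load_in_8bit.lower() in ["true", "1"]
    print(load_in_8bit)  # False, as the user intended
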
src/petals/server/throughput.py (+16 -13)

@@ -34,6 +34,12 @@ def get_host_throughput(
     cache_path: str = DEFAULT_CACHE_PATH,
     lock_path: str = DEFAULT_LOCK_PATH,
 ) -> float:
+    # Resolve default dtypes
+    if dtype == "auto" or dtype is None:
+        dtype = config.torch_dtype
+        if dtype == "auto" or dtype is None:
+            dtype = torch.float32
+
     # We use the system-wide lock since only one process at a time can measure the host throughput
     os.makedirs(lock_path.parent, exist_ok=True)
     with open(lock_path, "wb") as lock_fd:
@@ -42,8 +48,8 @@ def get_host_throughput(
         # The OS will release the lock when lock_fd is closed or the process is killed
 
         cache_key = f"config_{sha256(str(config).encode()).hexdigest()[-16:]}"
-        cache_key += f"_device_{_get_device_name(device).replace(' ', '_')}"
-        cache_key += f"_dtype_{_get_dtype_name(dtype, load_in_8bit)}"
+        cache_key += f"_device_{get_device_name(device).replace(' ', '_')}"
+        cache_key += f"_dtype_{get_dtype_name(dtype, load_in_8bit)}"
 
         cache = {}
         try:
@@ -71,7 +77,7 @@ def get_host_throughput(
 def measure_throughput_info(
     config: BloomConfig,
     device: torch.device,
-    dtype: Union[str, torch.dtype],
+    dtype: torch.dtype,
     *,
     load_in_8bit: bool,
 ) -> float:
@@ -107,7 +113,7 @@ def measure_network_rps(config: BloomConfig) -> float:
 def measure_compute_rps(
     config: BloomConfig,
     device: torch.device,
-    dtype: Union[str, torch.dtype],
+    dtype: torch.dtype,
     *,
     load_in_8bit: bool,
     n_tokens: int = 16,
@@ -115,10 +121,7 @@ def measure_compute_rps(
     layer_index: int = 0,
 ) -> float:
     with torch.inference_mode():
-        block = BloomBlock(config, layer_index)
-        if dtype != "auto":
-            block = block.to(dtype)
-        input_dtype = block.input_layernorm.weight.dtype
+        block = BloomBlock(config, layer_index).to(dtype)
         if load_in_8bit:
             block = replace_8bit_linear(block)
         block = block.to(device)
@@ -126,8 +129,8 @@ def measure_compute_rps(
         cache = None
         elapsed = 0
         for step in range(n_steps + 1):
-            dummy_input = torch.randn(n_tokens, 1, config.hidden_size, device=device, dtype=input_dtype)
-            alibi = build_alibi_tensor(step + 1, config.num_attention_heads, device=device, dtype=input_dtype)
+            dummy_input = torch.randn(n_tokens, 1, config.hidden_size, device=device, dtype=dtype)
+            alibi = build_alibi_tensor(step + 1, config.num_attention_heads, device=device, dtype=dtype)
 
             start_time = time.perf_counter()
             _, cache = block.forward(dummy_input, alibi=alibi, use_cache=True, layer_past=cache)
@@ -136,15 +139,15 @@ def measure_compute_rps(
         device_rps = n_steps * n_tokens / elapsed
 
     logger.info(
-        f"Forward pass throughput ({_get_device_name(device)}, {_get_dtype_name(dtype, load_in_8bit)}): "
+        f"Forward pass throughput ({get_device_name(device)}, {get_dtype_name(dtype, load_in_8bit)}): "
         f"{device_rps:.1f} RPS"
     )
     return device_rps
 
 
-def _get_device_name(device: torch.device) -> str:
+def get_device_name(device: torch.device) -> str:
     return f"{torch.cuda.get_device_name(device)} GPU" if device == "cuda" else "CPU"
 
 
-def _get_dtype_name(dtype: torch.dtype, load_in_8bit: bool) -> str:
+def get_dtype_name(dtype: torch.dtype, load_in_8bit: bool) -> str:
     return "8-bit" if load_in_8bit else str(dtype)

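Note: the dtype resolution added to get_host_throughput() falls back from "auto"/None to `config.torch_dtype`, and then to `torch.float32` if the config leaves it unset, so measure_compute_rps() can assume a concrete torch.dtype. A minimal sketch of that fallback chain, using a hypothetical config stand-in rather than a real BloomConfig:

    import torch

    class DummyConfig:
        torch_dtype = None  # many checkpoints leave this as None or "auto"

    def resolve_dtype(dtype, config):
        # Mirrors the fallback chain from the patch above
        if dtype == "auto" or dtype is None:
            dtype = config.torch_dtype
            if dtype == "auto" or dtype is None:
                dtype = torch.float32
        return dtype

    print(resolve_dtype("auto", DummyConfig()))  # torch.float32
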
src/petals/utils/linear8bitlt_patch.py (+2 -2)

@@ -9,7 +9,7 @@ Based on: https://github.com/TimDettmers/bitsandbytes/blob/main/csrc/kernels.cu#
 Exact match tests: see $REPO/tests/test_linear8bitlt.py
 """
 import dataclasses
-import warnings
+import logging
 from typing import Optional, Tuple
 
 import bitsandbytes.functional as F
@@ -155,7 +155,7 @@ class CustomMatMul8bitLt(MatMul8bitLt):
 
         # Cast A to fp16
         if A.dtype != torch.float16:
-            warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
+            logging.debug(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
 
         # 1. Quantize A
         if len(A.shape) == 3:
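
Note: replacing `warnings.warn` with `logging.debug` suppresses the cast message by default, because DEBUG records are dropped unless the process explicitly configures logging at that level. A tiny standalone sketch of that effect (message text copied only for illustration):

    import logging

    # Dropped by default: the root logger's effective level is WARNING.
    logging.debug("MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization")

    # Only visible if the server opts into debug-level logging.
    logging.basicConfig(level=logging.DEBUG)
    logging.debug("MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization")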