
Divide compute throughput by average no. of used blocks (#314)

See #192.
Alexander Borzunov · 2 years ago
Parent
Current commit
d9e7bfc949
2 changed files with 31 additions and 17 deletions
  1. src/petals/server/server.py (+3, -2)
  2. src/petals/server/throughput.py (+28, -15)

src/petals/server/server.py (+3, -2)

@@ -27,7 +27,7 @@ from petals.server.block_utils import get_block_size
 from petals.server.handler import TransformerConnectionHandler
 from petals.server.memory_cache import MemoryCache
 from petals.server.reachability import ReachabilityProtocol, check_direct_reachability, validate_reachability
-from petals.server.throughput import get_dtype_name, get_host_throughput
+from petals.server.throughput import get_dtype_name, get_server_throughput
 from petals.utils.convert_block import check_device_balance, convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
 
@@ -193,10 +193,11 @@ class Server:
 
         assert isinstance(throughput, float) or throughput in ["auto", "eval"]
         if throughput in ["auto", "eval"]:
-            throughput = get_host_throughput(
+            throughput = get_server_throughput(
                 self.block_config,
                 device,
                 torch_dtype,
+                num_blocks=num_blocks,
                 load_in_8bit=load_in_8bit,
                 tensor_parallel_devices=self.tensor_parallel_devices,
                 force_eval=(throughput == "eval"),

src/petals/server/throughput.py (+28, -15)

@@ -1,11 +1,12 @@
 import fcntl
 import json
+import math
 import os
 import time
 from collections import Counter
 from hashlib import sha256
 from pathlib import Path
-from typing import Optional, Sequence, Union
+from typing import Dict, Optional, Sequence, Union
 
 import torch
 from hivemind.utils.logging import get_logger
@@ -32,11 +33,12 @@ if not hasattr(speedtest, "Speedtest"):
     )
 
 
-def get_host_throughput(
+def get_server_throughput(
     config: BloomConfig,
     device: torch.device,
     dtype: Union[str, torch.dtype],
     *,
+    num_blocks: int,
     load_in_8bit: bool,
     tensor_parallel_devices: Sequence[torch.device],
     force_eval: bool = False,
@@ -47,7 +49,7 @@ def get_host_throughput(
     if cache_dir is None:
         cache_dir = DEFAULT_CACHE_DIR
     lock_path = Path(cache_dir, "throughput.lock")
-    cache_path = Path(cache_dir, "throughput_v2.json")
+    cache_path = Path(cache_dir, "throughput_v3.json")
 
     # We use the system-wide lock since only one process at a time can measure the host throughput
     os.makedirs(lock_path.parent, exist_ok=True)
@@ -85,7 +87,16 @@ def get_host_throughput(
             except Exception:
                 logger.exception(f"Failed to save throughput info in {cache_path}")
 
-    return cache[cache_key]
+    throughput_info = cache[cache_key]
+
+    # Most requests start at some block hosted by a server, then use all next blocks hosted on this server.
+    # Assuming the start block index is distributed uniformly, the average number of blocks used per request is
+    # E[Uniform{1, 2, ..., num_blocks}] = (num_blocks + 1) / 2
+    average_blocks_used = (num_blocks + 1) / 2
+    throughput = throughput_info["compute_rps"] / average_blocks_used
+    throughput = min(throughput, throughput_info.get("network_rps", math.inf))
+    logger.info(f"Reporting throughput: {throughput:.1f} RPS for {num_blocks} blocks")
+    return throughput
 
 
 def measure_throughput_info(
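The comment block added above leans on a small expectation argument: if a request's start index is uniform over the server's blocks and the request then uses every later block the server hosts, the mean number of blocks per request is E[Uniform{1, ..., num_blocks}] = (num_blocks + 1) / 2. A quick sanity check plus a worked example of the new division (the RPS figures below are illustrative, not from this commit):

```python
import random

# Empirically confirm E[Uniform{1, ..., num_blocks}] = (num_blocks + 1) / 2
num_blocks = 8
samples = [random.randint(1, num_blocks) for _ in range(100_000)]
empirical = sum(samples) / len(samples)
assert abs(empirical - (num_blocks + 1) / 2) < 0.05  # expectation is 4.5

# Worked example of the new reporting logic with made-up measurements:
compute_rps = 900.0  # per-block forward passes/sec, as measured below
network_rps = 300.0  # requests/sec the network link can sustain
throughput = min(compute_rps / ((num_blocks + 1) / 2), network_rps)
print(f"{throughput:.1f} RPS")  # 200.0 -- compute-bound: 900 / 4.5
```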
@@ -95,22 +106,24 @@ def measure_throughput_info(
     *,
     load_in_8bit: bool,
     tensor_parallel_devices: Sequence[torch.device],
-) -> float:
+) -> Dict[str, float]:
     """Measure network and compute throughput in forward pass tokens per second"""
 
     logger.info(
         "Measuring network and compute throughput. This takes about a minute and will be cached for future runs"
     )
 
-    result = measure_compute_rps(
-        config, device, dtype, load_in_8bit=load_in_8bit, tensor_parallel_devices=tensor_parallel_devices
-    )
+    throughput_info = {
+        "compute_rps": measure_compute_rps(
+            config, device, dtype, load_in_8bit=load_in_8bit, tensor_parallel_devices=tensor_parallel_devices
+        )
+    }
     try:
-        result = min(result, measure_network_rps(config))
+        throughput_info["network_rps"] = measure_network_rps(config)
     except Exception:
         logger.warning("Failed to measure network throughput:", exc_info=True)
         logger.warning("Proceeding with the compute throughput only")
-    return result
+    return throughput_info
 
 
 def measure_network_rps(config: BloomConfig) -> Optional[float]:
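With measure_throughput_info() now returning both numbers instead of their min, a cache entry in throughput_v3.json presumably looks like the sketch below (the key format is a guess; "network_rps" is simply absent when the speedtest fails). Caching the raw per-block measurements and dividing by average_blocks_used at read time lets one cached entry serve any num_blocks, which is also why the cache file name was bumped from v2 to v3.

```python
# Hypothetical shape of one throughput_v3.json entry (key format and values
# are illustrative assumptions, not taken from the diff):
{
    "bloom-petals_cuda:0_float16": {
        "compute_rps": 900.0,  # per-block result of measure_compute_rps()
        "network_rps": 300.0,  # omitted entirely if measure_network_rps() raised
    }
}
```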
@@ -127,10 +140,9 @@ def measure_network_rps(config: BloomConfig) -> Optional[float]:
         raise ValueError("speedtest has returned network_rps == 0")
 
     logger.info(
-        f"Network throughput: "
-        f"{network_info['download'] / 1e6:.2f} Mbit/s on download, "
-        f"{network_info['upload'] / 1e6:.2f} Mbit/s on upload, "
-        f"{network_rps:.1f} RPS"
+        f"Network throughput: {network_rps:.1f} RPS "
+        f"({network_info['download'] / 1e6:.2f} Mbit/s on download, "
+        f"{network_info['upload'] / 1e6:.2f} Mbit/s on upload)"
     )
     return network_rps
 
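The bandwidth-to-RPS conversion itself is outside this hunk; a hedged reconstruction of the idea (the helper name and the 16-bit-activation cost model are my assumptions):

```python
def network_rps_from_speedtest(network_info: dict, hidden_size: int) -> float:
    # One request carries one hidden state per token in each direction;
    # assuming 16-bit activations, that is roughly hidden_size * 16 bits.
    # The slower of upload/download is the bottleneck.
    bits_per_request = hidden_size * 16
    return min(network_info["download"], network_info["upload"]) / bits_per_request

# e.g. hidden_size=14336 (BLOOM-176B) on a 100/40 Mbit/s connection:
rps = network_rps_from_speedtest({"download": 100e6, "upload": 40e6}, 14336)
print(f"{rps:.1f} RPS")  # ~174.4
```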
@@ -168,7 +180,8 @@ def measure_compute_rps(
         devices_repr = ", ".join(f"{count}x {name}" for name, count in Counter(device_names).most_common())
 
     logger.info(
-        f"Forward pass throughput ({devices_repr}, {get_dtype_name(dtype, load_in_8bit)}): " f"{device_rps:.1f} RPS"
+        f"Forward pass throughput: {device_rps:.1f} RPS per block "
+        f"({devices_repr}, {get_dtype_name(dtype, load_in_8bit)})"
     )
     return device_rps
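For reference, measure_compute_rps (only its log line changes in this commit) boils down to timing repeated forward passes through a single block. A minimal runnable sketch of that kind of benchmark, with illustrative names and parameters rather than the commit's code:

```python
import time
import torch

def benchmark_block_rps(block: torch.nn.Module, hidden_size: int,
                        device: torch.device, n_tokens: int = 1024,
                        n_steps: int = 10) -> float:
    """Time repeated forward passes through one block; return tokens/sec."""
    block = block.to(device)
    dummy = torch.randn(1, n_tokens, hidden_size, device=device)
    with torch.inference_mode():
        block(dummy)  # warm-up pass (kernel selection, caches)
        if device.type == "cuda":
            torch.cuda.synchronize(device)
        start = time.perf_counter()
        for _ in range(n_steps):
            block(dummy)
        if device.type == "cuda":
            torch.cuda.synchronize(device)
        elapsed = time.perf_counter() - start
    return n_tokens * n_steps / elapsed  # forward-pass tokens/sec per block

# Stand-in module for a runnable demo; a real server benchmarks a BLOOM block:
rps = benchmark_block_rps(torch.nn.Linear(1024, 1024), hidden_size=1024,
                          device=torch.device("cpu"))
print(f"{rps:.1f} tokens/sec per block")
```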