
Divide compute throughput by average no. of used blocks (#314)

See #192.
Alexander Borzunov 2 years ago
parent commit d9e7bfc949
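
In short, `measure_throughput_info()` now returns the compute and network rates separately, and `get_server_throughput()` divides the per-block compute rate by the expected number of blocks a request traverses before applying the network cap. A minimal sketch of the new reporting logic, assuming the diff below (the helper name `report_throughput` is hypothetical; the real code lives in `get_server_throughput`):

    import math

    def report_throughput(compute_rps: float, num_blocks: int, network_rps: float = math.inf) -> float:
        # A request entering at a uniformly random one of this server's blocks
        # traverses E[Uniform{1, ..., num_blocks}] = (num_blocks + 1) / 2 blocks on average.
        average_blocks_used = (num_blocks + 1) / 2
        throughput = compute_rps / average_blocks_used
        # A server cannot serve requests faster than its network connection allows.
        return min(throughput, network_rps)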

+ 3 - 2
src/petals/server/server.py

@@ -27,7 +27,7 @@ from petals.server.block_utils import get_block_size
 from petals.server.handler import TransformerConnectionHandler
 from petals.server.memory_cache import MemoryCache
 from petals.server.reachability import ReachabilityProtocol, check_direct_reachability, validate_reachability
-from petals.server.throughput import get_dtype_name, get_host_throughput
+from petals.server.throughput import get_dtype_name, get_server_throughput
 from petals.utils.convert_block import check_device_balance, convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR

@@ -193,10 +193,11 @@ class Server:

         assert isinstance(throughput, float) or throughput in ["auto", "eval"]
         if throughput in ["auto", "eval"]:
-            throughput = get_host_throughput(
+            throughput = get_server_throughput(
                 self.block_config,
                 device,
                 torch_dtype,
+                num_blocks=num_blocks,
                 load_in_8bit=load_in_8bit,
                 tensor_parallel_devices=self.tensor_parallel_devices,
                 force_eval=(throughput == "eval"),
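
On the caller side, `num_blocks` is keyword-only in the new signature. A hedged example of the updated call, with made-up argument values:

    throughput = get_server_throughput(
        config,                 # BloomConfig of the served model
        torch.device("cuda"),
        torch.float16,
        num_blocks=8,           # how many transformer blocks this server hosts
        load_in_8bit=False,
        tensor_parallel_devices=(),
        force_eval=False,
    )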

+ 28 - 15
src/petals/server/throughput.py

@@ -1,11 +1,12 @@
 import fcntl
 import json
+import math
 import os
 import time
 from collections import Counter
 from hashlib import sha256
 from pathlib import Path
-from typing import Optional, Sequence, Union
+from typing import Dict, Optional, Sequence, Union

 import torch
 from hivemind.utils.logging import get_logger
@@ -32,11 +33,12 @@ if not hasattr(speedtest, "Speedtest"):
     )


-def get_host_throughput(
+def get_server_throughput(
     config: BloomConfig,
     device: torch.device,
     dtype: Union[str, torch.dtype],
     *,
+    num_blocks: int,
     load_in_8bit: bool,
     tensor_parallel_devices: Sequence[torch.device],
     force_eval: bool = False,
@@ -47,7 +49,7 @@ def get_host_throughput(
     if cache_dir is None:
         cache_dir = DEFAULT_CACHE_DIR
     lock_path = Path(cache_dir, "throughput.lock")
-    cache_path = Path(cache_dir, "throughput_v2.json")
+    cache_path = Path(cache_dir, "throughput_v3.json")

     # We use the system-wide lock since only one process at a time can measure the host throughput
     os.makedirs(lock_path.parent, exist_ok=True)
@@ -85,7 +87,16 @@ def get_host_throughput(
             except Exception:
                 logger.exception(f"Failed to save throughput info in {cache_path}")

-    return cache[cache_key]
+    throughput_info = cache[cache_key]
+
+    # Most requests start at some block hosted by a server, then use all next blocks hosted on this server.
+    # Assuming the start block index is distributed uniformly, the average number of blocks used per request is
+    # E[Uniform{1, 2, ..., num_blocks}] = (num_blocks + 1) / 2
+    average_blocks_used = (num_blocks + 1) / 2
+    throughput = throughput_info["compute_rps"] / average_blocks_used
+    throughput = min(throughput, throughput_info.get("network_rps", math.inf))
+    logger.info(f"Reporting throughput: {throughput:.1f} RPS for {num_blocks} blocks")
+    return throughput


 def measure_throughput_info(
@@ -95,22 +106,24 @@ def measure_throughput_info(
     *,
     load_in_8bit: bool,
     tensor_parallel_devices: Sequence[torch.device],
-) -> float:
+) -> Dict[str, float]:
     """Measure network and compute throughput in forward pass tokens per second"""

     logger.info(
         "Measuring network and compute throughput. This takes about a minute and will be cached for future runs"
     )

-    result = measure_compute_rps(
-        config, device, dtype, load_in_8bit=load_in_8bit, tensor_parallel_devices=tensor_parallel_devices
-    )
+    throughput_info = {
+        "compute_rps": measure_compute_rps(
+            config, device, dtype, load_in_8bit=load_in_8bit, tensor_parallel_devices=tensor_parallel_devices
+        )
+    }
     try:
-        result = min(result, measure_network_rps(config))
+        throughput_info["network_rps"] = measure_network_rps(config)
     except Exception:
         logger.warning("Failed to measure network throughput:", exc_info=True)
         logger.warning("Proceeding with the compute throughput only")
-    return result
+    return throughput_info


 def measure_network_rps(config: BloomConfig) -> Optional[float]:
@@ -127,10 +140,9 @@ def measure_network_rps(config: BloomConfig) -> Optional[float]:
         raise ValueError("speedtest has returned network_rps == 0")

     logger.info(
-        f"Network throughput: "
-        f"{network_info['download'] / 1e6:.2f} Mbit/s on download, "
-        f"{network_info['upload'] / 1e6:.2f} Mbit/s on upload, "
-        f"{network_rps:.1f} RPS"
+        f"Network throughput: {network_rps:.1f} RPS "
+        f"({network_info['download'] / 1e6:.2f} Mbit/s on download, "
+        f"{network_info['upload'] / 1e6:.2f} Mbit/s on upload)"
     )
     return network_rps

@@ -168,7 +180,8 @@ def measure_compute_rps(
         devices_repr = ", ".join(f"{count}x {name}" for name, count in Counter(device_names).most_common())

     logger.info(
-        f"Forward pass throughput ({devices_repr}, {get_dtype_name(dtype, load_in_8bit)}): " f"{device_rps:.1f} RPS"
+        f"Forward pass throughput: {device_rps:.1f} RPS per block "
+        f"({devices_repr}, {get_dtype_name(dtype, load_in_8bit)})"
     )
     return device_rps
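
A worked example with made-up numbers: a server hosting 7 blocks that measures 800 RPS for a single block now advertises 800 / ((7 + 1) / 2) = 200 RPS (before the `min()` with `network_rps`), rather than 800 RPS as before:

    compute_rps = 800.0                          # measured forward-pass RPS for one block (made up)
    num_blocks = 7                               # blocks hosted by this server (made up)
    average_blocks_used = (num_blocks + 1) / 2   # (7 + 1) / 2 = 4.0
    print(compute_rps / average_blocks_used)     # 200.0 -> reported RPS before the network cap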