
Report inference, forward, and network RPS separately (#358)

Inference RPS may be very different from forward RPS. For example, bnb currently uses a completely different algorithm for NF4 inference. We report detailed RPS info that can then be used for shortest-path routing for inference.
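
For illustration, a minimal sketch (with made-up numbers) of how a server could report the three new RPS values and how the single aggregate throughput is derived from them, following the formula in get_server_throughput() below:

    from petals.data_structures import ServerInfo, ServerState

    # Made-up benchmark results for a hypothetical server holding 8 blocks.
    forward_rps, inference_rps, network_rps = 560.0, 24.0, 300.0
    num_blocks = 8

    # As in get_server_throughput(): an average request spans (num_blocks + 1) / 2 blocks,
    # and the reported aggregate throughput is capped by the network bandwidth.
    average_blocks_used = (num_blocks + 1) / 2
    throughput = min(forward_rps / average_blocks_used, network_rps)  # ~124.4 RPS

    server_info = ServerInfo(
        state=ServerState.ONLINE,
        throughput=throughput,
        network_rps=network_rps,
        forward_rps=forward_rps,
        inference_rps=inference_rps,
    )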
Alexander Borzunov, 2 years ago
Commit 11f0d992d7

+ 6 - 3
src/petals/client/routing/sequence_info.py

@@ -73,12 +73,15 @@ class RemoteSequenceInfo:
         active_spans = {}
         for block_index, info in enumerate(block_infos):
             if info is not None:
-                for peer_id, server in info.servers.items():
-                    if server.state != ServerState.ONLINE:
+                for peer_id, server_info in info.servers.items():
+                    if server_info.state != ServerState.ONLINE:
                         continue
                     if peer_id not in active_spans:
                         active_spans[peer_id] = RemoteSpanInfo(
-                            peer_id=peer_id, start=block_index, end=block_index + 1, throughput=server.throughput
+                            peer_id=peer_id,
+                            start=block_index,
+                            end=block_index + 1,
+                            server_info=server_info,
                         )
                     else:  # peer_id in active_spans
                         active_spans[peer_id].end = block_index + 1

+ 1 - 1
src/petals/client/routing/sequence_manager.py

@@ -151,7 +151,7 @@ class RemoteSequenceManager:
                 raise MissingBlocksError(current_index)
 
             if mode == "max_throughput":
-                span_weights = np.array([span.throughput for span in candidate_spans], dtype=np.float64)
+                span_weights = np.array([span.server_info.throughput for span in candidate_spans], dtype=np.float64)
             elif mode == "min_latency":
                 span_weights = np.array([span.end - current_index for span in candidate_spans], dtype=np.float64)
             else:
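
Since each candidate span now carries the full ServerInfo, a future inference-aware routing mode could weight spans by inference_rps instead of the aggregate throughput, as the commit message suggests. A hypothetical sketch of such a weighting (not part of this commit, which keeps the existing max_throughput behaviour):

    import numpy as np

    def inference_span_weights(candidate_spans):
        # Hypothetical: prefer servers with high single-token (autoregressive) throughput,
        # falling back to the aggregate throughput when inference_rps was not reported.
        return np.array(
            [
                span.server_info.inference_rps
                if span.server_info.inference_rps is not None
                else span.server_info.throughput
                for span in candidate_spans
            ],
            dtype=np.float64,
        )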

+ 9 - 2
src/petals/data_structures.py

@@ -19,10 +19,17 @@ class ServerState(Enum):
     ONLINE = 2
 
 
+RPS = pydantic.confloat(ge=0, allow_inf_nan=False, strict=True)
+
+
 @pydantic.dataclasses.dataclass
 class ServerInfo:
     state: ServerState
-    throughput: pydantic.confloat(ge=0, allow_inf_nan=False, strict=True)
+    throughput: RPS
+
+    network_rps: Optional[RPS] = None
+    forward_rps: Optional[RPS] = None
+    inference_rps: Optional[RPS] = None
 
     adapters: Sequence[str] = ()
     version: Optional[str] = None
@@ -60,7 +67,7 @@ class RemoteSpanInfo:
     peer_id: PeerID
     start: int
     end: int
-    throughput: float
+    server_info: ServerInfo
 
     @property
     def length(self):

+ 4 - 2
src/petals/server/server.py

@@ -206,7 +206,7 @@ class Server:
 
         assert isinstance(throughput, float) or throughput in ["auto", "eval"]
         if throughput in ["auto", "eval"]:
-            throughput = get_server_throughput(
+            throughput_info = get_server_throughput(
                 converted_model_name_or_path,
                 self.block_config,
                 device,
@@ -217,14 +217,16 @@ class Server:
                 force_eval=(throughput == "eval"),
                 cache_dir=cache_dir,
             )
+        else:
+            throughput_info = {"throughput": throughput}
         self.server_info = ServerInfo(
             state=ServerState.JOINING,
-            throughput=throughput,
             adapters=tuple(adapters),
             version=petals.__version__,
             torch_dtype=str(torch_dtype).replace("torch.", ""),
             quant_type=quant_type.name.lower(),
             using_relay=self.dht.client_mode,
+            **throughput_info,
         )
 
         self.balance_quality = balance_quality

+ 45 - 21
src/petals/server/throughput.py

@@ -43,13 +43,13 @@ def get_server_throughput(
     tensor_parallel_devices: Sequence[torch.device],
     force_eval: bool = False,
     cache_dir: Optional[str] = None,
-) -> float:
+) -> Dict[str, float]:
     dtype = resolve_block_dtype(config, dtype)
 
     if cache_dir is None:
         cache_dir = DEFAULT_CACHE_DIR
     lock_path = Path(cache_dir, "throughput.lock")
-    cache_path = Path(cache_dir, "throughput_v3.json")
+    cache_path = Path(cache_dir, "throughput_v4.json")
 
     # We use the system-wide lock since only one process at a time can measure the host throughput
     os.makedirs(lock_path.parent, exist_ok=True)
@@ -93,10 +93,12 @@ def get_server_throughput(
     # Assuming the start block index is distributed uniformly, the average number of blocks used per request is
     # E[Uniform{1, 2, ..., num_blocks}] = (num_blocks + 1) / 2
     average_blocks_used = (num_blocks + 1) / 2
-    throughput = throughput_info["compute_rps"] / average_blocks_used
+    throughput = throughput_info["forward_rps"] / average_blocks_used
     throughput = min(throughput, throughput_info.get("network_rps", math.inf))
+    throughput_info["throughput"] = throughput
     logger.info(f"Reporting throughput: {throughput:.1f} RPS for {num_blocks} blocks")
-    return throughput
+
+    return throughput_info
 
 
 def measure_throughput_info(
@@ -114,15 +116,31 @@ def measure_throughput_info(
     )
 
     throughput_info = {
-        "compute_rps": measure_compute_rps(
-            config, device, dtype, quant_type=quant_type, tensor_parallel_devices=tensor_parallel_devices
-        )
+        "inference_rps": measure_compute_rps(
+            config,
+            device,
+            dtype,
+            quant_type=quant_type,
+            tensor_parallel_devices=tensor_parallel_devices,
+            n_tokens=1,
+            n_steps=100,
+            inference=True,
+        ),
+        "forward_rps": measure_compute_rps(
+            config,
+            device,
+            dtype,
+            quant_type=quant_type,
+            tensor_parallel_devices=tensor_parallel_devices,
+            n_tokens=1024,
+            n_steps=10,
+            inference=False,
+        ),
     }
     try:
         throughput_info["network_rps"] = measure_network_rps(config)
     except Exception as e:
-        logger.warning(f"Failed to measure network throughput: {repr(e)}")
-        logger.warning("Proceeding with the compute throughput only")
+        logger.info(f"Network throughput is not available: {e}")
     return throughput_info
 
 
@@ -135,6 +153,8 @@ def measure_network_rps(config: PretrainedConfig, *, timeout: float = 60) -> Opt
         process.terminate()
         raise RuntimeError(f"speedtest did not finish in {timeout} seconds")
     network_info = pipe_recv.recv()
+    if "exception" in network_info:
+        raise RuntimeError(f"speedtest failed: {network_info['exception']}")
 
     bits_per_request = config.hidden_size * 16  # Clients usually send 16-bit tensors for forward/backward
     network_rps = min(network_info["download"], network_info["upload"]) / bits_per_request
@@ -150,12 +170,15 @@ def measure_network_rps(config: PretrainedConfig, *, timeout: float = 60) -> Opt
 
 
 def _measure_bits_per_second(pipe_send: mp.Pipe):
-    s = speedtest.Speedtest()
-    s.get_servers()
-    s.get_best_server()
-    s.download()
-    s.upload()
-    pipe_send.send(s.results.dict())
+    try:
+        s = speedtest.Speedtest()
+        s.get_servers()
+        s.get_best_server()
+        s.download()
+        s.upload()
+        pipe_send.send(s.results.dict())
+    except Exception as e:
+        pipe_send.send({"exception": repr(e)})
 
 
 def measure_compute_rps(
@@ -165,8 +188,9 @@ def measure_compute_rps(
     *,
     quant_type: QuantType,
     tensor_parallel_devices: Sequence[torch.device],
-    n_tokens: int = 16,
-    n_steps: int = 500,
+    n_tokens: int,
+    n_steps: int,
+    inference: bool,
 ) -> float:
     if not tensor_parallel_devices:
         tensor_parallel_devices = (device,)
@@ -180,7 +204,7 @@ def measure_compute_rps(
             dummy_input = torch.randn(n_tokens, 1, config.hidden_size, device=device, dtype=dtype)
 
             start_time = time.perf_counter()
-            _, cache = block.forward(dummy_input, use_cache=True, layer_past=cache)
+            _, cache = block.forward(dummy_input, use_cache=True, layer_past=cache if inference else None)
             if step >= 1:  # Skip the 1st step to exclude the initialization time
                 elapsed += time.perf_counter() - start_time
         device_rps = n_steps * n_tokens / elapsed
@@ -191,8 +215,8 @@ def measure_compute_rps(
         devices_repr = ", ".join(f"{count}x {name}" for name, count in Counter(device_names).most_common())
 
     logger.info(
-        f"Forward pass throughput: {device_rps:.1f} RPS per block "
-        f"({devices_repr}, {get_dtype_name(dtype, quant_type)})"
+        f"{'Inference' if inference else 'Forward pass'} throughput: {device_rps:.1f} RPS per block "
+        f"({n_tokens} tokens/batch, {devices_repr}, {get_dtype_name(dtype, quant_type)})"
     )
     return device_rps
 
@@ -202,7 +226,7 @@ def get_device_name(device: torch.device) -> str:
 
 
 def get_dtype_name(dtype: torch.dtype, quant_type: QuantType) -> str:
-    name = str(dtype)
+    name = str(dtype).replace("torch.", "")
     if quant_type != QuantType.NONE:
         name += f", quantized to {quant_type.name.lower()}"
     return name

+ 1 - 1
src/petals/utils/ping.py

@@ -16,7 +16,7 @@ async def ping(
     _dht: hivemind.DHT,
     node: hivemind.dht.DHTNode,
     *,
-    wait_timeout: float = 1,
+    wait_timeout: float = 5,
 ) -> float:
     try:
         ping_request = dht_pb2.PingRequest(peer=node.protocol.node_info)

+ 6 - 2
tests/test_aux_functions.py

@@ -24,8 +24,10 @@ def test_bnb_not_imported_when_unnecessary():
 
 
 @pytest.mark.forked
+@pytest.mark.parametrize("inference", [False, True])
+@pytest.mark.parametrize("n_tokens", [1, 16])
 @pytest.mark.parametrize("tensor_parallel", [False, True])
-def test_compute_throughput(tensor_parallel: bool):
+def test_compute_throughput(inference: bool, n_tokens: int, tensor_parallel: bool):
     config = AutoDistributedConfig.from_pretrained(MODEL_NAME)
     tensor_parallel_devices = ("cpu", "cpu") if tensor_parallel else ()
     compute_rps = measure_compute_rps(
@@ -34,6 +36,8 @@ def test_compute_throughput(tensor_parallel: bool):
         dtype=torch.bfloat16,
         quant_type=QuantType.NONE,
         tensor_parallel_devices=tensor_parallel_devices,
-        n_steps=10,
+        n_tokens=n_tokens,
+        n_steps=5,
+        inference=inference,
     )
     assert isinstance(compute_rps, float) and compute_rps > 0