Selaa lähdekoodia

Fix missing torch.cuda.synchronize for computing throughput (#456)

justheuristic 2 vuotta sitten
vanhempi
commit
55eb36ef48
1 muutettua tiedostoa, jossa 11 lisäystä ja 6 poistoa
  1. 11 6
      src/petals/server/throughput.py

+ 11 - 6
src/petals/server/throughput.py

@@ -51,7 +51,7 @@ def get_server_throughput(
     if cache_dir is None:
         cache_dir = DEFAULT_CACHE_DIR
     lock_path = Path(cache_dir, "throughput.lock")
-    cache_path = Path(cache_dir, "throughput_v4.json")
+    cache_path = Path(cache_dir, "throughput_v5.json")
 
     # We use the system-wide lock since only one process at a time can measure the host throughput
     os.makedirs(lock_path.parent, exist_ok=True)
@@ -196,6 +196,7 @@ def measure_compute_rps(
     n_steps: int,
     inference: bool,
 ) -> float:
+    device = torch.device(device)
     if not tensor_parallel_devices:
         tensor_parallel_devices = (device,)
     with torch.inference_mode():
@@ -204,13 +205,17 @@ def measure_compute_rps(
 
         cache = None
         elapsed = 0
-        for step in range(n_steps + 1):
-            dummy_input = torch.randn(n_tokens, 1, config.hidden_size, device=device, dtype=dtype)
+        dummy_input = torch.randn(1, n_tokens, config.hidden_size, device=device, dtype=dtype)
+        _, cache = block.forward(dummy_input, use_cache=True)  # Skip the 1st step to exclude the initialization time
+        if device.type == "cuda":
+            torch.cuda.synchronize(device)
 
-            start_time = time.perf_counter()
+        start_time = time.perf_counter()
+        for step in range(n_steps):
             _, cache = block.forward(dummy_input, use_cache=True, layer_past=cache if inference else None)
-            if step >= 1:  # Skip the 1st step to exclude the initialization time
-                elapsed += time.perf_counter() - start_time
+        if device.type == "cuda":
+            torch.cuda.synchronize(device)
+        elapsed = time.perf_counter() - start_time
         device_rps = n_steps * n_tokens / elapsed
 
     devices_repr = get_device_name(device)