Quellcode durchsuchen

Ignore network RPS if we failed to measure it (#198)

Alexander Borzunov vor 2 Jahren
Ursprung
Commit
127cf66bee
1 geänderte Dateien mit 17 neuen und 17 gelöschten Zeilen
  1. 17 17
      src/petals/server/throughput.py

+ 17 - 17
src/petals/server/throughput.py

@@ -101,25 +101,25 @@ def measure_throughput_info(
     logger.info(
         "Measuring network and compute throughput. This takes about a minute and will be cached for future runs"
     )
-    return min(
-        measure_network_rps(config),
-        measure_compute_rps(
-            config, device, dtype, load_in_8bit=load_in_8bit, tensor_parallel_devices=tensor_parallel_devices
-        ),
-    )
-
 
-def measure_network_rps(config: BloomConfig) -> float:
+    result = measure_compute_rps(
+        config, device, dtype, load_in_8bit=load_in_8bit, tensor_parallel_devices=tensor_parallel_devices
+    )
     try:
-        s = speedtest.Speedtest()
-        s.get_servers()
-        s.get_best_server()
-        s.download()
-        s.upload()
-        network_info = s.results.dict()
-    except:
-        logger.error("Failed to measure network throughput:")
-        raise
+        result = min(result, measure_network_rps(config))
+    except Exception:
+        logger.warning("Failed to measure network throughput:", exc_info=True)
+        logger.warning("Proceeding with the compute throughput only")
+    return result
+
+
+def measure_network_rps(config: BloomConfig) -> Optional[float]:
+    s = speedtest.Speedtest()
+    s.get_servers()
+    s.get_best_server()
+    s.download()
+    s.upload()
+    network_info = s.results.dict()
 
     bits_per_request = config.hidden_size * 16  # Clients usually send 16-bit tensors for forward/backward
     network_rps = min(network_info["download"], network_info["upload"]) / bits_per_request