|
@@ -101,25 +101,25 @@ def measure_throughput_info(
|
|
|
logger.info(
|
|
|
"Measuring network and compute throughput. This takes about a minute and will be cached for future runs"
|
|
|
)
|
|
|
- return min(
|
|
|
- measure_network_rps(config),
|
|
|
- measure_compute_rps(
|
|
|
- config, device, dtype, load_in_8bit=load_in_8bit, tensor_parallel_devices=tensor_parallel_devices
|
|
|
- ),
|
|
|
- )
|
|
|
-
|
|
|
|
|
|
-def measure_network_rps(config: BloomConfig) -> float:
|
|
|
+ result = measure_compute_rps(
|
|
|
+ config, device, dtype, load_in_8bit=load_in_8bit, tensor_parallel_devices=tensor_parallel_devices
|
|
|
+ )
|
|
|
try:
|
|
|
- s = speedtest.Speedtest()
|
|
|
- s.get_servers()
|
|
|
- s.get_best_server()
|
|
|
- s.download()
|
|
|
- s.upload()
|
|
|
- network_info = s.results.dict()
|
|
|
- except:
|
|
|
- logger.error("Failed to measure network throughput:")
|
|
|
- raise
|
|
|
+ result = min(result, measure_network_rps(config))
|
|
|
+ except Exception:
|
|
|
+ logger.warning("Failed to measure network throughput:", exc_info=True)
|
|
|
+ logger.warning("Proceeding with the compute throughput only")
|
|
|
+ return result
|
|
|
+
|
|
|
+
|
|
|
+def measure_network_rps(config: BloomConfig) -> Optional[float]:
|
|
|
+ s = speedtest.Speedtest()
|
|
|
+ s.get_servers()
|
|
|
+ s.get_best_server()
|
|
|
+ s.download()
|
|
|
+ s.upload()
|
|
|
+ network_info = s.results.dict()
|
|
|
|
|
|
bits_per_request = config.hidden_size * 16 # Clients usually send 16-bit tensors for forward/backward
|
|
|
network_rps = min(network_info["download"], network_info["upload"]) / bits_per_request
|