ソースを参照

Abort speedtest if it runs too long (#316)

Addresses #192 and, specifically, #280.
Alexander Borzunov 2 年 前
コミット
e026952338
1 ファイル変更22 行追加10 行削除
  1. 22 10
      src/petals/server/throughput.py

+ 22 - 10
src/petals/server/throughput.py

@@ -1,6 +1,7 @@
 import fcntl
 import json
 import math
+import multiprocessing as mp
 import os
 import time
 from collections import Counter
@@ -120,24 +121,26 @@ def measure_throughput_info(
     }
     try:
         throughput_info["network_rps"] = measure_network_rps(config)
-    except Exception:
-        logger.warning("Failed to measure network throughput:", exc_info=True)
+    except Exception as e:
+        logger.warning(f"Failed to measure network throughput: {repr(e)}")
         logger.warning("Proceeding with the compute throughput only")
     return throughput_info
 
 
-def measure_network_rps(config: BloomConfig) -> Optional[float]:
-    s = speedtest.Speedtest()
-    s.get_servers()
-    s.get_best_server()
-    s.download()
-    s.upload()
-    network_info = s.results.dict()
+def measure_network_rps(config: BloomConfig, *, timeout: float = 60) -> Optional[float]:
+    pipe_recv, pipe_send = mp.Pipe(duplex=False)
+    process = mp.Process(target=_measure_bits_per_second, args=(pipe_send,))
+    process.start()
+
+    if not pipe_recv.poll(timeout):
+        process.terminate()
+        raise RuntimeError(f"speedtest did not finish in {timeout} seconds")
+    network_info = pipe_recv.recv()
 
     bits_per_request = config.hidden_size * 16  # Clients usually send 16-bit tensors for forward/backward
     network_rps = min(network_info["download"], network_info["upload"]) / bits_per_request
     if network_rps == 0:
-        raise ValueError("speedtest has returned network_rps == 0")
+        raise RuntimeError("speedtest has returned network_rps == 0")
 
     logger.info(
         f"Network throughput: {network_rps:.1f} RPS "
@@ -147,6 +150,15 @@ def measure_network_rps(config: BloomConfig) -> Optional[float]:
     return network_rps
 
 
+def _measure_bits_per_second(pipe_send: mp.Pipe):
+    s = speedtest.Speedtest()
+    s.get_servers()
+    s.get_best_server()
+    s.download()
+    s.upload()
+    pipe_send.send(s.results.dict())
+
+
 def measure_compute_rps(
     config: BloomConfig,
     device: torch.device,