ソースを参照

Update throughput.py

justheuristic 2 年 前
コミット
b3115dac58
1 ファイル変更1 行追加1 行削除
  1. 1 1
      src/petals/server/throughput.py

+ 1 - 1
src/petals/server/throughput.py

@@ -92,7 +92,7 @@ def measure_network_rps(config: BloomConfig) -> float:
         raise RuntimeError(f"Failed to measure network throughput (stdout: {proc.stdout}, stderr: {proc.stderr})")
     network_info = json.loads(proc.stdout)
 
-    bits_per_request = config.hidden_size * 32
+    bits_per_request = config.hidden_size * (16 if config.torch_dtype in (torch.float16, torch.bfloat16) else 32)
     network_rps = min(network_info["download"], network_info["upload"]) / bits_per_request
 
     logger.info(