|
@@ -92,7 +92,7 @@ def measure_network_rps(config: BloomConfig) -> float:
|
|
|
raise RuntimeError(f"Failed to measure network throughput (stdout: {proc.stdout}, stderr: {proc.stderr})")
|
|
|
network_info = json.loads(proc.stdout)
|
|
|
|
|
|
- bits_per_request = config.hidden_size * 32
|
|
|
+ bits_per_request = config.hidden_size * (16 if config.torch_dtype in (torch.float16, torch.bfloat16) else 32)
|
|
|
network_rps = min(network_info["download"], network_info["upload"]) / bits_per_request
|
|
|
|
|
|
logger.info(
|