|
@@ -41,7 +41,6 @@ def get_host_throughput(
|
|
|
# We only keep the device type, assuming that the throughput is similar among all host's GPUs
|
|
|
device = torch.device(device).type
|
|
|
|
|
|
- info = None
|
|
|
# We use the system-wide lock since only one process at a time can measure the host throughput
|
|
|
os.makedirs(lock_path.parent, exist_ok=True)
|
|
|
with open(lock_path, 'wb') as lock_fd:
|
|
@@ -49,11 +48,16 @@ def get_host_throughput(
|
|
|
fcntl.flock(lock_fd.fileno(), fcntl.LOCK_EX)
|
|
|
# The OS will release the lock when lock_fd is closed or the process is killed
|
|
|
|
|
|
- if not force_eval and os.path.exists(cache_path):
|
|
|
- with open(cache_path) as cache_fd:
|
|
|
- info = ThroughputInfo(**json.load(cache_fd))
|
|
|
- if device not in info.device_rps:
|
|
|
- force_eval = True
|
|
|
+ info = None
|
|
|
+ try:
|
|
|
+ if not force_eval and os.path.exists(cache_path):
|
|
|
+ with open(cache_path) as cache_fd:
|
|
|
+ info = ThroughputInfo(**json.load(cache_fd))
|
|
|
+ if device not in info.device_rps:
|
|
|
+ force_eval = True
|
|
|
+ except Exception:
|
|
|
+ logger.exception(f"Failed to read throughput info from {cache_path}")
|
|
|
+ force_eval = True
|
|
|
|
|
|
if force_eval or info is None:
|
|
|
info = measure_throughput_info()
|