|
@@ -36,7 +36,7 @@ def load_peft(
|
|
|
|
|
|
try:
|
|
|
with allow_cache_reads(cache_dir):
|
|
|
- path = get_file_from_repo(
|
|
|
+ config_path = get_file_from_repo(
|
|
|
repo_id,
|
|
|
CONFIG_NAME,
|
|
|
revision=revision,
|
|
@@ -44,21 +44,37 @@ def load_peft(
|
|
|
cache_dir=cache_dir,
|
|
|
local_files_only=True,
|
|
|
)
|
|
|
- config = PeftConfig.from_json_file(path)
|
|
|
+ assert config_path is not None
|
|
|
+ config = PeftConfig.from_json_file(config_path)
|
|
|
+
|
|
|
+ weight_path = get_file_from_repo(
|
|
|
+ repo_id,
|
|
|
+ SAFETENSORS_WEIGHTS_NAME,
|
|
|
+ revision=revision,
|
|
|
+ use_auth_token=use_auth_token,
|
|
|
+ cache_dir=cache_dir,
|
|
|
+ local_files_only=True,
|
|
|
+ )
|
|
|
+ assert weight_path is not None
|
|
|
+ return config, load_file(weight_path)
|
|
|
except Exception:
|
|
|
logger.warning(f"Cache for peft weights {repo_id} is corrupted, it will be downloaded again", exc_info=True)
|
|
|
|
|
|
while True:
|
|
|
try:
|
|
|
with allow_cache_writes(cache_dir):
|
|
|
- url = hf_hub_url(repo_id, CONFIG_NAME, revision=revision)
|
|
|
- file_size = get_hf_file_metadata(url, token=use_auth_token).size
|
|
|
+ config_url = hf_hub_url(repo_id, CONFIG_NAME, revision=revision)
|
|
|
+ config_file_size = get_hf_file_metadata(config_url, token=use_auth_token).size
|
|
|
+ wieght_url = hf_hub_url(repo_id, SAFETENSORS_WEIGHTS_NAME, revision=revision)
|
|
|
+ wieght_file_size = get_hf_file_metadata(wieght_url, token=use_auth_token).size
|
|
|
+
|
|
|
+ file_size = config_file_size + wieght_file_size
|
|
|
if file_size is not None:
|
|
|
free_disk_space_for(repo_id, file_size, cache_dir=cache_dir, max_disk_space=max_disk_space)
|
|
|
else:
|
|
|
- logger.warning(f"Failed to fetch size of weight from peft repo {repo_id}")
|
|
|
+ logger.warning(f"Failed to fetch size from peft repo {repo_id}")
|
|
|
|
|
|
- path = get_file_from_repo(
|
|
|
+ config_path = get_file_from_repo(
|
|
|
repo_id,
|
|
|
CONFIG_NAME,
|
|
|
revision=revision,
|
|
@@ -66,41 +82,11 @@ def load_peft(
|
|
|
cache_dir=cache_dir,
|
|
|
local_files_only=False,
|
|
|
)
|
|
|
- if path is None:
|
|
|
+ if config_path is None:
|
|
|
raise RuntimeError(f"File {CONFIG_NAME} does not exist in repo {repo_id}")
|
|
|
- config = PeftConfig.from_json_file(path)
|
|
|
- break
|
|
|
- except Exception as e:
|
|
|
- logger.warning(f"Failed to load file {CONFIG_NAME} from HF Hub (retry in {delay:.0f} sec)", exc_info=True)
|
|
|
- time.sleep(delay)
|
|
|
-
|
|
|
- try:
|
|
|
- with allow_cache_reads(cache_dir):
|
|
|
- path = get_file_from_repo(
|
|
|
- repo_id,
|
|
|
- SAFETENSORS_WEIGHTS_NAME,
|
|
|
- revision=revision,
|
|
|
- use_auth_token=use_auth_token,
|
|
|
- cache_dir=cache_dir,
|
|
|
- local_files_only=True,
|
|
|
- )
|
|
|
- if path is not None:
|
|
|
- return config, load_file(path)
|
|
|
- except Exception:
|
|
|
- logger.warning(f"Cache for peft weights {repo_id} is corrupted, it will be downloaded again", exc_info=True)
|
|
|
-
|
|
|
- # If not found, ensure that we have enough disk space to download them (maybe remove something)
|
|
|
- while True:
|
|
|
- try:
|
|
|
- with allow_cache_writes(cache_dir):
|
|
|
- url = hf_hub_url(repo_id, SAFETENSORS_WEIGHTS_NAME, revision=revision)
|
|
|
- file_size = get_hf_file_metadata(url, token=use_auth_token).size
|
|
|
- if file_size is not None:
|
|
|
- free_disk_space_for(repo_id, file_size, cache_dir=cache_dir, max_disk_space=max_disk_space)
|
|
|
- else:
|
|
|
- logger.warning(f"Failed to fetch size of weight from peft repo {repo_id}")
|
|
|
+ config = PeftConfig.from_json_file(config_path)
|
|
|
|
|
|
- path = get_file_from_repo(
|
|
|
+ weight_path = get_file_from_repo(
|
|
|
repo_id,
|
|
|
SAFETENSORS_WEIGHTS_NAME,
|
|
|
revision=revision,
|
|
@@ -108,9 +94,9 @@ def load_peft(
|
|
|
cache_dir=cache_dir,
|
|
|
local_files_only=False,
|
|
|
)
|
|
|
- if path is None:
|
|
|
+ if weight_path is None:
|
|
|
raise RuntimeError(f"File {SAFETENSORS_WEIGHTS_NAME} does not exist in repo {repo_id}")
|
|
|
- return config, load_file(path)
|
|
|
+ return config, load_file(weight_path)
|
|
|
except Exception as e:
|
|
|
- logger.warning(f"Failed to load file {SAFETENSORS_WEIGHTS_NAME} from HF Hub (retry in {delay:.0f} sec)", exc_info=True)
|
|
|
- time.sleep(delay)
|
|
|
+ logger.warning(f"Failed to load file {CONFIG_NAME} from HF Hub (retry in {delay:.0f} sec)", exc_info=True)
|
|
|
+ time.sleep(delay)
|