|
@@ -20,6 +20,18 @@ def check_peft_repository(repo_id: str) -> bool:
|
|
|
return len(list_of_files) > 0
|
|
|
|
|
|
|
|
|
+def get_adapter_from_repo(repo_id, **kwargs):
|
|
|
+ config_path = get_file_from_repo(repo_id, CONFIG_NAME, **kwargs)
|
|
|
+ if config_path is None:
|
|
|
+ raise RuntimeError(f"File {CONFIG_NAME} does not exist in repo {repo_id}")
|
|
|
+ config = PeftConfig.from_json_file(config_path)
|
|
|
+
|
|
|
+ weight_path = get_file_from_repo(repo_id, SAFETENSORS_WEIGHTS_NAME, **kwargs)
|
|
|
+ if weight_path is None:
|
|
|
+ raise RuntimeError(f"File {SAFETENSORS_WEIGHTS_NAME} does not exist in repo {repo_id}")
|
|
|
+ return config, load_file(weight_path)
|
|
|
+
|
|
|
+
|
|
|
def load_peft(
|
|
|
repo_id: str,
|
|
|
*,
|
|
@@ -36,27 +48,13 @@ def load_peft(
|
|
|
|
|
|
try:
|
|
|
with allow_cache_reads(cache_dir):
|
|
|
- config_path = get_file_from_repo(
|
|
|
- repo_id,
|
|
|
- CONFIG_NAME,
|
|
|
- revision=revision,
|
|
|
- use_auth_token=use_auth_token,
|
|
|
- cache_dir=cache_dir,
|
|
|
- local_files_only=True,
|
|
|
- )
|
|
|
- assert config_path is not None
|
|
|
- config = PeftConfig.from_json_file(config_path)
|
|
|
-
|
|
|
- weight_path = get_file_from_repo(
|
|
|
+ return get_adapter_from_repo(
|
|
|
repo_id,
|
|
|
- SAFETENSORS_WEIGHTS_NAME,
|
|
|
revision=revision,
|
|
|
use_auth_token=use_auth_token,
|
|
|
cache_dir=cache_dir,
|
|
|
- local_files_only=True,
|
|
|
+ local_files_only=False,
|
|
|
)
|
|
|
- assert weight_path is not None
|
|
|
- return config, load_file(weight_path)
|
|
|
except Exception:
|
|
|
logger.warning(f"Cache for peft weights {repo_id} is corrupted, it will be downloaded again", exc_info=True)
|
|
|
|
|
@@ -70,33 +68,17 @@ def load_peft(
|
|
|
|
|
|
file_size = config_file_size + wieght_file_size
|
|
|
if file_size is not None:
|
|
|
- free_disk_space_for(repo_id, file_size, cache_dir=cache_dir, max_disk_space=max_disk_space)
|
|
|
+ free_disk_space_for(file_size, cache_dir=cache_dir, max_disk_space=max_disk_space)
|
|
|
else:
|
|
|
logger.warning(f"Failed to fetch size from peft repo {repo_id}")
|
|
|
|
|
|
- config_path = get_file_from_repo(
|
|
|
- repo_id,
|
|
|
- CONFIG_NAME,
|
|
|
- revision=revision,
|
|
|
- use_auth_token=use_auth_token,
|
|
|
- cache_dir=cache_dir,
|
|
|
- local_files_only=False,
|
|
|
- )
|
|
|
- if config_path is None:
|
|
|
- raise RuntimeError(f"File {CONFIG_NAME} does not exist in repo {repo_id}")
|
|
|
- config = PeftConfig.from_json_file(config_path)
|
|
|
-
|
|
|
- weight_path = get_file_from_repo(
|
|
|
+ return get_adapter_from_repo(
|
|
|
repo_id,
|
|
|
- SAFETENSORS_WEIGHTS_NAME,
|
|
|
revision=revision,
|
|
|
use_auth_token=use_auth_token,
|
|
|
cache_dir=cache_dir,
|
|
|
local_files_only=False,
|
|
|
)
|
|
|
- if weight_path is None:
|
|
|
- raise RuntimeError(f"File {SAFETENSORS_WEIGHTS_NAME} does not exist in repo {repo_id}")
|
|
|
- return config, load_file(weight_path)
|
|
|
except Exception as e:
|
|
|
logger.warning(f"Failed to load file {CONFIG_NAME} from HF Hub (retry in {delay:.0f} sec)", exc_info=True)
|
|
|
time.sleep(delay)
|