|
@@ -19,10 +19,10 @@ def check_peft_repository(repo_id: str) -> bool:
|
|
return len(list_of_files) > 0
|
|
return len(list_of_files) > 0
|
|
|
|
|
|
|
|
|
|
-def load_specific_module(layers_name: List[str], filepath: str, framework: str = "pt"):
|
|
|
|
|
|
+def load_specific_module(layers_name: List[str], filepath: str, framework: str = "pt", device: Optional[int] = None):
|
|
tensors = dict()
|
|
tensors = dict()
|
|
is_tensors_found = dict()
|
|
is_tensors_found = dict()
|
|
- with safe_open(filepath, framework=framework) as f:
|
|
|
|
|
|
+ with safe_open(filepath, framework=framework, device=device) as f:
|
|
for k in f.keys():
|
|
for k in f.keys():
|
|
for layer_name in layers_name:
|
|
for layer_name in layers_name:
|
|
if k.startswith(layer_name):
|
|
if k.startswith(layer_name):
|
|
@@ -34,7 +34,9 @@ def load_specific_module(layers_name: List[str], filepath: str, framework: str =
|
|
return tensors
|
|
return tensors
|
|
|
|
|
|
|
|
|
|
-def get_adapter_from_repo(repo_id: str, layers_name: Optional[List[str]] = None, **kwargs):
|
|
|
|
|
|
+def get_adapter_from_repo(
|
|
|
|
+ repo_id: str, layers_name: Optional[List[str]] = None, device: Optional[int] = None, **kwargs
|
|
|
|
+):
|
|
config_path = get_file_from_repo(repo_id, CONFIG_NAME, **kwargs)
|
|
config_path = get_file_from_repo(repo_id, CONFIG_NAME, **kwargs)
|
|
if config_path is None:
|
|
if config_path is None:
|
|
raise RuntimeError(f"File {CONFIG_NAME} does not exist in repo {repo_id}")
|
|
raise RuntimeError(f"File {CONFIG_NAME} does not exist in repo {repo_id}")
|
|
@@ -45,12 +47,13 @@ def get_adapter_from_repo(repo_id: str, layers_name: Optional[List[str]] = None,
|
|
raise RuntimeError(f"File {SAFETENSORS_WEIGHTS_NAME} does not exist in repo {repo_id}")
|
|
raise RuntimeError(f"File {SAFETENSORS_WEIGHTS_NAME} does not exist in repo {repo_id}")
|
|
if layers_name is None:
|
|
if layers_name is None:
|
|
return config, load_file(weight_path)
|
|
return config, load_file(weight_path)
|
|
- return config, load_specific_module(layers_name, weight_path)
|
|
|
|
|
|
+ return config, load_specific_module(layers_name, weight_path, device=device)
|
|
|
|
|
|
|
|
|
|
def load_peft(
|
|
def load_peft(
|
|
repo_id: str,
|
|
repo_id: str,
|
|
layers_name: Optional[List[str]] = None,
|
|
layers_name: Optional[List[str]] = None,
|
|
|
|
+ device: Optional[int] = None,
|
|
*,
|
|
*,
|
|
revision: Optional[str] = None,
|
|
revision: Optional[str] = None,
|
|
use_auth_token: Optional[str] = None,
|
|
use_auth_token: Optional[str] = None,
|
|
@@ -68,6 +71,7 @@ def load_peft(
|
|
return get_adapter_from_repo(
|
|
return get_adapter_from_repo(
|
|
repo_id,
|
|
repo_id,
|
|
layers_name,
|
|
layers_name,
|
|
|
|
+ device,
|
|
revision=revision,
|
|
revision=revision,
|
|
use_auth_token=use_auth_token,
|
|
use_auth_token=use_auth_token,
|
|
cache_dir=cache_dir,
|
|
cache_dir=cache_dir,
|
|
@@ -93,6 +97,7 @@ def load_peft(
|
|
return get_adapter_from_repo(
|
|
return get_adapter_from_repo(
|
|
repo_id,
|
|
repo_id,
|
|
layers_name,
|
|
layers_name,
|
|
|
|
+ device,
|
|
revision=revision,
|
|
revision=revision,
|
|
use_auth_token=use_auth_token,
|
|
use_auth_token=use_auth_token,
|
|
cache_dir=cache_dir,
|
|
cache_dir=cache_dir,
|