|
@@ -1,10 +1,11 @@
|
|
|
import time
|
|
|
|
|
|
-from typing import Optional
|
|
|
+from typing import List, Optional
|
|
|
|
|
|
from hivemind.utils.logging import get_logger
|
|
|
from huggingface_hub import HfFileSystem, hf_hub_url, get_hf_file_metadata
|
|
|
from peft.utils import CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME, PeftConfig
|
|
|
+from safetensors import safe_open
|
|
|
from safetensors.torch import load_file
|
|
|
from transformers.utils import get_file_from_repo
|
|
|
|
|
@@ -20,7 +21,22 @@ def check_peft_repository(repo_id: str) -> bool:
|
|
|
return len(list_of_files) > 0
|
|
|
|
|
|
|
|
|
-def get_adapter_from_repo(repo_id, **kwargs):
|
|
|
+def load_specific_module(layers_name: List[str], filepath: str, framework: str = "pt"):
|
|
|
+ tensors = dict()
|
|
|
+ is_tensors_found = dict()
|
|
|
+ with safe_open(filepath, framework=framework) as f:
|
|
|
+ for k in f.keys():
|
|
|
+ for layer_name in layers_name:
|
|
|
+ if k.startswith(layer_name):
|
|
|
+ is_tensors_found[layer_name] = True
|
|
|
+ tensors[k] = f.get_tensor(k)
|
|
|
+ for layer_name in layers_name:
|
|
|
+ if not is_tensors_found.get(layer_name, False):
|
|
|
+ logger.warning(f"There is no peft weights with prefix {layer_name}")
|
|
|
+ return tensors
|
|
|
+
|
|
|
+
|
|
|
+def get_adapter_from_repo(repo_id: str, layers_name: Optional[List[str]] = None, **kwargs):
|
|
|
config_path = get_file_from_repo(repo_id, CONFIG_NAME, **kwargs)
|
|
|
if config_path is None:
|
|
|
raise RuntimeError(f"File {CONFIG_NAME} does not exist in repo {repo_id}")
|
|
@@ -29,11 +45,14 @@ def get_adapter_from_repo(repo_id, **kwargs):
|
|
|
weight_path = get_file_from_repo(repo_id, SAFETENSORS_WEIGHTS_NAME, **kwargs)
|
|
|
if weight_path is None:
|
|
|
raise RuntimeError(f"File {SAFETENSORS_WEIGHTS_NAME} does not exist in repo {repo_id}")
|
|
|
- return config, load_file(weight_path)
|
|
|
+ if layers_name is None:
|
|
|
+ return config, load_file(weight_path)
|
|
|
+ return config, load_specific_module(layers_name, weight_path)
|
|
|
|
|
|
|
|
|
def load_peft(
|
|
|
repo_id: str,
|
|
|
+ layers_name: Optional[List[str]] = None,
|
|
|
*,
|
|
|
revision: Optional[str] = None,
|
|
|
use_auth_token: Optional[str] = None,
|
|
@@ -50,6 +69,7 @@ def load_peft(
|
|
|
with allow_cache_reads(cache_dir):
|
|
|
return get_adapter_from_repo(
|
|
|
repo_id,
|
|
|
+ layers_name,
|
|
|
revision=revision,
|
|
|
use_auth_token=use_auth_token,
|
|
|
cache_dir=cache_dir,
|
|
@@ -63,10 +83,10 @@ def load_peft(
|
|
|
with allow_cache_writes(cache_dir):
|
|
|
config_url = hf_hub_url(repo_id, CONFIG_NAME, revision=revision)
|
|
|
config_file_size = get_hf_file_metadata(config_url, token=use_auth_token).size
|
|
|
- wieght_url = hf_hub_url(repo_id, SAFETENSORS_WEIGHTS_NAME, revision=revision)
|
|
|
- wieght_file_size = get_hf_file_metadata(wieght_url, token=use_auth_token).size
|
|
|
+ weight_url = hf_hub_url(repo_id, SAFETENSORS_WEIGHTS_NAME, revision=revision)
|
|
|
+ weight_file_size = get_hf_file_metadata(weight_url, token=use_auth_token).size
|
|
|
|
|
|
- file_size = config_file_size + wieght_file_size
|
|
|
+ file_size = config_file_size + weight_file_size
|
|
|
if file_size is not None:
|
|
|
free_disk_space_for(file_size, cache_dir=cache_dir, max_disk_space=max_disk_space)
|
|
|
else:
|
|
@@ -74,11 +94,12 @@ def load_peft(
|
|
|
|
|
|
return get_adapter_from_repo(
|
|
|
repo_id,
|
|
|
+ layers_name,
|
|
|
revision=revision,
|
|
|
use_auth_token=use_auth_token,
|
|
|
cache_dir=cache_dir,
|
|
|
local_files_only=False,
|
|
|
)
|
|
|
except Exception as e:
|
|
|
- logger.warning(f"Failed to load file {CONFIG_NAME} from HF Hub (retry in {delay:.0f} sec)", exc_info=True)
|
|
|
+ logger.warning(f"Failed to load peft weights {repo_id} from HF Hub (retry in {delay:.0f} sec)", exc_info=True)
|
|
|
time.sleep(delay)
|