浏览代码

Add layer specific loading

artek0chumak 2 年之前
父节点
当前提交
ddd770000b
共有 2 个文件被更改,包括 42 次插入 和 7 次删除
  1. 28 7
      src/petals/utils/peft.py
  2. 14 0
      tests/test_peft.py

+ 28 - 7
src/petals/utils/peft.py

@@ -1,10 +1,11 @@
 import time
 
-from typing import Optional
+from typing import List, Optional
 
 from hivemind.utils.logging import get_logger
 from huggingface_hub import HfFileSystem, hf_hub_url, get_hf_file_metadata
 from peft.utils import CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME, PeftConfig
+from safetensors import safe_open
 from safetensors.torch import load_file
 from transformers.utils import get_file_from_repo
 
@@ -20,7 +21,22 @@ def check_peft_repository(repo_id: str) -> bool:
     return len(list_of_files) > 0
 
 
-def get_adapter_from_repo(repo_id, **kwargs):
+def load_specific_module(layers_name: List[str], filepath: str, framework: str = "pt"):
+    tensors = dict()
+    is_tensors_found = dict()
+    with safe_open(filepath, framework=framework) as f:
+        for k in f.keys():
+            for layer_name in layers_name:
+                if k.startswith(layer_name):
+                    is_tensors_found[layer_name] = True
+                    tensors[k] = f.get_tensor(k)
+        for layer_name in layers_name:
+            if not is_tensors_found.get(layer_name, False):
+                logger.warning(f"There is no peft weights with prefix {layer_name}")
+        return tensors
+
+
+def get_adapter_from_repo(repo_id: str, layers_name: Optional[List[str]] = None, **kwargs):
     config_path = get_file_from_repo(repo_id, CONFIG_NAME, **kwargs)
     if config_path is None:
         raise RuntimeError(f"File {CONFIG_NAME} does not exist in repo {repo_id}")
@@ -29,11 +45,14 @@ def get_adapter_from_repo(repo_id, **kwargs):
     weight_path = get_file_from_repo(repo_id, SAFETENSORS_WEIGHTS_NAME, **kwargs)
     if weight_path is None:
         raise RuntimeError(f"File {SAFETENSORS_WEIGHTS_NAME} does not exist in repo {repo_id}")
-    return config, load_file(weight_path)
+    if layers_name is None:
+        return config, load_file(weight_path)
+    return config, load_specific_module(layers_name, weight_path)
 
 
 def load_peft(
     repo_id: str,
+    layers_name: Optional[List[str]] = None,
     *,
     revision: Optional[str] = None,
     use_auth_token: Optional[str] = None,
@@ -50,6 +69,7 @@ def load_peft(
         with allow_cache_reads(cache_dir):
             return get_adapter_from_repo(
                 repo_id,
+                layers_name,
                 revision=revision,
                 use_auth_token=use_auth_token,
                 cache_dir=cache_dir,
@@ -63,10 +83,10 @@ def load_peft(
             with allow_cache_writes(cache_dir):
                 config_url = hf_hub_url(repo_id, CONFIG_NAME, revision=revision)
                 config_file_size = get_hf_file_metadata(config_url, token=use_auth_token).size
-                wieght_url = hf_hub_url(repo_id, SAFETENSORS_WEIGHTS_NAME, revision=revision)
-                wieght_file_size = get_hf_file_metadata(wieght_url, token=use_auth_token).size
+                weight_url = hf_hub_url(repo_id, SAFETENSORS_WEIGHTS_NAME, revision=revision)
+                weight_file_size = get_hf_file_metadata(weight_url, token=use_auth_token).size
 
-                file_size = config_file_size + wieght_file_size
+                file_size = config_file_size + weight_file_size
                 if file_size is not None:
                     free_disk_space_for(file_size, cache_dir=cache_dir, max_disk_space=max_disk_space)
                 else:
@@ -74,11 +94,12 @@ def load_peft(
 
                 return get_adapter_from_repo(
                     repo_id,
+                    layers_name,
                     revision=revision,
                     use_auth_token=use_auth_token,
                     cache_dir=cache_dir,
                     local_files_only=False,
                 )
         except Exception as e:
-            logger.warning(f"Failed to load file {CONFIG_NAME} from HF Hub (retry in {delay:.0f} sec)", exc_info=True)
+            logger.warning(f"Failed to load peft weights {repo_id} from HF Hub (retry in {delay:.0f} sec)", exc_info=True)
             time.sleep(delay)

+ 14 - 0
tests/test_peft.py

@@ -47,3 +47,17 @@ def test_load_cached(tmpdir):
     snapshot_download(SAFE_PEFT_REPO, cache_dir=tmpdir)
     
     load_peft(SAFE_PEFT_REPO, cache_dir=tmpdir)
+
+
+@pytest.mark.forked
+def test_load_layer_exists(tmpdir):
+    clear_dir(tmpdir)
+    
+    load_peft(SAFE_PEFT_REPO, layers_name=["base_model.model.transformer.h.0"], cache_dir=tmpdir)
+
+
+@pytest.mark.forked
+def test_load_layer_nonexists(tmpdir):
+    clear_dir(tmpdir)
+    
+    load_peft(SAFE_PEFT_REPO, layers_name=["base_model.model.transformer.h.0", "base_model.model.transformer.h.100"], cache_dir=tmpdir)