artek0chumak authored 2 years ago
Commit 92612aebed
2 files changed, 16 insertions and 13 deletions
  1. src/petals/utils/peft.py (+6 -6)
  2. tests/test_peft.py (+10 -7)

+ 6 - 6
src/petals/utils/peft.py

@@ -1,9 +1,8 @@
 import time
-
 from typing import List, Optional
 
 from hivemind.utils.logging import get_logger
-from huggingface_hub import HfFileSystem, hf_hub_url, get_hf_file_metadata
+from huggingface_hub import HfFileSystem, get_hf_file_metadata, hf_hub_url
 from peft.utils import CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME, PeftConfig
 from safetensors import safe_open
 from safetensors.torch import load_file
@@ -11,7 +10,6 @@ from transformers.utils import get_file_from_repo
 
 from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for
 
-
 logger = get_logger(__name__)
 
 
@@ -58,13 +56,13 @@ def load_peft(
     use_auth_token: Optional[str] = None,
     cache_dir: str,
     max_disk_space: Optional[int] = None,
-    delay: float = 30
+    delay: float = 30,
 ):
     # TODO: Check is it possible to add safetensors loading inside petals/server/from_pretrained.py and reuse it here
 
     if not check_peft_repository(repo_id):
         raise ValueError(f"Repo: {repo_id} doesn't have safetensors inside for a safe loading.")
-    
+
     try:
         with allow_cache_reads(cache_dir):
             return get_adapter_from_repo(
@@ -101,5 +99,7 @@ def load_peft(
                     local_files_only=False,
                 )
         except Exception as e:
-            logger.warning(f"Failed to load peft weights {repo_id} from HF Hub (retry in {delay:.0f} sec)", exc_info=True)
+            logger.warning(
+                f"Failed to load peft weights {repo_id} from HF Hub (retry in {delay:.0f} sec)", exc_info=True
+            )
             time.sleep(delay)
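
For context on where the reformatted warning sits: load_peft retries the Hub download after logging and sleeping. The sketch below is a minimal reconstruction of that try/warn/sleep shape using only what is visible in this hunk; load_fn and the attempt cap are illustrative assumptions, and the real code calls get_adapter_from_repo together with the disk-cache helpers imported earlier in this file.

import logging
import time

logger = logging.getLogger(__name__)


def load_with_retry(load_fn, repo_id: str, delay: float = 30.0, max_attempts: int = 3):
    """Sketch of the retry shape around the warning reformatted above.

    load_fn stands in for the real download call (get_adapter_from_repo in peft.py);
    the attempt cap is an assumption for the sketch, not taken from the source.
    """
    for attempt in range(max_attempts):
        try:
            return load_fn(repo_id)
        except Exception:
            logger.warning(
                f"Failed to load peft weights {repo_id} from HF Hub (retry in {delay:.0f} sec)", exc_info=True
            )
            time.sleep(delay)
    raise RuntimeError(f"Failed to load {repo_id} after {max_attempts} attempts")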

+ 10 - 7
tests/test_peft.py

@@ -1,12 +1,11 @@
 import os
-import pytest
 import shutil
 
+import pytest
 from huggingface_hub import snapshot_download
 
 from petals.utils.peft import check_peft_repository, load_peft
 
-
 UNSAFE_PEFT_REPO = "artek0chumak/bloom-560m-unsafe-peft"
 SAFE_PEFT_REPO = "artek0chumak/bloom-560m-safe-peft"
 TMP_CACHE_DIR = "tmp_cache/"
@@ -33,7 +32,7 @@ def test_load_noncached(tmpdir):
     clear_dir(tmpdir)
     with pytest.raises(Exception):
         load_peft(UNSAFE_PEFT_REPO, cache_dir=tmpdir)
-        
+
     assert dir_empty(tmpdir), "UNSAFE_PEFT_REPO is loaded"
 
     load_peft(SAFE_PEFT_REPO, cache_dir=tmpdir)
@@ -45,19 +44,23 @@ def test_load_noncached(tmpdir):
 def test_load_cached(tmpdir):
     clear_dir(tmpdir)
     snapshot_download(SAFE_PEFT_REPO, cache_dir=tmpdir)
-    
+
     load_peft(SAFE_PEFT_REPO, cache_dir=tmpdir)
 
 
 @pytest.mark.forked
 def test_load_layer_exists(tmpdir):
     clear_dir(tmpdir)
-    
+
     load_peft(SAFE_PEFT_REPO, layers_name=["base_model.model.transformer.h.0"], cache_dir=tmpdir)
 
 
 @pytest.mark.forked
 def test_load_layer_nonexists(tmpdir):
     clear_dir(tmpdir)
-    
-    load_peft(SAFE_PEFT_REPO, layers_name=["base_model.model.transformer.h.0", "base_model.model.transformer.h.100"], cache_dir=tmpdir)
+
+    load_peft(
+        SAFE_PEFT_REPO,
+        layers_name=["base_model.model.transformer.h.0", "base_model.model.transformer.h.100"],
+        cache_dir=tmpdir,
+    )
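
The tests above double as usage examples for the two helpers this commit touches. A minimal sketch of the same calls outside pytest, assuming the repo id from tests/test_peft.py and an illustrative cache directory:

from petals.utils.peft import check_peft_repository, load_peft

SAFE_PEFT_REPO = "artek0chumak/bloom-560m-safe-peft"  # taken from tests/test_peft.py
CACHE_DIR = "tmp_cache/"  # illustrative; mirrors TMP_CACHE_DIR in the tests

# load_peft refuses repos without safetensors weights, so check first.
if check_peft_repository(SAFE_PEFT_REPO):
    # Fetch a single adapter layer, as test_load_layer_exists does.
    load_peft(SAFE_PEFT_REPO, layers_name=["base_model.model.transformer.h.0"], cache_dir=CACHE_DIR)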