Forráskód Böngészése

Fix long downloading

artek0chumak 2 éve
szülő
commit
c02ae4c9ba
2 módosított fájl, 21 hozzáadás és 39 törlés
  1. 16 34
      src/petals/utils/peft.py
  2. 5 5
      tests/test_peft.py

+ 16 - 34
src/petals/utils/peft.py

@@ -20,6 +20,18 @@ def check_peft_repository(repo_id: str) -> bool:
     return len(list_of_files) > 0
 
 
+def get_adapter_from_repo(repo_id, **kwargs):
+    config_path = get_file_from_repo(repo_id, CONFIG_NAME, **kwargs)
+    if config_path is None:
+        raise RuntimeError(f"File {CONFIG_NAME} does not exist in repo {repo_id}")
+    config = PeftConfig.from_json_file(config_path)
+
+    weight_path = get_file_from_repo(repo_id, SAFETENSORS_WEIGHTS_NAME, **kwargs)
+    if weight_path is None:
+        raise RuntimeError(f"File {SAFETENSORS_WEIGHTS_NAME} does not exist in repo {repo_id}")
+    return config, load_file(weight_path)
+
+
 def load_peft(
     repo_id: str,
     *,
@@ -36,27 +48,13 @@ def load_peft(
     
     try:
         with allow_cache_reads(cache_dir):
-            config_path = get_file_from_repo(
-                repo_id,
-                CONFIG_NAME,
-                revision=revision,
-                use_auth_token=use_auth_token,
-                cache_dir=cache_dir,
-                local_files_only=True,
-            )
-            assert config_path is not None
-            config = PeftConfig.from_json_file(config_path)
-
-            weight_path = get_file_from_repo(
+            return get_adapter_from_repo(
                 repo_id,
-                SAFETENSORS_WEIGHTS_NAME,
                 revision=revision,
                 use_auth_token=use_auth_token,
                 cache_dir=cache_dir,
-                local_files_only=True,
+                local_files_only=False,
             )
-            assert weight_path is not None
-            return config, load_file(weight_path)
     except Exception:
         logger.warning(f"Cache for peft weights {repo_id} is corrupted, it will be downloaded again", exc_info=True)
 
@@ -70,33 +68,17 @@ def load_peft(
 
                 file_size = config_file_size + wieght_file_size
                 if file_size is not None:
-                    free_disk_space_for(repo_id, file_size, cache_dir=cache_dir, max_disk_space=max_disk_space)
+                    free_disk_space_for(file_size, cache_dir=cache_dir, max_disk_space=max_disk_space)
                 else:
                     logger.warning(f"Failed to fetch size from peft repo {repo_id}")
 
-                config_path = get_file_from_repo(
-                    repo_id,
-                    CONFIG_NAME,
-                    revision=revision,
-                    use_auth_token=use_auth_token,
-                    cache_dir=cache_dir,
-                    local_files_only=False,
-                )
-                if config_path is None:
-                    raise RuntimeError(f"File {CONFIG_NAME} does not exist in repo {repo_id}")
-                config = PeftConfig.from_json_file(config_path)
-
-                weight_path = get_file_from_repo(
+                return get_adapter_from_repo(
                     repo_id,
-                    SAFETENSORS_WEIGHTS_NAME,
                     revision=revision,
                     use_auth_token=use_auth_token,
                     cache_dir=cache_dir,
                     local_files_only=False,
                 )
-                if weight_path is None:
-                    raise RuntimeError(f"File {SAFETENSORS_WEIGHTS_NAME} does not exist in repo {repo_id}")
-                return config, load_file(weight_path)
         except Exception as e:
             logger.warning(f"Failed to load file {CONFIG_NAME} from HF Hub (retry in {delay:.0f} sec)", exc_info=True)
             time.sleep(delay)

+ 5 - 5
tests/test_peft.py

@@ -7,8 +7,8 @@ from huggingface_hub import snapshot_download
 from petals.utils.peft import check_peft_repository, load_peft
 
 
-NOSAFE_PEFT_REPO = "timdettmers/guanaco-7b"
-SAFE_PEFT_REPO = "artek0chumak/guanaco-7b"
+UNSAFE_PEFT_REPO = "artek0chumak/bloom-560m-unsafe-peft"
+SAFE_PEFT_REPO = "artek0chumak/bloom-560m-safe-peft"
 TMP_CACHE_DIR = "tmp_cache/"
 
 
@@ -24,7 +24,7 @@ def dir_empty(path_to_dir):
 
 @pytest.mark.forked
 def test_check_peft():
-    assert not check_peft_repository(NOSAFE_PEFT_REPO), "NOSAFE_PEFT_REPO is safe to load."
+    assert not check_peft_repository(UNSAFE_PEFT_REPO), "NOSAFE_PEFT_REPO is safe to load."
     assert check_peft_repository(SAFE_PEFT_REPO), "SAFE_PEFT_REPO is not safe to load."
 
 
@@ -32,9 +32,9 @@ def test_check_peft():
 def test_load_noncached(tmpdir):
     clear_dir(tmpdir)
     with pytest.raises(Exception):
-        load_peft(NOSAFE_PEFT_REPO, cache_dir=tmpdir)
+        load_peft(UNSAFE_PEFT_REPO, cache_dir=tmpdir)
         
-    assert dir_empty(tmpdir), "NOSAFE_PEFT_REPO is loaded"
+    assert dir_empty(tmpdir), "UNSAFE_PEFT_REPO is loaded"
 
     load_peft(SAFE_PEFT_REPO, cache_dir=tmpdir)