Browse Source

Add first functional code

artek0chumak 2 years ago
parent
commit
6563bf1750
3 changed files with 132 additions and 16 deletions
  1. 2 0
      setup.cfg
  2. 116 0
      src/petals/utils/peft.py
  3. 14 16
      tests/test_peft.py

+ 2 - 0
setup.cfg

@@ -46,6 +46,8 @@ install_requires =
     cpufeature>=0.2.0
     packaging>=20.9
     sentencepiece>=0.1.99
+    peft @ git+https://github.com/huggingface/peft
+    safetensors>=0.3.1
 
 [options.extras_require]
 dev =

+ 116 - 0
src/petals/utils/peft.py

@@ -0,0 +1,116 @@
+import time
+
+from typing import Optional
+
+from hivemind.utils.logging import get_logger
+from huggingface_hub import HfFileSystem, hf_hub_url, get_hf_file_metadata
+from peft.utils import CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME, PeftConfig
+from safetensors.torch import load_file
+from transformers.utils import get_file_from_repo
+
+from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for
+
+
+logger = get_logger(__name__)
+
+
+def check_peft_repository(repo_id: str) -> bool:
+    fs = HfFileSystem()
+    list_of_files = fs.glob(f"{repo_id}/{SAFETENSORS_WEIGHTS_NAME}", detail=False)
+    return len(list_of_files) > 0
+
+
+def load_peft(
+    repo_id: str,
+    *,
+    revision: Optional[str] = None,
+    use_auth_token: Optional[str] = None,
+    cache_dir: str,
+    max_disk_space: Optional[int] = None,
+    delay: float = 30
+):
+    # TODO: Check is it possible to add safetensors loading inside petals/server/from_pretrained.py and reuse it here
+
+    if not check_peft_repository(repo_id):
+        raise ValueError(f"Repo: {repo_id} doesn't have safetensors inside for a safe loading.")
+    
+    try:
+        with allow_cache_reads(cache_dir):
+            path = get_file_from_repo(
+                repo_id,
+                CONFIG_NAME,
+                revision=revision,
+                use_auth_token=use_auth_token,
+                cache_dir=cache_dir,
+                local_files_only=True,
+            )
+            config = PeftConfig.from_json_file(path)
+    except Exception:
+        logger.warning(f"Cache for peft weights {repo_id} is corrupted, it will be downloaded again", exc_info=True)
+
+    while True:
+        try:
+            with allow_cache_writes(cache_dir):
+                url = hf_hub_url(repo_id, CONFIG_NAME, revision=revision)
+                file_size = get_hf_file_metadata(url, token=use_auth_token).size
+                if file_size is not None:
+                    free_disk_space_for(repo_id, file_size, cache_dir=cache_dir, max_disk_space=max_disk_space)
+                else:
+                    logger.warning(f"Failed to fetch size of weight from peft repo {repo_id}")
+
+                path = get_file_from_repo(
+                    repo_id,
+                    CONFIG_NAME,
+                    revision=revision,
+                    use_auth_token=use_auth_token,
+                    cache_dir=cache_dir,
+                    local_files_only=False,
+                )
+                if path is None:
+                    raise RuntimeError(f"File {CONFIG_NAME} does not exist in repo {repo_id}")
+                config = PeftConfig.from_json_file(path)
+                break
+        except Exception as e:
+            logger.warning(f"Failed to load file {CONFIG_NAME} from HF Hub (retry in {delay:.0f} sec)", exc_info=True)
+            time.sleep(delay)
+    
+    try:
+        with allow_cache_reads(cache_dir):
+            path = get_file_from_repo(
+                repo_id,
+                SAFETENSORS_WEIGHTS_NAME,
+                revision=revision,
+                use_auth_token=use_auth_token,
+                cache_dir=cache_dir,
+                local_files_only=True,
+            )
+            if path is not None:
+                return config, load_file(path)
+    except Exception:
+        logger.warning(f"Cache for peft weights {repo_id} is corrupted, it will be downloaded again", exc_info=True)
+        
+    # If not found, ensure that we have enough disk space to download them (maybe remove something)
+    while True:
+        try:
+            with allow_cache_writes(cache_dir):
+                url = hf_hub_url(repo_id, SAFETENSORS_WEIGHTS_NAME, revision=revision)
+                file_size = get_hf_file_metadata(url, token=use_auth_token).size
+                if file_size is not None:
+                    free_disk_space_for(repo_id, file_size, cache_dir=cache_dir, max_disk_space=max_disk_space)
+                else:
+                    logger.warning(f"Failed to fetch size of weight from peft repo {repo_id}")
+
+                path = get_file_from_repo(
+                    repo_id,
+                    SAFETENSORS_WEIGHTS_NAME,
+                    revision=revision,
+                    use_auth_token=use_auth_token,
+                    cache_dir=cache_dir,
+                    local_files_only=False,
+                )
+                if path is None:
+                    raise RuntimeError(f"File {SAFETENSORS_WEIGHTS_NAME} does not exist in repo {repo_id}")
+                return config, load_file(path)
+        except Exception as e:
+            logger.warning(f"Failed to load file {SAFETENSORS_WEIGHTS_NAME} from HF Hub (retry in {delay:.0f} sec)", exc_info=True)
+            time.sleep(delay)

+ 14 - 16
tests/test_peft.py

@@ -2,9 +2,9 @@ import os
 import pytest
 import shutil
 
-from huggingface_hub import hf_hub_download
+from huggingface_hub import snapshot_download
 
-from petals.utils import check_peft_repository, load_peft
+from petals.utils.peft import check_peft_repository, load_peft
 
 
 NOSAFE_PEFT_REPO = "timdettmers/guanaco-7b"
@@ -19,7 +19,7 @@ def clear_dir(path_to_dir):
 
 def dir_empty(path_to_dir):
     files = os.listdir(path_to_dir)
-    return files.empty()
+    return len(files) == 0
 
 
 @pytest.mark.forked
@@ -29,23 +29,21 @@ def test_check_peft():
 
 
 @pytest.mark.forked
-def test_load_noncached():
-    clear_dir(TMP_CACHE_DIR)
+def test_load_noncached(tmpdir):
+    clear_dir(tmpdir)
     with pytest.raises(Exception):
-        load_peft(NOSAFE_PEFT_REPO, cache_dir=TMP_CACHE_DIR)
+        load_peft(NOSAFE_PEFT_REPO, cache_dir=tmpdir)
         
-    assert dir_empty(TMP_CACHE_DIR), "NOSAFE_PEFT_REPO is loaded"
+    assert dir_empty(tmpdir), "NOSAFE_PEFT_REPO is loaded"
 
-    status = load_peft(SAFE_PEFT_REPO, cache_dir=TMP_CACHE_DIR)
-    
-    assert status, "PEFT is not loaded"
-    assert not dir_empty(TMP_CACHE_DIR), "SAFE_PEFT_REPO is not loaded"
+    load_peft(SAFE_PEFT_REPO, cache_dir=tmpdir)
+
+    assert not dir_empty(tmpdir), "SAFE_PEFT_REPO is not loaded"
 
 
 @pytest.mark.forked
-def test_load_cached():
-    clear_dir(TMP_CACHE_DIR)
-    hf_hub_download(SAFE_PEFT_REPO, cache_dir=TMP_CACHE_DIR)
+def test_load_cached(tmpdir):
+    clear_dir(tmpdir)
+    snapshot_download(SAFE_PEFT_REPO, cache_dir=tmpdir)
     
-    status = load_peft(SAFE_PEFT_REPO, cache_dir=TMP_CACHE_DIR)
-    assert status, "PEFT is not loaded"
+    load_peft(SAFE_PEFT_REPO, cache_dir=tmpdir)