test_peft.py 1.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. import os
  2. import shutil
  3. import pytest
  4. from huggingface_hub import snapshot_download
  5. from petals.utils.peft import check_peft_repository, load_peft
  6. UNSAFE_PEFT_REPO = "artek0chumak/bloom-560m-unsafe-peft"
  7. SAFE_PEFT_REPO = "artek0chumak/bloom-560m-safe-peft"
  8. TMP_CACHE_DIR = "tmp_cache/"
  9. def clear_dir(path_to_dir):
  10. shutil.rmtree(path_to_dir)
  11. os.mkdir(path_to_dir)
  12. def dir_empty(path_to_dir):
  13. files = os.listdir(path_to_dir)
  14. return len(files) == 0
  15. @pytest.mark.forked
  16. def test_check_peft():
  17. assert not check_peft_repository(UNSAFE_PEFT_REPO), "NOSAFE_PEFT_REPO is safe to load."
  18. assert check_peft_repository(SAFE_PEFT_REPO), "SAFE_PEFT_REPO is not safe to load."
  19. @pytest.mark.forked
  20. def test_load_noncached(tmpdir):
  21. clear_dir(tmpdir)
  22. with pytest.raises(Exception):
  23. load_peft(UNSAFE_PEFT_REPO, cache_dir=tmpdir)
  24. assert dir_empty(tmpdir), "UNSAFE_PEFT_REPO is loaded"
  25. load_peft(SAFE_PEFT_REPO, cache_dir=tmpdir)
  26. assert not dir_empty(tmpdir), "SAFE_PEFT_REPO is not loaded"
  27. @pytest.mark.forked
  28. def test_load_cached(tmpdir):
  29. clear_dir(tmpdir)
  30. snapshot_download(SAFE_PEFT_REPO, cache_dir=tmpdir)
  31. load_peft(SAFE_PEFT_REPO, cache_dir=tmpdir)
  32. @pytest.mark.forked
  33. def test_load_layer_exists(tmpdir):
  34. clear_dir(tmpdir)
  35. load_peft(SAFE_PEFT_REPO, block_idx=2, cache_dir=tmpdir)
  36. @pytest.mark.forked
  37. def test_load_layer_nonexists(tmpdir):
  38. clear_dir(tmpdir)
  39. load_peft(
  40. SAFE_PEFT_REPO,
  41. block_idx=1337,
  42. cache_dir=tmpdir,
  43. )