5
0

test_peft.py 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. import os
  2. import shutil
  3. import pytest
  4. from huggingface_hub import snapshot_download
  5. pytest.skip("LoRA adapters are not supported on AMD GPUs", allow_module_level=True)
  6. from petals.utils.peft import check_peft_repository, load_peft
  7. UNSAFE_PEFT_REPO = "artek0chumak/bloom-560m-unsafe-peft"
  8. SAFE_PEFT_REPO = "artek0chumak/bloom-560m-safe-peft"
  9. TMP_CACHE_DIR = "tmp_cache/"
  10. def clear_dir(path_to_dir):
  11. shutil.rmtree(path_to_dir)
  12. os.mkdir(path_to_dir)
  13. def dir_empty(path_to_dir):
  14. files = os.listdir(path_to_dir)
  15. return len(files) == 0
  16. @pytest.mark.forked
  17. def test_check_peft():
  18. assert not check_peft_repository(UNSAFE_PEFT_REPO), "NOSAFE_PEFT_REPO is safe to load."
  19. assert check_peft_repository(SAFE_PEFT_REPO), "SAFE_PEFT_REPO is not safe to load."
  20. @pytest.mark.forked
  21. def test_load_noncached(tmpdir):
  22. clear_dir(tmpdir)
  23. with pytest.raises(Exception):
  24. load_peft(UNSAFE_PEFT_REPO, cache_dir=tmpdir)
  25. assert dir_empty(tmpdir), "UNSAFE_PEFT_REPO is loaded"
  26. load_peft(SAFE_PEFT_REPO, cache_dir=tmpdir)
  27. assert not dir_empty(tmpdir), "SAFE_PEFT_REPO is not loaded"
  28. @pytest.mark.forked
  29. def test_load_cached(tmpdir):
  30. clear_dir(tmpdir)
  31. snapshot_download(SAFE_PEFT_REPO, cache_dir=tmpdir)
  32. load_peft(SAFE_PEFT_REPO, cache_dir=tmpdir)
  33. @pytest.mark.forked
  34. def test_load_layer_exists(tmpdir):
  35. clear_dir(tmpdir)
  36. load_peft(SAFE_PEFT_REPO, block_idx=2, cache_dir=tmpdir)
  37. @pytest.mark.forked
  38. def test_load_layer_nonexists(tmpdir):
  39. clear_dir(tmpdir)
  40. load_peft(
  41. SAFE_PEFT_REPO,
  42. block_idx=1337,
  43. cache_dir=tmpdir,
  44. )