瀏覽代碼

Test that bitsandbytes is not imported when it's not used (#351)

We avoid importing bitsandbytes when it's not used, since bitsandbytes doesn't always find correct CUDA libs and may raise exceptions because of that.
Alexander Borzunov 2 年之前
父節點
當前提交
1a78638c02
共有 4 個文件被更改,包括 19 次插入4 次删除
  1. 1 1
      setup.cfg
  2. 0 1
      src/petals/server/server.py
  3. 16 0
      tests/test_aux_functions.py
  4. 2 2
      tests/test_sequence_manager.py

+ 1 - 1
setup.cfg

@@ -33,7 +33,7 @@ python_requires = >=3.7
 install_requires =
     torch>=1.12
     bitsandbytes==0.40.0.post4
-    accelerate>=0.16.0,<1.0.0
+    accelerate>=0.16.0,<0.21.0
     huggingface-hub>=0.11.1,<1.0.0
     tokenizers>=0.13.3
     transformers>=4.30.1,<5.0.0

+ 0 - 1
src/petals/server/server.py

@@ -30,7 +30,6 @@ from petals.server.throughput import get_dtype_name, get_server_throughput
 from petals.utils.auto_config import AutoDistributedConfig
 from petals.utils.convert_block import QuantType, check_device_balance, convert_block
 from petals.utils.disk_cache import DEFAULT_CACHE_DIR
-from petals.utils.peft import estimate_adapter_memory_per_block
 from petals.utils.version import get_compatible_model_repo
 
 logger = get_logger(__name__)

+ 16 - 0
tests/test_aux_functions.py

@@ -1,3 +1,6 @@
+import subprocess
+import sys
+
 import pytest
 import torch
 
@@ -7,6 +10,19 @@ from petals.utils.convert_block import QuantType
 from test_utils import MODEL_NAME
 
 
+def test_bnb_not_imported_when_unnecessary():
+    """
+    We avoid importing bitsandbytes when it's not used,
+    since bitsandbytes doesn't always find correct CUDA libs and may raise exceptions because of that.
+
+    If this test fails, please change your code to import bitsandbytes and/or petals.utils.peft
+    in the function's/method's code when it's actually needed instead of importing them in the beginning of the file.
+    This won't slow down the code - importing a module for the 2nd time doesn't rerun module code.
+    """
+
+    subprocess.check_call([sys.executable, "-c", "import petals, sys; assert 'bitsandbytes' not in sys.modules"])
+
+
 @pytest.mark.forked
 @pytest.mark.parametrize("tensor_parallel", [False, True])
 def test_compute_throughput(tensor_parallel: bool):

+ 2 - 2
tests/test_sequence_manager.py

@@ -25,7 +25,7 @@ def test_sequence_manager_basics(mode: str):
     block_uids = [f"{config.dht_prefix}{UID_DELIMITER}{i}" for i in range(config.num_hidden_layers)]
     sequential = RemoteSequential(
         config,
-        sequence_manager=TestSequenceManager(config, block_uids, dht=dht, _was_shut_down=shutdown_evt),
+        sequence_manager=RemoteSequenceManagerWithChecks(config, block_uids, dht=dht, _was_shut_down=shutdown_evt),
     )
 
     sequence = sequential.sequence_manager.make_sequence(mode=mode)
@@ -43,7 +43,7 @@ def test_sequence_manager_basics(mode: str):
     assert shutdown_evt.is_set()
 
 
-class TestSequenceManager(RemoteSequenceManager):
+class RemoteSequenceManagerWithChecks(RemoteSequenceManager):
     """A sequence manager that signals if it was shut down"""
 
     def __init__(self, *args, _was_shut_down: threading.Event, **kwargs):