Преглед на файлове

Import petals.utils.peft only when needed to avoid unnecessary import of bitsandbytes (#345)

The motivation is the same as in #180.
Alexander Borzunov преди 2 години
родител
ревизия
43acfe52a7
променени са 2 файла, в които са добавени 9 реда и са изтрити 3 реда
  1. 5 2
      src/petals/server/backend.py
  2. 4 1
      src/petals/utils/convert_block.py

+ 5 - 2
src/petals/server/backend.py

@@ -4,7 +4,6 @@ from collections import Counter
 from itertools import chain
 from typing import Any, Dict, Optional, Sequence, Tuple, Union
 
-import peft
 import torch
 from hivemind import BatchTensorDescriptor, TensorDescriptor
 from hivemind.moe.expert_uid import ExpertUID
@@ -156,9 +155,13 @@ class TransformerBackend(ModuleBackend):
 
     def load_adapter_(self, active_adapter: Optional[str] = None) -> bool:
         """Activate a given adapter set if available. Return True if available (or no adapter), False if missing"""
+
+        # Import petals.utils.peft only when necessary to avoid importing bitsandbytes
+        from peft.tuners.lora import Linear, Linear4bit, Linear8bitLt
+
         adapter_was_loaded = False
         for layer in self.module.modules():  # select adapter set -- leave empty string for no adapter
-            if isinstance(layer, (peft.tuners.lora.Linear, peft.tuners.lora.Linear8bitLt, peft.tuners.lora.Linear4bit)):
+            if isinstance(layer, (Linear, Linear4bit, Linear8bitLt)):
                 layer.active_adapter = active_adapter  # empty string for no adapter
                 if active_adapter in layer.lora_A.keys():
                     adapter_was_loaded = True

+ 4 - 1
src/petals/utils/convert_block.py

@@ -13,7 +13,6 @@ from tensor_parallel.slicing_configs import get_bloom_config
 from transformers import PretrainedConfig
 
 from petals.utils.misc import QuantType
-from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft
 
 use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__name__)
@@ -56,6 +55,10 @@ def convert_block(
         shard.to(device)
 
     if adapters:
+        # Import petals.utils.peft only when necessary to avoid importing bitsandbytes
+        os.environ["BITSANDBYTES_NOWELCOME"] = "1"
+        from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft
+
         create_lora_adapter(block, quant_type=quant_type)
         for adapter_name in adapters:
             adapter_config, adapter_state_dict = load_peft(