ソースを参照

Update peft dependency, fix initialization and inference with new peft (#557)

* Make fixes

* lib number

* Fix inference without adapter

* Fix trainability

* Fix versions

* style

* Update comments

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>

* Remove unnesc todo

---------

Co-authored-by: Max Ryabinin <mryabinin0@gmail.com>
Co-authored-by: justheuristic <justheuristic@gmail.com>
Artem Chumachenko 1 年間 前
コミット
f1e1b051d0
3 ファイル変更44 行追加49 行削除
  1. 1 1
      setup.cfg
  2. 1 1
      src/petals/utils/convert_block.py
  3. 42 47
      src/petals/utils/peft.py

+ 1 - 1
setup.cfg

@@ -47,7 +47,7 @@ install_requires =
     cpufeature>=0.2.0; platform_machine == "x86_64"
     packaging>=20.9
     sentencepiece>=0.1.99
-    peft==0.5.0
+    peft==0.8.2
     safetensors>=0.3.1
     Dijkstar>=2.6.0
     numpy<2

+ 1 - 1
src/petals/utils/convert_block.py

@@ -61,7 +61,7 @@ def convert_block(
     if adapters:
         from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft
 
-        create_lora_adapter(block, quant_type=quant_type)
+        create_lora_adapter(block)
         for adapter_name in adapters:
             adapter_config, adapter_state_dict = load_peft(
                 adapter_name,

+ 42 - 47
src/petals/utils/peft.py

@@ -1,7 +1,7 @@
 import contextlib
 import re
 import time
-from typing import Optional, Sequence, Union
+from typing import List, Optional, Sequence, Union
 
 import bitsandbytes as bnb
 import torch
@@ -12,7 +12,7 @@ from hivemind.utils.logging import get_logger
 from huggingface_hub import HfFileSystem, get_hf_file_metadata, hf_hub_url
 from peft.config import PeftConfig
 from peft.tuners import lora
-from peft.utils import COMMON_LAYERS_PATTERN, CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME
+from peft.utils import CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME
 from safetensors import safe_open
 from safetensors.torch import load_file
 from transformers.utils import get_file_from_repo
@@ -25,6 +25,9 @@ from petals.utils.misc import get_size_in_bytes
 logger = get_logger(__name__)
 
 
+COMMON_LAYERS_PATTERN = ["layers", "h", "block", "blocks", "layer"]
+
+
 def check_peft_repository(repo_id: str) -> bool:
     return HfFileSystem().exists(f"{repo_id}/{SAFETENSORS_WEIGHTS_NAME}")
 
@@ -151,6 +154,18 @@ class AdapterContextMixin:
     def active_adapter(self, value: Optional[str]):
         assert value == self.ADAPTER_NOT_SET, "active adapter can only be changed via .using_adapter" ""
 
+    @property
+    def active_adapters(self):
+        return [self._context_active_adapter]
+
+    def set_adapter(self, adapter_names) -> None:
+        """
+        In PEFT, this function makes the adapter trainable. However, in Petals environment this is not possible now. Thus,
+        this code removes this functionality.
+        Link to peft code: https://github.com/huggingface/peft/blob/98f4db2c7990ef9c879a0e1da9a28a19a04701ef/src/peft/tuners/tuners_utils.py#L463
+        """
+        pass
+
 
 using_adapter = AdapterContextMixin.using_adapter
 
@@ -158,60 +173,39 @@ using_adapter = AdapterContextMixin.using_adapter
 class LoraLinear(AdapterContextMixin, lora.Linear):
     """LoRA linear layer that uses adapter selected via using_adapter"""
 
+    def __init__(self, base_layer, adapter_name: str):
+        nn.Module.__init__(self)
+        lora.LoraLayer.__init__(self, base_layer)
+
+        self._active_adapter = adapter_name
+        self.is_target_conv_1d_layer = False
+
 
-class LoraLinear8bitLt(AdapterContextMixin, lora.Linear8bitLt):
+class LoraLinear8bitLt(LoraLinear, lora.Linear8bitLt):
     """LoRA linear 8-bit with outliers that uses adapter selected via using_adapter"""
 
 
-class LoraLinear4bit(AdapterContextMixin, lora.Linear4bit):
+class LoraLinear4bit(LoraLinear, lora.Linear4bit):
     """LoRA linear 4-bit that uses adapter selected via using_adapter"""
 
 
-def create_lora_adapter(block, quant_type: QuantType):
-    for _, module in block.named_modules():
+def create_lora_adapter(block):
+    for module_name, module in block.named_modules():
+        if isinstance(module, LoraLinear):
+            continue
         for child_name, child in module.named_children():
-            lora_wrapped_child = None
-            if not isinstance(child, (nn.Linear, bnb.nn.Linear8bitLt, bnb.nn.Linear4bit)):
-                continue
-            if quant_type == QuantType.INT8:
-                kwargs = {
-                    "has_fp16_weights": False,
-                    "threshold": 6.0,
-                    "bias": hasattr(child, "bias") and child.bias is not None,
-                }
-                lora_wrapped_child = LoraLinear8bitLt(
-                    AdapterContextMixin.ADAPTER_NOT_SET,
-                    child.in_features,
-                    child.out_features,
-                    **kwargs,
-                )
-            elif quant_type == QuantType.NF4:
-                kwargs = {
-                    "compress_statistics": True,
-                    "quant_type": "nf4",
-                    "blocksize": 64,
-                    "bias": hasattr(child, "bias") and child.bias is not None,
-                }
-                lora_wrapped_child = LoraLinear4bit(
-                    AdapterContextMixin.ADAPTER_NOT_SET,
-                    child.in_features,
-                    child.out_features,
-                    **kwargs,
-                )
-                lora_wrapped_child.compute_dtype = child.compute_dtype
-            else:
-                bias = hasattr(child, "bias") and child.bias is not None
-                lora_wrapped_child = LoraLinear(
+            lora_class = None
+            if isinstance(child, nn.Linear):
+                lora_class = LoraLinear
+            elif isinstance(child, bnb.nn.Linear8bitLt):
+                lora_class = LoraLinear8bitLt
+            elif isinstance(child, bnb.nn.Linear4bit):
+                lora_class = LoraLinear4bit
+            if lora_class:
+                lora_wrapped_child = lora_class(
+                    child,
                     AdapterContextMixin.ADAPTER_NOT_SET,
-                    child.in_features,
-                    child.out_features,
-                    bias=bias,
                 )
-            if lora_wrapped_child:
-                lora_wrapped_child.weight = child.weight
-                lora_wrapped_child.bias = child.bias
-                for p in lora_wrapped_child.parameters():
-                    p.requires_grad = False
                 setattr(module, child_name, lora_wrapped_child)
 
 
@@ -240,6 +234,7 @@ def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_sta
                             adapter_name,
                             peft_config["r"],
                             peft_config["lora_alpha"],
+                            use_rslora=peft_config.get("use_rslora", False),
                             lora_dropout=peft_config["lora_dropout"],
                             init_lora_weights=peft_config["init_lora_weights"],
                         )
@@ -275,7 +270,7 @@ def estimate_adapter_memory_per_block(
     with init_empty_weights(include_buffers=True):
         block = get_model_block(block_config)
         base_block_parameters = sum(p.numel() for p in block.parameters())
-        create_lora_adapter(block, quant_type=QuantType.NONE)
+        create_lora_adapter(block)
 
         for adapter in adapters:
             peft_config, peft_state_dict = load_peft(adapter, block_idx=0, **load_peft_kwargs)