Add skeleton for peft init

artek0chumak, 2 years ago
Parent commit: e452df25cc
3 changed files, 58 additions and 1 deletion
  1. setup.cfg (+1, -1)
  2. src/petals/utils/convert_block.py (+8, -0)
  3. src/petals/utils/peft.py (+49, -0)

+ 1 - 1
setup.cfg

@@ -46,7 +46,7 @@ install_requires =
     cpufeature>=0.2.0
     packaging>=20.9
     sentencepiece>=0.1.99
-    peft @ git+https://github.com/huggingface/peft
+    peft@git+https://github.com/huggingface/peft@5884bdbea49e5e71e2cd06ecfa484bb635063735
     safetensors>=0.3.1
 
 [options.extras_require]
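
As a side note (not part of the diff): replacing the bare branch reference with a fixed commit hash means every install resolves the same adapter API instead of whatever the branch tip happens to be at install time. If needed, the installed commit can be read back from the PEP 610 metadata that pip writes for VCS installs; a minimal Python sketch:

    import json
    from importlib import metadata

    # direct_url.json exists only for direct-URL / VCS installs (PEP 610).
    raw = metadata.distribution("peft").read_text("direct_url.json")
    if raw is not None:
        info = json.loads(raw)
        # For the pinned install above, this prints the commit hash from setup.cfg.
        print(info.get("vcs_info", {}).get("commit_id"))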

+ 8 - 0
src/petals/utils/convert_block.py

@@ -10,6 +10,7 @@ import tensor_parallel as tp
 import torch
 import torch.nn as nn
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
+from petals.utils.peft import create_lora_adapter, add_adapter_to_block, load_peft
 from tensor_parallel.slicing_configs import get_bloom_config
 from transformers import PretrainedConfig
 
@@ -30,6 +31,7 @@ def convert_block(
     output_device: torch.device,
     quant_type: QuantType,
     freeze: bool = True,
+    adapters: Optional[List[str]] = None,
 ) -> tp.TensorParallel:
     """
     Optimize a transformer block for use in a Petals server, apply tensor parallelism and/or LLM.8bit quantization
@@ -55,6 +57,12 @@ def convert_block(
 
     for shard, device in zip(block.module_shards, block.devices):
         shard.to(device)
+
+    if adapters:
+        create_lora_adapter(block)
+        for adapter in adapters:
+            adapter_config, adapter_state_dict = load_peft(adapter)
+            add_adapter_to_block(block, adapter_config, adapter_state_dict)
 
     return block
 
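The new adapters argument lets a caller attach one or more LoRA adapters, referenced by Hugging Face Hub repo id, after the block has been frozen, tensor-parallelized and optionally quantized: create_lora_adapter installs empty LoRA wrappers once, then each adapter's config and weights are fetched with load_peft and handed to add_adapter_to_block (still a stub in this commit). A rough caller-side sketch follows, assuming the parameters not shown in the hunk keep their existing names; the adapter repo id in the comment is a hypothetical placeholder:

    import torch
    from transformers import AutoConfig
    from transformers.models.bloom.modeling_bloom import BloomBlock

    from petals.utils.convert_block import QuantType, convert_block

    config = AutoConfig.from_pretrained("bigscience/bloom-560m")
    block = BloomBlock(config)  # in a real server this is a block loaded from the model

    block = convert_block(
        block,
        config=config,
        tensor_parallel_devices=[torch.device("cpu")],
        output_device=torch.device("cpu"),
        quant_type=QuantType.NONE,
        freeze=True,
        adapters=None,  # e.g. ["some-user/some-bloom-lora"] would exercise the new path
    )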

+ 49 - 0
src/petals/utils/peft.py

@@ -108,3 +108,52 @@ def load_peft(
                 f"Failed to load peft weights {repo_id} from HF Hub (retry in {delay:.0f} sec)", exc_info=True
             )
             time.sleep(delay)
+
+
+def create_lora_adapter(block):
+    for name, module in block.named_modules():
+        for child_name, child in module.named_children():
+            lora_wrapped_child = None
+            if isinstance(child, nn.Linear):
+                bias = hasattr(child, "bias") and child.bias is not None
+                lora_wrapped_child = peft.tuners.lora.Linear(
+                    child_name,
+                    child.in_features,
+                    child.out_features,
+                    bias=bias,
+                )
+            elif isinstance(child, bnb.nn.Linear8bitLt):
+                kwargs = {
+                    "has_fp16_weights": child.state.has_fp16_weights,
+                    "memory_efficient_backward": child.state.memory_efficient_backward,
+                    "threshold": child.state.threshold,
+                    "index": child.index,
+                    "bias": hasattr(child, "bias") and child.bias is not None,
+                }
+                lora_wrapped_child = peft.tuners.lora.Linear8bitLt(
+                    child_name,
+                    child.in_features,
+                    child.out_features,
+                    **kwargs,
+                )
+            elif isinstance(child, bnb.nn.Linear4bit):
+                kwargs = {
+                    "compute_dtype": child.compute_dtype,
+                    "compress_statistics": child.weight.compress_statistics,
+                    "quant_type": child.weight.quant_type,
+                    "bias": hasattr(child, "bias") and child.bias is not None,
+                }
+                lora_wrapped_child = peft.tuners.lora.Linear4bit(
+                    child_name,
+                    child.in_features,
+                    child.out_features,
+                    **kwargs,
+                )
+            if lora_wrapped_child:
+                lora_wrapped_child.active_adapter = None
+                setattr(module, child_name, lora_wrapped_child)
+
+
+def add_adapter_to_block(block, peft_config, peft_state_dict):
+    assert peft_config.peft_type == peft.PeftType.LORA, "Petals works only with LoRA adapters"
+    pass  # TODO: load the LoRA weights from peft_state_dict into the block's adapter layers
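
create_lora_adapter above follows a plain module-surgery pattern: walk named_modules(), look at each module's direct children, and setattr a LoRA-wrapped replacement over every linear layer (plain, 8-bit or 4-bit). The toy sketch below shows only that traversal-and-replace mechanic with a pure-torch stand-in wrapper; ToyLoRALinear, wrap_linears and the rank of 2 are illustrative assumptions, not peft or Petals API:

    import torch
    import torch.nn as nn

    class ToyLoRALinear(nn.Module):
        # Stand-in for a LoRA-wrapped linear: frozen base projection plus a
        # zero-initialized low-rank update, so the wrapper starts as a no-op.
        def __init__(self, base: nn.Linear, rank: int = 2):
            super().__init__()
            self.base = base.requires_grad_(False)
            self.lora_A = nn.Parameter(torch.zeros(rank, base.in_features))
            self.lora_B = nn.Parameter(torch.zeros(base.out_features, rank))

        def forward(self, x):
            return self.base(x) + x @ self.lora_A.T @ self.lora_B.T

    def wrap_linears(block: nn.Module) -> None:
        # Same pattern as create_lora_adapter: find every nn.Linear child and
        # swap it for a wrapper via setattr. Replacements are collected first
        # so the walk never descends into freshly inserted wrappers.
        replacements = []
        for _, module in block.named_modules():
            for child_name, child in module.named_children():
                if isinstance(child, nn.Linear):
                    replacements.append((module, child_name, child))
        for module, child_name, child in replacements:
            setattr(module, child_name, ToyLoRALinear(child))

    block = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 8))
    wrap_linears(block)
    print(block)  # both Linear layers are now ToyLoRALinear wrappers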