Add adapters loading

artek0chumak · 2 years ago · commit 30e3f4a6b4

.github/workflows/run-tests.yaml (+5, -1)

@@ -33,6 +33,7 @@ jobs:
         run: |
           export MODEL_NAME=bigscience/bloom-560m
           export REF_NAME=bigscience/bloom-560m
+          export ADAPTER_NAME=artek0chumak/bloom-560m-safe-peft
 
           python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:12 \
             --new_swarm --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 --throughput 1 \
@@ -58,11 +59,14 @@ jobs:
             --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --tensor_parallel_devices cpu cpu &> server4.log &
           SERVER4_PID=$!
 
+          python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:24             --new_swarm --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 --throughput 1             --torch_dtype float32 --compression NONE --attn_cache_tokens 2048 --adapters $ADAPTER_NAME &> server5.log &
+          SERVER5_PID=$!
+
           tail -n 100 -f server*.log &
           LOGGER_PID=$!
           sleep 30  # wait for servers to download layers
 
-          kill -0 $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID # ensure all servers survived init
+          kill -0 $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $SERVER5_PID # ensure all servers survived init
 
           pytest tests --durations=0 --durations-min=1.0 -v
 

src/petals/cli/run_server.py (+2, -0)

@@ -146,6 +146,8 @@ def main():
                         help="Skip checking this server's reachability via health.petals.ml "
                              "when connecting to the public swarm. If you connect to a private swarm, "
                              "the check is skipped by default. Use this option only if you know what you are doing")
+    
+    parser.add_argument("--adapters", nargs='+', default=None, help="List of pretrained LoRA adapters that can be used for inference or training.")
 
     # fmt:on
     args = vars(parser.parse_args())
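
Review note: `nargs='+'` makes `--adapters` accept one or more space-separated Hub repo ids; the CI job above passes a single one through `$ADAPTER_NAME`. A minimal, self-contained sketch of how the flag parses — only the `add_argument` call is taken from this diff, the rest is illustrative:

```python
# Sketch of the new flag's behavior; only the add_argument line comes from the diff.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--adapters", nargs="+", default=None,
                    help="List of pretrained LoRA adapters that can be used for inference or training.")

args = vars(parser.parse_args(["--adapters", "artek0chumak/bloom-560m-safe-peft"]))
print(args["adapters"])  # ['artek0chumak/bloom-560m-safe-peft']
# run_server.py forwards the parsed dict to the Server constructor,
# which is why Server.__init__ gains a matching `adapters` keyword below.
```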

src/petals/server/server.py (+11, -1)

@@ -81,6 +81,7 @@ class Server:
         dht_client_mode: Optional[bool] = None,
         use_relay: bool = True,
         use_auto_relay: bool = True,
+        adapters: Optional[List[str]] = None,
         **kwargs,
     ):
         """Create a server with one or more bloom blocks. See run_server.py for documentation."""
@@ -217,6 +218,8 @@ class Server:
         self.balance_quality = balance_quality
         self.mean_balance_check_period = mean_balance_check_period
         self.mean_block_selection_delay = mean_block_selection_delay
+        
+        self.adapters = adapters
 
         self.stop = threading.Event()
 
@@ -291,6 +294,7 @@ class Server:
                 quant_type=self.quant_type,
                 tensor_parallel_devices=self.tensor_parallel_devices,
                 should_validate_reachability=self.should_validate_reachability,
+                adapters=self.adapters,
                 start=True,
             )
             try:
@@ -384,6 +388,7 @@ class ModuleContainer(threading.Thread):
         quant_type: QuantType,
         tensor_parallel_devices: Sequence[torch.device],
         should_validate_reachability: bool,
+        adapters: Optional[List[str]] = None,
         **kwargs,
     ) -> ModuleContainer:
         module_uids = [f"{dht_prefix}{UID_DELIMITER}{block_index}" for block_index in block_indices]
@@ -415,7 +420,12 @@ class ModuleContainer(threading.Thread):
                     cache_dir=cache_dir,
                     max_disk_space=max_disk_space,
                 )
-                block = convert_block(block, block_config, tensor_parallel_devices, device, quant_type, freeze=True)
+                block = convert_block(
+                    block, block_index, block_config, tensor_parallel_devices, device, quant_type, adapters=adapters, freeze=True,
+                    use_auth_token=use_auth_token,
+                    cache_dir=cache_dir,
+                    max_disk_space=max_disk_space,
+                )
                 blocks[module_uid] = TransformerBackend(
                     module_uid,
                     block,
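
The server.py changes are plumbing only: `Server.__init__` stores the adapter list, and `ModuleContainer.create` hands it to `convert_block` for every block it serves, together with the download-related kwargs (`use_auth_token`, `cache_dir`, `max_disk_space`). A toy, runnable sketch of that flow — the stub functions below are stand-ins, not Petals code:

```python
# Stand-in sketch of how the adapter list is threaded down to each served block.
from typing import Dict, List, Optional

def convert_block(block: dict, block_index: int, adapters: Optional[List[str]] = None, **kwargs) -> dict:
    # the real convert_block downloads and injects LoRA weights here
    block["adapters"] = list(adapters or [])
    return block

def create_container(block_indices: List[int], adapters: Optional[List[str]] = None, **kwargs) -> Dict[int, dict]:
    # mirrors ModuleContainer.create: one convert_block call per served block
    return {i: convert_block({"index": i}, i, adapters=adapters, **kwargs) for i in block_indices}

container = create_container([0, 1, 2], adapters=["artek0chumak/bloom-560m-safe-peft"],
                             cache_dir=None, max_disk_space=None)
print(container[1]["adapters"])  # ['artek0chumak/bloom-560m-safe-peft']
```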

src/petals/utils/convert_block.py (+12, -6)

@@ -4,13 +4,13 @@ Tools for converting transformer blocks, applying quantization and/or tensor par
 import os
 import re
 from enum import Enum
-from typing import Sequence
+from typing import List, Optional, Sequence
 
 import tensor_parallel as tp
 import torch
 import torch.nn as nn
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
-from peft import create_lora_adapter, add_adapter_to_block, load_peft
+from petals.utils.peft import create_lora_adapter, add_adapter_to_block, load_peft
 from tensor_parallel.slicing_configs import get_bloom_config
 from transformers import PretrainedConfig
 
@@ -26,12 +26,14 @@ class QuantType(Enum):
 
 def convert_block(
     block: nn.Module,
+    block_index: int,
     config: PretrainedConfig,
     tensor_parallel_devices: Sequence[torch.device],
     output_device: torch.device,
     quant_type: QuantType,
     freeze: bool = True,
     adapters: Optional[List[str]] = None,
+    **kwargs,
 ) -> tp.TensorParallel:
     """
     Optimize a transformer block for use in a Petals server, apply tensor parallelism and/or LLM.8bit quantization
@@ -57,12 +59,16 @@ def convert_block(
 
     for shard, device in zip(block.module_shards, block.devices):
         shard.to(device)
-        
+
     if adapters:
         create_lora_adapter(block)
-        for adapter in adapters:
-            adapter_config, adapter_state_dict = load_peft(adapter)
-            add_adapter_to_block(block, adapter_config, adapter_state_dict)
+        for adapter_name in adapters:
+            adapter_config, adapter_state_dict = load_peft(
+                adapter_name,
+                block_idx=block_index,
+                **kwargs,
+            )
+            add_adapter_to_block(block, block_index, adapter_name, adapter_config, adapter_state_dict)
 
     return block
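
The tail appended to `convert_block` runs once per adapter: `load_peft` fetches the adapter config and weights from the Hub and keeps only the tensors belonging to `block_index`, then `add_adapter_to_block` injects them under the adapter's name. A compact stand-in sketch of that loop (stubs and key names below are hypothetical, not the real return values):

```python
# Hypothetical stubs illustrating the per-adapter, per-block loading loop.
from typing import Dict, Tuple

def load_peft(adapter_name: str, *, block_idx: int, **kwargs) -> Tuple[dict, Dict[str, str]]:
    # stand-in: the real function downloads CONFIG_NAME / SAFETENSORS_WEIGHTS_NAME
    # from the repo and returns only the tensors that belong to block `block_idx`
    return {"peft_type": "LORA"}, {f"h.{block_idx}.self_attention.query_key_value.lora_A.weight": "<tensor>"}

def add_adapter_to_block(block: dict, block_index: int, adapter_name: str,
                         peft_config: dict, peft_state_dict: Dict[str, str]) -> None:
    block.setdefault("adapters", {})[adapter_name] = sorted(peft_state_dict)

block, block_index = {}, 3
for adapter_name in ["artek0chumak/bloom-560m-safe-peft"]:
    adapter_config, adapter_state_dict = load_peft(adapter_name, block_idx=block_index)
    add_adapter_to_block(block, block_index, adapter_name, adapter_config, adapter_state_dict)
print(block["adapters"])  # one entry per adapter, holding that block's LoRA keys
```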
 

src/petals/utils/peft.py (+58, -24)

@@ -1,9 +1,14 @@
+import re
 import time
 from typing import List, Optional
 
+import torch.nn as nn
+import bitsandbytes as bnb
+
 from hivemind.utils.logging import get_logger
 from huggingface_hub import HfFileSystem, get_hf_file_metadata, hf_hub_url
-from peft.utils import CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME, PeftConfig
+from peft.tuners import lora
+from peft.utils import COMMON_LAYERS_PATTERN, CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME, PeftConfig
 from safetensors import safe_open
 from safetensors.torch import load_file
 from transformers.utils import get_file_from_repo
@@ -19,23 +24,22 @@ def check_peft_repository(repo_id: str) -> bool:
     return len(list_of_files) > 0
 
 
-def load_specific_module(layers_name: List[str], filepath: str, framework: str = "pt", device: Optional[int] = None):
+def load_specific_module(block_idx: int, filepath: str, framework: str = "pt", device: Optional[int] = None):
     tensors = dict()
     is_tensors_found = dict()
+    common_layer_pattern_re = r".+\." + "".join(f"({common_name})?" for common_name in COMMON_LAYERS_PATTERN) + rf"({block_idx})?\.0\..+"
     with safe_open(filepath, framework=framework, device=device) as f:
         for k in f.keys():
-            for layer_name in layers_name:
-                if k.startswith(layer_name):
-                    is_tensors_found[layer_name] = True
-                    tensors[k] = f.get_tensor(k)
-        for layer_name in layers_name:
-            if not is_tensors_found.get(layer_name, False):
-                logger.warning(f"There is no peft weights with prefix {layer_name}")
+            if re.match(common_layer_pattern_re, k):
+                is_tensors_found[block_idx] = True
+                tensors[k] = f.get_tensor(k)
+        if not is_tensors_found.get(block_idx, False):
+            logger.warning(f"There are no PEFT weights for block {block_idx}")
         return tensors
 
 
 def get_adapter_from_repo(
-    repo_id: str, layers_name: Optional[List[str]] = None, device: Optional[int] = None, **kwargs
+    repo_id: str, block_idx: Optional[int] = None, device: Optional[int] = None, **kwargs
 ):
     config_path = get_file_from_repo(repo_id, CONFIG_NAME, **kwargs)
     if config_path is None:
@@ -45,14 +49,14 @@ def get_adapter_from_repo(
     weight_path = get_file_from_repo(repo_id, SAFETENSORS_WEIGHTS_NAME, **kwargs)
     if weight_path is None:
         raise RuntimeError(f"File {SAFETENSORS_WEIGHTS_NAME} does not exist in repo {repo_id}")
-    if layers_name is None:
+    if block_idx is None:
         return config, load_file(weight_path)
-    return config, load_specific_module(layers_name, weight_path, device=device)
+    return config, load_specific_module(block_idx, weight_path, device=device)
 
 
 def load_peft(
     repo_id: str,
-    layers_name: Optional[List[str]] = None,
+    block_idx: Optional[int] = None,
     device: Optional[int] = None,
     *,
     revision: Optional[str] = None,
@@ -70,7 +74,7 @@ def load_peft(
         with allow_cache_reads(cache_dir):
             return get_adapter_from_repo(
                 repo_id,
-                layers_name,
+                block_idx,
                 device,
                 revision=revision,
                 use_auth_token=use_auth_token,
@@ -96,7 +100,7 @@ def load_peft(
 
                 return get_adapter_from_repo(
                     repo_id,
-                    layers_name,
+                    block_idx,
                     device,
                     revision=revision,
                     use_auth_token=use_auth_token,
@@ -115,8 +119,8 @@ def create_lora_adapter(block):
         for child_name, child in module.named_children():
             lora_wrapped_child = None
             if isinstance(child, nn.Linear):
-                bias = hasattr(target, "bias") and target.bias is not None
-                lora_wrapped_child = peft.tuners.lora.Linear(
+                bias = hasattr(child, "bias") and child.bias is not None
+                lora_wrapped_child = lora.Linear(
                     child_name,
                     child.in_features,
                     child.out_features,
@@ -128,9 +132,9 @@ def create_lora_adapter(block):
                     "memory_efficient_backward": child.state.memory_efficient_backward,
                     "threshold": child.state.threshold,
                     "index": child.index,
-                    "bias": hasattr(target, "bias") and target.bias is not None,
+                    "bias": hasattr(child, "bias") and child.bias is not None,
                 }
-                lora_wrapped_child = peft.tuners.lora.Linear8bitLt(
+                lora_wrapped_child = lora.Linear8bitLt(
                     child_name,
                     child.in_features,
                     child.out_features,
@@ -141,9 +145,9 @@ def create_lora_adapter(block):
                     "compute_dtype": child.compute_dtype,
                     "compress_statistics": child.weight.compress_statistics,
                     "quant_type": child.weight.quant_type,
-                    "bias": hasattr(target, "bias") and target.bias is not None,
+                    "bias": hasattr(child, "bias") and child.bias is not None,
                 }
-                lora_wrapped_child = peft.tuners.lora.Linear4bit(
+                lora_wrapped_child = lora.Linear4bit(
                     child_name,
                     child.in_features,
                     child.out_features,
@@ -151,9 +155,39 @@ def create_lora_adapter(block):
                 )
             if lora_wrapped_child:
                 lora_wrapped_child.active_adapter = None
+                for p in lora_wrapped_child.parameters():
+                    p.requires_grad = False
                 setattr(module, child_name, lora_wrapped_child)
                 
                 
-def add_adapter_to_block(block, peft_config, peft_state_dict):
-    assert peft_config.peft_type == peft.PeftType.LORA, "Petals works only with LORA adapters"
-    pass
+def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_state_dict):
+    assert peft_config["peft_type"] == "LORA", "Petals works only with LORA adapters"
+    for name, module in block.named_modules():
+        for child_name, child in module.named_children():
+            if not isinstance(child, (lora.Linear, lora.Linear8bitLt, lora.Linear4bit)):
+                continue
+
+            if child_name in peft_config["target_modules"] or (isinstance(peft_config["target_modules"], str) and re.fullmatch(peft_config["target_modules"], child_name)):
+                is_lora_a_loaded = False
+                is_lora_b_loaded = False
+                for peft_key in peft_state_dict:
+                    if adapter_name not in child.lora_A:
+                        child.update_layer(
+                            adapter_name,
+                            peft_config["r"],
+                            peft_config["lora_alpha"],
+                            peft_config["lora_dropout"],
+                            peft_config["init_lora_weights"],
+                        )
+                        for p in child.parameters():
+                            p.requires_grad = False
+
+                    if "lora_A" in peft_key:
+                        child.lora_A[adapter_name].weight.data = peft_state_dict[peft_key] * child.scaling[adapter_name]
+                        is_lora_a_loaded = True
+                    elif "lora_B" in peft_key:
+                        child.lora_B[adapter_name].weight.data = peft_state_dict[peft_key]
+                        is_lora_b_loaded = True
+                        
+                if is_lora_a_loaded and is_lora_b_loaded:
+                    logger.info(f"Loading {adapter_name} for block {block_index} finished successfully")
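
`add_adapter_to_block` leans on the fact that `peft.tuners.lora` layers keep per-adapter containers (`lora_A`, `lora_B`, `scaling`) keyed by adapter name, so one wrapped linear layer can carry several named adapters and have their weights assigned after `update_layer`. A small runnable illustration, assuming a peft release contemporary with this commit (the 0.3–0.5 API, where `lora.Linear` takes the adapter name first; later releases changed this constructor):

```python
# Demonstrates the per-adapter containers that add_adapter_to_block writes into.
import torch
from peft.tuners import lora  # same import as in the diff above

layer = lora.Linear("base-adapter", in_features=16, out_features=16,
                    r=4, lora_alpha=8, lora_dropout=0.0)
print(list(layer.lora_A.keys()))      # ['base-adapter']
print(layer.scaling["base-adapter"])  # lora_alpha / r == 2.0

# A second adapter is registered the same way add_adapter_to_block does it,
# then its weights can be overwritten from a loaded state dict:
layer.update_layer("artek0chumak/bloom-560m-safe-peft", r=8, lora_alpha=16,
                   lora_dropout=0.0, init_lora_weights=True)
layer.lora_B["artek0chumak/bloom-560m-safe-peft"].weight.data = torch.zeros(16, 8)
print(sorted(layer.lora_A.keys()))    # ['artek0chumak/bloom-560m-safe-peft', 'base-adapter']
```

Using the Hub repo id itself as the adapter name works because `nn.ModuleDict` accepts keys containing `/`, so no separate adapter-naming scheme is needed.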