
Add adapters loading

artek0chumak 2 years ago
parent
commit
30e3f4a6b4

+ 5 - 1
.github/workflows/run-tests.yaml

@@ -33,6 +33,7 @@ jobs:
        run: |
          export MODEL_NAME=bigscience/bloom-560m
          export REF_NAME=bigscience/bloom-560m
+          export ADAPTER_NAME=artek0chumak/bloom-560m-safe-peft

          python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:12 \
            --new_swarm --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 --throughput 1 \
@@ -58,11 +59,14 @@ jobs:
            --initial_peers $INITIAL_PEERS --throughput 1 --torch_dtype float32 --tensor_parallel_devices cpu cpu &> server4.log &
          SERVER4_PID=$!

+          python -m petals.cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:24             --new_swarm --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 --throughput 1             --torch_dtype float32 --compression NONE --attn_cache_tokens 2048 --adapters $ADAPTER_NAME &> server5.log &
+          SERVER5_PID=$!
+
          tail -n 100 -f server*.log &
          LOGGER_PID=$!
          sleep 30  # wait for servers to download layers

-          kill -0 $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID # ensure all servers survived init
+          kill -0 $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $SERVER5_PID # ensure all servers survived init

          pytest tests --durations=0 --durations-min=1.0 -v


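For reference, the fifth server above can be reproduced outside CI with the same flags. A minimal Python sketch (not part of this commit) that rebuilds the command from the environment variables exported in the workflow:

import os
import subprocess

# Values exported by the workflow step above; the fallbacks simply repeat them for local runs.
model_name = os.environ.get("MODEL_NAME", "bigscience/bloom-560m")
adapter_name = os.environ.get("ADAPTER_NAME", "artek0chumak/bloom-560m-safe-peft")

# Same arguments as the server5 line in the workflow, expressed as a list for subprocess.
cmd = [
    "python", "-m", "petals.cli.run_server",
    "--converted_model_name_or_path", model_name,
    "--block_indices", "0:24",
    "--new_swarm", "--identity", "tests/test.id",
    "--host_maddrs", "/ip4/127.0.0.1/tcp/31337",
    "--throughput", "1",
    "--torch_dtype", "float32",
    "--compression", "NONE",
    "--attn_cache_tokens", "2048",
    "--adapters", adapter_name,
]
server5 = subprocess.Popen(cmd)  # the workflow additionally redirects output to server5.log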
+ 2 - 0
src/petals/cli/run_server.py

@@ -146,6 +146,8 @@ def main():
                         help="Skip checking this server's reachability via health.petals.ml "
                         help="Skip checking this server's reachability via health.petals.ml "
                              "when connecting to the public swarm. If you connect to a private swarm, "
                              "when connecting to the public swarm. If you connect to a private swarm, "
                              "the check is skipped by default. Use this option only if you know what you are doing")
                              "the check is skipped by default. Use this option only if you know what you are doing")
+    
+    parser.add_argument("--adapters", nargs='+', default=None, help="List of pretrained LoRA adapters that can be used for inference or training.")
 
 
     # fmt:on
     # fmt:on
     args = vars(parser.parse_args())
     args = vars(parser.parse_args())

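Because the new flag is declared with nargs='+', run_server.py accepts one or more adapter repo IDs and hands them on as a plain Python list (or None when the flag is omitted). A quick standalone sketch of that argparse behaviour, using a throwaway parser and made-up repo names rather than the real CLI:

from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--adapters", nargs="+", default=None,
                    help="List of pretrained LoRA adapters that can be used for inference or training.")

# One or more values after --adapters arrive as a list; omitting the flag yields None.
print(parser.parse_args(["--adapters", "user/lora-a", "user/lora-b"]).adapters)  # ['user/lora-a', 'user/lora-b']
print(parser.parse_args([]).adapters)                                            # None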
+ 11 - 1
src/petals/server/server.py

@@ -81,6 +81,7 @@ class Server:
         dht_client_mode: Optional[bool] = None,
         use_relay: bool = True,
         use_auto_relay: bool = True,
+        adapters: Optional[List[str]] = None,
         **kwargs,
     ):
         """Create a server with one or more bloom blocks. See run_server.py for documentation."""
@@ -217,6 +218,8 @@ class Server:
         self.balance_quality = balance_quality
         self.mean_balance_check_period = mean_balance_check_period
         self.mean_block_selection_delay = mean_block_selection_delay
+        
+        self.adapters = adapters

         self.stop = threading.Event()

@@ -291,6 +294,7 @@ class Server:
                 quant_type=self.quant_type,
                 tensor_parallel_devices=self.tensor_parallel_devices,
                 should_validate_reachability=self.should_validate_reachability,
+                adapters=self.adapters,
                 start=True,
             )
             try:
@@ -384,6 +388,7 @@ class ModuleContainer(threading.Thread):
         quant_type: QuantType,
         tensor_parallel_devices: Sequence[torch.device],
         should_validate_reachability: bool,
+        adapters: Optional[List[str]] = None,
         **kwargs,
     ) -> ModuleContainer:
         module_uids = [f"{dht_prefix}{UID_DELIMITER}{block_index}" for block_index in block_indices]
@@ -415,7 +420,12 @@ class ModuleContainer(threading.Thread):
                     cache_dir=cache_dir,
                     max_disk_space=max_disk_space,
                 )
-                block = convert_block(block, block_config, tensor_parallel_devices, device, quant_type, freeze=True)
+                block = convert_block(
+                    block, block_index, block_config, tensor_parallel_devices, device, quant_type, adapters=adapters, freeze=True,
+                    use_auth_token=use_auth_token,
+                    cache_dir=cache_dir,
+                    max_disk_space=max_disk_space,
+                )
                 blocks[module_uid] = TransformerBackend(
                     module_uid,
                     block,

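The server-side changes are pure plumbing: the list from --adapters is stored on the Server, forwarded to ModuleContainer.create, and passed into convert_block for every block the server hosts. A condensed, runnable sketch of that flow with toy stand-ins (DemoServer, DemoContainer and demo_convert_block are illustrative, not the real petals classes):

from typing import List, Optional


def demo_convert_block(block_index: int, adapters: Optional[List[str]] = None) -> None:
    # In the real convert_block this is where per-block LoRA weights get attached.
    print(f"block {block_index}: adapters={adapters}")


class DemoContainer:
    @classmethod
    def create(cls, *, block_indices: range, adapters: Optional[List[str]] = None) -> "DemoContainer":
        for block_index in block_indices:
            demo_convert_block(block_index, adapters=adapters)
        return cls()


class DemoServer:
    def __init__(self, adapters: Optional[List[str]] = None):
        self.adapters = adapters  # stored exactly like the new self.adapters above

    def run(self) -> None:
        DemoContainer.create(block_indices=range(2), adapters=self.adapters)


DemoServer(adapters=["artek0chumak/bloom-560m-safe-peft"]).run()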
+ 12 - 6
src/petals/utils/convert_block.py

@@ -4,13 +4,13 @@ Tools for converting transformer blocks, applying quantization and/or tensor par
 import os
 import re
 from enum import Enum
-from typing import Sequence
+from typing import List, Optional, Sequence

 import tensor_parallel as tp
 import torch
 import torch.nn as nn
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
-from peft import create_lora_adapter, add_adapter_to_block, load_peft
+from petals.utils.peft import create_lora_adapter, add_adapter_to_block, load_peft
 from tensor_parallel.slicing_configs import get_bloom_config
 from transformers import PretrainedConfig

@@ -26,12 +26,14 @@ class QuantType(Enum):

 def convert_block(
     block: nn.Module,
+    block_index: int,
     config: PretrainedConfig,
     tensor_parallel_devices: Sequence[torch.device],
     output_device: torch.device,
     quant_type: QuantType,
     freeze: bool = True,
     adapters: Optional[List[str]] = None,
+    **kwargs,
 ) -> tp.TensorParallel:
     """
     Optimize a transformer block for use in a Petals server, apply tensor parallelism and/or LLM.8bit quantization
@@ -57,12 +59,16 @@ def convert_block(

     for shard, device in zip(block.module_shards, block.devices):
         shard.to(device)
-        
+
     if adapters:
         create_lora_adapter(block)
-        for adapter in adapters:
-            adapter_config, adapter_state_dict = load_peft(adapter)
-            add_adapter_to_block(block, adapter_config, adapter_state_dict)
+        for adapter_name in adapters:
+            adapter_config, adapter_state_dict = load_peft(
+                adapter_name,
+                block_idx=block_index,
+                **kwargs,
+            )
+            add_adapter_to_block(block, block_index, adapter_name, adapter_config, adapter_state_dict)

     return block


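The new convert_block signature takes the block index plus extra keyword arguments (use_auth_token, cache_dir, max_disk_space) and forwards them straight into load_peft, so each block loads only its own adapter tensors from every adapter repo. A toy sketch of that loop with stub functions (the fake_* names are illustrative; the real helpers live in petals.utils.peft):

from typing import Dict, List, Optional, Tuple


def fake_create_lora_adapter(block: object) -> None:
    print("wrapping the block's linear layers with empty LoRA shells")


def fake_load_peft(adapter_name: str, *, block_idx: int, **kwargs) -> Tuple[Dict, Dict]:
    # kwargs carries use_auth_token / cache_dir / max_disk_space through unchanged.
    print(f"fetching {adapter_name} weights for block {block_idx}, kwargs={kwargs}")
    return {"peft_type": "LORA"}, {}


def fake_add_adapter_to_block(block, block_index, adapter_name, config, state_dict) -> None:
    print(f"attaching {adapter_name} to block {block_index}")


def demo_convert_block(block, block_index: int, adapters: Optional[List[str]] = None, **kwargs):
    if adapters:
        fake_create_lora_adapter(block)
        for adapter_name in adapters:
            adapter_config, adapter_state_dict = fake_load_peft(adapter_name, block_idx=block_index, **kwargs)
            fake_add_adapter_to_block(block, block_index, adapter_name, adapter_config, adapter_state_dict)
    return block


demo_convert_block(object(), 5, adapters=["artek0chumak/bloom-560m-safe-peft"], cache_dir="/tmp/petals_cache")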
+ 58 - 24
src/petals/utils/peft.py

@@ -1,9 +1,14 @@
+import re
 import time
 from typing import List, Optional

+import torch.nn as nn
+import bitsandbytes as bnb
+
 from hivemind.utils.logging import get_logger
 from huggingface_hub import HfFileSystem, get_hf_file_metadata, hf_hub_url
-from peft.utils import CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME, PeftConfig
+from peft.tuners import lora
+from peft.utils import COMMON_LAYERS_PATTERN, CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME, PeftConfig
 from safetensors import safe_open
 from safetensors.torch import load_file
 from transformers.utils import get_file_from_repo
@@ -19,23 +24,22 @@ def check_peft_repository(repo_id: str) -> bool:
     return len(list_of_files) > 0


-def load_specific_module(layers_name: List[str], filepath: str, framework: str = "pt", device: Optional[int] = None):
+def load_specific_module(block_idx: int, filepath: str, framework: str = "pt", device: Optional[int] = None):
     tensors = dict()
     is_tensors_found = dict()
+    common_layer_patter_re = ".+\." + "".join(f"({common_name})?" for common_name in COMMON_LAYERS_PATTERN) + f"({block_idx})?\.0\..+"
     with safe_open(filepath, framework=framework, device=device) as f:
         for k in f.keys():
-            for layer_name in layers_name:
-                if k.startswith(layer_name):
-                    is_tensors_found[layer_name] = True
-                    tensors[k] = f.get_tensor(k)
-        for layer_name in layers_name:
-            if not is_tensors_found.get(layer_name, False):
-                logger.warning(f"There is no peft weights with prefix {layer_name}")
+            if re.match(common_layer_patter_re, k):
+                is_tensors_found[block_idx] = True
+                tensors[k] = f.get_tensor(k)
+        if not is_tensors_found.get(block_idx, False):
+            logger.warning(f"There is no peft weights for block {block_idx}")
         return tensors


 def get_adapter_from_repo(
-    repo_id: str, layers_name: Optional[List[str]] = None, device: Optional[int] = None, **kwargs
+    repo_id: str, block_idx: Optional[int] = None, device: Optional[int] = None, **kwargs
 ):
     config_path = get_file_from_repo(repo_id, CONFIG_NAME, **kwargs)
     if config_path is None:
@@ -45,14 +49,14 @@ def get_adapter_from_repo(
     weight_path = get_file_from_repo(repo_id, SAFETENSORS_WEIGHTS_NAME, **kwargs)
     if weight_path is None:
         raise RuntimeError(f"File {SAFETENSORS_WEIGHTS_NAME} does not exist in repo {repo_id}")
-    if layers_name is None:
+    if block_idx is None:
         return config, load_file(weight_path)
-    return config, load_specific_module(layers_name, weight_path, device=device)
+    return config, load_specific_module(block_idx, weight_path, device=device)


 def load_peft(
     repo_id: str,
-    layers_name: Optional[List[str]] = None,
+    block_idx: Optional[int] = None,
     device: Optional[int] = None,
     *,
     revision: Optional[str] = None,
@@ -70,7 +74,7 @@ def load_peft(
         with allow_cache_reads(cache_dir):
             return get_adapter_from_repo(
                 repo_id,
-                layers_name,
+                block_idx,
                 device,
                 revision=revision,
                 use_auth_token=use_auth_token,
@@ -96,7 +100,7 @@ def load_peft(

                 return get_adapter_from_repo(
                     repo_id,
-                    layers_name,
+                    block_idx,
                     device,
                     revision=revision,
                     use_auth_token=use_auth_token,
@@ -115,8 +119,8 @@ def create_lora_adapter(block):
         for child_name, child in module.named_children():
             lora_wrapped_child = None
             if isinstance(child, nn.Linear):
-                bias = hasattr(target, "bias") and target.bias is not None
-                lora_wrapped_child = peft.tuners.lora.Linear(
+                bias = hasattr(child, "bias") and child.bias is not None
+                lora_wrapped_child = lora.Linear(
                     child_name,
                     child.in_features,
                     child.out_features,
@@ -128,9 +132,9 @@ def create_lora_adapter(block):
                     "memory_efficient_backward": child.state.memory_efficient_backward,
                     "memory_efficient_backward": child.state.memory_efficient_backward,
                     "threshold": child.state.threshold,
                     "threshold": child.state.threshold,
                     "index": child.index,
                     "index": child.index,
-                    "bias": hasattr(target, "bias") and target.bias is not None,
+                    "bias": hasattr(child, "bias") and child.bias is not None,
                 }
-                lora_wrapped_child = peft.tuners.lora.Linear8bitLt(
+                lora_wrapped_child = lora.Linear8bitLt(
                     child_name,
                     child.in_features,
                     child.out_features,
@@ -141,9 +145,9 @@ def create_lora_adapter(block):
                     "compute_dtype": child.compute_dtype,
                     "compute_dtype": child.compute_dtype,
                     "compress_statistics": child.weight.compress_statistics,
                     "compress_statistics": child.weight.compress_statistics,
                     "quant_type": child.weight.quant_type,
                     "quant_type": child.weight.quant_type,
-                    "bias": hasattr(target, "bias") and target.bias is not None,
+                    "bias": hasattr(child, "bias") and child.bias is not None,
                 }
-                lora_wrapped_child = peft.tuners.lora.Linear4bit(
+                lora_wrapped_child = lora.Linear4bit(
                     child_name,
                     child.in_features,
                     child.out_features,
@@ -151,9 +155,39 @@ def create_lora_adapter(block):
                 )
             if lora_wrapped_child:
                 lora_wrapped_child.active_adapter = None
+                for p in lora_wrapped_child.parameters():
+                    p.requires_grad = False
                 setattr(module, child_name, lora_wrapped_child)


-def add_adapter_to_block(block, peft_config, peft_state_dict):
-    assert peft_config.peft_type == peft.PeftType.LORA, "Petals works only with LORA adapters"
-    pass
+def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_state_dict):
+    assert peft_config["peft_type"] == "LORA", "Petals works only with LORA adapters"
+    for name, module in block.named_modules():
+        for child_name, child in module.named_children():
+            if not isinstance(child, (lora.Linear, lora.Linear8bitLt, lora.Linear4bit)):
+                continue
+
+            if child_name in peft_config["target_modules"] or (isinstance(peft_config["target_modules"], str) and re.fullmatch(peft_config["target_modules"], child_name)):
+                is_lora_a_loaded = False
+                is_lora_b_loaded = False
+                for peft_key in peft_state_dict:
+                    if adapter_name not in child.lora_A:
+                        child.update_layer(
+                            adapter_name,
+                            peft_config["r"],
+                            peft_config["lora_alpha"],
+                            peft_config["lora_dropout"],
+                            peft_config["init_lora_weights"],
+                        )
+                        for p in child.parameters():
+                            p.requires_grad = False
+
+                    if "lora_A" in peft_key:
+                        child.lora_A[adapter_name].weight.data = peft_state_dict[peft_key] * child.scaling[adapter_name]
+                        is_lora_a_loaded = True
+                    elif "lora_B" in peft_key:
+                        child.lora_B[adapter_name].weight.data = peft_state_dict[peft_key]
+                        is_lora_b_loaded = True
+                        
+                if is_lora_a_loaded and is_lora_b_loaded:
+                    logger.info(f"Loading {adapter_name} for block {block_index} is ended successfully")