@@ -1,9 +1,14 @@
+import re
 import time
 from typing import List, Optional
 
+import torch.nn as nn
+import bitsandbytes as bnb
+
 from hivemind.utils.logging import get_logger
 from huggingface_hub import HfFileSystem, get_hf_file_metadata, hf_hub_url
-from peft.utils import CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME, PeftConfig
+from peft.tuners import lora
+from peft.utils import COMMON_LAYERS_PATTERN, CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME, PeftConfig
 from safetensors import safe_open
 from safetensors.torch import load_file
 from transformers.utils import get_file_from_repo
@@ -19,23 +24,22 @@ def check_peft_repository(repo_id: str) -> bool:
     return len(list_of_files) > 0
 
 
-def load_specific_module(layers_name: List[str], filepath: str, framework: str = "pt", device: Optional[int] = None):
+def load_specific_module(block_idx: int, filepath: str, framework: str = "pt", device: Optional[int] = None):
     tensors = dict()
     is_tensors_found = dict()
+    common_layer_pattern_re = r".+\." + "".join(f"({common_name})?" for common_name in COMMON_LAYERS_PATTERN) + rf"({block_idx})?\.0\..+"
     with safe_open(filepath, framework=framework, device=device) as f:
         for k in f.keys():
-            for layer_name in layers_name:
-                if k.startswith(layer_name):
-                    is_tensors_found[layer_name] = True
-                    tensors[k] = f.get_tensor(k)
-    for layer_name in layers_name:
-        if not is_tensors_found.get(layer_name, False):
-            logger.warning(f"There is no peft weights with prefix {layer_name}")
+            if re.match(common_layer_pattern_re, k):
+                is_tensors_found[block_idx] = True
+                tensors[k] = f.get_tensor(k)
+    if not is_tensors_found.get(block_idx, False):
+        logger.warning(f"There are no peft weights for block {block_idx}")
     return tensors
 
 
 def get_adapter_from_repo(
-    repo_id: str, layers_name: Optional[List[str]] = None, device: Optional[int] = None, **kwargs
+    repo_id: str, block_idx: Optional[int] = None, device: Optional[int] = None, **kwargs
 ):
     config_path = get_file_from_repo(repo_id, CONFIG_NAME, **kwargs)
     if config_path is None:
@@ -45,14 +49,14 @@ def get_adapter_from_repo(
     weight_path = get_file_from_repo(repo_id, SAFETENSORS_WEIGHTS_NAME, **kwargs)
     if weight_path is None:
         raise RuntimeError(f"File {SAFETENSORS_WEIGHTS_NAME} does not exist in repo {repo_id}")
-    if layers_name is None:
+    if block_idx is None:
         return config, load_file(weight_path)
-    return config, load_specific_module(layers_name, weight_path, device=device)
+    return config, load_specific_module(block_idx, weight_path, device=device)
 
 
 def load_peft(
     repo_id: str,
-    layers_name: Optional[List[str]] = None,
+    block_idx: Optional[int] = None,
     device: Optional[int] = None,
     *,
     revision: Optional[str] = None,
@@ -70,7 +74,7 @@ def load_peft(
         with allow_cache_reads(cache_dir):
             return get_adapter_from_repo(
                 repo_id,
-                layers_name,
+                block_idx,
                 device,
                 revision=revision,
                 use_auth_token=use_auth_token,
@@ -96,7 +100,7 @@ def load_peft(
 
                 return get_adapter_from_repo(
                     repo_id,
-                    layers_name,
+                    block_idx,
                     device,
                     revision=revision,
                     use_auth_token=use_auth_token,
@@ -115,8 +119,8 @@ def create_lora_adapter(block):
         for child_name, child in module.named_children():
             lora_wrapped_child = None
             if isinstance(child, nn.Linear):
-                bias = hasattr(target, "bias") and target.bias is not None
-                lora_wrapped_child = peft.tuners.lora.Linear(
+                bias = hasattr(child, "bias") and child.bias is not None
+                lora_wrapped_child = lora.Linear(
                     child_name,
                     child.in_features,
                     child.out_features,
@@ -128,9 +132,9 @@ def create_lora_adapter(block):
                     "memory_efficient_backward": child.state.memory_efficient_backward,
                     "threshold": child.state.threshold,
                     "index": child.index,
-                    "bias": hasattr(target, "bias") and target.bias is not None,
+                    "bias": hasattr(child, "bias") and child.bias is not None,
                 }
-                lora_wrapped_child = peft.tuners.lora.Linear8bitLt(
+                lora_wrapped_child = lora.Linear8bitLt(
                     child_name,
                     child.in_features,
                     child.out_features,
@@ -141,9 +145,9 @@ def create_lora_adapter(block):
                     "compute_dtype": child.compute_dtype,
                     "compress_statistics": child.weight.compress_statistics,
                     "quant_type": child.weight.quant_type,
-                    "bias": hasattr(target, "bias") and target.bias is not None,
+                    "bias": hasattr(child, "bias") and child.bias is not None,
                 }
-                lora_wrapped_child = peft.tuners.lora.Linear4bit(
+                lora_wrapped_child = lora.Linear4bit(
                     child_name,
                     child.in_features,
                     child.out_features,
@@ -151,9 +155,42 @@ def create_lora_adapter(block):
                 )
             if lora_wrapped_child:
                 lora_wrapped_child.active_adapter = None
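+                # freeze the wrapped layer; adapter weights are loaded later by add_adapter_to_block()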
+                for p in lora_wrapped_child.parameters():
+                    p.requires_grad = False
                 setattr(module, child_name, lora_wrapped_child)
 
 
-def add_adapter_to_block(block, peft_config, peft_state_dict):
-    assert peft_config.peft_type == peft.PeftType.LORA, "Petals works only with LORA adapters"
-    pass
+def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_state_dict):
+    assert peft_config["peft_type"] == "LORA", "Petals works only with LORA adapters"
+    for name, module in block.named_modules():
+        for child_name, child in module.named_children():
+            if not isinstance(child, (lora.Linear, lora.Linear8bitLt, lora.Linear4bit)):
+                continue
+
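+            # only modules listed in peft_config["target_modules"] (a list of names or a regex) receive the adapter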
+            if child_name in peft_config["target_modules"] or (isinstance(peft_config["target_modules"], str) and re.fullmatch(peft_config["target_modules"], child_name)):
+                is_lora_a_loaded = False
+                is_lora_b_loaded = False
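+                # copy lora_A (pre-scaled) and lora_B weights for this adapter from the checkpoint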
+                for peft_key in peft_state_dict:
+                    if adapter_name not in child.lora_A:
+                        child.update_layer(
+                            adapter_name,
+                            peft_config["r"],
+                            peft_config["lora_alpha"],
+                            peft_config["lora_dropout"],
+                            peft_config["init_lora_weights"],
+                        )
+                        for p in child.parameters():
+                            p.requires_grad = False
+
+                    if "lora_A" in peft_key:
+                        child.lora_A[adapter_name].weight.data = peft_state_dict[peft_key] * child.scaling[adapter_name]
+                        is_lora_a_loaded = True
+                    elif "lora_B" in peft_key:
+                        child.lora_B[adapter_name].weight.data = peft_state_dict[peft_key]
+                        is_lora_b_loaded = True
+
+                if is_lora_a_loaded and is_lora_b_loaded:
+                    logger.info(f"Loading {adapter_name} for block {block_index} finished successfully")