@@ -0,0 +1,215 @@
+import re
+import time
+from typing import List, Optional
+
+import bitsandbytes as bnb
+import torch.nn as nn
+from hivemind.utils.logging import get_logger
+from huggingface_hub import HfFileSystem, get_hf_file_metadata, hf_hub_url
+from peft.tuners import lora
+from peft.utils import COMMON_LAYERS_PATTERN, CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME, PeftConfig
+from safetensors import safe_open
+from safetensors.torch import load_file
+from transformers.utils import get_file_from_repo
+
+from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for
+from petals.utils.misc import QuantType
+
+logger = get_logger(__name__)
+
+
+def check_peft_repository(repo_id: str) -> bool:
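+    """Return True if the repo publishes safetensors weights, which are required here for safe loading."""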
+    fs = HfFileSystem()
+    list_of_files = fs.glob(f"{repo_id}/{SAFETENSORS_WEIGHTS_NAME}", detail=False)
+    return len(list_of_files) > 0
+
+
+def load_specific_module(block_idx: int, filepath: str, framework: str = "pt", device: Optional[int] = None):
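+    """Load only the tensors belonging to block `block_idx` from a safetensors file, matching keys via COMMON_LAYERS_PATTERN."""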
+    tensors = dict()
+    is_tensors_found = dict()
+    common_layer_pattern_re = (
+        r".+\." + "".join(f"({common_name})?" for common_name in COMMON_LAYERS_PATTERN) + rf"\.({block_idx})?\..+"
+    )
+    with safe_open(filepath, framework=framework, device=device) as f:
+        for k in f.keys():
+            if re.match(common_layer_pattern_re, k):
+                is_tensors_found[block_idx] = True
+                tensors[k] = f.get_tensor(k)
+        if not is_tensors_found.get(block_idx, False):
+            logger.warning(f"No PEFT weights found for block {block_idx}")
+        return tensors
+
+
+def get_adapter_from_repo(repo_id: str, block_idx: Optional[int] = None, device: Optional[int] = None, **kwargs):
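+    """Fetch the adapter config and safetensors weights from the repo; if `block_idx` is given, load only that block's tensors."""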
+    config_path = get_file_from_repo(repo_id, CONFIG_NAME, **kwargs)
+    if config_path is None:
+        raise RuntimeError(f"File {CONFIG_NAME} does not exist in repo {repo_id}")
+    config = PeftConfig.from_json_file(config_path)
+
+    weight_path = get_file_from_repo(repo_id, SAFETENSORS_WEIGHTS_NAME, **kwargs)
+    if weight_path is None:
+        raise RuntimeError(f"File {SAFETENSORS_WEIGHTS_NAME} does not exist in repo {repo_id}")
+    if block_idx is None:
+        return config, load_file(weight_path)
+    return config, load_specific_module(block_idx, weight_path, device=device)
+
+
+def load_peft(
+    repo_id: str,
+    block_idx: Optional[int] = None,
+    device: Optional[int] = None,
+    *,
+    revision: Optional[str] = None,
+    use_auth_token: Optional[str] = None,
+    cache_dir: str,
+    max_disk_space: Optional[int] = None,
+    delay: float = 30,
+):
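+    """Load a PEFT adapter from the HF Hub, using the local disk cache and retrying the download every `delay` seconds on failure."""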
+    # TODO: Check whether it is possible to add safetensors loading inside petals/server/from_pretrained.py and reuse it here
+
+    if not check_peft_repository(repo_id):
+        raise ValueError(f"Repo {repo_id} does not contain safetensors weights required for safe loading.")
+
+    try:
+        with allow_cache_reads(cache_dir):
+            return get_adapter_from_repo(
+                repo_id,
+                block_idx,
+                device,
+                revision=revision,
+                use_auth_token=use_auth_token,
+                cache_dir=cache_dir,
+                local_files_only=False,
+            )
+    except Exception:
+        logger.warning(f"Cached PEFT weights for {repo_id} are corrupted, they will be downloaded again", exc_info=True)
+
+    while True:
+        try:
+ with allow_cache_writes(cache_dir):
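+                # Fetch file sizes from the Hub metadata so that enough cache space can be freed before downloading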
+                config_url = hf_hub_url(repo_id, CONFIG_NAME, revision=revision)
+                config_file_size = get_hf_file_metadata(config_url, token=use_auth_token).size
+                weight_url = hf_hub_url(repo_id, SAFETENSORS_WEIGHTS_NAME, revision=revision)
+                weight_file_size = get_hf_file_metadata(weight_url, token=use_auth_token).size
+
+                if config_file_size is not None and weight_file_size is not None:
+                    file_size = config_file_size + weight_file_size
+                    free_disk_space_for(file_size, cache_dir=cache_dir, max_disk_space=max_disk_space)
+                else:
+                    logger.warning(f"Failed to fetch file sizes from peft repo {repo_id}")
+
+                return get_adapter_from_repo(
+                    repo_id,
+                    block_idx,
+                    device,
+                    revision=revision,
+                    use_auth_token=use_auth_token,
+                    cache_dir=cache_dir,
+                    local_files_only=False,
+                )
+        except Exception:
+            logger.warning(
+                f"Failed to load peft weights {repo_id} from HF Hub (retry in {delay:.0f} sec)", exc_info=True
+            )
+            time.sleep(delay)
+
+
+def create_lora_adapter(block, quant_type: QuantType):
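+    """Wrap every linear layer of the block in a frozen LoRA-compatible layer matching the server's quantization type."""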
+    for _, module in block.named_modules():
+        for child_name, child in module.named_children():
+            lora_wrapped_child = None
+            if not isinstance(child, (nn.Linear, bnb.nn.Linear8bitLt, bnb.nn.Linear4bit)):
+                continue
+            if quant_type == QuantType.INT8:
+                kwargs = {
+                    "has_fp16_weights": False,
+                    "threshold": 6.0,
+                    "bias": hasattr(child, "bias") and child.bias is not None,
+                }
+                lora_wrapped_child = lora.Linear8bitLt(
+                    child_name,
+                    child.in_features,
+                    child.out_features,
+                    **kwargs,
+                )
+            elif quant_type == QuantType.NF4:
+                kwargs = {
+                    "compress_statistics": True,
+                    "quant_type": "nf4",
+                    "blocksize": 64,
+                    "bias": hasattr(child, "bias") and child.bias is not None,
+                }
+                lora_wrapped_child = lora.Linear4bit(
+                    child_name,
+                    child.in_features,
+                    child.out_features,
+                    **kwargs,
+                )
+            else:
+                bias = hasattr(child, "bias") and child.bias is not None
+                lora_wrapped_child = lora.Linear(
+                    child_name,
+                    child.in_features,
+                    child.out_features,
+                    bias=bias,
+                )
+            if lora_wrapped_child:
+                lora_wrapped_child.active_adapter = None
+                lora_wrapped_child.weight = child.weight
+                lora_wrapped_child.bias = child.bias
+                for p in lora_wrapped_child.parameters():
+                    p.requires_grad = False
+                setattr(module, child_name, lora_wrapped_child)
+
+
+def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_state_dict):
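+    """Load the LoRA A/B weights from `peft_state_dict` into the matching wrapped layers of the block."""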
+    assert peft_config["peft_type"] == "LORA", "Petals supports only LoRA adapters"
+    for _, module in block.named_modules():
+        for child_name, child in module.named_children():
+            if not isinstance(child, (lora.Linear, lora.Linear8bitLt, lora.Linear4bit)):
+                continue
+
+            if child_name in peft_config["target_modules"] or (
+                isinstance(peft_config["target_modules"], str)
+                and re.fullmatch(peft_config["target_modules"], child_name)
+            ):
+                is_lora_a_loaded = False
+                is_lora_b_loaded = False
+                for peft_key in peft_state_dict:
+                    if child_name not in peft_key:
+                        continue
+
+                    if adapter_name not in child.lora_A:
+                        child.update_layer(
+                            adapter_name,
+                            peft_config["r"],
+                            peft_config["lora_alpha"],
+                            lora_dropout=peft_config["lora_dropout"],
+                            init_lora_weights=peft_config["init_lora_weights"],
+                        )
+                        child.train(False)
+                        if peft_config["lora_dropout"] > 0:
+                            logger.warning("Loading LoRA config with dropout enabled; this server will disable dropout")
+                        for p in child.parameters():
+                            p.requires_grad = False
+
+                    if peft_key.endswith(".lora_A.weight"):
+                        child.lora_A[adapter_name].weight.data = peft_state_dict[peft_key]
+                        is_lora_a_loaded = True
+                    elif peft_key.endswith(".lora_A.bias"):
+                        raise NotImplementedError(f"LoRA adapters with bias are not supported: {peft_key}")
+                    elif peft_key.endswith(".lora_B.weight"):
+                        child.lora_B[adapter_name].weight.data = peft_state_dict[peft_key]
+                        is_lora_b_loaded = True
+                    elif peft_key.endswith(".lora_B.bias"):
+                        raise NotImplementedError(f"LoRA adapters with bias are not supported: {peft_key}")
+
+                if is_lora_a_loaded and is_lora_b_loaded:
+                    logger.info(f"Loaded adapter {adapter_name} for block {block_index}.{child_name}")