# peft.py
import re
import time
from typing import List, Optional

import bitsandbytes as bnb
import torch.nn as nn
from hivemind.utils.logging import get_logger
from huggingface_hub import HfFileSystem, get_hf_file_metadata, hf_hub_url
from peft.tuners import lora
from peft.utils import COMMON_LAYERS_PATTERN, CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME, PeftConfig
from safetensors import safe_open
from safetensors.torch import load_file
from transformers.utils import get_file_from_repo

from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for

logger = get_logger(__name__)
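
# This module loads LoRA (PEFT) adapters from the Hugging Face Hub and attaches them to
# transformer blocks served by Petals: `load_peft` downloads the adapter config and
# safetensors weights, `create_lora_adapter` wraps a block's linear layers with their
# LoRA-capable counterparts, and `add_adapter_to_block` copies the downloaded LoRA
# weights into those wrappers.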


def check_peft_repository(repo_id: str) -> bool:
    fs = HfFileSystem()
    list_of_files = fs.glob(f"{repo_id}/{SAFETENSORS_WEIGHTS_NAME}", detail=False)
    return len(list_of_files) > 0


def load_specific_module(block_idx: int, filepath: str, framework: str = "pt", device: Optional[int] = None):
    tensors = dict()
    is_tensors_found = dict()
    # Match only the keys that belong to block `block_idx`, based on the common layer-name
    # patterns used by PEFT (such as "layers" or "h") followed by the block index.
    common_layer_pattern_re = (
        r".+\." + "".join(f"({common_name})?" for common_name in COMMON_LAYERS_PATTERN) + rf"({block_idx})?\.0\..+"
    )
    with safe_open(filepath, framework=framework, device=device) as f:
        for k in f.keys():
            if re.match(common_layer_pattern_re, k):
                is_tensors_found[block_idx] = True
                tensors[k] = f.get_tensor(k)
        if not is_tensors_found.get(block_idx, False):
            logger.warning(f"There are no PEFT weights for block {block_idx}")
        return tensors
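
# Illustrative usage (the file path below is hypothetical): load only the LoRA tensors
# that belong to block 4 from an already-downloaded safetensors file.
#
#   block_tensors = load_specific_module(4, "/path/to/adapter_model.safetensors")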


def get_adapter_from_repo(repo_id: str, block_idx: Optional[int] = None, device: Optional[int] = None, **kwargs):
    config_path = get_file_from_repo(repo_id, CONFIG_NAME, **kwargs)
    if config_path is None:
        raise RuntimeError(f"File {CONFIG_NAME} does not exist in repo {repo_id}")
    config = PeftConfig.from_json_file(config_path)

    weight_path = get_file_from_repo(repo_id, SAFETENSORS_WEIGHTS_NAME, **kwargs)
    if weight_path is None:
        raise RuntimeError(f"File {SAFETENSORS_WEIGHTS_NAME} does not exist in repo {repo_id}")
    if block_idx is None:
        return config, load_file(weight_path)
    return config, load_specific_module(block_idx, weight_path, device=device)
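
# Illustrative call (the repo id is hypothetical): returns the adapter config together with
# either the full adapter state dict or only the tensors for one block.
#
#   config, state_dict = get_adapter_from_repo("someuser/some-lora-adapter", block_idx=4)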


def load_peft(
    repo_id: str,
    block_idx: Optional[int] = None,
    device: Optional[int] = None,
    *,
    revision: Optional[str] = None,
    use_auth_token: Optional[str] = None,
    cache_dir: str,
    max_disk_space: Optional[int] = None,
    delay: float = 30,
):
    # TODO: Check whether safetensors loading can be added to petals/server/from_pretrained.py and reused here

    if not check_peft_repository(repo_id):
        raise ValueError(f"Repo {repo_id} does not contain safetensors weights, which are required for safe loading")

    try:
        with allow_cache_reads(cache_dir):
            return get_adapter_from_repo(
                repo_id,
                block_idx,
                device,
                revision=revision,
                use_auth_token=use_auth_token,
                cache_dir=cache_dir,
                local_files_only=False,
            )
    except Exception:
        logger.warning(f"Cache for PEFT weights {repo_id} is corrupted, it will be downloaded again", exc_info=True)

    while True:
        try:
            with allow_cache_writes(cache_dir):
                config_url = hf_hub_url(repo_id, CONFIG_NAME, revision=revision)
                config_file_size = get_hf_file_metadata(config_url, token=use_auth_token).size
                weight_url = hf_hub_url(repo_id, SAFETENSORS_WEIGHTS_NAME, revision=revision)
                weight_file_size = get_hf_file_metadata(weight_url, token=use_auth_token).size

                if config_file_size is not None and weight_file_size is not None:
                    file_size = config_file_size + weight_file_size
                    free_disk_space_for(file_size, cache_dir=cache_dir, max_disk_space=max_disk_space)
                else:
                    logger.warning(f"Failed to fetch file sizes from PEFT repo {repo_id}")

                return get_adapter_from_repo(
                    repo_id,
                    block_idx,
                    device,
                    revision=revision,
                    use_auth_token=use_auth_token,
                    cache_dir=cache_dir,
                    local_files_only=False,
                )
        except Exception:
            logger.warning(
                f"Failed to load PEFT weights {repo_id} from HF Hub (retry in {delay:.0f} sec)", exc_info=True
            )
            time.sleep(delay)
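
# Illustrative call (the repo id and cache directory are hypothetical). Note that load_peft
# retries indefinitely, sleeping `delay` seconds between attempts, until the download succeeds.
#
#   peft_config, peft_state_dict = load_peft(
#       "someuser/some-lora-adapter",
#       block_idx=4,
#       cache_dir="/path/to/petals/cache",
#   )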


def create_lora_adapter(block):
    # Replace every linear layer in the block (plain, 8-bit, or 4-bit) with its LoRA-capable
    # counterpart so that adapters can be attached to it later.
    for name, module in block.named_modules():
        for child_name, child in module.named_children():
            lora_wrapped_child = None
            if isinstance(child, nn.Linear):
                bias = hasattr(child, "bias") and child.bias is not None
                lora_wrapped_child = lora.Linear(
                    child_name,
                    child.in_features,
                    child.out_features,
                    bias=bias,
                )
            elif isinstance(child, bnb.nn.Linear8bitLt):
                kwargs = {
                    "has_fp16_weights": child.state.has_fp16_weights,
                    "memory_efficient_backward": child.state.memory_efficient_backward,
                    "threshold": child.state.threshold,
                    "index": child.index,
                    "bias": hasattr(child, "bias") and child.bias is not None,
                }
                lora_wrapped_child = lora.Linear8bitLt(
                    child_name,
                    child.in_features,
                    child.out_features,
                    **kwargs,
                )
            elif isinstance(child, bnb.nn.Linear4bit):
                kwargs = {
                    "compute_dtype": child.compute_dtype,
                    "compress_statistics": child.weight.compress_statistics,
                    "quant_type": child.weight.quant_type,
                    "bias": hasattr(child, "bias") and child.bias is not None,
                }
                lora_wrapped_child = lora.Linear4bit(
                    child_name,
                    child.in_features,
                    child.out_features,
                    **kwargs,
                )
            if lora_wrapped_child:
                # Reuse the original weights, freeze all parameters, and swap the wrapped
                # module into its parent in place of the original layer.
                lora_wrapped_child.active_adapter = None
                lora_wrapped_child.weight = child.weight
                lora_wrapped_child.bias = child.bias
                for p in lora_wrapped_child.parameters():
                    p.requires_grad = False
                setattr(module, child_name, lora_wrapped_child)


def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_state_dict):
    assert peft_config["peft_type"] == "LORA", "Petals works only with LORA adapters"
    for name, module in block.named_modules():
        for child_name, child in module.named_children():
            if not isinstance(child, (lora.Linear, lora.Linear8bitLt, lora.Linear4bit)):
                continue

            if child_name in peft_config["target_modules"] or (
                isinstance(peft_config["target_modules"], str)
                and re.fullmatch(peft_config["target_modules"], child_name)
            ):
                is_lora_a_loaded = False
                is_lora_b_loaded = False
                for peft_key in peft_state_dict:
                    if child_name not in peft_key:
                        continue  # this key belongs to a different target module

                    if adapter_name not in child.lora_A:
                        # Register the adapter on this layer before copying its weights
                        child.update_layer(
                            adapter_name,
                            peft_config["r"],
                            peft_config["lora_alpha"],
                            peft_config["lora_dropout"],
                            peft_config["init_lora_weights"],
                        )
                        for p in child.parameters():
                            p.requires_grad = False

                    if "lora_A" in peft_key:
                        # The LoRA scaling factor is folded into lora_A while copying
                        child.lora_A[adapter_name].weight.data = peft_state_dict[peft_key] * child.scaling[adapter_name]
                        is_lora_a_loaded = True
                    elif "lora_B" in peft_key:
                        child.lora_B[adapter_name].weight.data = peft_state_dict[peft_key]
                        is_lora_b_loaded = True

                if is_lora_a_loaded and is_lora_b_loaded:
                    logger.info(f"Loaded adapter {adapter_name} for block {block_index}")
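
# A minimal sketch of how these helpers fit together on the server side (the repo id,
# adapter name, block index, and cache path are hypothetical; `block` stands for an
# already-loaded transformer block):
#
#   create_lora_adapter(block)
#   peft_config, peft_state_dict = load_peft(
#       "someuser/some-lora-adapter", block_idx=4, cache_dir="/path/to/cache"
#   )
#   add_adapter_to_block(block, 4, "someuser/some-lora-adapter", peft_config, peft_state_dict)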