|
@@ -108,3 +108,52 @@ def load_peft(
|
|
|
f"Failed to load peft weights {repo_id} from HF Hub (retry in {delay:.0f} sec)", exc_info=True
|
|
|
)
|
|
|
time.sleep(delay)
|
|
|
+
|
|
|
+
|
|
|
+def create_lora_adapter(block):
|
|
|
+ for name, module in block.named_modules():
|
|
|
+ for child_name, child in module.named_children():
|
|
|
+ lora_wrapped_child = None
|
|
|
+ if isinstance(child, nn.Linear):
|
|
|
+ bias = hasattr(target, "bias") and target.bias is not None
|
|
|
+ lora_wrapped_child = peft.tuners.lora.Linear(
|
|
|
+ child_name,
|
|
|
+ child.in_features,
|
|
|
+ child.out_features,
|
|
|
+ bias=bias,
|
|
|
+ )
|
|
|
+ elif isinstance(child, bnb.nn.Linear8bitLt):
|
|
|
+ kwargs = {
|
|
|
+ "has_fp16_weights": child.state.has_fp16_weights,
|
|
|
+ "memory_efficient_backward": child.state.memory_efficient_backward,
|
|
|
+ "threshold": child.state.threshold,
|
|
|
+ "index": child.index,
|
|
|
+ "bias": hasattr(target, "bias") and target.bias is not None,
|
|
|
+ }
|
|
|
+ lora_wrapped_child = peft.tuners.lora.Linear8bitLt(
|
|
|
+ child_name,
|
|
|
+ child.in_features,
|
|
|
+ child.out_features,
|
|
|
+ **kwargs,
|
|
|
+ )
|
|
|
+ elif isinstance(child, bnb.nn.Linear4bit):
|
|
|
+ kwargs = {
|
|
|
+ "compute_dtype": child.compute_dtype,
|
|
|
+ "compress_statistics": child.weight.compress_statistics,
|
|
|
+ "quant_type": child.weight.quant_type,
|
|
|
+ "bias": hasattr(target, "bias") and target.bias is not None,
|
|
|
+ }
|
|
|
+ lora_wrapped_child = peft.tuners.lora.Linear4bit(
|
|
|
+ child_name,
|
|
|
+ child.in_features,
|
|
|
+ child.out_features,
|
|
|
+ **kwargs,
|
|
|
+ )
|
|
|
+ if lora_wrapped_child:
|
|
|
+ lora_wrapped_child.active_adapter = None
|
|
|
+ setattr(module, child_name, lora_wrapped_child)
|
|
|
+
|
|
|
+
|
|
|
+def add_adapter_to_block(block, peft_config, peft_state_dict):
|
|
|
+ assert peft_config.peft_type == peft.PeftType.LORA, "Petals works only with LORA adapters"
|
|
|
+ pass
|