@@ -1,12 +1,8 @@
-import os
 import re
 import time
-from typing import List, Optional, Sequence
-
-os.environ["BITSANDBYTES_NOWELCOME"] = "1"
+from typing import Optional, Sequence
 
 import bitsandbytes as bnb
-import peft
 import torch
 import torch.nn as nn
 import transformers
@@ -19,7 +15,6 @@ from safetensors import safe_open
 from safetensors.torch import load_file
 from transformers.utils import get_file_from_repo
 
-from petals.client.ptune import force_non_empty_weights
 from petals.server.block_utils import resolve_block_dtype
 from petals.utils.disk_cache import allow_cache_reads, allow_cache_writes, free_disk_space_for
 from petals.utils.misc import QuantType
@@ -124,7 +119,7 @@ def load_peft(
 
 
 def create_lora_adapter(block, quant_type: QuantType):
-    for name, module in block.named_modules():
+    for _, module in block.named_modules():
         for child_name, child in module.named_children():
             lora_wrapped_child = None
             if not isinstance(child, (nn.Linear, bnb.nn.Linear8bitLt, bnb.nn.Linear4bit)):
@@ -173,7 +168,10 @@ def create_lora_adapter(block, quant_type: QuantType):
 
 def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_state_dict):
     assert peft_config["peft_type"] == "LORA", "Petals works only with LORA adapters"
-    for name, module in block.named_modules():
+    if peft_config["lora_dropout"] > 0:
+        logger.info(f"Adapter {adapter_name} has dropout enabled, this server will disable dropout")
+
+    for _, module in block.named_modules():
         for child_name, child in module.named_children():
             if not isinstance(child, (lora.Linear, lora.Linear8bitLt, lora.Linear4bit)):
                 continue
@@ -185,7 +183,7 @@ def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_sta
                 is_lora_a_loaded = False
                 is_lora_b_loaded = False
                 for peft_key in peft_state_dict:
-                    if peft_key.find(child_name) == -1:
+                    if child_name not in peft_key:
                         continue
 
                     if adapter_name not in child.lora_A:
@@ -197,8 +195,6 @@ def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_sta
                             init_lora_weights=peft_config["init_lora_weights"],
                         )
                         child.train(False)
-                        if peft_config["lora_dropout"] > 0:
-                            logger.warning("Loading LoRA config with dropout enabled; this server will disable dropout")
                         for p in child.parameters():
                             p.requires_grad = False
 
@@ -214,7 +210,10 @@ def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_sta
                         raise NotImplementedError(f"LoRA adapters with bias not supported: {peft_key}")
 
                 if is_lora_a_loaded and is_lora_b_loaded:
-                    logger.info(f"Loading {adapter_name} for block {block_index}.{child_name} is ended successfully")
+                    logger.debug(f"Loaded adapter {adapter_name} for block {block_index}.{child_name}")
+                elif is_lora_a_loaded or is_lora_b_loaded:
+                    raise ValueError(f"Invalid adapter {adapter_name} for block {block_index}.{child_name}")
+    logger.info(f"Loaded adapter {adapter_name} for block {block_index}")
 
 
 def estimate_adapter_memory_per_block(