justheuristic · 3 years ago
Current commit ac42a70ac8
1 changed file with 27 insertions and 42 deletions
  src/peft_utils.py

src/peft_utils.py  +27 -42

@@ -1,8 +1,9 @@
 """
-
 Generalized parameter-efficient finetuning module that supports deep prompts, bitfit, and several types of adapters.
 Designed to be used on both client and server side.
 
+Note: if you want to fine-tune a model in a way that is not covered by this module, please implement the
+necessary parts on client side and keep the server-side code unchanged.
 """
 from enum import Enum
 from typing import Optional
@@ -27,11 +28,13 @@ class TransformerBlockPEFT(nn.Module):
         self.hidden_size = hidden_size
 
         # "deep" prompts, applied to the outputs of each layer (https://arxiv.org/abs/2110.07602)
-        self.prompts = nn.Parameter(DUMMY)   # dummy or [batch_size or 1, seq_length_prefix, hid_size]
+        self.output_prompts = nn.Parameter(DUMMY)   # dummy or [batch_size or 1, seq_length_prefix, hid_size]
 
-        # adapter input projection; used for output adapters, can be reused for other adapters
-        self.key_adapter = LowRankAdapter
-        self.adapter_in_bias = nn.Parameter(DUMMY)  # [hid_size]
+        self.attention_query_adapter = GenericAdapter(self.hidden_size)
+        self.attention_key_adapter = GenericAdapter(self.hidden_size)
+        self.attention_value_adapter = GenericAdapter(self.hidden_size)
+        self.attention_out_adapter = GenericAdapter(self.hidden_size)
+        self.mlp_in_adapter = GenericAdapter(self.hidden_size)
 
         # output projection, applied to the residual layer after MLP
         self.adapter_out_weight = nn.Parameter(DUMMY)  # [adapter_dim, hid_size or hid_size * 2]
@@ -40,7 +43,6 @@ class TransformerBlockPEFT(nn.Module):
 
 # planned:
 # strategy: define
-# - remove the part that stacks multiplicative and additive adapter weights - it does not help!
 # - check that LowRankAdapter works :)
 # - implement a function that converts lowrank adapter to [list_of_tensors, metadata]
 # - pass list of tensors and metadata in chained requests
@@ -48,15 +50,16 @@ class TransformerBlockPEFT(nn.Module):
 # - check exact match with local layer
 
 
-class LowRankAdapter(nn.Module):
-    def __init__(self, hidden_size: int):
+class GenericAdapter(nn.Module):
+    def __init__(self, in_features: int, out_features: int):
         super().__init__()
-        self.hidden_size = hidden_size
-        self.in_proj = nn.Parameter(DUMMY, requires_grad=False)     # [rank, hid_size]
-        self.hid_bias = nn.Parameter(DUMMY, requires_grad=False)    # [rank]
-        self.out_proj = nn.Parameter(DUMMY, requires_grad=False)    # [hid_size or 2 * hid_size, rank]
-        self.out_scale = nn.Parameter(DUMMY, requires_grad=False)   # [hid_size]
-        self.out_bias = nn.Parameter(DUMMY, requires_grad=False)    # [hid_size]
+        self.in_features, self.out_features = in_features, out_features
+        self.in_proj = nn.Parameter(DUMMY, requires_grad=False)         # [rank, in_features]
+        self.hid_bias = nn.Parameter(DUMMY, requires_grad=False)        # [rank]
+        self.out_proj = nn.Parameter(DUMMY, requires_grad=False)        # [out_features, rank]
+        self.out_bias = nn.Parameter(DUMMY, requires_grad=False)        # [out_features]
+        self.out_scale_proj = nn.Parameter(DUMMY, requires_grad=False)  # [out_features, rank]
+        self.out_scale = nn.Parameter(DUMMY, requires_grad=False)       # [out_features]
         self.register_buffer("activation", torch.tensor(0, torch.int64), persistent=True)  # []
 
     def forward(self, input: torch.Tensor, base_output: Optional[torch.Tensor] = None) -> torch.Tensor:
@@ -66,40 +69,22 @@ class LowRankAdapter(nn.Module):
         :return: adjusted output, after using the low-rank adapter
         """
         base_output = base_output if base_output is not None else input
+        dtype, device = input.dtype, input.device
         has_scale, has_bias = not is_dummy(self.out_scale), not is_dummy(self.out_bias)
         has_adapter = not is_dummy(self.in_proj)
 
         # adapter components
-        additive = self.out_bias if has_bias else None
-        multiplicative = self.out_scale if has_scale else None
+        additive = self.out_bias if has_bias else torch.zeros(self.out_features, dtype=dtype, device=device)
+        multiplicative = self.out_scale if has_scale else torch.ones(self.out_features, dtype=dtype, device=device)
 
         if has_adapter:
             hid = F.linear(input, weight=self.in_proj, bias=None if is_dummy(self.in_bias) else self.in_bias)
-
-            if self.activation:
-                activation_fn = _ACTIVATIONS_BY_INDEX[int(self.activation.item())]
-                hid = activation_fn(hid)
-
-            if self.out_proj.shape[0] == self.hidden_size:
+            hid = _ACTIVATIONS_BY_INDEX[int(self.activation.item())](hid)
+            if not is_dummy(self.out_proj):
                 additive = F.linear(hid, self.out_proj, bias=additive)
-
-            elif self.out_proj.shape[0] == 2 * self.hidden_size:
-                bias_and_scale = None
-                if has_scale or has_bias:
-                    scale_or_ones = self.out_scale if has_scale else torch.ones_like(self.out_bias)
-                    bias_or_zeros = self.out_bias if has_bias else torch.zeros_like(self.out_scale)
-                    bias_and_scale = torch.cat([bias_or_zeros, scale_or_ones], dim=0)
-                combined_out = F.linear(hid, self.out_proj, bias=bias_and_scale)
-                additive, multiplicative = combined_out.split(self.hidden_size, dim=-1)
-
-        if additive is not None and multiplicative is not None:
-            return torch.addcmul(additive, base_output, multiplicative)
-        elif additive is not None:
-            return additive.add_(base_output)
-        elif multiplicative is not None:
-            return base_output * multiplicative
-        else:
-            return base_output
+            if not is_dummy(self.out_scale_proj):
+                multiplicative = F.linear(hid, self.out_scale_proj, bias=multiplicative)
+        return torch.addcmul(additive, base_output, multiplicative)
 
     @property
     def rank(self) -> int:
@@ -109,12 +94,12 @@ class LowRankAdapter(nn.Module):
 class ACTIVATIONS(Enum):
     # enum of allowed activations for server-side adapters; linear activation is represented with DUMMY tensor
     # beware: these activations should be backwards compatible! new activations can only be added to the end of the list
-    relu, gelu, relu6, leaky_relu, sigmoid, tanh = range(1, 7)
+    linear, relu, gelu, relu6, leaky_relu, sigmoid, tanh = range(7)
 
 
 for act in list(ACTIVATIONS)[1:]:
     assert hasattr(F, act.name), act.name
 
-_ACTIVATIONS_BY_INDEX = {act.value: getattr(F, act.name) for act in ACTIVATIONS}
+_ACTIVATIONS_BY_INDEX = {act.value: getattr(F, act.name) for act in list(ACTIVATIONS)[1:]}
 _ACTIVATIONS_BY_INDEX[0] = lambda x: x
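
For reference, below is a minimal standalone sketch of the update rule that the new GenericAdapter.forward implements: the additive component starts from a bitfit-style bias and the multiplicative component from a learned scale, both are optionally refined by low-rank projections (out_proj, out_scale_proj) of the activated in_proj bottleneck, and the result is combined with the frozen layer's output via torch.addcmul, i.e. additive + base_output * multiplicative. The function name, the toy shapes, and the fixed F.relu stand-in for the _ACTIVATIONS_BY_INDEX lookup are hypothetical; the dummy-tensor checks and in_bias handling from the commit are omitted for brevity.

import torch
import torch.nn.functional as F


def adapter_forward_sketch(
    input: torch.Tensor,           # [*, in_features]
    base_output: torch.Tensor,     # [*, out_features] -- output of the frozen base layer
    in_proj: torch.Tensor,         # [rank, in_features]
    out_proj: torch.Tensor,        # [out_features, rank] -- low-rank additive head
    out_scale_proj: torch.Tensor,  # [out_features, rank] -- low-rank multiplicative head
    out_bias: torch.Tensor,        # [out_features] -- bitfit-style bias
    out_scale: torch.Tensor,       # [out_features] -- learned elementwise scale
    activation=F.relu,             # hypothetical stand-in for the activation-index lookup
) -> torch.Tensor:
    # start from the plain bias/scale (the module falls back to zeros/ones when these are dummies)
    additive, multiplicative = out_bias, out_scale
    # low-rank bottleneck: project down, apply the activation, then refine both components
    hid = activation(F.linear(input, in_proj))                           # [*, rank]
    additive = F.linear(hid, out_proj, bias=additive)                    # [*, out_features]
    multiplicative = F.linear(hid, out_scale_proj, bias=multiplicative)  # [*, out_features]
    # same combination as in the diff: additive + base_output * multiplicative
    return torch.addcmul(additive, base_output, multiplicative)


# toy usage with hypothetical sizes: 2 tokens, hidden size 8, adapter rank 4
x = torch.randn(2, 8)
out = adapter_forward_sketch(
    input=x,
    base_output=x,  # forward() falls back to `input` when no base output is passed
    in_proj=torch.randn(4, 8) * 0.01,
    out_proj=torch.randn(8, 4) * 0.01,
    out_scale_proj=torch.randn(8, 4) * 0.01,
    out_bias=torch.zeros(8),
    out_scale=torch.ones(8),
)
assert out.shape == (2, 8)

Note that the commit's constructor takes in_features and out_features while the call sites in TransformerBlockPEFT pass a single hidden size, so the sketch simply assumes in_features == out_features.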