justheuristic, 3 years ago
parent commit ac42a70ac8
1 file changed with 27 additions and 42 deletions

src/peft_utils.py (+27, -42)

@@ -1,8 +1,9 @@
 """
-
 Generalized parameter-efficient finetuning module that supports deep prompts, bitfit, and several types of adapters.
 Designed to be used on both client and server side.
 
+Note: if you want to fine-tune a model in a way that is not covered by this module, please implement the
+necessary parts on the client side and keep the server-side code unchanged.
 """
 from enum import Enum
 from typing import Optional
@@ -27,11 +28,13 @@ class TransformerBlockPEFT(nn.Module):
         self.hidden_size = hidden_size
 
         # "deep" prompts, applied to the outputs of each layer (https://arxiv.org/abs/2110.07602)
-        self.prompts = nn.Parameter(DUMMY)   # dummy or [batch_size or 1, seq_length_prefix, hid_size]
+        self.output_prompts = nn.Parameter(DUMMY)   # dummy or [batch_size or 1, seq_length_prefix, hid_size]
 
-        # adapter input projection; used for output adapters, can be reused for other adapters
-        self.key_adapter = LowRankAdapter
-        self.adapter_in_bias = nn.Parameter(DUMMY)  # [hid_size]
+        self.attention_query_adapter = GenericAdapter(self.hidden_size)
+        self.attention_key_adapter = GenericAdapter(self.hidden_size)
+        self.attention_value_adapter = GenericAdapter(self.hidden_size)
+        self.attention_out_adapter = GenericAdapter(self.hidden_size)
+        self.mlp_in_adapter = GenericAdapter(self.hidden_size)
 
         # output projection, applied to the residual layer after MLP
         self.adapter_out_weight = nn.Parameter(DUMMY)  # [adapter_dim, hid_size or hid_size * 2]
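
For orientation, here is a minimal sketch (an assumption, not part of this commit) of how the per-projection adapters above might wrap a host attention block. The names `q_proj`, `k_proj`, `v_proj`, `dense`, and `attention` are hypothetical stand-ins for the underlying sublayers; each adapter is called as `adapter(input, base_output)` and returns an adjusted output, so the base weights can stay frozen:

    def adapted_attention(block, peft, hidden_states):
        # each GenericAdapter learns a low-rank correction on top of the frozen base projection
        q = peft.attention_query_adapter(hidden_states, block.q_proj(hidden_states))
        k = peft.attention_key_adapter(hidden_states, block.k_proj(hidden_states))
        v = peft.attention_value_adapter(hidden_states, block.v_proj(hidden_states))
        context = block.attention(q, k, v)
        return peft.attention_out_adapter(context, block.dense(context))
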
@@ -40,7 +43,6 @@ class TransformerBlockPEFT(nn.Module):
 
 # planned:
 # strategy: define
-# - remove the part that stacks multiplicative and additive adapter weights - it does not help!
 # - check that GenericAdapter works :)
 # - implement a function that converts lowrank adapter to [list_of_tensors, metadata]
 # - pass list of tensors and metadata in chained requests
@@ -48,15 +50,16 @@ class TransformerBlockPEFT(nn.Module):
 # - check exact match with local layer
 
 
-class LowRankAdapter(nn.Module):
-    def __init__(self, hidden_size: int):
+class GenericAdapter(nn.Module):
+    def __init__(self, in_features: int, out_features: Optional[int] = None):
         super().__init__()
-        self.hidden_size = hidden_size
-        self.in_proj = nn.Parameter(DUMMY, requires_grad=False)     # [rank, hid_size]
-        self.hid_bias = nn.Parameter(DUMMY, requires_grad=False)    # [rank]
-        self.out_proj = nn.Parameter(DUMMY, requires_grad=False)    # [hid_size or 2 * hid_size, rank]
-        self.out_scale = nn.Parameter(DUMMY, requires_grad=False)   # [hid_size]
-        self.out_bias = nn.Parameter(DUMMY, requires_grad=False)    # [hid_size]
+        self.in_features, self.out_features = in_features, out_features if out_features is not None else in_features
+        self.in_proj = nn.Parameter(DUMMY, requires_grad=False)         # [rank, in_features]
+        self.hid_bias = nn.Parameter(DUMMY, requires_grad=False)        # [rank]
+        self.out_proj = nn.Parameter(DUMMY, requires_grad=False)        # [out_features, rank]
+        self.out_bias = nn.Parameter(DUMMY, requires_grad=False)        # [out_features]
+        self.out_scale_proj = nn.Parameter(DUMMY, requires_grad=False)  # [out_features, rank]
+        self.out_scale = nn.Parameter(DUMMY, requires_grad=False)       # [out_features]
         self.register_buffer("activation", torch.tensor(0, dtype=torch.int64), persistent=True)  # []
 
     def forward(self, input: torch.Tensor, base_output: Optional[torch.Tensor] = None) -> torch.Tensor:
@@ -66,40 +69,22 @@ class LowRankAdapter(nn.Module):
         :return: adjusted output, after using the low-rank adapter
         """
         base_output = base_output if base_output is not None else input
+        dtype, device = input.dtype, input.device
         has_scale, has_bias = not is_dummy(self.out_scale), not is_dummy(self.out_bias)
         has_adapter = not is_dummy(self.in_proj)
 
         # adapter components
-        additive = self.out_bias if has_bias else None
-        multiplicative = self.out_scale if has_scale else None
+        additive = self.out_bias if has_bias else torch.zeros(self.out_features, dtype=dtype, device=device)
+        multiplicative = self.out_scale if has_scale else torch.ones(self.out_features, dtype=dtype, device=device)
 
         if has_adapter:
             hid = F.linear(input, weight=self.in_proj, bias=None if is_dummy(self.hid_bias) else self.hid_bias)
-
-            if self.activation:
-                activation_fn = _ACTIVATIONS_BY_INDEX[int(self.activation.item())]
-                hid = activation_fn(hid)
-
-            if self.out_proj.shape[0] == self.hidden_size:
+            hid = _ACTIVATIONS_BY_INDEX[int(self.activation.item())](hid)
+            if not is_dummy(self.out_proj):
                 additive = F.linear(hid, self.out_proj, bias=additive)
-
-            elif self.out_proj.shape[0] == 2 * self.hidden_size:
-                bias_and_scale = None
-                if has_scale or has_bias:
-                    scale_or_ones = self.out_scale if has_scale else torch.ones_like(self.out_bias)
-                    bias_or_zeros = self.out_bias if has_bias else torch.zeros_like(self.out_scale)
-                    bias_and_scale = torch.cat([bias_or_zeros, scale_or_ones], dim=0)
-                combined_out = F.linear(hid, self.out_proj, bias=bias_and_scale)
-                additive, multiplicative = combined_out.split(self.hidden_size, dim=-1)
-
-        if additive is not None and multiplicative is not None:
-            return torch.addcmul(additive, base_output, multiplicative)
-        elif additive is not None:
-            return additive.add_(base_output)
-        elif multiplicative is not None:
-            return base_output * multiplicative
-        else:
-            return base_output
+            if not is_dummy(self.out_scale_proj):
+                multiplicative = F.linear(hid, self.out_scale_proj, bias=multiplicative)
+        return torch.addcmul(additive, base_output, multiplicative)
 
     @property
     def rank(self) -> int:
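
The rewritten forward pass boils down to `output = base_output * multiplicative + additive`, with both corrections produced from the same low-rank bottleneck. A hedged reference version, assuming every placeholder has been replaced with a real tensor (no DUMMY values left):

    import torch
    import torch.nn.functional as F

    def generic_adapter_reference(input, base_output, in_proj, hid_bias,
                                  out_proj, out_bias, out_scale_proj, out_scale,
                                  activation=F.gelu):
        hid = activation(F.linear(input, in_proj, hid_bias))         # [..., rank] bottleneck
        additive = F.linear(hid, out_proj, out_bias)                 # low-rank additive correction
        multiplicative = F.linear(hid, out_scale_proj, out_scale)    # low-rank multiplicative correction
        return torch.addcmul(additive, base_output, multiplicative)  # base_output * multiplicative + additive
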
@@ -109,12 +94,12 @@ class LowRankAdapter(nn.Module):
 class ACTIVATIONS(Enum):
     # enum of allowed activations for server-side adapters; the linear (identity) activation is index 0, the default
     # beware: these activations should be backwards compatible! new activations can only be added to the end of the list
-    relu, gelu, relu6, leaky_relu, sigmoid, tanh = range(1, 7)
+    linear, relu, gelu, relu6, leaky_relu, sigmoid, tanh = range(7)
 
 
 for act in list(ACTIVATIONS)[1:]:
     assert hasattr(F, act.name), act.name
 
-_ACTIVATIONS_BY_INDEX = {act.value: getattr(F, act.name) for act in ACTIVATIONS}
+_ACTIVATIONS_BY_INDEX = {act.value: getattr(F, act.name) for act in list(ACTIVATIONS)[1:]}
 _ACTIVATIONS_BY_INDEX[0] = lambda x: x
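
As a usage illustration (the rank, shapes, initialization, and import path below are assumptions, not prescribed by this commit), a client could materialize the DUMMY placeholders LoRA-style, with a zero-initialized output projection so the adapter is a no-op at initialization, and select the activation through its backwards-compatible index:

    import torch
    from torch import nn
    from src.peft_utils import GenericAdapter, ACTIVATIONS  # hypothetical import path

    adapter = GenericAdapter(in_features=1024, out_features=1024)
    rank = 8
    adapter.in_proj = nn.Parameter(0.02 * torch.randn(rank, adapter.in_features))
    adapter.hid_bias = nn.Parameter(torch.zeros(rank))
    adapter.out_proj = nn.Parameter(torch.zeros(adapter.out_features, rank))  # zero init: starts as identity
    adapter.activation.fill_(ACTIVATIONS.gelu.value)  # stored as an integer index for compatibility

    x = torch.randn(2, 16, adapter.in_features)
    assert torch.allclose(adapter(x), x)  # out_proj is zero and out_scale is dummy, so output == base_output
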