generic version

justheuristic, 3 years ago
commit 8ff36dc741
1 changed file with 6 additions and 10 deletions

src/peft_utils.py (+6, -10)

@@ -29,17 +29,13 @@ class TransformerBlockPEFT(nn.Module):
 
         # "deep" prompts, applied to the outputs of each layer (https://arxiv.org/abs/2110.07602)
         self.output_prompts = nn.Parameter(DUMMY)   # dummy or [batch_size or 1, seq_length_prefix, hid_size]
+        self.attention_query_adapter = GenericAdapter(self.hidden_size, self.hidden_size)
+        self.attention_key_adapter = GenericAdapter(self.hidden_size, self.hidden_size)
+        self.attention_value_adapter = GenericAdapter(self.hidden_size, self.hidden_size)
+        self.attention_out_adapter = GenericAdapter(self.hidden_size, self.hidden_size)
+        self.mlp_in_adapter = GenericAdapter(self.hidden_size, self.hidden_size)
+        self.mlp_out_adapter = GenericAdapter(self.hidden_size, self.hidden_size)
 
-        self.attention_query_adapter = GenericAdapter(self.hidden_size)
-        self.attention_key_adapter = GenericAdapter(self.hidden_size)
-        self.attention_value_adapter = GenericAdapter(self.hidden_size)
-        self.attention_out_adapter = GenericAdapter(self.hidden_size)
-        self.mlp_in_adapter = GenericAdapter(self.hidden_size)
-
-        # output projection, applied to the residual layer after MLP
-        self.adapter_out_weight = nn.Parameter(DUMMY)  # [adapter_dim, hid_size or hid_size * 2]
-        self.adapter_out_bias = nn.Parameter(DUMMY)  # [hid_size]
-        self.adapter_out_scale = nn.Parameter(DUMMY)  # [hid_size]
 
 # planned:
 # strategy: define
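
Note: the GenericAdapter implementation itself is not shown in this commit; only its call sites change, from a single size argument to an explicit (in_features, out_features) pair. A minimal sketch of what such a module could look like, assuming a LoRA-style low-rank residual update; the rank argument and parameter names here are hypothetical, not taken from the repository:

import torch
import torch.nn as nn

class GenericAdapter(nn.Module):
    """Hypothetical sketch: low-rank adapter mapping in_features -> out_features."""

    def __init__(self, in_features: int, out_features: int, rank: int = 8):
        super().__init__()
        # Small random init for the first factor; the second factor starts at
        # zero, so the adapter is initially a no-op (cf. the DUMMY placeholders
        # used for the other parameters above).
        self.lora_first = nn.Parameter(torch.randn(in_features, rank) * 0.01)
        self.lora_second = nn.Parameter(torch.zeros(rank, out_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Returns only the low-rank correction; the caller adds it to the
        # frozen layer's output, e.g. query = q_proj(x) + adapter(x).
        return (x @ self.lora_first) @ self.lora_second

Under this assumption, each wrapped projection in TransformerBlockPEFT would be used as, e.g., query = q_proj(hidden_states) + self.attention_query_adapter(hidden_states), where q_proj and hidden_states stand in for the frozen base projection and the layer input.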