add deep p-tuning only for _non-distributed_ case

dbaranchuk, 3 years ago
parent commit 81422fdfc2
1 changed file with 164 additions and 7 deletions:
  1. src/bloom/model.py  (+164 / −7)
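
For context, "deep" p-tuning learns not only a prefix of virtual tokens prepended to the input embeddings (shallow prompt tuning) but also a separate learned offset that is added to the prefix positions before every remaining transformer block. A minimal sketch of the idea, using the same names and shapes as the diff below (illustrative only, not part of this commit):

```python
import torch
import torch.nn as nn

# Illustrative sketch of the "deep" prompt-tuning idea (not part of this commit).
# Shallow mode: only `prompt_embeddings` is learned and prepended to the input embeddings.
# Deep mode: an extra `intermediate_prompt_embeddings` table stores one prompt vector per
# remaining layer; it is added to the prefix positions before each subsequent block.
pre_seq_len, hidden_size, num_layers = 16, 64, 4

prompt_embeddings = nn.Embedding(pre_seq_len, hidden_size)
intermediate_prompt_embeddings = nn.Embedding(pre_seq_len, (num_layers - 1) * hidden_size)

prefix_tokens = torch.arange(pre_seq_len)
prompts = prompt_embeddings(prefix_tokens)              # (pre_seq_len, hidden_size)
deep = intermediate_prompt_embeddings(prefix_tokens)    # (pre_seq_len, (num_layers - 1) * hidden_size)
deep = deep.view(pre_seq_len, num_layers - 1, hidden_size).permute(1, 0, 2)
# deep[i] is added to hidden_states[:, :pre_seq_len] before transformer block i + 1.
```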

src/bloom/model.py  (+164 / −7)

@@ -3,7 +3,7 @@ PyTorch BLOOM model that implements several memory-efficient modes.
 Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
 See commit history for authorship.
 """
-from typing import Tuple, Union
+from typing import Tuple, Union, Optional
 
 import torch
 import torch.nn.functional as F
@@ -237,11 +237,6 @@ class BloomModel(BloomPreTrainedModel):
         all_self_attentions = () if output_attentions else None
         all_hidden_states = () if output_hidden_states else None
 
-        # Compute alibi tensor: check build_alibi_tensor documentation
-        current_sequence_length = hidden_states.shape[1]
-        if past_key_values and past_key_values[0]:
-            current_sequence_length += past_key_values[0][0].shape[1]
-
         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
 
             if output_hidden_states:
@@ -512,11 +507,22 @@ class BloomForSequenceClassification(BloomPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
-        self.transformer = BloomModel(config)
+
+        config.pre_seq_len = 16  # prefix length (number of virtual prompt tokens); hard-coded for now
+        config.prompt_tuning_mode = 'deep'  # 'deep' adds per-layer prompts, 'shallow' only input-level prompts
+
+        if config.pre_seq_len > 0:
+            self.transformer = BloomPrefix(config)
+        else:
+            self.transformer = BloomModel(config)
+
+        self.pooled_dropout = nn.Dropout(0.0)  # dropout on the pooled logits; probability hard-coded to 0.0 for now
         self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
 
         # Initialize weights and apply final processing
         self.post_init()
+        if hasattr(self.transformer, 'intermediate_prompt_embeddings'):
+            self.transformer.intermediate_prompt_embeddings.weight.data.zero_()
 
     @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
@@ -584,6 +590,7 @@ class BloomForSequenceClassification(BloomPreTrainedModel):
                 )
 
         pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+        pooled_logits = self.pooled_dropout(pooled_logits)
 
         loss = None
         if labels is not None:
@@ -618,3 +625,153 @@ class BloomForSequenceClassification(BloomPreTrainedModel):
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
         )
+
+
+class BloomPrefix(BloomModel):
+    """DistributedBloomModel with prefix tokens for prompt tuning"""
+
+    def __init__(self, config):
+        super().__init__(config)
+        assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
+        assert config.prompt_tuning_mode in ['deep', 'shallow']
+
+        self.pre_seq_len = config.pre_seq_len
+        self.prefix_tokens = torch.arange(self.pre_seq_len).long()
+        self.hidden_size = config.hidden_size
+        self.prompt_tuning_mode = config.prompt_tuning_mode
+
+        self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size)
+        if config.prompt_tuning_mode == 'deep':
+            self.intermediate_prompt_embeddings = nn.Embedding(
+                self.pre_seq_len, (config.num_hidden_layers - 1) * config.hidden_size
+            )
+            self.intermediate_prompt_embeddings.weight.data.zero_()
+        
+    def get_prompt(self, batch_size: int) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.device)
+        prompts = self.prompt_embeddings(prefix_tokens)
+        
+        if hasattr(self, 'intermediate_prompt_embeddings'):
+            intermediate_prompts = self.intermediate_prompt_embeddings(prefix_tokens)
+            intermediate_prompts = intermediate_prompts.view(
+                batch_size,
+                self.pre_seq_len,
+                len(self.h) - 1,
+                self.hidden_size
+            )
+            intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3])
+            return prompts, intermediate_prompts
+        else:
+            return prompts, None
+        
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_values=None,
+        position_ids=None,
+        head_mask=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        assert (
+            input_ids is None or inputs_embeds is None
+        ), "You cannot specify both input_ids and inputs_embeds at the same time"
+        assert input_ids is not None or inputs_embeds is not None, "You must specify either input_ids or inputs_embeds"
+
+        if position_ids is not None:
+            logger.warning("position_ids are ignored in this bloom implementation")
+
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+
+        batch_size = inputs_embeds.shape[0]
+
+        # Updated Bloom Model forward
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if past_key_values is None:
+            past_key_values = tuple([None] * len(self.h))
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_head x N x N
+        # head_mask has shape n_layer x batch x n_head x N x N
+        head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+
+        prompts, intermediate_prompts = self.get_prompt(batch_size)
+        inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
+        input_shape = inputs_embeds.size()[:-1]
+
+        hidden_states = self.word_embeddings_layernorm(inputs_embeds.float())
+
+        output_shape = input_shape + (hidden_states.size(-1),)
+
+        presents = () if use_cache else None
+        all_self_attentions = () if output_attentions else None
+        all_hidden_states = () if output_hidden_states else None
+
+        #############################################################
+        # Prompt Tuning
+        if attention_mask is not None and self.pre_seq_len > 0:
+            prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len, device=attention_mask.device)
+            attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
+        #############################################################
+
+        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            # Prompt Tuning
+            if i > 0 and intermediate_prompts is not None:
+                hidden_states[:, :self.pre_seq_len] += intermediate_prompts[i - 1]
+
+            outputs = block(
+                hidden_states,
+                layer_past=layer_past,
+                attention_mask=attention_mask,
+                head_mask=head_mask[i],
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                alibi=None,
+            )
+
+            hidden_states = outputs[0]
+            if use_cache is True:
+                presents = presents + (outputs[1],)
+
+            if output_attentions:
+                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
+
+        # Add last hidden state
+        hidden_states = self.ln_f(hidden_states)
+
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        hidden_states = hidden_states.view(output_shape)
+
+        # Remove prefix
+        hidden_states = hidden_states[:, self.pre_seq_len:]
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
+
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=presents,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+        )
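
A hedged usage sketch of the new classes (assumptions: the sequence-classification head forwards `input_ids`/`attention_mask` to the transformer as in the stock Hugging Face implementation, the config class and import path are only guesses at this repo's layout, and freezing the backbone mirrors common prompt-tuning practice but is not done by this commit itself):

```python
import torch
# Assumption: a compatible BloomConfig; this repo may ship its own config class instead.
from transformers import BloomConfig

# Assumption: import path follows the file location src/bloom/model.py in this repo.
from src.bloom.model import BloomForSequenceClassification

config = BloomConfig(num_labels=2)               # tiny random config; use from_pretrained(...) for a real checkpoint
model = BloomForSequenceClassification(config)   # __init__ above sets pre_seq_len=16 and prompt_tuning_mode='deep'

# Train only the prompt tables and the classification head (standard prompt-tuning practice,
# not enforced by the commit itself).
for name, param in model.named_parameters():
    param.requires_grad = ("prompt_embeddings" in name) or name.startswith("score")

input_ids = torch.randint(0, config.vocab_size, (2, 32))
attention_mask = torch.ones_like(input_ids)
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=torch.tensor([0, 1]))
outputs.loss.backward()  # gradients flow only into the prompt embeddings and the head
```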