dbaranchuk 3 years ago
Parent
Commit
bfeddb5463

File diff suppressed because it is too large
+ 0 - 815
notebooks/deep_prompt_tuning_cola.ipynb


File diff suppressed because it is too large
+ 0 - 864
notebooks/deep_prompt_tuning_sst2.ipynb


File diff suppressed because it is too large
+ 0 - 810
notebooks/shallow_prompt_tuning_cola.ipynb


File diff suppressed because it is too large
+ 0 - 954
notebooks/shallow_prompt_tuning_sst2.ipynb


+ 1 - 151
src/bloom/model.py

@@ -3,7 +3,7 @@ PyTorch BLOOM model that implements several memory-efficient modes.
 Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
 See commit history for authorship.
 """
-from typing import Tuple, Union, Optional
+from typing import Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -625,153 +625,3 @@ class BloomForSequenceClassification(BloomPreTrainedModel):
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
         )
-
-
-class BloomPrefix(BloomModel):
-    """DistributedBloomModel with prefix tokens for prompt tuning"""
-
-    def __init__(self, config):
-        super().__init__(config)
-        assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
-        assert config.prompt_tuning_mode in ['deep', 'shallow']
-
-        self.pre_seq_len = config.pre_seq_len
-        self.prefix_tokens = torch.arange(self.pre_seq_len).long()
-        self.hidden_size = config.hidden_size
-        self.prompt_tuning_mode = config.prompt_tuning_mode
-
-        self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size)
-        if config.prompt_tuning_mode == 'deep':
-            self.intermediate_prompt_embeddings = nn.Embedding(
-                self.pre_seq_len, (config.num_hidden_layers - 1) * config.hidden_size
-            )
-            self.intermediate_prompt_embeddings.weight.data.zero_()
-        
-    def get_prompt(self, batch_size: int) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.device)
-        prompts = self.prompt_embeddings(prefix_tokens)
-        
-        if hasattr(self, 'intermediate_prompt_embeddings'):
-            intermediate_prompts = self.intermediate_prompt_embeddings(prefix_tokens)
-            intermediate_prompts = intermediate_prompts.view(
-                batch_size,
-                self.pre_seq_len,
-                len(self.h) - 1,
-                self.hidden_size
-            )
-            intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3])
-            return prompts, intermediate_prompts
-        else:
-            return prompts, None
-        
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor],
-        inputs_embeds: Optional[torch.Tensor],
-        attention_mask: Optional[torch.Tensor],
-        past_key_values=None,
-        position_ids=None,
-        head_mask=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
-        assert (
-            input_ids is None or inputs_embeds is None
-        ), "You cannot specify both input_ids and inputs_embeds at the same time"
-        assert input_ids is not None or inputs_embeds is not None, "You must specify either input_ids or inputs_embeds"
-
-        if position_ids is not None:
-            logger.warning("position_ids are ignored in this bloom implementation")
-
-        if inputs_embeds is None:
-            inputs_embeds = self.word_embeddings(input_ids)
-
-        batch_size = inputs_embeds.shape[0]
-
-        # Updated Bloom Model forward
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        if past_key_values is None:
-            past_key_values = tuple([None] * len(self.h))
-
-        # Prepare head mask if needed
-        # 1.0 in head_mask indicate we keep the head
-        # attention_probs has shape bsz x n_head x N x N
-        # head_mask has shape n_layer x batch x n_head x N x N
-        head_mask = self.get_head_mask(head_mask, self.config.n_layer)
-
-        if inputs_embeds is None:
-            inputs_embeds = self.word_embeddings(input_ids)
-
-        prompts, intermediate_prompts = self.get_prompt(batch_size)
-        inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
-        input_shape = inputs_embeds.size()[:-1]
-
-        hidden_states = self.word_embeddings_layernorm(inputs_embeds.float())
-
-        output_shape = input_shape + (hidden_states.size(-1),)
-
-        presents = () if use_cache else None
-        all_self_attentions = () if output_attentions else None
-        all_hidden_states = () if output_hidden_states else None
-
-        #############################################################
-        # Prompt Tuning
-        if attention_mask is not None and self.pre_seq_len > 0:
-            prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len, device=attention_mask.device)
-            attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
-        #############################################################
-
-        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
-
-            if output_hidden_states:
-                all_hidden_states = all_hidden_states + (hidden_states,)
-
-            # Prompt Tuning
-            if i > 0 and intermediate_prompts is not None:
-                hidden_states[:, :self.pre_seq_len] += intermediate_prompts[i - 1]
-
-            outputs = block(
-                hidden_states,
-                layer_past=layer_past,
-                attention_mask=attention_mask,
-                head_mask=head_mask[i],
-                use_cache=use_cache,
-                output_attentions=output_attentions,
-                alibi=None,
-            )
-
-            hidden_states = outputs[0]
-            if use_cache is True:
-                presents = presents + (outputs[1],)
-
-            if output_attentions:
-                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
-
-        # Add last hidden state
-        hidden_states = self.ln_f(hidden_states)
-
-        if output_hidden_states:
-            all_hidden_states = all_hidden_states + (hidden_states,)
-
-        hidden_states = hidden_states.view(output_shape)
-
-        # Remove prefix
-        hidden_states = hidden_states[:, self.pre_seq_len: ]
-
-        if not return_dict:
-            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
-
-        return BaseModelOutputWithPastAndCrossAttentions(
-            last_hidden_state=hidden_states,
-            past_key_values=presents,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attentions,
-        )
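For reference, the BloomPrefix class removed above prepends trainable prompt embeddings to the input sequence (and, in 'deep' mode, adds per-layer intermediate prompts), then strips the prefix positions from the returned hidden states. A minimal usage sketch of the pre-commit class, assuming a BloomConfig extended with the custom pre_seq_len and prompt_tuning_mode fields it reads (these are not standard BloomConfig attributes):

import torch
from transformers import BloomConfig

from src.bloom.model import BloomPrefix  # the class as it existed before this commit

config = BloomConfig.from_pretrained("bigscience/bloom-560m")
config.pre_seq_len = 16               # number of trainable prefix tokens
config.prompt_tuning_mode = "deep"    # 'deep' also trains per-layer intermediate prompts

model = BloomPrefix(config)

# Freeze the backbone; only the prompt embedding tables remain trainable.
for name, param in model.named_parameters():
    param.requires_grad = "prompt_embeddings" in name

input_ids = torch.randint(0, config.vocab_size, (2, 8))
attention_mask = torch.ones_like(input_ids)
outputs = model(input_ids=input_ids, inputs_embeds=None, attention_mask=attention_mask)
print(outputs.last_hidden_state.shape)  # (2, 8, hidden_size): the prefix is removed before returning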

+ 0 - 111
src/bloom/ptune_v2_model.py

@@ -1,111 +0,0 @@
-"""
-PyTorch BLOOM model that implements several memory-efficient modes.
-Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
-See commit history for authorship.
-"""
-from typing import Tuple, Union, Optional
-
-import torch
-import torch.utils.checkpoint
-from torch import nn
-
-from hivemind import use_hivemind_log_handler
-from transformers.utils import logging
-
-from src.bloom.model import BloomModel
-
-use_hivemind_log_handler("in_root_logger")
-logger = logging.get_logger(__file__)
-
-
-class PrefixEncoder(torch.nn.Module):
-    r'''
-    The torch.nn model to encode the prefix
-    Input shape: (batch-size, prefix-length)
-    Output shape: (batch-size, prefix-length, 2*layers*hidden)
-    '''
-    def __init__(self, config):
-        super().__init__()
-        self.prefix_projection = False
-        self.embedding = nn.Embedding(config.pre_seq_len, config.num_hidden_layers * 2 * config.hidden_size)
-
-    def forward(self, prefix: torch.Tensor):
-        if self.prefix_projection:
-            prefix_tokens = self.embedding(prefix)
-            past_key_values = self.trans(prefix_tokens)
-        else:
-            past_key_values = self.embedding(prefix)
-        return past_key_values
-
-
-class BloomPrefixV2(BloomModel):
-    """DistributedBloomModel with prefix tokens for prompt tuning"""
-
-    def __init__(self, config):
-        super().__init__(config)
-        assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
-        assert config.prompt_tuning_mode == 'deep'
-
-        self.pre_seq_len = config.pre_seq_len
-        self.prefix_tokens = torch.arange(self.pre_seq_len).long()
-        
-        self.prefix_encoder = PrefixEncoder(config)
-        self.hidden_size = config.hidden_size 
-        # self.dropout = torch.nn.Dropout(0.0)
-
-    def get_prompt(self, batch_size):
-        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.device)
-        past_key_values = self.prefix_encoder(prefix_tokens)
-        
-        # bsz, seqlen, _ = past_key_values.shape
-        past_key_values = past_key_values.view(
-            batch_size,
-            self.pre_seq_len,
-            len(self.h) * 2, 
-            self.n_head,
-            self.hidden_size // self.n_head
-        )
-        # past_key_values = self.dropout(past_key_values)
-        past_key_values = past_key_values.permute([2, 0, 1, 3, 4]).split(2)
-        return past_key_values
-
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor],
-        inputs_embeds: Optional[torch.Tensor],
-        attention_mask: Optional[torch.Tensor],
-        past_key_values=None,
-        position_ids=None,
-        head_mask=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
-        assert (
-            input_ids is None or inputs_embeds is None
-        ), "You cannot specify both input_ids and inputs_embeds at the same time"
-        assert input_ids is not None or inputs_embeds is not None, "You must specify either input_ids or inputs_embeds"
-
-        if position_ids is not None:
-            logger.warning("position_ids are ignored in this bloom implementation")
-
-        if inputs_embeds is None:
-            inputs_embeds = self.word_embeddings(input_ids)
-
-        batch_size = inputs_embeds.shape[0]
-
-        if attention_mask is not None:
-            prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len, device=attention_mask.device)
-            attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
-
-        past_key_values = self.get_prompt(batch_size=batch_size)
-
-        transformer_outputs = super().forward(
-            inputs_embeds=inputs_embeds, 
-            attention_mask=attention_mask,
-            past_key_values=past_key_values,
-            head_mask=head_mask, 
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict
-        )
-        return transformer_outputs
-        
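The removed BloomPrefixV2 follows the p-tuning v2 recipe instead: the prefix is not prepended to the input embeddings but encoded into per-layer key/value tensors and fed to the backbone through past_key_values. A shape-only sketch of the reshape performed in get_prompt, with hypothetical sizes chosen purely for illustration:

import torch

# Hypothetical sizes, not tied to any particular BLOOM checkpoint.
batch_size, pre_seq_len = 2, 16
num_layers, n_head, hidden_size = 24, 16, 1024

# PrefixEncoder output: one flat key/value vector per prefix position.
flat = torch.randn(batch_size, pre_seq_len, num_layers * 2 * hidden_size)

past_key_values = flat.view(batch_size, pre_seq_len, num_layers * 2, n_head, hidden_size // n_head)
past_key_values = past_key_values.permute([2, 0, 1, 3, 4]).split(2)

print(len(past_key_values))      # num_layers: one entry per transformer block
print(past_key_values[0].shape)  # (2, batch, pre_seq_len, n_head, head_dim): key and value for one layer

Each entry is then consumed by the corresponding block as cached keys and values, which is why the removed forward also extends attention_mask by pre_seq_len ones.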

Some files were not shown because too many files changed in this diff