|
@@ -3,7 +3,7 @@ PyTorch BLOOM model that implements several memory-efficient modes.
|
|
|
Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
|
|
|
See commit history for authorship.
|
|
|
"""
|
|
|
-from typing import Tuple, Union, Optional
|
|
|
+from typing import Tuple, Union
|
|
|
|
|
|
import torch
|
|
|
import torch.nn.functional as F
|
|
@@ -625,153 +625,3 @@ class BloomForSequenceClassification(BloomPreTrainedModel):
|
|
|
hidden_states=transformer_outputs.hidden_states,
|
|
|
attentions=transformer_outputs.attentions,
|
|
|
)
|
|
|
-
|
|
|
-
|
|
|
-class BloomPrefix(BloomModel):
|
|
|
- """DistributedBloomModel with prefix tokens for prompt tuning"""
|
|
|
-
|
|
|
- def __init__(self, config):
|
|
|
- super().__init__(config)
|
|
|
- assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
|
|
|
- assert config.prompt_tuning_mode in ['deep', 'shallow']
|
|
|
-
|
|
|
- self.pre_seq_len = config.pre_seq_len
|
|
|
- self.prefix_tokens = torch.arange(self.pre_seq_len).long()
|
|
|
- self.hidden_size = config.hidden_size
|
|
|
- self.prompt_tuning_mode = config.prompt_tuning_mode
|
|
|
-
|
|
|
- self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size)
|
|
|
- if config.prompt_tuning_mode == 'deep':
|
|
|
- self.intermediate_prompt_embeddings = nn.Embedding(
|
|
|
- self.pre_seq_len, (config.num_hidden_layers - 1) * config.hidden_size
|
|
|
- )
|
|
|
- self.intermediate_prompt_embeddings.weight.data.zero_()
|
|
|
-
|
|
|
- def get_prompt(self, batch_size: int) -> torch.Tensor:
|
|
|
- prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.device)
|
|
|
- prompts = self.prompt_embeddings(prefix_tokens)
|
|
|
-
|
|
|
- if hasattr(self, 'intermediate_prompt_embeddings'):
|
|
|
- intermediate_prompts = self.intermediate_prompt_embeddings(prefix_tokens)
|
|
|
- intermediate_prompts = intermediate_prompts.view(
|
|
|
- batch_size,
|
|
|
- self.pre_seq_len,
|
|
|
- len(self.h) - 1,
|
|
|
- self.hidden_size
|
|
|
- )
|
|
|
- intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3])
|
|
|
- return prompts, intermediate_prompts
|
|
|
- else:
|
|
|
- return prompts, None
|
|
|
-
|
|
|
- def forward(
|
|
|
- self,
|
|
|
- input_ids: Optional[torch.LongTensor],
|
|
|
- inputs_embeds: Optional[torch.Tensor],
|
|
|
- attention_mask: Optional[torch.Tensor],
|
|
|
- past_key_values=None,
|
|
|
- position_ids=None,
|
|
|
- head_mask=None,
|
|
|
- use_cache=None,
|
|
|
- output_attentions=None,
|
|
|
- output_hidden_states=None,
|
|
|
- return_dict=None,
|
|
|
- ):
|
|
|
- assert (
|
|
|
- input_ids is None or inputs_embeds is None
|
|
|
- ), "You cannot specify both input_ids and inputs_embeds at the same time"
|
|
|
- assert input_ids is not None or inputs_embeds is not None, "You must specify either input_ids or inputs_embeds"
|
|
|
-
|
|
|
- if position_ids is not None:
|
|
|
- logger.warning("position_ids are ignored in this bloom implementation")
|
|
|
-
|
|
|
- if inputs_embeds is None:
|
|
|
- inputs_embeds = self.word_embeddings(input_ids)
|
|
|
-
|
|
|
- batch_size = inputs_embeds.shape[0]
|
|
|
-
|
|
|
- # Updated Bloom Model forward
|
|
|
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
|
- output_hidden_states = (
|
|
|
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
|
- )
|
|
|
- use_cache = use_cache if use_cache is not None else self.config.use_cache
|
|
|
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
|
-
|
|
|
- if past_key_values is None:
|
|
|
- past_key_values = tuple([None] * len(self.h))
|
|
|
-
|
|
|
- # Prepare head mask if needed
|
|
|
- # 1.0 in head_mask indicate we keep the head
|
|
|
- # attention_probs has shape bsz x n_head x N x N
|
|
|
- # head_mask has shape n_layer x batch x n_head x N x N
|
|
|
- head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
|
|
-
|
|
|
- if inputs_embeds is None:
|
|
|
- inputs_embeds = self.word_embeddings(input_ids)
|
|
|
-
|
|
|
- prompts, intermediate_prompts = self.get_prompt(batch_size)
|
|
|
- inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
|
|
|
- input_shape = inputs_embeds.size()[:-1]
|
|
|
-
|
|
|
- hidden_states = self.word_embeddings_layernorm(inputs_embeds.float())
|
|
|
-
|
|
|
- output_shape = input_shape + (hidden_states.size(-1),)
|
|
|
-
|
|
|
- presents = () if use_cache else None
|
|
|
- all_self_attentions = () if output_attentions else None
|
|
|
- all_hidden_states = () if output_hidden_states else None
|
|
|
-
|
|
|
- #############################################################
|
|
|
- # Prompt Tuning
|
|
|
- if attention_mask is not None and self.pre_seq_len > 0:
|
|
|
- prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len, device=attention_mask.device)
|
|
|
- attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
|
|
|
- #############################################################
|
|
|
-
|
|
|
- for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
|
|
-
|
|
|
- if output_hidden_states:
|
|
|
- all_hidden_states = all_hidden_states + (hidden_states,)
|
|
|
-
|
|
|
- # Prompt Tuning
|
|
|
- if i > 0 and intermediate_prompts is not None:
|
|
|
- hidden_states[:, :self.pre_seq_len] += intermediate_prompts[i - 1]
|
|
|
-
|
|
|
- outputs = block(
|
|
|
- hidden_states,
|
|
|
- layer_past=layer_past,
|
|
|
- attention_mask=attention_mask,
|
|
|
- head_mask=head_mask[i],
|
|
|
- use_cache=use_cache,
|
|
|
- output_attentions=output_attentions,
|
|
|
- alibi=None,
|
|
|
- )
|
|
|
-
|
|
|
- hidden_states = outputs[0]
|
|
|
- if use_cache is True:
|
|
|
- presents = presents + (outputs[1],)
|
|
|
-
|
|
|
- if output_attentions:
|
|
|
- all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
|
|
-
|
|
|
- # Add last hidden state
|
|
|
- hidden_states = self.ln_f(hidden_states)
|
|
|
-
|
|
|
- if output_hidden_states:
|
|
|
- all_hidden_states = all_hidden_states + (hidden_states,)
|
|
|
-
|
|
|
- hidden_states = hidden_states.view(output_shape)
|
|
|
-
|
|
|
- # Remove prefix
|
|
|
- hidden_states = hidden_states[:, self.pre_seq_len: ]
|
|
|
-
|
|
|
- if not return_dict:
|
|
|
- return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
|
|
|
-
|
|
|
- return BaseModelOutputWithPastAndCrossAttentions(
|
|
|
- last_hidden_state=hidden_states,
|
|
|
- past_key_values=presents,
|
|
|
- hidden_states=all_hidden_states,
|
|
|
- attentions=all_self_attentions,
|
|
|
- )
|