@@ -0,0 +1,418 @@
+"""PyTorch BLOOM model."""
+
+from typing import Tuple
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import CrossEntropyLoss, LayerNorm
+
+from transformers.file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
+from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import logging
+from transformers.models.bloom.configuration_bloom import BloomConfig
+
+from src.layer import BloomBlock
+from src.ops import build_alibi_tensor
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "bigscience/bloom"
+_CONFIG_FOR_DOC = "BloomConfig"
+_TOKENIZER_FOR_DOC = "BloomTokenizer"
+
+
+class BloomPreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
+    config_class = BloomConfig
+    base_model_prefix = "transformer"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["BloomBlock"]
+
+    def __init__(self, *inputs, **kwargs):
+        super().__init__(*inputs, **kwargs)
+
+    def _init_weights(self, module):
+        """Initialize the weights."""
+        if isinstance(module, nn.Linear):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, BloomModel):
+            module.gradient_checkpointing = value
+
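+# Note: because `supports_gradient_checkpointing = True`, the standard transformers entry
+# point `model.gradient_checkpointing_enable()` routes through `_set_gradient_checkpointing`
+# above and flips `BloomModel.gradient_checkpointing`, selecting the checkpointed branch in
+# `BloomModel.forward` below.
+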
+
+BLOOM_START_DOCSTRING = r"""
+
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`BloomConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
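+# A config-only construction sketch, as the docstring above notes: `BloomModel(BloomConfig())`
+# builds a randomly initialized model from the default configuration, while
+# `BloomModel.from_pretrained(...)` also loads pretrained weights.
+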
+BLOOM_INPUTS_DOCSTRING = r"""
+    Args:
+        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
+            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
+            `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
+            sequence tokens in the vocabulary.
+
+            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
+            `input_ids`.
+
+            Indices can be obtained using [`BloomTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+
+            [What are input IDs?](../glossary#input-ids)
+        past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`):
+            Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
+            `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
+            their past given to this model should not be passed as `input_ids` as they have already been computed.
+        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            [What are attention masks?](../glossary#attention-mask)
+        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
+            config.max_position_embeddings - 1]`.
+
+            [What are position IDs?](../glossary#position-ids)
+        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+            - 1 indicates the head is **not masked**,
+            - 0 indicates the head is **masked**.
+
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+            model's internal embedding lookup matrix.
+
+            If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
+            `past_key_values`).
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+        output_attentions (`bool`, *optional*):
+            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+            tensors for more detail.
+        output_hidden_states (`bool`, *optional*):
+            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+            more detail.
+        return_dict (`bool`, *optional*):
+            Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
+"""
+
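+# A minimal usage sketch for the inputs described above (illustrative only: it assumes
+# the `bigscience/bloom` checkpoint name and the standard `transformers` auto classes):
+#
+#     from transformers import AutoTokenizer
+#
+#     tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom")
+#     model = BloomModel.from_pretrained("bigscience/bloom")
+#     inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+#     outputs = model(**inputs, use_cache=True)
+#     past = outputs.past_key_values  # reuse on the next step to skip recomputing keys/values
+
+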
+@add_start_docstrings(
+    "The bare Bloom Model transformer outputting raw hidden-states without any specific head on top.",
+    BLOOM_START_DOCSTRING,
+)
+class BloomModel(BloomPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.embed_dim = config.hidden_size
+        self.n_head = config.n_head
+
+        # Embedding + LN Embedding
+        self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
+        self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
+
+        # Transformer blocks
+        self.h = nn.ModuleList([BloomBlock(config, layer_number=i) for i in range(config.num_hidden_layers)])
+
+        # Final Layer Norm
+        self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
+
+        self.gradient_checkpointing = False
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.word_embeddings
+
+    def set_input_embeddings(self, new_embeddings):
+        self.word_embeddings = new_embeddings
+
+    @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        processor_class=_TOKENIZER_FOR_DOC,
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=BaseModelOutputWithPastAndCrossAttentions,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(
+        self,
+        input_ids=None,
+        past_key_values=None,
+        attention_mask=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        if past_key_values is None:
+            past_key_values = tuple([None] * len(self.h))
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicates we keep the head
+        # attention_probs has shape bsz x n_head x N x N
+        # head_mask has shape n_layer x batch x n_head x N x N
+        head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+
+        hidden_states = self.word_embeddings_layernorm(inputs_embeds)
+
+        output_shape = input_shape + (hidden_states.size(-1),)
+
+        presents = () if use_cache else None
+        all_self_attentions = () if output_attentions else None
+        all_hidden_states = () if output_hidden_states else None
+
+        # Compute alibi tensor: check build_alibi_tensor documentation
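+        # ALiBi (Press et al., 2021) stands in for position embeddings: each head adds a
+        # linear bias to its attention scores proportional to key distance, with per-head
+        # slopes forming a geometric sequence (for n_head = 8 the slopes are 1/2, 1/4, ...,
+        # 1/256). The bias must also cover cached past tokens, so the current length is
+        # extended by the past length below.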
+        current_sequence_length = hidden_states.shape[1]
+        if past_key_values[0] is not None:
+            current_sequence_length += past_key_values[0][0].shape[1]
+        alibi = build_alibi_tensor(current_sequence_length, self.n_head, hidden_states.dtype)
+
+        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            if self.gradient_checkpointing and self.training:
+
+                if use_cache:
+                    logger.warning(
+                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                    )
+                    use_cache = False
+
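+                # `torch.utils.checkpoint.checkpoint` forwards only positional arguments to
+                # the wrapped function, so the non-tensor flags (`use_cache`,
+                # `output_attentions`) and the shared `alibi` tensor are captured in a
+                # closure rather than passed through the checkpoint boundary.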
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        # None for past_key_value
+                        return module(*inputs, use_cache, output_attentions, alibi)
+
+                    return custom_forward
+
+                outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    None,
+                    attention_mask,
+                    head_mask[i],
+                )
+            else:
+                outputs = block(
+                    hidden_states,
+                    layer_past=layer_past,
+                    attention_mask=attention_mask,
+                    head_mask=head_mask[i],
+                    use_cache=use_cache,
+                    output_attentions=output_attentions,
+                    alibi=alibi,
+                )
+
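+            # The block returns hidden_states first, then (if `use_cache`) the present
+            # key/value pair, then (if `output_attentions`) the attention probabilities,
+            # which is why the attention index below shifts with `use_cache`.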
+            hidden_states = outputs[0]
+            if use_cache:
+                presents = presents + (outputs[1],)
+
+            if output_attentions:
+                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
+
+        # Add last hidden state
+        hidden_states = self.ln_f(hidden_states)
+
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        hidden_states = hidden_states.view(output_shape)
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
+
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=presents,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+        )
+
+
+@add_start_docstrings(
+    """
+    The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input
+    embeddings).
+    """,
+    BLOOM_START_DOCSTRING,
+)
+class BloomForCausalLM(BloomPreTrainedModel):
+    _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.transformer = BloomModel(config)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
+        # only keep the last token of input_ids if the past is defined in kwargs
+        if past:
+            input_ids = input_ids[:, -1].unsqueeze(-1)
+
+        attention_mask = kwargs.get("attention_mask", None)
+        position_ids = kwargs.get("position_ids", None)
+
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
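+            # e.g. a left-padded mask [0, 0, 1, 1, 1] gives positions [1, 1, 0, 1, 2]:
+            # counting restarts at the first real token, and padded slots receive a
+            # harmless placeholder value since they are masked out of attention anyway.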
+            if past:
+                position_ids = position_ids[:, -1].unsqueeze(-1)
+        else:
+            position_ids = None
+        return {
+            "input_ids": input_ids,
+            "past_key_values": past,
+            "use_cache": kwargs.get("use_cache"),
+            "position_ids": position_ids,
+            "attention_mask": attention_mask,
+        }
+
+    @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
+    @add_code_sample_docstrings(
+        processor_class=_TOKENIZER_FOR_DOC,
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=CausalLMOutputWithCrossAttentions,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(
+        self,
+        input_ids=None,
+        past_key_values=None,
+        attention_mask=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to
+            `-100` are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        transformer_outputs = self.transformer(
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        hidden_states = transformer_outputs[0]
+
+        lm_logits = self.lm_head(hidden_states)
+
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = lm_logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
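+            # With sequence length T this scores logits at positions 0..T-2 against labels
+            # at positions 1..T-1, so each token is predicted from its strict prefix.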
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+        if not return_dict:
+            output = (lm_logits,) + transformer_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return CausalLMOutputWithCrossAttentions(
+            loss=loss,
+            logits=lm_logits,
+            past_key_values=transformer_outputs.past_key_values,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+        )
+
+    @staticmethod
+    def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
+        """
+        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
+        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
+        beam_idx at every generation step.
+        """
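+        # Each cached tensor is gathered along dim 0 (the beam dimension); e.g. with
+        # beam_idx = tensor([2, 2, 0, 1]), new beams 0 and 1 both continue from old
+        # beam 2, so its cached keys/values are duplicated.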
+        return tuple(
+            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
+            for layer_past in past
+        )
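+
+
+if __name__ == "__main__":
+    # Minimal smoke test, a sketch only: the tiny hyperparameters are illustrative, and it
+    # assumes the local `src.layer.BloomBlock` mirrors the upstream transformers block.
+    config = BloomConfig(vocab_size=256, hidden_size=64, n_layer=2, n_head=4)
+    model = BloomForCausalLM(config)
+    input_ids = torch.randint(0, 256, (1, 8))
+    attention_mask = torch.ones_like(input_ids)
+    out = model(input_ids, attention_mask=attention_mask, labels=input_ids)
+    print(out.loss.item(), out.logits.shape)  # scalar loss and torch.Size([1, 8, 256])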