123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433 |
- """
- PyTorch BLOOM model that implements several memory-efficient modes.
- Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
- See commit history for authorship.
- """
- from typing import Tuple
- import torch
- import torch.utils.checkpoint
- from hivemind import use_hivemind_log_handler
- from torch import nn
- from torch.nn import CrossEntropyLoss, LayerNorm
- from transformers.file_utils import (
- add_code_sample_docstrings,
- add_start_docstrings,
- add_start_docstrings_to_model_forward,
- )
- from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
- from transformers.modeling_utils import PreTrainedModel
- from transformers.models.bloom.configuration_bloom import BloomConfig as _VanillaBloomConfig
- from transformers.utils import logging
- from src.bloom.block import BloomBlock
- from src.bloom.ops import build_alibi_tensor
# Install hivemind's log handler; the "in_root_logger" mode name suggests it attaches to the root logger
# (NOTE(review): confirm against hivemind's `use_hivemind_log_handler` docs).
use_hivemind_log_handler("in_root_logger")
# Name the logger after the dotted module path (standard logging convention) instead of the source file path,
# so log records carry a short, hierarchical logger name rather than an absolute filesystem path.
logger = logging.get_logger(__name__)

# Placeholders substituted into the auto-generated code-sample docstrings of the model `forward` methods.
_CHECKPOINT_FOR_DOC = "bigscience/Bloom"
_CONFIG_FOR_DOC = "DistributedBloomConfig"
_TOKENIZER_FOR_DOC = "BloomTokenizer"
class DistributedBloomConfig(_VanillaBloomConfig):
    """Vanilla ``BloomConfig`` plus two extra class-level options used by this implementation.

    * ``compression`` -- defaults to ``"none"``; presumably selects a compression scheme
      (NOTE(review): its consumers are not visible in this module).
    * ``slow_but_exact`` -- defaults to ``False``; ``BloomModel`` asserts that it stays ``False``.
    """

    # Class attributes act as fallbacks for config instances that do not set these fields themselves.
    compression: str = "none"
    slow_but_exact: bool = False
class BloomPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # Regex patterns of parameter names whose absence from a checkpoint should not be reported as missing weights
    # (transformers `PreTrainedModel` convention).
    _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
    config_class = DistributedBloomConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    _no_split_modules = ["BloomBlock"]

    def _init_weights(self, module):
        """Initialize the weights of a single submodule in place.

        Hook used by transformers' `PreTrainedModel` weight initialization:

        - `nn.Linear`: weight ~ N(0, config.initializer_range); bias (if any) zeroed.
        - `nn.Embedding`: weight ~ N(0, config.initializer_range); row `padding_idx` (if set) zeroed.
        - `LayerNorm`: weight = 1 and bias = 0, for whichever of the two exist.
        """
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, LayerNorm):
            # A LayerNorm built without affine parameters has weight/bias set to None; skip those.
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)

    def _set_gradient_checkpointing(self, module, value=False):
        """Set the `gradient_checkpointing` flag on `BloomModel` submodules; other modules are left untouched."""
        if isinstance(module, BloomModel):
            module.gradient_checkpointing = value
# Class-level documentation prepended to the BLOOM model classes via `add_start_docstrings`.
BLOOM_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`DistributedBloomConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
# Shared `forward` argument documentation, injected via `add_start_docstrings_to_model_forward`.
BLOOM_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None`; otherwise it is the number of new
            tokens that are not yet cached (the cached length is `past_key_values[0][0].shape[1]`). Indices of input
            sequence tokens in the vocabulary.

            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`BloomTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layer`, *optional*):
            Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
            `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
            their past given to this model should not be passed as `input_ids` as they have already been computed.
        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Accepted for interface compatibility only and ignored: this model has no position embeddings and encodes
            token positions with ALiBi attention biases instead.
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.

            If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
            `past_key_values`).
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
    "The bare Bloom Model transformer outputting raw hidden-states without any specific head on top.",
    BLOOM_START_DOCSTRING,
)
class BloomModel(BloomPreTrainedModel):
    # Class documentation is supplied by @add_start_docstrings above (no inline docstring, so it is not appended to).
    # Layout: word embeddings -> embedding LayerNorm -> `config.num_hidden_layers` BloomBlocks -> final LayerNorm.

    def __init__(self, config):
        super().__init__(config)
        # The slow_but_exact code path no longer exists here; reject configs that still request it.
        assert not config.slow_but_exact, "slow_but_exact mode was removed for code simplicity"

        self.embed_dim = config.hidden_size
        self.n_head = config.n_head

        # Embedding + LN Embedding
        self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
        self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

        # Transformer blocks
        self.h = nn.ModuleList([BloomBlock(config, layer_number=i) for i in range(config.num_hidden_layers)])

        # Final Layer Norm
        self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

        # Toggled through BloomPreTrainedModel._set_gradient_checkpointing.
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.word_embeddings

    def set_input_embeddings(self, new_embeddings):
        self.word_embeddings = new_embeddings

    # Documentation is generated by the decorators from BLOOM_INPUTS_DOCSTRING; an inline docstring would be spliced
    # into the generated Args section, so none is given here.
    @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPastAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        # NOTE: `position_ids` is accepted for interface compatibility but never used; positions come from ALiBi.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            # Flatten any leading dims to 2-D; the caller's shape is restored on the output below.
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # Settle `use_cache` once, BEFORE the output containers are created. Previously this downgrade happened
        # inside the layer loop, after `presents` had already been initialised to `()`, so the model returned an
        # empty tuple instead of None for `past_key_values`.
        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

        if past_key_values is None:
            past_key_values = tuple([None] * len(self.h))

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_head x N x N
        # head_mask has shape n_layer x batch x n_head x N x N
        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        hidden_states = self.word_embeddings_layernorm(inputs_embeds)
        output_shape = input_shape + (hidden_states.size(-1),)

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        # The ALiBi tensor must span the cached tokens plus the new ones; see build_alibi_tensor.
        current_sequence_length = hidden_states.shape[1]
        if past_key_values[0] is not None:
            # The cached length is read from dim 1 of the first layer's cached key tensor.
            current_sequence_length += past_key_values[0][0].shape[1]
        alibi = build_alibi_tensor(current_sequence_length, self.n_head, hidden_states.dtype)

        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # `use_cache` is already False here (settled above).

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, use_cache, output_attentions, alibi)

                    return custom_forward

                outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    None,  # layer_past: no KV cache under gradient checkpointing
                    attention_mask,
                    head_mask[i],
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=attention_mask,
                    head_mask=head_mask[i],
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    alibi=alibi,
                )

            hidden_states = outputs[0]
            # Block outputs are (hidden_states, [present if use_cache], [attentions if output_attentions]).
            if use_cache:
                presents = presents + (outputs[1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)

        # Final LayerNorm, then record the last hidden state.
        hidden_states = self.ln_f(hidden_states)
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # Restore the caller's leading dimensions.
        hidden_states = hidden_states.view(output_shape)

        if not return_dict:
            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
@add_start_docstrings(
    """
    The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
    """,
    BLOOM_START_DOCSTRING,
)
class BloomForCausalLM(BloomPreTrainedModel):
    # Class documentation is supplied by @add_start_docstrings above.
    # Structure: a BloomModel body (`self.transformer`) followed by a bias-free vocabulary projection (`self.lm_head`).

    _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.transformer = BloomModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        """Return the language-modeling head module."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Swap in a new language-modeling head module."""
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
        """Build the keyword arguments for one generation step.

        When a cache (`past`) is present, only the newest token of `input_ids` is fed. Position ids are derived
        from `attention_mask` only when the caller passed a mask and no explicit `position_ids`; in every other
        case they are set to None.
        """
        has_cache = bool(past)
        if has_cache:
            # Earlier tokens are already represented in the cache.
            input_ids = input_ids[:, -1].unsqueeze(-1)

        attention_mask = kwargs.get("attention_mask", None)
        position_ids = None
        if attention_mask is not None and kwargs.get("position_ids", None) is None:
            # Count real (unmasked) tokens so left-padded rows in a batch start their positions at 0.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if has_cache:
                position_ids = position_ids[:, -1].unsqueeze(-1)

        return dict(
            input_ids=input_ids,
            past_key_values=past,
            use_cache=kwargs.get("use_cache"),
            position_ids=position_ids,
            attention_mask=attention_mask,
        )

    @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=CausalLMOutputWithCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. The labels **are shifted** inside the model, so you can simply set
            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`; labels equal to
            `-100` are ignored (masked) and the loss is only computed over labels in `[0, ..., config.vocab_size - 1]`.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        body_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        lm_logits = self.lm_head(body_outputs[0])

        loss = None
        if labels is not None:
            vocab_size = lm_logits.size(-1)
            # Position t predicts token t + 1: drop the last logit and the first label, then flatten.
            predicted = lm_logits[..., :-1, :].contiguous().view(-1, vocab_size)
            expected = labels[..., 1:].contiguous().view(-1)
            loss = CrossEntropyLoss()(predicted, expected)

        if return_dict:
            return CausalLMOutputWithCrossAttentions(
                loss=loss,
                logits=lm_logits,
                past_key_values=body_outputs.past_key_values,
                hidden_states=body_outputs.hidden_states,
                attentions=body_outputs.attentions,
            )

        flat = (lm_logits,) + body_outputs[1:]
        return flat if loss is None else (loss,) + flat

    @staticmethod
    def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
        """
        Re-order the `past_key_values` cache so that it follows the selected beams. Used by
        [`~PreTrainedModel.beam_search`] and [`~PreTrainedModel.beam_sample`] at every generation step.
        Every cached tensor is indexed along dim 0 with `beam_idx` (moved to that tensor's device).
        """

        def select_beams(layer_past):
            return tuple(state.index_select(0, beam_idx.to(state.device)) for state in layer_past)

        return tuple(select_beams(layer_past) for layer_past in past)
|