@@ -4,8 +4,6 @@ Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e
 See commit history for authorship.
 """
 
-from typing import Tuple
-
 import torch
 import torch.utils.checkpoint
 from hivemind import use_hivemind_log_handler
@@ -155,8 +153,9 @@ class BloomModel(BloomPreTrainedModel):
         self.n_head = config.n_head
 
         # Embedding + LN Embedding
-        self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
-
+
+        # TODO: @dbaranchuk make efficient fp16 on cpu (convert only word_embeddings!)
+        self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim) # dtype=config.torch_dtype
         self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
 
         # Transformer blocks
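
The TODO above asks for the embedding matrix alone to be kept in half precision on CPU, since it dominates the model's memory footprint. A rough sketch of that idea, purely as an assumption about how the comment could be implemented (the helper name is hypothetical and not part of this patch):

import torch

def convert_word_embeddings_to_fp16(model: "BloomModel") -> "BloomModel":
    # Keep only the large vocab x hidden embedding table in fp16;
    # the layernorm and transformer blocks stay in their original dtype.
    model.word_embeddings.to(torch.float16)
    return model

The `.float()` cast added to `forward` in a later hunk brings the fp16 embedding outputs back to fp32 before the layernorm, so the rest of the model is unaffected.
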
@@ -170,11 +169,18 @@ class BloomModel(BloomPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
 
+        # Forbid accumulate grads for embeddings and layernorm
+        self.set_requires_grad(False)
+
     def get_input_embeddings(self):
         return self.word_embeddings
 
     def set_input_embeddings(self, new_embeddings):
         self.word_embeddings = new_embeddings
+
+    def set_requires_grad(self, value):
+        for p in self.parameters():
+            p.requires_grad=value
 
     @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
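
After this hunk, every parameter of BloomModel is frozen as the last step of __init__, so gradients never accumulate in the embeddings or layernorms unless a caller opts back in. A minimal usage sketch, assuming the standard BloomConfig from transformers (not shown in this diff):

from transformers import BloomConfig

config = BloomConfig()      # any valid BLOOM configuration
model = BloomModel(config)  # __init__ ends with self.set_requires_grad(False)

assert not any(p.requires_grad for p in model.parameters())

# re-enable gradients explicitly, e.g. for local fine-tuning
model.set_requires_grad(True)
assert all(p.requires_grad for p in model.parameters())
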
@@ -227,7 +233,7 @@ class BloomModel(BloomPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
 
-        hidden_states = self.word_embeddings_layernorm(inputs_embeds)
+        hidden_states = self.word_embeddings_layernorm(inputs_embeds.float())
 
         output_shape = input_shape + (hidden_states.size(-1),)
 
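
Casting the embedding output to fp32 before the embedding layernorm keeps the normalization numerically stable when the word embeddings themselves are stored in half precision (see the TODO in the earlier hunk). A self-contained sketch of the dtype flow, with hypothetical toy sizes:

import torch
from torch import nn

vocab_size, hidden = 32, 16                                 # hypothetical toy sizes
word_embeddings = nn.Embedding(vocab_size, hidden).half()   # table stored in fp16
word_embeddings_layernorm = nn.LayerNorm(hidden)            # kept in fp32

input_ids = torch.randint(0, vocab_size, (1, 4))
inputs_embeds = word_embeddings(input_ids)                  # fp16 activations
hidden_states = word_embeddings_layernorm(inputs_embeds.float())
print(hidden_states.dtype)                                  # torch.float32
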
@@ -306,126 +312,17 @@ class BloomModel(BloomPreTrainedModel):
 
 @add_start_docstrings(
     """
-    The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input
-    embeddings).
+    The Bloom interface for various applications, e.g., inference, classification...
     """,
     BLOOM_START_DOCSTRING,
 )
-class BloomForCausalLM(BloomPreTrainedModel):
+class BloomForYou(BloomPreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
 
     def __init__(self, config):
-        super().__init__(config)
-        self.transformer = BloomModel(config)
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-    def get_output_embeddings(self):
-        return self.lm_head
-
-    def set_output_embeddings(self, new_embeddings):
-        self.lm_head = new_embeddings
-
-    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
-        # only last token for inputs_ids if past is defined in kwargs
-        if past:
-            input_ids = input_ids[:, -1].unsqueeze(-1)
-
-        attention_mask = kwargs.get("attention_mask", None)
-        position_ids = kwargs.get("position_ids", None)
-
-        if attention_mask is not None and position_ids is None:
-            # create position_ids on the fly for batch generation
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-            if past:
-                position_ids = position_ids[:, -1].unsqueeze(-1)
-        else:
-            position_ids = None
-        return {
-            "input_ids": input_ids,
-            "past_key_values": past,
-            "use_cache": kwargs.get("use_cache"),
-            "position_ids": position_ids,
-            "attention_mask": attention_mask,
-        }
-
-    @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
-    @add_code_sample_docstrings(
-        processor_class=_TOKENIZER_FOR_DOC,
-        checkpoint=_CHECKPOINT_FOR_DOC,
-        output_type=CausalLMOutputWithCrossAttentions,
-        config_class=_CONFIG_FOR_DOC,
-    )
-    def forward(
-        self,
-        input_ids=None,
-        past_key_values=None,
-        attention_mask=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
-        r"""
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
-            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
-            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
-        """
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        transformer_outputs = self.transformer(
-            input_ids,
-            past_key_values=past_key_values,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-        hidden_states = transformer_outputs[0]
-
-        lm_logits = self.lm_head(hidden_states)
+        super().__init__(config)
+        self.transformer = BloomModel(config)
+        self.lm_head = None
 
-        loss = None
-        if labels is not None:
-            # Shift so that tokens < n predict n
-            shift_logits = lm_logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
-
-        if not return_dict:
-            output = (lm_logits,) + transformer_outputs[1:]
-            return ((loss,) + output) if loss is not None else output
-
-        return CausalLMOutputWithCrossAttentions(
-            loss=loss,
-            logits=lm_logits,
-            past_key_values=transformer_outputs.past_key_values,
-            hidden_states=transformer_outputs.hidden_states,
-            attentions=transformer_outputs.attentions,
-        )
-
-    @staticmethod
-    def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
-        """
-        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
-        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
-        beam_idx at every generation step.
-        """
-        return tuple(
-            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
-            for layer_past in past
-        )
+        # Initialize weights and apply final processing
+        self.post_init()
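
With lm_head set to None, BloomForYou is only a thin wrapper around the frozen BloomModel; each application attaches whatever head it needs. As a purely hypothetical sketch (not part of this patch), a causal-LM style wrapper could tie a linear head to the input embeddings, much like the removed BloomForCausalLM did:

from torch import nn

class BloomForGeneration(BloomForYou):  # hypothetical subclass for illustration
    def __init__(self, config):
        super().__init__(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # tie the output projection to the (frozen) input embedding weights
        self.lm_head.weight = self.transformer.word_embeddings.weight

    def forward(self, input_ids, **kwargs):
        hidden_states = self.transformer(input_ids, **kwargs)[0]
        return self.lm_head(hidden_states)
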