@@ -3,7 +3,7 @@ PyTorch BLOOM model that implements several memory-efficient modes.
 Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
 See commit history for authorship.
 """
-from typing import Tuple
+from typing import Tuple, Union

 import torch
 import torch.nn.functional as F
@@ -165,19 +165,12 @@ class BloomModel(BloomPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()

-        # Forbid accumulate grads for embeddings and layernorm
-        self.set_requires_grad(False)
-
     def get_input_embeddings(self):
         return self.word_embeddings

     def set_input_embeddings(self, new_embeddings):
         self.word_embeddings = new_embeddings

-    def set_requires_grad(self, value):
-        for p in self.parameters():
-            p.requires_grad = value
-
     @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -319,14 +312,16 @@ class BloomForCausalLM(BloomPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.transformer = BloomModel(config)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
         # Initialize weights and apply final processing
         self.post_init()

     def get_output_embeddings(self):
-        return self.transformer.word_embeddings
+        return self.lm_head

     def set_output_embeddings(self, new_embeddings):
-        self.transformer.word_embeddings.weight = new_embeddings.weight
+        self.lm_head = new_embeddings

     def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
         # only last token for inputs_ids if past is defined in kwargs
@@ -359,7 +354,20 @@ class BloomForCausalLM(BloomPreTrainedModel):
         output_type=CausalLMOutputWithCrossAttentions,
         config_class=_CONFIG_FOR_DOC,
     )
-    def forward(self, input_ids=None, labels=None, return_dict=None, **kwargs):
+    def forward(
+        self,
+        input_ids=None,
+        past_key_values=None,
+        attention_mask=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
@@ -367,12 +375,22 @@ class BloomForCausalLM(BloomPreTrainedModel):
             are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        transformer_outputs = self.transformer.forward(input_ids=input_ids, return_dict=return_dict, **kwargs)
-        word_embeddings = self.transformer.word_embeddings.weight
+        transformer_outputs = self.transformer(
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        hidden_states = transformer_outputs[0]

-        # Switch dtype in case word_embeddings are fp16/bf16
-        hidden_states = transformer_outputs[0].to(word_embeddings.dtype)
-        lm_logits = F.linear(hidden_states, word_embeddings).float()
+
+        lm_logits = self.lm_head(hidden_states)

         loss = None
         if labels is not None:
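The change above, together with the new `lm_head` in `__init__`, swaps the tied-embedding projection `F.linear(hidden_states, word_embeddings.weight)` for a plain `self.lm_head(hidden_states)` call. A minimal toy check, with illustrative sizes, that the two compute the same logits whenever the head's weight is the same parameter as the input embedding matrix:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    hidden_size, vocab_size = 32, 100
    word_embeddings = nn.Embedding(vocab_size, hidden_size)
    hidden_states = torch.randn(2, 7, hidden_size)

    # Old path: project directly onto the embedding matrix (weights tied by construction).
    old_logits = F.linear(hidden_states, word_embeddings.weight)

    # New path: a bias-free linear head; sharing the parameter keeps the weights tied.
    lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
    lm_head.weight = word_embeddings.weight
    new_logits = lm_head(hidden_states)

    print(torch.equal(old_logits, new_logits))  # identical projection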
@@ -406,3 +424,48 @@ class BloomForCausalLM(BloomPreTrainedModel):
             tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
             for layer_past in past
         )
+
+
+@add_start_docstrings(
+    """
+    The modified language modeling head, which does not create an extra tensor for the linear layer with weights tied to the input
+    embeddings. Thus, it reduces initial memory consumption, which might be crucial for large vocabularies.
+    In addition, it provides an efficient way to deal with half-precision word embeddings on CPU.
+    """,
+    BLOOM_START_DOCSTRING,
+)
+class LMHead(nn.Module):
+    def __init__(self, config, word_embeddings: nn.Embedding):
+        super().__init__()
+        self.word_embeddings = word_embeddings
+        self.chunk_size = config.chunk_size_for_efficient_fp16_on_cpu
+
+    def forward(self, hidden_states):
+        word_embeddings = self.word_embeddings.weight
+
+        # We use 'chunked_forward' only when embeddings are in half-precision on CPU.
+        if word_embeddings.dtype in [torch.float16, torch.bfloat16] and \
+                word_embeddings.device.type == 'cpu':
+            lm_logits = self.chunked_forward(hidden_states)
+        else:
+            # Switch dtype in case word_embeddings are fp16/bf16
+            hidden_states = hidden_states.to(word_embeddings.dtype)
+            lm_logits = F.linear(hidden_states, word_embeddings).float()
+        return lm_logits
+
+    def chunked_forward(self, hidden_states):
+        """Splits the word embeddings into chunks and iteratively casts them to fp32 to perform the matmul more efficiently on CPU.
+        chunk_size: provides a trade-off between efficiency and extra memory consumption.
+        """
+        assert self.chunk_size > 0, "Chunk size for chunked forward must be positive"
+
+        word_embeddings = self.word_embeddings.weight
+        num_embeddings = self.word_embeddings.num_embeddings
+
+        hidden_states = hidden_states.float()
+        output = torch.zeros(*hidden_states.shape[:-1], num_embeddings)
+
+        for i in range(0, num_embeddings, self.chunk_size):
+            chunk = word_embeddings[i: i + self.chunk_size].float()
+            output[..., i: i + self.chunk_size] = F.linear(hidden_states, chunk)
+        return output
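
A standalone sketch of the `chunked_forward` idea above: keep the half-precision weight matrix on CPU as-is and cast only one `chunk_size`-row slice to fp32 per matmul. The helper name and tensor sizes below are made up for the example and are not taken from the BLOOM config:

    import torch
    import torch.nn.functional as F

    def chunked_logits(hidden_states, weight, chunk_size):
        # hidden_states: [..., hidden_size]; weight: [vocab_size, hidden_size] in fp16/bf16 on CPU.
        vocab_size = weight.shape[0]
        hidden_states = hidden_states.float()
        logits = torch.zeros(*hidden_states.shape[:-1], vocab_size)
        for i in range(0, vocab_size, chunk_size):
            # Cast only a [chunk_size, hidden_size] slice to fp32 before the matmul.
            logits[..., i: i + chunk_size] = F.linear(hidden_states, weight[i: i + chunk_size].float())
        return logits

    hidden = torch.randn(2, 5, 64)                              # [batch, seq, hidden_size]
    weight = torch.randn(1000, 64, dtype=torch.bfloat16)        # half-precision "word embeddings" on CPU
    chunked = chunked_logits(hidden, weight, chunk_size=256)
    direct = F.linear(hidden.to(weight.dtype), weight).float()  # mirrors the non-chunked branch of LMHead.forward
    print(chunked.shape, (chunked - direct).abs().max())        # same shape; values differ only by rounding

The chunked path trades a Python-level loop for never materializing a full fp32 copy of the weights, which is what the `chunk_size_for_efficient_fp16_on_cpu` setting controls.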
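For a rough sense of that trade-off: casting the whole half-precision matrix to fp32 at once would materialize a `vocab_size x hidden_size` fp32 copy, while the loop above only ever holds one `chunk_size x hidden_size` slice. With placeholder dimensions (not read from an actual BLOOM config):

    vocab_size, hidden_size, chunk_size = 250_000, 14_336, 10_000  # illustrative sizes only

    full_cast_bytes = vocab_size * hidden_size * 4   # one-shot fp32 copy of the whole weight matrix
    per_chunk_bytes = chunk_size * hidden_size * 4   # transient fp32 slice held per loop iteration

    print(f"full fp32 copy of the weights: {full_cast_bytes / 2**30:.1f} GiB")
    print(f"fp32 slice per chunk:          {per_chunk_bytes / 2**20:.1f} MiB")

Both variants still allocate the fp32 logits buffer of shape `[..., vocab_size]`; the saving applies only to the weight copy.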