@@ -155,8 +155,9 @@ class BloomModel(BloomPreTrainedModel):
         self.n_head = config.n_head
 
         # Embedding + LN Embedding
-        self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
-
+
+        # TODO: @dbaranchuk make efficient fp16 on cpu (convert only word_embeddings!)
+        self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)  # dtype=config.torch_dtype
         self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
 
         # Transformer blocks
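
The hunk above only changes how the embedding is constructed: its dtype is noted to follow config.torch_dtype, and the TODO points at keeping just word_embeddings in fp16. A minimal standalone sketch of that idea, with made-up sizes and outside the patched classes:

import torch
import torch.nn as nn

# Illustrative sizes only; the real values come from the BLOOM config.
vocab_size, embed_dim = 1000, 64

word_embeddings = nn.Embedding(vocab_size, embed_dim)

# Cast only this sub-module to fp16; LayerNorm and the transformer blocks
# would keep their original dtype.
word_embeddings.to(torch.float16)

ids = torch.tensor([[1, 2, 3]])
print(word_embeddings(ids).dtype)  # torch.float16 -- the lookup keeps the weight dtype
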
@@ -170,11 +171,18 @@ class BloomModel(BloomPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
 
+        # Forbid accumulate grads for embeddings and layernorm
+        self.set_requires_grad(False)
+
     def get_input_embeddings(self):
         return self.word_embeddings
 
     def set_input_embeddings(self, new_embeddings):
         self.word_embeddings = new_embeddings
+
+    def set_requires_grad(self, value):
+        for p in self.parameters():
+            p.requires_grad = value
 
     @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
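
With this hunk the constructor freezes every parameter registered on the model (per the comment, the embeddings and the embedding LayerNorm) via set_requires_grad(False). A small self-contained sketch of the same pattern on a toy module, not the BLOOM classes themselves:

import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(100, 16)
        self.norm = nn.LayerNorm(16)
        self.set_requires_grad(False)  # frozen by default, like the patched BloomModel

    def set_requires_grad(self, value):
        for p in self.parameters():
            p.requires_grad = value

model = TinyModel()
assert all(not p.requires_grad for p in model.parameters())
model.set_requires_grad(True)  # a caller can opt back in, e.g. for fine-tuning
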
@@ -227,7 +235,7 @@ class BloomModel(BloomPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
 
-        hidden_states = self.word_embeddings_layernorm(inputs_embeds)
+        hidden_states = self.word_embeddings_layernorm(inputs_embeds.float())
 
         output_shape = input_shape + (hidden_states.size(-1),)
 
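
The forward pass now upcasts the embedding output before the embedding LayerNorm, so the LayerNorm runs in fp32 even when word_embeddings is kept in fp16. A tiny sketch of the dtype behaviour this relies on (shapes are arbitrary):

import torch
import torch.nn as nn

ln = nn.LayerNorm(8)                              # fp32 parameters
x_fp16 = torch.randn(2, 8, dtype=torch.float16)   # stand-in for fp16 embedding output

out = ln(x_fp16.float())                          # upcast first, as in the patched forward
print(out.dtype)                                  # torch.float32

# Feeding x_fp16 directly to ln would mix fp16 activations with fp32 weights,
# which PyTorch typically rejects with a dtype-mismatch error.
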
@@ -311,23 +319,9 @@ class BloomModel(BloomPreTrainedModel):
     """,
     BLOOM_START_DOCSTRING,
 )
-class BloomForCausalLM(BloomPreTrainedModel):
+class BloomForCausalLM(BloomModel):
     _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
 
-    def __init__(self, config):
-        super().__init__(config)
-        self.transformer = BloomModel(config)
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-    def get_output_embeddings(self):
-        return self.lm_head
-
-    def set_output_embeddings(self, new_embeddings):
-        self.lm_head = new_embeddings
-
     def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
         # only last token for inputs_ids if past is defined in kwargs
         if past:
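
With the hunk above, BloomForCausalLM inherits from BloomModel directly instead of wrapping it in self.transformer, and the separate lm_head is dropped; the next two hunks then call super().forward(...) and reuse the input embedding matrix for the output projection. A toy illustration of that structure (not the real classes):

import torch
import torch.nn as nn

class Base(nn.Module):                      # stands in for BloomModel
    def __init__(self):
        super().__init__()
        self.word_embeddings = nn.Embedding(100, 16)

    def forward(self, ids):
        return self.word_embeddings(ids)    # pretend this is the full transformer

class WithLMHead(Base):                     # stands in for BloomForCausalLM(BloomModel)
    def forward(self, ids):
        hidden = super().forward(ids)                       # replaces self.transformer(...)
        return hidden @ self.word_embeddings.weight.t()     # tied weights, no nn.Linear head

logits = WithLMHead()(torch.tensor([[1, 2, 3]]))
print(logits.shape)                         # torch.Size([1, 3, 100])
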
@@ -381,7 +375,7 @@ class BloomForCausalLM(BloomPreTrainedModel):
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        transformer_outputs = self.transformer(
+        transformer_outputs = super().forward(
             input_ids,
             past_key_values=past_key_values,
             attention_mask=attention_mask,
@@ -393,9 +387,11 @@ class BloomForCausalLM(BloomPreTrainedModel):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )
-        hidden_states = transformer_outputs[0]
 
-        lm_logits = self.lm_head(hidden_states)
+        # Switch dtype in case word_embeddings are fp16
+        word_embeddings = self.word_embeddings.weight.t()
+        hidden_states = transformer_outputs[0].to(word_embeddings.dtype)
+        lm_logits = (hidden_states @ word_embeddings).float()
 
         loss = None
         if labels is not None:
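
The final hunk computes the logits by multiplying the hidden states with the transposed (tied) embedding matrix, matching dtypes first and upcasting the result to fp32. A short sketch showing that this is the same projection a bias-free lm_head sharing the embedding weight would perform (sizes are placeholders):

import torch
import torch.nn as nn
import torch.nn.functional as F

vocab_size, hidden_size = 100, 16
word_embeddings = nn.Embedding(vocab_size, hidden_size)
hidden_states = torch.randn(1, 5, hidden_size)              # stand-in transformer output

weight_t = word_embeddings.weight.t()                               # (hidden_size, vocab_size)
lm_logits = (hidden_states.to(weight_t.dtype) @ weight_t).float()   # as in the patch

# Same result as an untied linear head that shares the embedding weight.
ref = F.linear(hidden_states, word_embeddings.weight)
assert torch.allclose(lm_logits, ref, atol=1e-5)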