|
@@ -10,10 +10,15 @@ import torch.nn.functional as F
|
|
import torch.utils.checkpoint
|
|
import torch.utils.checkpoint
|
|
from hivemind import use_hivemind_log_handler
|
|
from hivemind import use_hivemind_log_handler
|
|
from torch import nn
|
|
from torch import nn
|
|
-from torch.nn import CrossEntropyLoss, LayerNorm
|
|
|
|
|
|
+from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss, LayerNorm
|
|
from transformers.file_utils import (add_code_sample_docstrings, add_start_docstrings,
|
|
from transformers.file_utils import (add_code_sample_docstrings, add_start_docstrings,
|
|
add_start_docstrings_to_model_forward)
|
|
add_start_docstrings_to_model_forward)
|
|
-from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
|
|
|
|
|
|
+from transformers.modeling_outputs import (
|
|
|
|
+ BaseModelOutputWithPastAndCrossAttentions,
|
|
|
|
+ CausalLMOutputWithCrossAttentions,
|
|
|
|
+ SequenceClassifierOutputWithPast,
|
|
|
|
+ TokenClassifierOutput,
|
|
|
|
+)
|
|
from transformers.modeling_utils import PreTrainedModel
|
|
from transformers.modeling_utils import PreTrainedModel
|
|
from transformers.models.bloom.configuration_bloom import BloomConfig
|
|
from transformers.models.bloom.configuration_bloom import BloomConfig
|
|
from transformers.utils import logging
|
|
from transformers.utils import logging
|
|
@@ -469,3 +474,130 @@ class LMHead(nn.Module):
|
|
chunk = word_embeddings[i: i + self.chunk_size].float()
|
|
chunk = word_embeddings[i: i + self.chunk_size].float()
|
|
output[..., i: i + self.chunk_size] = F.linear(hidden_states, chunk)
|
|
output[..., i: i + self.chunk_size] = F.linear(hidden_states, chunk)
|
|
return output
|
|
return output
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+@add_start_docstrings(
|
|
|
|
+ """
|
|
|
|
+ The Bloom Model transformer with a sequence classification head on top (linear layer).
|
|
|
|
+ [`BloomForSequenceClassification`] uses the last token in order to do the classification, as other causal models
|
|
|
|
+ (e.g. GPT-1) do.
|
|
|
|
+ Since it does classification on the last token, it requires to know the position of the last token. If a
|
|
|
|
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
|
|
|
|
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
|
|
|
|
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
|
|
|
|
+ each row of the batch).
|
|
|
|
+ """,
|
|
|
|
+ BLOOM_START_DOCSTRING,
|
|
|
|
+)
|
|
|
|
+class BloomForSequenceClassification(BloomPreTrainedModel):
|
|
|
|
+ _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
|
|
|
|
+
|
|
|
|
+ def __init__(self, config):
|
|
|
|
+ super().__init__(config)
|
|
|
|
+ self.num_labels = config.num_labels
|
|
|
|
+ self.transformer = BloomModel(config)
|
|
|
|
+ self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
|
|
|
|
+
|
|
|
|
+ # Initialize weights and apply final processing
|
|
|
|
+ self.post_init()
|
|
|
|
+
|
|
|
|
+ @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
|
|
|
|
+ @add_code_sample_docstrings(
|
|
|
|
+ processor_class=_TOKENIZER_FOR_DOC,
|
|
|
|
+ checkpoint=_CHECKPOINT_FOR_DOC,
|
|
|
|
+ output_type=SequenceClassifierOutputWithPast,
|
|
|
|
+ config_class=_CONFIG_FOR_DOC,
|
|
|
|
+ )
|
|
|
|
+ def forward(
|
|
|
|
+ self,
|
|
|
|
+ input_ids=None,
|
|
|
|
+ past_key_values=None,
|
|
|
|
+ attention_mask=None,
|
|
|
|
+ position_ids=None,
|
|
|
|
+ head_mask=None,
|
|
|
|
+ inputs_embeds=None,
|
|
|
|
+ labels=None,
|
|
|
|
+ use_cache=None,
|
|
|
|
+ output_attentions=None,
|
|
|
|
+ output_hidden_states=None,
|
|
|
|
+ return_dict=None,
|
|
|
|
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
|
|
|
|
+ r"""
|
|
|
|
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
|
|
|
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
|
|
|
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
|
|
|
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
|
|
|
+ """
|
|
|
|
+
|
|
|
|
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
|
|
+
|
|
|
|
+ transformer_outputs = self.transformer(
|
|
|
|
+ input_ids,
|
|
|
|
+ past_key_values=past_key_values,
|
|
|
|
+ attention_mask=attention_mask,
|
|
|
|
+ position_ids=position_ids,
|
|
|
|
+ head_mask=head_mask,
|
|
|
|
+ inputs_embeds=inputs_embeds,
|
|
|
|
+ use_cache=use_cache,
|
|
|
|
+ output_attentions=output_attentions,
|
|
|
|
+ output_hidden_states=output_hidden_states,
|
|
|
|
+ return_dict=return_dict,
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ hidden_states = transformer_outputs[0]
|
|
|
|
+ logits = self.score(hidden_states)
|
|
|
|
+
|
|
|
|
+ if input_ids is not None:
|
|
|
|
+ batch_size = input_ids.shape[0]
|
|
|
|
+ else:
|
|
|
|
+ batch_size = inputs_embeds.shape[0]
|
|
|
|
+
|
|
|
|
+ if self.config.pad_token_id is None and batch_size != 1:
|
|
|
|
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
|
|
|
|
+ if self.config.pad_token_id is None:
|
|
|
|
+ sequence_lengths = -1
|
|
|
|
+ else:
|
|
|
|
+ if input_ids is not None:
|
|
|
|
+ sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
|
|
|
|
+ else:
|
|
|
|
+ sequence_lengths = -1
|
|
|
|
+ logger.warning(
|
|
|
|
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
|
|
|
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
|
|
|
+ )
|
|
|
|
+
|
|
|
|
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
|
|
|
+
|
|
|
|
+ loss = None
|
|
|
|
+ if labels is not None:
|
|
|
|
+ if self.config.problem_type is None:
|
|
|
|
+ if self.num_labels == 1:
|
|
|
|
+ self.config.problem_type = "regression"
|
|
|
|
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
|
|
|
+ self.config.problem_type = "single_label_classification"
|
|
|
|
+ else:
|
|
|
|
+ self.config.problem_type = "multi_label_classification"
|
|
|
|
+
|
|
|
|
+ if self.config.problem_type == "regression":
|
|
|
|
+ loss_fct = MSELoss()
|
|
|
|
+ if self.num_labels == 1:
|
|
|
|
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
|
|
|
+ else:
|
|
|
|
+ loss = loss_fct(pooled_logits, labels)
|
|
|
|
+ elif self.config.problem_type == "single_label_classification":
|
|
|
|
+ loss_fct = CrossEntropyLoss()
|
|
|
|
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
|
|
|
+ elif self.config.problem_type == "multi_label_classification":
|
|
|
|
+ loss_fct = BCEWithLogitsLoss()
|
|
|
|
+ loss = loss_fct(pooled_logits, labels)
|
|
|
|
+ if not return_dict:
|
|
|
|
+ output = (pooled_logits,) + transformer_outputs[1:]
|
|
|
|
+ return ((loss,) + output) if loss is not None else output
|
|
|
|
+
|
|
|
|
+ return SequenceClassifierOutputWithPast(
|
|
|
|
+ loss=loss,
|
|
|
|
+ logits=pooled_logits,
|
|
|
|
+ past_key_values=transformer_outputs.past_key_values,
|
|
|
|
+ hidden_states=transformer_outputs.hidden_states,
|
|
|
|
+ attentions=transformer_outputs.attentions,
|
|
|
|
+ )
|