model.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618
  1. """
  2. PyTorch BLOOM model that implements several memory-efficient modes.
  3. Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
  4. See commit history for authorship.
  5. """
  6. from typing import Tuple, Union
  7. import torch
  8. import torch.nn.functional as F
  9. import torch.utils.checkpoint
  10. from hivemind import use_hivemind_log_handler
  11. from torch import nn
  12. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
  13. from transformers.file_utils import (
  14. add_code_sample_docstrings,
  15. add_start_docstrings,
  16. add_start_docstrings_to_model_forward,
  17. )
  18. from transformers.modeling_outputs import (
  19. BaseModelOutputWithPastAndCrossAttentions,
  20. CausalLMOutputWithCrossAttentions,
  21. SequenceClassifierOutputWithPast,
  22. )
  23. from transformers.modeling_utils import PreTrainedModel
  24. from transformers.models.bloom.configuration_bloom import BloomConfig
  25. from transformers.utils import logging
  26. from src.bloom.block import BloomBlock
  27. use_hivemind_log_handler("in_root_logger")
  28. logger = logging.get_logger(__file__)
  29. _CHECKPOINT_FOR_DOC = "bigscience/Bloom"
  30. _CONFIG_FOR_DOC = "BloomConfig"
  31. _TOKENIZER_FOR_DOC = "BloomTokenizer"
  32. class BloomPreTrainedModel(PreTrainedModel):
  33. _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
  34. """
  35. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  36. models.
  37. """
  38. config_class = BloomConfig
  39. base_model_prefix = "transformer"
  40. supports_gradient_checkpointing = True
  41. _no_split_modules = ["BloomBlock"]
  42. def __init__(self, *inputs, **kwargs):
  43. super().__init__(*inputs, **kwargs)
  44. def _init_weights(self, module):
  45. """Initialize the weights."""
  46. if isinstance(module, (nn.Linear)):
  47. # Slightly different from the TF version which uses truncated_normal for initialization
  48. # cf https://github.com/pytorch/pytorch/pull/5617
  49. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  50. if module.bias is not None:
  51. module.bias.data.zero_()
  52. elif isinstance(module, nn.Embedding):
  53. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  54. if module.padding_idx is not None:
  55. module.weight.data[module.padding_idx].zero_()
  56. elif isinstance(module, LayerNorm):
  57. module.bias.data.zero_()
  58. module.weight.data.fill_(1.0)
  59. def _set_gradient_checkpointing(self, module, value=False):
  60. if isinstance(module, BloomModel):
  61. module.gradient_checkpointing = value
  62. BLOOM_START_DOCSTRING = r"""
  63. This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
  64. library implements for all its model (such as downloading or saving, resizing the input embeddings etc.)
  65. This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
  66. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
  67. and behavior.
  68. Parameters:
  69. config ([`MemoryEfficientBloomConfig`]): Model configuration class with all the parameters of the model.
  70. Initializing with a config file does not load the weights associated with the model, only the
  71. configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
  72. """
  73. BLOOM_INPUTS_DOCSTRING = r"""
  74. Args:
  75. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
  76. `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
  77. `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
  78. sequence tokens in the vocabulary.
  79. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
  80. `input_ids`.
  81. Indices can be obtained using [`BloomTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  82. [`PreTrainedTokenizer.__call__`] for details.
  83. [What are input IDs?](../glossary#input-ids)
  84. past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`):
  85. Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
  86. `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
  87. their past given to this model should not be passed as `input_ids` as they have already been computed.
  88. attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
  89. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  90. - 1 for tokens that are **not masked**,
  91. - 0 for tokens that are **masked**.
  92. [What are attention masks?](../glossary#attention-mask)
  93. position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  94. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  95. config.max_position_embeddings - 1]`.
  96. [What are position IDs?](../glossary#position-ids)
  97. head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  98. Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
  99. - 1 indicates the head is **not masked**,
  100. - 0 indicates the head is **masked**.
  101. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  102. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  103. is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  104. model's internal embedding lookup matrix.
  105. If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
  106. `past_key_values`).
  107. use_cache (`bool`, *optional*):
  108. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
  109. `past_key_values`).
  110. output_attentions (`bool`, *optional*):
  111. Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
  112. tensors for more detail.
  113. output_hidden_states (`bool`, *optional*):
  114. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  115. more detail.
  116. return_dict (`bool`, *optional*):
  117. Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
  118. """
  119. @add_start_docstrings(
  120. "The bare Bloom Model transformer outputting raw hidden-states without any specific head on top.",
  121. BLOOM_START_DOCSTRING,
  122. )
  123. class BloomModel(BloomPreTrainedModel):
  124. def __init__(self, config):
  125. super().__init__(config)
  126. assert not config.slow_but_exact, "slow_but_exact mode was removed for code simplicity"
  127. self.embed_dim = config.hidden_size
  128. self.n_head = config.n_head
  129. # Embedding + LN Embedding
  130. self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
  131. self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
  132. # Transformer blocks
  133. self.h = nn.ModuleList([BloomBlock(config, layer_number=i) for i in range(config.num_hidden_layers)])
  134. # Final Layer Norm
  135. self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
  136. self.gradient_checkpointing = False
  137. # Initialize weights and apply final processing
  138. self.post_init()
  139. def get_input_embeddings(self):
  140. return self.word_embeddings
  141. def set_input_embeddings(self, new_embeddings):
  142. self.word_embeddings = new_embeddings
  143. @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
  144. @add_code_sample_docstrings(
  145. processor_class=_TOKENIZER_FOR_DOC,
  146. checkpoint=_CHECKPOINT_FOR_DOC,
  147. output_type=BaseModelOutputWithPastAndCrossAttentions,
  148. config_class=_CONFIG_FOR_DOC,
  149. )
  150. def forward(
  151. self,
  152. input_ids=None,
  153. past_key_values=None,
  154. attention_mask=None,
  155. position_ids=None,
  156. head_mask=None,
  157. inputs_embeds=None,
  158. use_cache=None,
  159. output_attentions=None,
  160. output_hidden_states=None,
  161. return_dict=None,
  162. ):
  163. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  164. output_hidden_states = (
  165. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  166. )
  167. use_cache = use_cache if use_cache is not None else self.config.use_cache
  168. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  169. if input_ids is not None and inputs_embeds is not None:
  170. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  171. if position_ids is not None:
  172. logger.warning("position_ids are ignored in this bloom implementation")
  173. elif input_ids is not None:
  174. input_shape = input_ids.size()
  175. input_ids = input_ids.view(-1, input_shape[-1])
  176. elif inputs_embeds is not None:
  177. input_shape = inputs_embeds.size()[:-1]
  178. else:
  179. raise ValueError("You have to specify either input_ids or inputs_embeds")
  180. if past_key_values is None:
  181. past_key_values = tuple([None] * len(self.h))
  182. # Prepare head mask if needed
  183. # 1.0 in head_mask indicate we keep the head
  184. # attention_probs has shape bsz x n_head x N x N
  185. # head_mask has shape n_layer x batch x n_head x N x N
  186. head_mask = self.get_head_mask(head_mask, self.config.n_layer)
  187. if inputs_embeds is None:
  188. inputs_embeds = self.word_embeddings(input_ids)
  189. # Note: it supports only float32 or bfloat16 inputs
  190. hidden_states = self.word_embeddings_layernorm(inputs_embeds)
  191. output_shape = input_shape + (hidden_states.size(-1),)
  192. presents = () if use_cache else None
  193. all_self_attentions = () if output_attentions else None
  194. all_hidden_states = () if output_hidden_states else None
  195. # Compute alibi tensor: check build_alibi_tensor documentation
  196. current_sequence_length = hidden_states.shape[1]
  197. if past_key_values and past_key_values[0]:
  198. current_sequence_length += past_key_values[0][0].shape[1]
  199. for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
  200. if output_hidden_states:
  201. all_hidden_states = all_hidden_states + (hidden_states,)
  202. if self.gradient_checkpointing and self.training:
  203. if use_cache:
  204. logger.warning(
  205. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  206. )
  207. use_cache = False
  208. def create_custom_forward(module):
  209. def custom_forward(*inputs):
  210. # None for past_key_value
  211. return module(*inputs, use_cache, output_attentions, alibi=None)
  212. return custom_forward
  213. outputs = torch.utils.checkpoint.checkpoint(
  214. create_custom_forward(block),
  215. hidden_states,
  216. None,
  217. attention_mask,
  218. head_mask[i],
  219. )
  220. else:
  221. outputs = block(
  222. hidden_states,
  223. layer_past=layer_past,
  224. attention_mask=attention_mask,
  225. head_mask=head_mask[i],
  226. use_cache=use_cache,
  227. output_attentions=output_attentions,
  228. alibi=None,
  229. )
  230. hidden_states = outputs[0]
  231. if use_cache is True:
  232. presents = presents + (outputs[1],)
  233. if output_attentions:
  234. all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
  235. # Add last hidden state
  236. hidden_states = self.ln_f(hidden_states)
  237. if output_hidden_states:
  238. all_hidden_states = all_hidden_states + (hidden_states,)
  239. hidden_states = hidden_states.view(output_shape)
  240. if not return_dict:
  241. return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
  242. return BaseModelOutputWithPastAndCrossAttentions(
  243. last_hidden_state=hidden_states,
  244. past_key_values=presents,
  245. hidden_states=all_hidden_states,
  246. attentions=all_self_attentions,
  247. )
  248. @add_start_docstrings(
  249. """
  250. The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input
  251. embeddings).
  252. """,
  253. BLOOM_START_DOCSTRING,
  254. )
  255. class BloomForCausalLM(BloomPreTrainedModel):
  256. _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
  257. def __init__(self, config):
  258. super().__init__(config)
  259. self.transformer = BloomModel(config)
  260. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  261. # Initialize weights and apply final processing
  262. self.post_init()
  263. def get_output_embeddings(self):
  264. return self.lm_head
  265. def set_output_embeddings(self, new_embeddings):
  266. self.lm_head = new_embeddings
  267. def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
  268. # only last token for inputs_ids if past is defined in kwargs
  269. if past:
  270. input_ids = input_ids[:, -1].unsqueeze(-1)
  271. attention_mask = kwargs.get("attention_mask", None)
  272. position_ids = kwargs.get("position_ids", None)
  273. if attention_mask is not None and position_ids is None:
  274. # create position_ids on the fly for batch generation
  275. position_ids = attention_mask.long().cumsum(-1) - 1
  276. position_ids.masked_fill_(attention_mask == 0, 1)
  277. if past:
  278. position_ids = position_ids[:, -1].unsqueeze(-1)
  279. else:
  280. position_ids = None
  281. return {
  282. "input_ids": input_ids,
  283. "past_key_values": past,
  284. "use_cache": kwargs.get("use_cache"),
  285. "position_ids": position_ids,
  286. "attention_mask": attention_mask,
  287. }
  288. @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
  289. @add_code_sample_docstrings(
  290. processor_class=_TOKENIZER_FOR_DOC,
  291. checkpoint=_CHECKPOINT_FOR_DOC,
  292. output_type=CausalLMOutputWithCrossAttentions,
  293. config_class=_CONFIG_FOR_DOC,
  294. )
  295. def forward(
  296. self,
  297. input_ids=None,
  298. past_key_values=None,
  299. attention_mask=None,
  300. position_ids=None,
  301. head_mask=None,
  302. inputs_embeds=None,
  303. labels=None,
  304. use_cache=None,
  305. output_attentions=None,
  306. output_hidden_states=None,
  307. return_dict=None,
  308. ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
  309. r"""
  310. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  311. Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
  312. `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
  313. are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
  314. """
  315. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  316. transformer_outputs = self.transformer(
  317. input_ids,
  318. past_key_values=past_key_values,
  319. attention_mask=attention_mask,
  320. position_ids=position_ids,
  321. head_mask=head_mask,
  322. inputs_embeds=inputs_embeds,
  323. use_cache=use_cache,
  324. output_attentions=output_attentions,
  325. output_hidden_states=output_hidden_states,
  326. return_dict=return_dict,
  327. )
  328. hidden_states = transformer_outputs[0]
  329. lm_logits = self.lm_head(hidden_states)
  330. loss = None
  331. if labels is not None:
  332. # Shift so that tokens < n predict n
  333. shift_logits = lm_logits[..., :-1, :].contiguous()
  334. shift_labels = labels[..., 1:].contiguous()
  335. # Flatten the tokens
  336. loss_fct = CrossEntropyLoss()
  337. loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
  338. if not return_dict:
  339. output = (lm_logits,) + transformer_outputs[1:]
  340. return ((loss,) + output) if loss is not None else output
  341. return CausalLMOutputWithCrossAttentions(
  342. loss=loss,
  343. logits=lm_logits,
  344. past_key_values=transformer_outputs.past_key_values,
  345. hidden_states=transformer_outputs.hidden_states,
  346. attentions=transformer_outputs.attentions,
  347. )
  348. @staticmethod
  349. def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
  350. """
  351. This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
  352. [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
  353. beam_idx at every generation step.
  354. """
  355. return tuple(
  356. tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
  357. for layer_past in past
  358. )
  359. @add_start_docstrings(
  360. """
  361. The modified language modeling head which does not create extra tensor for the linear layer with weights tied to the input
  362. embeddings. Thus, it reduces initial memory consumption which might be crucial for large dictionaries.
  363. In addition, it provides an effcient way to deal with half-precision word embeddings on CPU.
  364. """,
  365. BLOOM_START_DOCSTRING,
  366. )
  367. class LMHead(nn.Module):
  368. def __init__(self, config, word_embeddings: nn.Embedding):
  369. super().__init__()
  370. self.word_embeddings = word_embeddings
  371. self.chunk_size = config.chunk_size_for_efficient_fp16_on_cpu
  372. @property
  373. def in_features(self) -> int:
  374. return self.word_embeddings.num_embeddings
  375. @property
  376. def out_features(self) -> int:
  377. return self.word_embeddings.embedding_dim
  378. @property
  379. def weight(self):
  380. return self.word_embeddings.weight
  381. @property
  382. def bias(self):
  383. return None
  384. def forward(self, hidden_states):
  385. word_embeddings = self.word_embeddings.weight
  386. # We use 'chunked_forward' only when embeddings are in half-precision on CPU.
  387. if word_embeddings.dtype in [torch.float16, torch.bfloat16] and word_embeddings.device.type == "cpu":
  388. lm_logits = self.chunked_forward(hidden_states)
  389. else:
  390. # Switch dtype in case word_embeddings are fp16/bf16
  391. hidden_states = hidden_states.to(word_embeddings.dtype)
  392. lm_logits = F.linear(hidden_states, word_embeddings).float()
  393. return lm_logits
  394. def chunked_forward(self, hidden_states):
  395. """Splits word embeddings on chunks and iteratively casts them into fp32 to perform matmul more efficiently on CPU.
  396. chunk_size: provides trade-off between efficiency and extra memory consumption.
  397. """
  398. assert self.chunk_size > 0, "Chunk size for chunked forward must be positive"
  399. word_embeddings = self.word_embeddings.weight
  400. num_embeddings = self.word_embeddings.num_embeddings
  401. hidden_states = hidden_states.float()
  402. output = torch.zeros(*hidden_states.shape[:-1], num_embeddings)
  403. for i in range(0, num_embeddings, self.chunk_size):
  404. chunk = word_embeddings[i : i + self.chunk_size].float()
  405. output[..., i : i + self.chunk_size] = F.linear(hidden_states, chunk)
  406. return output
  407. @add_start_docstrings(
  408. """
  409. The Bloom Model transformer with a sequence classification head on top (linear layer).
  410. [`BloomForSequenceClassification`] uses the last token in order to do the classification, as other causal models
  411. (e.g. GPT-1) do.
  412. Since it does classification on the last token, it requires to know the position of the last token. If a
  413. `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
  414. no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
  415. padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
  416. each row of the batch).
  417. """,
  418. BLOOM_START_DOCSTRING,
  419. )
  420. class BloomForSequenceClassification(BloomPreTrainedModel):
  421. _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
  422. def __init__(self, config):
  423. super().__init__(config)
  424. self.num_labels = config.num_labels
  425. self.transformer = BloomModel(config)
  426. self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
  427. # Initialize weights and apply final processing
  428. self.post_init()
  429. @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
  430. @add_code_sample_docstrings(
  431. processor_class=_TOKENIZER_FOR_DOC,
  432. checkpoint=_CHECKPOINT_FOR_DOC,
  433. output_type=SequenceClassifierOutputWithPast,
  434. config_class=_CONFIG_FOR_DOC,
  435. )
  436. def forward(
  437. self,
  438. input_ids=None,
  439. past_key_values=None,
  440. attention_mask=None,
  441. position_ids=None,
  442. head_mask=None,
  443. inputs_embeds=None,
  444. labels=None,
  445. use_cache=None,
  446. output_attentions=None,
  447. output_hidden_states=None,
  448. return_dict=None,
  449. ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
  450. r"""
  451. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  452. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  453. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  454. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  455. """
  456. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  457. transformer_outputs = self.transformer(
  458. input_ids,
  459. past_key_values=past_key_values,
  460. attention_mask=attention_mask,
  461. position_ids=position_ids,
  462. head_mask=head_mask,
  463. inputs_embeds=inputs_embeds,
  464. use_cache=use_cache,
  465. output_attentions=output_attentions,
  466. output_hidden_states=output_hidden_states,
  467. return_dict=return_dict,
  468. )
  469. hidden_states = transformer_outputs[0]
  470. logits = self.score(hidden_states)
  471. if input_ids is not None:
  472. batch_size = input_ids.shape[0]
  473. else:
  474. batch_size = inputs_embeds.shape[0]
  475. if self.config.pad_token_id is None and batch_size != 1:
  476. raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
  477. if self.config.pad_token_id is None:
  478. sequence_lengths = -1
  479. else:
  480. if input_ids is not None:
  481. sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
  482. else:
  483. sequence_lengths = -1
  484. logger.warning(
  485. f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
  486. "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
  487. )
  488. pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
  489. loss = None
  490. if labels is not None:
  491. if self.config.problem_type is None:
  492. if self.num_labels == 1:
  493. self.config.problem_type = "regression"
  494. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  495. self.config.problem_type = "single_label_classification"
  496. else:
  497. self.config.problem_type = "multi_label_classification"
  498. if self.config.problem_type == "regression":
  499. loss_fct = MSELoss()
  500. if self.num_labels == 1:
  501. loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
  502. else:
  503. loss = loss_fct(pooled_logits, labels)
  504. elif self.config.problem_type == "single_label_classification":
  505. loss_fct = CrossEntropyLoss()
  506. loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
  507. elif self.config.problem_type == "multi_label_classification":
  508. loss_fct = BCEWithLogitsLoss()
  509. loss = loss_fct(pooled_logits, labels)
  510. if not return_dict:
  511. output = (pooled_logits,) + transformer_outputs[1:]
  512. return ((loss,) + output) if loss is not None else output
  513. return SequenceClassifierOutputWithPast(
  514. loss=loss,
  515. logits=pooled_logits,
  516. past_key_values=transformer_outputs.past_key_values,
  517. hidden_states=transformer_outputs.hidden_states,
  518. attentions=transformer_outputs.attentions,
  519. )