model.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596
  1. """
  2. PyTorch BLOOM model that implements several memory-efficient modes.
  3. Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
  4. See commit history for authorship.
  5. """
  6. from typing import Optional, Tuple, Union
  7. import torch
  8. import torch.nn.functional as F
  9. import torch.utils.checkpoint
  10. from hivemind import use_hivemind_log_handler
  11. from torch import nn
  12. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
  13. from transformers.file_utils import (
  14. add_code_sample_docstrings,
  15. add_start_docstrings,
  16. add_start_docstrings_to_model_forward,
  17. )
  18. from transformers.modeling_outputs import (
  19. BaseModelOutputWithPastAndCrossAttentions,
  20. CausalLMOutputWithCrossAttentions,
  21. SequenceClassifierOutputWithPast,
  22. )
  23. from transformers.modeling_utils import PreTrainedModel
  24. from transformers.models.bloom.configuration_bloom import BloomConfig
  25. from transformers.models.bloom.modeling_bloom import BloomPreTrainedModel
  26. from transformers.utils import logging
  27. from src.bloom.block import BloomBlock
  28. use_hivemind_log_handler("in_root_logger")
  29. logger = logging.get_logger(__file__)
  30. _CHECKPOINT_FOR_DOC = "bigscience/Bloom"
  31. _CONFIG_FOR_DOC = "BloomConfig"
  32. _TOKENIZER_FOR_DOC = "BloomTokenizer"
  33. BLOOM_START_DOCSTRING = r"""
  34. This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
  35. library implements for all its model (such as downloading or saving, resizing the input embeddings etc.)
  36. This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
  37. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
  38. and behavior.
  39. Parameters:
  40. config ([`MemoryEfficientBloomConfig`]): Model configuration class with all the parameters of the model.
  41. Initializing with a config file does not load the weights associated with the model, only the
  42. configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
  43. """
  44. BLOOM_INPUTS_DOCSTRING = r"""
  45. Args:
  46. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
  47. `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
  48. `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
  49. sequence tokens in the vocabulary.
  50. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
  51. `input_ids`.
  52. Indices can be obtained using [`BloomTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  53. [`PreTrainedTokenizer.__call__`] for details.
  54. [What are input IDs?](../glossary#input-ids)
  55. past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`):
  56. Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
  57. `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
  58. their past given to this model should not be passed as `input_ids` as they have already been computed.
  59. attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
  60. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  61. - 1 for tokens that are **not masked**,
  62. - 0 for tokens that are **masked**.
  63. [What are attention masks?](../glossary#attention-mask)
  64. position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  65. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  66. config.max_position_embeddings - 1]`.
  67. [What are position IDs?](../glossary#position-ids)
  68. head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  69. Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
  70. - 1 indicates the head is **not masked**,
  71. - 0 indicates the head is **masked**.
  72. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  73. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  74. is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  75. model's internal embedding lookup matrix.
  76. If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
  77. `past_key_values`).
  78. use_cache (`bool`, *optional*):
  79. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
  80. `past_key_values`).
  81. output_attentions (`bool`, *optional*):
  82. Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
  83. tensors for more detail.
  84. output_hidden_states (`bool`, *optional*):
  85. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  86. more detail.
  87. return_dict (`bool`, *optional*):
  88. Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
  89. """
  90. class _BloomPreTrainedModelWithModifiedDefaults(BloomPreTrainedModel):
  91. @classmethod
  92. def from_pretrained(cls, *args, low_cpu_mem_usage: Optional[bool] = None, **kwargs):
  93. if low_cpu_mem_usage is None:
  94. low_cpu_mem_usage = True
  95. return super().from_pretrained(*args, low_cpu_mem_usage=low_cpu_mem_usage, **kwargs)
  96. from_pretrained.__doc__ = BloomPreTrainedModel.from_pretrained.__doc__.replace(
  97. "low_cpu_mem_usage(`bool`, *optional*)",
  98. "low_cpu_mem_usage(`bool`, *optional*, defaults to `True` in Petals)",
  99. )
  100. @add_start_docstrings(
  101. "The bare Bloom Model transformer outputting raw hidden-states without any specific head on top.",
  102. BLOOM_START_DOCSTRING,
  103. )
  104. class BloomModel(_BloomPreTrainedModelWithModifiedDefaults):
  105. def __init__(self, config):
  106. super().__init__(config)
  107. assert not config.slow_but_exact, "slow_but_exact mode was removed for code simplicity"
  108. self.embed_dim = config.hidden_size
  109. self.n_head = config.n_head
  110. # Embedding + LN Embedding
  111. self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
  112. self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
  113. # Transformer blocks
  114. self.h = nn.ModuleList([BloomBlock(config, layer_number=i) for i in range(config.num_hidden_layers)])
  115. # Final Layer Norm
  116. self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
  117. self.gradient_checkpointing = False
  118. # Initialize weights and apply final processing
  119. self.post_init()
  120. def get_input_embeddings(self):
  121. return self.word_embeddings
  122. def set_input_embeddings(self, new_embeddings):
  123. self.word_embeddings = new_embeddings
  124. @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
  125. @add_code_sample_docstrings(
  126. processor_class=_TOKENIZER_FOR_DOC,
  127. checkpoint=_CHECKPOINT_FOR_DOC,
  128. output_type=BaseModelOutputWithPastAndCrossAttentions,
  129. config_class=_CONFIG_FOR_DOC,
  130. )
  131. def forward(
  132. self,
  133. input_ids=None,
  134. past_key_values=None,
  135. attention_mask=None,
  136. position_ids=None,
  137. head_mask=None,
  138. inputs_embeds=None,
  139. use_cache=None,
  140. output_attentions=None,
  141. output_hidden_states=None,
  142. return_dict=None,
  143. ):
  144. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  145. output_hidden_states = (
  146. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  147. )
  148. use_cache = use_cache if use_cache is not None else self.config.use_cache
  149. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  150. if input_ids is not None and inputs_embeds is not None:
  151. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  152. if position_ids is not None:
  153. logger.warning("position_ids are ignored in this bloom implementation")
  154. elif input_ids is not None:
  155. input_shape = input_ids.size()
  156. input_ids = input_ids.view(-1, input_shape[-1])
  157. elif inputs_embeds is not None:
  158. input_shape = inputs_embeds.size()[:-1]
  159. else:
  160. raise ValueError("You have to specify either input_ids or inputs_embeds")
  161. if past_key_values is None:
  162. past_key_values = tuple([None] * len(self.h))
  163. # Prepare head mask if needed
  164. # 1.0 in head_mask indicate we keep the head
  165. # attention_probs has shape bsz x n_head x N x N
  166. # head_mask has shape n_layer x batch x n_head x N x N
  167. head_mask = self.get_head_mask(head_mask, self.config.n_layer)
  168. if inputs_embeds is None:
  169. inputs_embeds = self.word_embeddings(input_ids)
  170. # Note: it supports only float32 or bfloat16 inputs
  171. hidden_states = self.word_embeddings_layernorm(inputs_embeds)
  172. output_shape = input_shape + (hidden_states.size(-1),)
  173. presents = () if use_cache else None
  174. all_self_attentions = () if output_attentions else None
  175. all_hidden_states = () if output_hidden_states else None
  176. # Compute alibi tensor: check build_alibi_tensor documentation
  177. current_sequence_length = hidden_states.shape[1]
  178. if past_key_values and past_key_values[0]:
  179. current_sequence_length += past_key_values[0][0].shape[1]
  180. for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
  181. if output_hidden_states:
  182. all_hidden_states = all_hidden_states + (hidden_states,)
  183. if self.gradient_checkpointing and self.training:
  184. if use_cache:
  185. logger.warning(
  186. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  187. )
  188. use_cache = False
  189. def create_custom_forward(module):
  190. def custom_forward(*inputs):
  191. # None for past_key_value
  192. return module(*inputs, use_cache, output_attentions, alibi=None)
  193. return custom_forward
  194. outputs = torch.utils.checkpoint.checkpoint(
  195. create_custom_forward(block),
  196. hidden_states,
  197. None,
  198. attention_mask,
  199. head_mask[i],
  200. )
  201. else:
  202. outputs = block(
  203. hidden_states,
  204. layer_past=layer_past,
  205. attention_mask=attention_mask,
  206. head_mask=head_mask[i],
  207. use_cache=use_cache,
  208. output_attentions=output_attentions,
  209. alibi=None,
  210. )
  211. hidden_states = outputs[0]
  212. if use_cache is True:
  213. presents = presents + (outputs[1],)
  214. if output_attentions:
  215. all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
  216. # Add last hidden state
  217. hidden_states = self.ln_f(hidden_states)
  218. if output_hidden_states:
  219. all_hidden_states = all_hidden_states + (hidden_states,)
  220. hidden_states = hidden_states.view(output_shape)
  221. if not return_dict:
  222. return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
  223. return BaseModelOutputWithPastAndCrossAttentions(
  224. last_hidden_state=hidden_states,
  225. past_key_values=presents,
  226. hidden_states=all_hidden_states,
  227. attentions=all_self_attentions,
  228. )
  229. @add_start_docstrings(
  230. """
  231. The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input
  232. embeddings).
  233. """,
  234. BLOOM_START_DOCSTRING,
  235. )
  236. class BloomForCausalLM(_BloomPreTrainedModelWithModifiedDefaults):
  237. _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
  238. def __init__(self, config):
  239. super().__init__(config)
  240. self.transformer = BloomModel(config)
  241. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  242. # Initialize weights and apply final processing
  243. self.post_init()
  244. def get_output_embeddings(self):
  245. return self.lm_head
  246. def set_output_embeddings(self, new_embeddings):
  247. self.lm_head = new_embeddings
  248. def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
  249. # only last token for inputs_ids if past is defined in kwargs
  250. if past:
  251. input_ids = input_ids[:, -1].unsqueeze(-1)
  252. attention_mask = kwargs.get("attention_mask", None)
  253. position_ids = kwargs.get("position_ids", None)
  254. if attention_mask is not None and position_ids is None:
  255. # create position_ids on the fly for batch generation
  256. position_ids = attention_mask.long().cumsum(-1) - 1
  257. position_ids.masked_fill_(attention_mask == 0, 1)
  258. if past:
  259. position_ids = position_ids[:, -1].unsqueeze(-1)
  260. else:
  261. position_ids = None
  262. return {
  263. "input_ids": input_ids,
  264. "past_key_values": past,
  265. "use_cache": kwargs.get("use_cache"),
  266. "position_ids": position_ids,
  267. "attention_mask": attention_mask,
  268. }
  269. @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
  270. @add_code_sample_docstrings(
  271. processor_class=_TOKENIZER_FOR_DOC,
  272. checkpoint=_CHECKPOINT_FOR_DOC,
  273. output_type=CausalLMOutputWithCrossAttentions,
  274. config_class=_CONFIG_FOR_DOC,
  275. )
  276. def forward(
  277. self,
  278. input_ids=None,
  279. past_key_values=None,
  280. attention_mask=None,
  281. position_ids=None,
  282. head_mask=None,
  283. inputs_embeds=None,
  284. labels=None,
  285. use_cache=None,
  286. output_attentions=None,
  287. output_hidden_states=None,
  288. return_dict=None,
  289. ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
  290. r"""
  291. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  292. Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
  293. `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
  294. are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
  295. """
  296. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  297. transformer_outputs = self.transformer(
  298. input_ids,
  299. past_key_values=past_key_values,
  300. attention_mask=attention_mask,
  301. position_ids=position_ids,
  302. head_mask=head_mask,
  303. inputs_embeds=inputs_embeds,
  304. use_cache=use_cache,
  305. output_attentions=output_attentions,
  306. output_hidden_states=output_hidden_states,
  307. return_dict=return_dict,
  308. )
  309. hidden_states = transformer_outputs[0]
  310. lm_logits = self.lm_head(hidden_states)
  311. loss = None
  312. if labels is not None:
  313. # Shift so that tokens < n predict n
  314. shift_logits = lm_logits[..., :-1, :].contiguous()
  315. shift_labels = labels[..., 1:].contiguous()
  316. # Flatten the tokens
  317. loss_fct = CrossEntropyLoss()
  318. loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
  319. if not return_dict:
  320. output = (lm_logits,) + transformer_outputs[1:]
  321. return ((loss,) + output) if loss is not None else output
  322. return CausalLMOutputWithCrossAttentions(
  323. loss=loss,
  324. logits=lm_logits,
  325. past_key_values=transformer_outputs.past_key_values,
  326. hidden_states=transformer_outputs.hidden_states,
  327. attentions=transformer_outputs.attentions,
  328. )
  329. @staticmethod
  330. def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
  331. """
  332. This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
  333. [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
  334. beam_idx at every generation step.
  335. """
  336. return tuple(
  337. tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
  338. for layer_past in past
  339. )
  340. @add_start_docstrings(
  341. """
  342. The modified language modeling head which does not create extra tensor for the linear layer with weights tied to the input
  343. embeddings. Thus, it reduces initial memory consumption which might be crucial for large dictionaries.
  344. In addition, it provides an effcient way to deal with half-precision word embeddings on CPU.
  345. """,
  346. BLOOM_START_DOCSTRING,
  347. )
  348. class LMHead(nn.Module):
  349. def __init__(self, config, word_embeddings: nn.Embedding):
  350. super().__init__()
  351. self.word_embeddings = word_embeddings
  352. self.chunk_size = config.chunk_size_for_efficient_fp16_on_cpu
  353. @property
  354. def in_features(self) -> int:
  355. return self.word_embeddings.num_embeddings
  356. @property
  357. def out_features(self) -> int:
  358. return self.word_embeddings.embedding_dim
  359. @property
  360. def weight(self):
  361. return self.word_embeddings.weight
  362. @property
  363. def bias(self):
  364. return None
  365. def forward(self, hidden_states):
  366. word_embeddings = self.word_embeddings.weight
  367. # We use 'chunked_forward' only when embeddings are in half-precision on CPU.
  368. if word_embeddings.dtype in [torch.float16, torch.bfloat16] and word_embeddings.device.type == "cpu":
  369. lm_logits = self.chunked_forward(hidden_states)
  370. else:
  371. # Switch dtype in case word_embeddings are fp16/bf16
  372. hidden_states = hidden_states.to(word_embeddings.dtype)
  373. lm_logits = F.linear(hidden_states, word_embeddings)
  374. return lm_logits
  375. def chunked_forward(self, hidden_states):
  376. """Splits word embeddings on chunks and iteratively casts them into fp32 to perform matmul more efficiently on CPU.
  377. chunk_size: provides trade-off between efficiency and extra memory consumption.
  378. """
  379. assert self.chunk_size > 0, "Chunk size for chunked forward must be positive"
  380. word_embeddings = self.word_embeddings.weight
  381. num_embeddings = self.word_embeddings.num_embeddings
  382. hidden_states = hidden_states.float()
  383. output = torch.zeros(*hidden_states.shape[:-1], num_embeddings)
  384. for i in range(0, num_embeddings, self.chunk_size):
  385. chunk = word_embeddings[i : i + self.chunk_size].float()
  386. output[..., i : i + self.chunk_size] = F.linear(hidden_states, chunk)
  387. return output
  388. @add_start_docstrings(
  389. """
  390. The Bloom Model transformer with a sequence classification head on top (linear layer).
  391. [`BloomForSequenceClassification`] uses the last token in order to do the classification, as other causal models
  392. (e.g. GPT-1) do.
  393. Since it does classification on the last token, it requires to know the position of the last token. If a
  394. `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
  395. no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
  396. padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
  397. each row of the batch).
  398. """,
  399. BLOOM_START_DOCSTRING,
  400. )
  401. class BloomForSequenceClassification(_BloomPreTrainedModelWithModifiedDefaults):
  402. _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
  403. def __init__(self, config):
  404. super().__init__(config)
  405. self.num_labels = config.num_labels
  406. self.transformer = BloomModel(config)
  407. self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)
  408. # Initialize weights and apply final processing
  409. self.post_init()
  410. @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
  411. @add_code_sample_docstrings(
  412. processor_class=_TOKENIZER_FOR_DOC,
  413. checkpoint=_CHECKPOINT_FOR_DOC,
  414. output_type=SequenceClassifierOutputWithPast,
  415. config_class=_CONFIG_FOR_DOC,
  416. )
  417. def forward(
  418. self,
  419. input_ids=None,
  420. past_key_values=None,
  421. attention_mask=None,
  422. position_ids=None,
  423. head_mask=None,
  424. inputs_embeds=None,
  425. labels=None,
  426. use_cache=None,
  427. output_attentions=None,
  428. output_hidden_states=None,
  429. return_dict=None,
  430. ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
  431. r"""
  432. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  433. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  434. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  435. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  436. """
  437. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  438. transformer_outputs = self.transformer(
  439. input_ids,
  440. past_key_values=past_key_values,
  441. attention_mask=attention_mask,
  442. position_ids=position_ids,
  443. head_mask=head_mask,
  444. inputs_embeds=inputs_embeds,
  445. use_cache=use_cache,
  446. output_attentions=output_attentions,
  447. output_hidden_states=output_hidden_states,
  448. return_dict=return_dict,
  449. )
  450. hidden_states = transformer_outputs[0]
  451. logits = self.score(hidden_states)
  452. if input_ids is not None:
  453. batch_size = input_ids.shape[0]
  454. else:
  455. batch_size = inputs_embeds.shape[0]
  456. if self.config.pad_token_id is None and batch_size != 1:
  457. raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
  458. if self.config.pad_token_id is None:
  459. sequence_lengths = -1
  460. else:
  461. if input_ids is not None:
  462. sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
  463. else:
  464. sequence_lengths = -1
  465. logger.warning(
  466. f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
  467. "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
  468. )
  469. pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
  470. loss = None
  471. if labels is not None:
  472. if self.config.problem_type is None:
  473. if self.num_labels == 1:
  474. self.config.problem_type = "regression"
  475. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  476. self.config.problem_type = "single_label_classification"
  477. else:
  478. self.config.problem_type = "multi_label_classification"
  479. if self.config.problem_type == "regression":
  480. loss_fct = MSELoss()
  481. if self.num_labels == 1:
  482. loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
  483. else:
  484. loss = loss_fct(pooled_logits, labels)
  485. elif self.config.problem_type == "single_label_classification":
  486. loss_fct = CrossEntropyLoss()
  487. loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
  488. elif self.config.problem_type == "multi_label_classification":
  489. loss_fct = BCEWithLogitsLoss()
  490. loss = loss_fct(pooled_logits, labels)
  491. if not return_dict:
  492. output = (pooled_logits,) + transformer_outputs[1:]
  493. return ((loss,) + output) if loss is not None else output
  494. return SequenceClassifierOutputWithPast(
  495. loss=loss,
  496. logits=pooled_logits,
  497. past_key_values=transformer_outputs.past_key_values,
  498. hidden_states=transformer_outputs.hidden_states,
  499. attentions=transformer_outputs.attentions,
  500. )