model.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418
  1. """PyTorch BLOOM model ."""
  2. from typing import Tuple
  3. import torch
  4. import torch.utils.checkpoint
  5. from torch import nn
  6. from torch.nn import CrossEntropyLoss, LayerNorm
  7. from transformers.file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
  8. from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
  9. from transformers.modeling_utils import PreTrainedModel
  10. from transformers.utils import logging
  11. from transformers.models.bloom.configuration_bloom import BloomConfig
  12. from src.layer import BloomBlock
  13. from src.ops import build_alibi_tensor
  14. logger = logging.get_logger(__name__)
  15. _CHECKPOINT_FOR_DOC = "bigscience/Bloom"
  16. _CONFIG_FOR_DOC = "BloomConfig"
  17. _TOKENIZER_FOR_DOC = "BloomTokenizer"
  18. class BloomPreTrainedModel(PreTrainedModel):
  19. _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
  20. """
  21. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  22. models.
  23. """
  24. config_class = BloomConfig
  25. base_model_prefix = "transformer"
  26. supports_gradient_checkpointing = True
  27. _no_split_modules = ["BloomBlock"]
  28. def __init__(self, *inputs, **kwargs):
  29. super().__init__(*inputs, **kwargs)
  30. def _init_weights(self, module):
  31. """Initialize the weights."""
  32. if isinstance(module, (nn.Linear)):
  33. # Slightly different from the TF version which uses truncated_normal for initialization
  34. # cf https://github.com/pytorch/pytorch/pull/5617
  35. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  36. if module.bias is not None:
  37. module.bias.data.zero_()
  38. elif isinstance(module, nn.Embedding):
  39. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  40. if module.padding_idx is not None:
  41. module.weight.data[module.padding_idx].zero_()
  42. elif isinstance(module, LayerNorm):
  43. module.bias.data.zero_()
  44. module.weight.data.fill_(1.0)
  45. def _set_gradient_checkpointing(self, module, value=False):
  46. if isinstance(module, BloomModel):
  47. module.gradient_checkpointing = value
  48. BLOOM_START_DOCSTRING = r"""
  49. This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
  50. library implements for all its model (such as downloading or saving, resizing the input embeddings etc.)
  51. This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
  52. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
  53. and behavior.
  54. Parameters:
  55. config ([`BloomConfig`]): Model configuration class with all the parameters of the model.
  56. Initializing with a config file does not load the weights associated with the model, only the
  57. configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
  58. """
  59. BLOOM_INPUTS_DOCSTRING = r"""
  60. Args:
  61. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
  62. `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
  63. `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
  64. sequence tokens in the vocabulary.
  65. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
  66. `input_ids`.
  67. Indices can be obtained using [`BloomTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  68. [`PreTrainedTokenizer.__call__`] for details.
  69. [What are input IDs?](../glossary#input-ids)
  70. past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`):
  71. Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
  72. `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
  73. their past given to this model should not be passed as `input_ids` as they have already been computed.
  74. attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
  75. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  76. - 1 for tokens that are **not masked**,
  77. - 0 for tokens that are **masked**.
  78. [What are attention masks?](../glossary#attention-mask)
  79. position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  80. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  81. config.max_position_embeddings - 1]`.
  82. [What are position IDs?](../glossary#position-ids)
  83. head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  84. Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
  85. - 1 indicates the head is **not masked**,
  86. - 0 indicates the head is **masked**.
  87. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  88. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  89. is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  90. model's internal embedding lookup matrix.
  91. If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
  92. `past_key_values`).
  93. use_cache (`bool`, *optional*):
  94. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
  95. `past_key_values`).
  96. output_attentions (`bool`, *optional*):
  97. Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
  98. tensors for more detail.
  99. output_hidden_states (`bool`, *optional*):
  100. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  101. more detail.
  102. return_dict (`bool`, *optional*):
  103. Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
  104. """
  105. @add_start_docstrings(
  106. "The bare Bloom Model transformer outputting raw hidden-states without any specific head on top.",
  107. BLOOM_START_DOCSTRING,
  108. )
  109. class BloomModel(BloomPreTrainedModel):
  110. def __init__(self, config):
  111. super().__init__(config)
  112. self.embed_dim = config.hidden_size
  113. self.n_head = config.n_head
  114. # Embedding + LN Embedding
  115. self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
  116. self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
  117. # Transformer blocks
  118. self.h = nn.ModuleList([BloomBlock(config, layer_number=i) for i in range(config.num_hidden_layers)])
  119. # Final Layer Norm
  120. self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
  121. self.gradient_checkpointing = False
  122. # Initialize weights and apply final processing
  123. self.post_init()
  124. def get_input_embeddings(self):
  125. return self.word_embeddings
  126. def set_input_embeddings(self, new_embeddings):
  127. self.word_embeddings = new_embeddings
  128. @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
  129. @add_code_sample_docstrings(
  130. processor_class=_TOKENIZER_FOR_DOC,
  131. checkpoint=_CHECKPOINT_FOR_DOC,
  132. output_type=BaseModelOutputWithPastAndCrossAttentions,
  133. config_class=_CONFIG_FOR_DOC,
  134. )
  135. def forward(
  136. self,
  137. input_ids=None,
  138. past_key_values=None,
  139. attention_mask=None,
  140. position_ids=None,
  141. head_mask=None,
  142. inputs_embeds=None,
  143. use_cache=None,
  144. output_attentions=None,
  145. output_hidden_states=None,
  146. return_dict=None,
  147. ):
  148. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  149. output_hidden_states = (
  150. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  151. )
  152. use_cache = use_cache if use_cache is not None else self.config.use_cache
  153. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  154. if input_ids is not None and inputs_embeds is not None:
  155. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  156. elif input_ids is not None:
  157. input_shape = input_ids.size()
  158. input_ids = input_ids.view(-1, input_shape[-1])
  159. elif inputs_embeds is not None:
  160. input_shape = inputs_embeds.size()[:-1]
  161. else:
  162. raise ValueError("You have to specify either input_ids or inputs_embeds")
  163. if past_key_values is None:
  164. past_key_values = tuple([None] * len(self.h))
  165. # Prepare head mask if needed
  166. # 1.0 in head_mask indicate we keep the head
  167. # attention_probs has shape bsz x n_head x N x N
  168. # head_mask has shape n_layer x batch x n_head x N x N
  169. head_mask = self.get_head_mask(head_mask, self.config.n_layer)
  170. if inputs_embeds is None:
  171. inputs_embeds = self.word_embeddings(input_ids)
  172. hidden_states = self.word_embeddings_layernorm(inputs_embeds)
  173. output_shape = input_shape + (hidden_states.size(-1),)
  174. presents = () if use_cache else None
  175. all_self_attentions = () if output_attentions else None
  176. all_hidden_states = () if output_hidden_states else None
  177. # Compute alibi tensor: check build_alibi_tensor documentation
  178. current_sequence_length = hidden_states.shape[1]
  179. if past_key_values[0] is not None:
  180. current_sequence_length += past_key_values[0][0].shape[1]
  181. alibi = build_alibi_tensor(current_sequence_length, self.n_head, hidden_states.dtype)
  182. for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
  183. if output_hidden_states:
  184. all_hidden_states = all_hidden_states + (hidden_states,)
  185. if self.gradient_checkpointing and self.training:
  186. if use_cache:
  187. logger.warning(
  188. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  189. )
  190. use_cache = False
  191. def create_custom_forward(module):
  192. def custom_forward(*inputs):
  193. # None for past_key_value
  194. return module(*inputs, use_cache, output_attentions, alibi)
  195. return custom_forward
  196. outputs = torch.utils.checkpoint.checkpoint(
  197. create_custom_forward(block),
  198. hidden_states,
  199. None,
  200. attention_mask,
  201. head_mask[i],
  202. )
  203. else:
  204. outputs = block(
  205. hidden_states,
  206. layer_past=layer_past,
  207. attention_mask=attention_mask,
  208. head_mask=head_mask[i],
  209. use_cache=use_cache,
  210. output_attentions=output_attentions,
  211. alibi=alibi,
  212. )
  213. hidden_states = outputs[0]
  214. if use_cache is True:
  215. presents = presents + (outputs[1],)
  216. if output_attentions:
  217. all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
  218. # Add last hidden state
  219. hidden_states = self.ln_f(hidden_states)
  220. if output_hidden_states:
  221. all_hidden_states = all_hidden_states + (hidden_states,)
  222. hidden_states = hidden_states.view(output_shape)
  223. if not return_dict:
  224. return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
  225. return BaseModelOutputWithPastAndCrossAttentions(
  226. last_hidden_state=hidden_states,
  227. past_key_values=presents,
  228. hidden_states=all_hidden_states,
  229. attentions=all_self_attentions,
  230. )
  231. @add_start_docstrings(
  232. """
  233. The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input
  234. embeddings).
  235. """,
  236. BLOOM_START_DOCSTRING,
  237. )
  238. class BloomForCausalLM(BloomPreTrainedModel):
  239. _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
  240. def __init__(self, config):
  241. super().__init__(config)
  242. self.transformer = BloomModel(config)
  243. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  244. # Initialize weights and apply final processing
  245. self.post_init()
  246. def get_output_embeddings(self):
  247. return self.lm_head
  248. def set_output_embeddings(self, new_embeddings):
  249. self.lm_head = new_embeddings
  250. def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
  251. # only last token for inputs_ids if past is defined in kwargs
  252. if past:
  253. input_ids = input_ids[:, -1].unsqueeze(-1)
  254. attention_mask = kwargs.get("attention_mask", None)
  255. position_ids = kwargs.get("position_ids", None)
  256. if attention_mask is not None and position_ids is None:
  257. # create position_ids on the fly for batch generation
  258. position_ids = attention_mask.long().cumsum(-1) - 1
  259. position_ids.masked_fill_(attention_mask == 0, 1)
  260. if past:
  261. position_ids = position_ids[:, -1].unsqueeze(-1)
  262. else:
  263. position_ids = None
  264. return {
  265. "input_ids": input_ids,
  266. "past_key_values": past,
  267. "use_cache": kwargs.get("use_cache"),
  268. "position_ids": position_ids,
  269. "attention_mask": attention_mask,
  270. }
  271. @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
  272. @add_code_sample_docstrings(
  273. processor_class=_TOKENIZER_FOR_DOC,
  274. checkpoint=_CHECKPOINT_FOR_DOC,
  275. output_type=CausalLMOutputWithCrossAttentions,
  276. config_class=_CONFIG_FOR_DOC,
  277. )
  278. def forward(
  279. self,
  280. input_ids=None,
  281. past_key_values=None,
  282. attention_mask=None,
  283. position_ids=None,
  284. head_mask=None,
  285. inputs_embeds=None,
  286. labels=None,
  287. use_cache=None,
  288. output_attentions=None,
  289. output_hidden_states=None,
  290. return_dict=None,
  291. ):
  292. r"""
  293. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  294. Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
  295. `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
  296. are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
  297. """
  298. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  299. transformer_outputs = self.transformer(
  300. input_ids,
  301. past_key_values=past_key_values,
  302. attention_mask=attention_mask,
  303. position_ids=position_ids,
  304. head_mask=head_mask,
  305. inputs_embeds=inputs_embeds,
  306. use_cache=use_cache,
  307. output_attentions=output_attentions,
  308. output_hidden_states=output_hidden_states,
  309. return_dict=return_dict,
  310. )
  311. hidden_states = transformer_outputs[0]
  312. lm_logits = self.lm_head(hidden_states)
  313. loss = None
  314. if labels is not None:
  315. # Shift so that tokens < n predict n
  316. shift_logits = lm_logits[..., :-1, :].contiguous()
  317. shift_labels = labels[..., 1:].contiguous()
  318. # Flatten the tokens
  319. loss_fct = CrossEntropyLoss()
  320. loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
  321. if not return_dict:
  322. output = (lm_logits,) + transformer_outputs[1:]
  323. return ((loss,) + output) if loss is not None else output
  324. return CausalLMOutputWithCrossAttentions(
  325. loss=loss,
  326. logits=lm_logits,
  327. past_key_values=transformer_outputs.past_key_values,
  328. hidden_states=transformer_outputs.hidden_states,
  329. attentions=transformer_outputs.attentions,
  330. )
  331. @staticmethod
  332. def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
  333. """
  334. This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
  335. [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
  336. beam_idx at every generation step.
  337. """
  338. return tuple(
  339. tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
  340. for layer_past in past
  341. )