model.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433
  1. """
  2. PyTorch BLOOM model that implements several memory-efficient modes.
  3. Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
  4. See commit history for authorship.
  5. """
  6. from typing import Tuple
  7. import torch
  8. import torch.utils.checkpoint
  9. from hivemind import use_hivemind_log_handler
  10. from torch import nn
  11. from torch.nn import CrossEntropyLoss, LayerNorm
  12. from transformers.file_utils import (
  13. add_code_sample_docstrings,
  14. add_start_docstrings,
  15. add_start_docstrings_to_model_forward,
  16. )
  17. from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
  18. from transformers.modeling_utils import PreTrainedModel
  19. from transformers.models.bloom.configuration_bloom import BloomConfig as _VanillaBloomConfig
  20. from transformers.utils import logging
  21. from src.bloom.block import BloomBlock
  22. from src.bloom.ops import build_alibi_tensor
  23. use_hivemind_log_handler("in_root_logger")
  24. logger = logging.get_logger(__name__)
  25. _CHECKPOINT_FOR_DOC = "bigscience/Bloom"
  26. _CONFIG_FOR_DOC = "MemoryEfficientBloomConfig"
  27. _TOKENIZER_FOR_DOC = "BloomTokenizer"
  28. class DistributedBloomConfig(_VanillaBloomConfig):
  29. compression: str = "none"
  30. slow_but_exact: bool = False
  31. class BloomPreTrainedModel(PreTrainedModel):
  32. _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
  33. """
  34. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  35. models.
  36. """
  37. config_class = DistributedBloomConfig
  38. base_model_prefix = "transformer"
  39. supports_gradient_checkpointing = True
  40. _no_split_modules = ["BloomBlock"]
  41. def __init__(self, *inputs, **kwargs):
  42. super().__init__(*inputs, **kwargs)
  43. def _init_weights(self, module):
  44. """Initialize the weights."""
  45. if isinstance(module, (nn.Linear)):
  46. # Slightly different from the TF version which uses truncated_normal for initialization
  47. # cf https://github.com/pytorch/pytorch/pull/5617
  48. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  49. if module.bias is not None:
  50. module.bias.data.zero_()
  51. elif isinstance(module, nn.Embedding):
  52. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  53. if module.padding_idx is not None:
  54. module.weight.data[module.padding_idx].zero_()
  55. elif isinstance(module, LayerNorm):
  56. module.bias.data.zero_()
  57. module.weight.data.fill_(1.0)
  58. def _set_gradient_checkpointing(self, module, value=False):
  59. if isinstance(module, BloomModel):
  60. module.gradient_checkpointing = value
  61. BLOOM_START_DOCSTRING = r"""
  62. This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
  63. library implements for all its model (such as downloading or saving, resizing the input embeddings etc.)
  64. This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
  65. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
  66. and behavior.
  67. Parameters:
  68. config ([`MemoryEfficientBloomConfig`]): Model configuration class with all the parameters of the model.
  69. Initializing with a config file does not load the weights associated with the model, only the
  70. configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
  71. """
  72. BLOOM_INPUTS_DOCSTRING = r"""
  73. Args:
  74. input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
  75. `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
  76. `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
  77. sequence tokens in the vocabulary.
  78. If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
  79. `input_ids`.
  80. Indices can be obtained using [`BloomTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  81. [`PreTrainedTokenizer.__call__`] for details.
  82. [What are input IDs?](../glossary#input-ids)
  83. past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`):
  84. Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
  85. `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
  86. their past given to this model should not be passed as `input_ids` as they have already been computed.
  87. attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
  88. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  89. - 1 for tokens that are **not masked**,
  90. - 0 for tokens that are **masked**.
  91. [What are attention masks?](../glossary#attention-mask)
  92. position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  93. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  94. config.max_position_embeddings - 1]`.
  95. [What are position IDs?](../glossary#position-ids)
  96. head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  97. Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
  98. - 1 indicates the head is **not masked**,
  99. - 0 indicates the head is **masked**.
  100. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  101. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  102. is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  103. model's internal embedding lookup matrix.
  104. If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
  105. `past_key_values`).
  106. use_cache (`bool`, *optional*):
  107. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
  108. `past_key_values`).
  109. output_attentions (`bool`, *optional*):
  110. Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
  111. tensors for more detail.
  112. output_hidden_states (`bool`, *optional*):
  113. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  114. more detail.
  115. return_dict (`bool`, *optional*):
  116. Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
  117. """
  118. @add_start_docstrings(
  119. "The bare Bloom Model transformer outputting raw hidden-states without any specific head on top.",
  120. BLOOM_START_DOCSTRING,
  121. )
  122. class BloomModel(BloomPreTrainedModel):
  123. def __init__(self, config):
  124. super().__init__(config)
  125. assert not config.slow_but_exact, "slow_but_exact mode was removed for code simplicity"
  126. self.embed_dim = config.hidden_size
  127. self.n_head = config.n_head
  128. # Embedding + LN Embedding
  129. self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
  130. self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
  131. # Transformer blocks
  132. self.h = nn.ModuleList([BloomBlock(config, layer_number=i) for i in range(config.num_hidden_layers)])
  133. # Final Layer Norm
  134. self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
  135. self.gradient_checkpointing = False
  136. # Initialize weights and apply final processing
  137. self.post_init()
  138. def get_input_embeddings(self):
  139. return self.word_embeddings
  140. def set_input_embeddings(self, new_embeddings):
  141. self.word_embeddings = new_embeddings
  142. @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
  143. @add_code_sample_docstrings(
  144. processor_class=_TOKENIZER_FOR_DOC,
  145. checkpoint=_CHECKPOINT_FOR_DOC,
  146. output_type=BaseModelOutputWithPastAndCrossAttentions,
  147. config_class=_CONFIG_FOR_DOC,
  148. )
  149. def forward(
  150. self,
  151. input_ids=None,
  152. past_key_values=None,
  153. attention_mask=None,
  154. position_ids=None,
  155. head_mask=None,
  156. inputs_embeds=None,
  157. use_cache=None,
  158. output_attentions=None,
  159. output_hidden_states=None,
  160. return_dict=None,
  161. ):
  162. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  163. output_hidden_states = (
  164. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  165. )
  166. use_cache = use_cache if use_cache is not None else self.config.use_cache
  167. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  168. if input_ids is not None and inputs_embeds is not None:
  169. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  170. elif input_ids is not None:
  171. input_shape = input_ids.size()
  172. input_ids = input_ids.view(-1, input_shape[-1])
  173. elif inputs_embeds is not None:
  174. input_shape = inputs_embeds.size()[:-1]
  175. else:
  176. raise ValueError("You have to specify either input_ids or inputs_embeds")
  177. if past_key_values is None:
  178. past_key_values = tuple([None] * len(self.h))
  179. # Prepare head mask if needed
  180. # 1.0 in head_mask indicate we keep the head
  181. # attention_probs has shape bsz x n_head x N x N
  182. # head_mask has shape n_layer x batch x n_head x N x N
  183. head_mask = self.get_head_mask(head_mask, self.config.n_layer)
  184. if inputs_embeds is None:
  185. inputs_embeds = self.word_embeddings(input_ids)
  186. hidden_states = self.word_embeddings_layernorm(inputs_embeds)
  187. output_shape = input_shape + (hidden_states.size(-1),)
  188. presents = () if use_cache else None
  189. all_self_attentions = () if output_attentions else None
  190. all_hidden_states = () if output_hidden_states else None
  191. # Compute alibi tensor: check build_alibi_tensor documentation
  192. current_sequence_length = hidden_states.shape[1]
  193. if past_key_values[0] is not None:
  194. current_sequence_length += past_key_values[0][0].shape[1]
  195. alibi = build_alibi_tensor(current_sequence_length, self.n_head, hidden_states.dtype)
  196. for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
  197. if output_hidden_states:
  198. all_hidden_states = all_hidden_states + (hidden_states,)
  199. if self.gradient_checkpointing and self.training:
  200. if use_cache:
  201. logger.warning(
  202. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  203. )
  204. use_cache = False
  205. def create_custom_forward(module):
  206. def custom_forward(*inputs):
  207. # None for past_key_value
  208. return module(*inputs, use_cache, output_attentions, alibi)
  209. return custom_forward
  210. outputs = torch.utils.checkpoint.checkpoint(
  211. create_custom_forward(block),
  212. hidden_states,
  213. None,
  214. attention_mask,
  215. head_mask[i],
  216. )
  217. else:
  218. outputs = block(
  219. hidden_states,
  220. layer_past=layer_past,
  221. attention_mask=attention_mask,
  222. head_mask=head_mask[i],
  223. use_cache=use_cache,
  224. output_attentions=output_attentions,
  225. alibi=alibi,
  226. )
  227. hidden_states = outputs[0]
  228. if use_cache is True:
  229. presents = presents + (outputs[1],)
  230. if output_attentions:
  231. all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
  232. # Add last hidden state
  233. hidden_states = self.ln_f(hidden_states)
  234. if output_hidden_states:
  235. all_hidden_states = all_hidden_states + (hidden_states,)
  236. hidden_states = hidden_states.view(output_shape)
  237. if not return_dict:
  238. return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
  239. return BaseModelOutputWithPastAndCrossAttentions(
  240. last_hidden_state=hidden_states,
  241. past_key_values=presents,
  242. hidden_states=all_hidden_states,
  243. attentions=all_self_attentions,
  244. )
  245. @add_start_docstrings(
  246. """
  247. The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input
  248. embeddings).
  249. """,
  250. BLOOM_START_DOCSTRING,
  251. )
  252. class BloomForCausalLM(BloomPreTrainedModel):
  253. _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
  254. def __init__(self, config):
  255. super().__init__(config)
  256. self.transformer = BloomModel(config)
  257. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  258. # Initialize weights and apply final processing
  259. self.post_init()
  260. def get_output_embeddings(self):
  261. return self.lm_head
  262. def set_output_embeddings(self, new_embeddings):
  263. self.lm_head = new_embeddings
  264. def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
  265. # only last token for inputs_ids if past is defined in kwargs
  266. if past:
  267. input_ids = input_ids[:, -1].unsqueeze(-1)
  268. attention_mask = kwargs.get("attention_mask", None)
  269. position_ids = kwargs.get("position_ids", None)
  270. if attention_mask is not None and position_ids is None:
  271. # create position_ids on the fly for batch generation
  272. position_ids = attention_mask.long().cumsum(-1) - 1
  273. position_ids.masked_fill_(attention_mask == 0, 1)
  274. if past:
  275. position_ids = position_ids[:, -1].unsqueeze(-1)
  276. else:
  277. position_ids = None
  278. return {
  279. "input_ids": input_ids,
  280. "past_key_values": past,
  281. "use_cache": kwargs.get("use_cache"),
  282. "position_ids": position_ids,
  283. "attention_mask": attention_mask,
  284. }
  285. @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
  286. @add_code_sample_docstrings(
  287. processor_class=_TOKENIZER_FOR_DOC,
  288. checkpoint=_CHECKPOINT_FOR_DOC,
  289. output_type=CausalLMOutputWithCrossAttentions,
  290. config_class=_CONFIG_FOR_DOC,
  291. )
  292. def forward(
  293. self,
  294. input_ids=None,
  295. past_key_values=None,
  296. attention_mask=None,
  297. position_ids=None,
  298. head_mask=None,
  299. inputs_embeds=None,
  300. labels=None,
  301. use_cache=None,
  302. output_attentions=None,
  303. output_hidden_states=None,
  304. return_dict=None,
  305. ):
  306. r"""
  307. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  308. Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
  309. `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
  310. are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
  311. """
  312. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  313. transformer_outputs = self.transformer(
  314. input_ids,
  315. past_key_values=past_key_values,
  316. attention_mask=attention_mask,
  317. position_ids=position_ids,
  318. head_mask=head_mask,
  319. inputs_embeds=inputs_embeds,
  320. use_cache=use_cache,
  321. output_attentions=output_attentions,
  322. output_hidden_states=output_hidden_states,
  323. return_dict=return_dict,
  324. )
  325. hidden_states = transformer_outputs[0]
  326. lm_logits = self.lm_head(hidden_states)
  327. loss = None
  328. if labels is not None:
  329. # Shift so that tokens < n predict n
  330. shift_logits = lm_logits[..., :-1, :].contiguous()
  331. shift_labels = labels[..., 1:].contiguous()
  332. # Flatten the tokens
  333. loss_fct = CrossEntropyLoss()
  334. loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
  335. if not return_dict:
  336. output = (lm_logits,) + transformer_outputs[1:]
  337. return ((loss,) + output) if loss is not None else output
  338. return CausalLMOutputWithCrossAttentions(
  339. loss=loss,
  340. logits=lm_logits,
  341. past_key_values=transformer_outputs.past_key_values,
  342. hidden_states=transformer_outputs.hidden_states,
  343. attentions=transformer_outputs.attentions,
  344. )
  345. @staticmethod
  346. def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
  347. """
  348. This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
  349. [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
  350. beam_idx at every generation step.
  351. """
  352. return tuple(
  353. tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
  354. for layer_past in past
  355. )