remote_generation.py

from typing import List, Optional

import torch
import torch.nn.functional as F

from src.utils.generation_algorithms import DecodingAlgorithm, GreedyAlgorithm, NucleusAlgorithm, TopKAlgorithm
from src.utils.generation_constraints import ABCBloomConstraint, EosConstraint, MaxNewTokensConstraint

class RemoteGenerationMixin:
    """
    A class containing all functions for auto-regressive text generation, to be used as a mixin in [`BloomForCausalLM`].

    The class can be used for:
        - *greedy decoding*.
        - *multinomial sampling*.

    This class is similar to transformers' [`generation_utils.GenerationMixin`] and can be used instead of it. However, it has some differences.
    """
    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        do_sample: Optional[bool] = None,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        bos_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        max_length: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        decoding_algorithm: Optional[DecodingAlgorithm] = None,
        provided_constraints: List[ABCBloomConstraint] = [],
        **model_kwargs,
    ) -> torch.LongTensor:
  31. """
  32. Generates sequences of token ids for models with a language modeling head.
  33. :param inputs: The input tokens to the model.
  34. :param do_sample: Whether to sample from the model predictions or take the argmax.
  35. :param temperature: The temperature to use for sampling.
  36. :param top_k: The number of results to return.
  37. :param top_p: The cumulative probability of results to return.
  38. :param bos_token_id: The id of the beginning of sentence token.
  39. :param eos_token_id: The id of the end of sentence token.
  40. :param pad_token_id: The id of the padding token.
  41. :param max_new_tokens: The maximum number of tokens to generate.
  42. :param decoding_algorithm: The decoding algorithm to use.
  43. :param provided_constraints: A list of constraints to use.
  44. :param model_kwargs: Additional arguments to pass to the model.
  45. """
        assert (
            model_kwargs.get("logits_processor", None) is None
        ), "For RemoteGenerationMixin models use BloomConstraints instead of logits_processor"
        assert (
            model_kwargs.get("logits_wrapper", None) is None
        ), "For RemoteGenerationMixin models use DecodingAlgorithm instead of logits_wrapper"
        assert (
            model_kwargs.get("stopping_criteria", None) is None
        ), "For RemoteGenerationMixin models use BloomConstraints instead of stopping_criteria"

        if inputs is not None:
            assert isinstance(inputs, torch.Tensor) and inputs.ndim == 2, "inputs must be a 2d tensor [batch, length]"
        prefix_length = 0 if inputs is None else inputs.size(1)
        prefix_length += self.config.pre_seq_len

        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id

        assert (max_length is None) != (max_new_tokens is None), "please set max_length or max_new_tokens (not both)"
        if max_length is not None and max_new_tokens is None:
            max_new_tokens = max_length - prefix_length
            assert max_new_tokens > 0, f"Provided max_length is less than prefix size: {max_length} < {inputs.size(1)}"
        elif max_length is None and max_new_tokens is not None:
            max_length = prefix_length + max_new_tokens

        if inputs is None:
            assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs"
            inputs = torch.tensor([[bos_token_id]])

        if decoding_algorithm is None:
            if do_sample:
                decoding_algorithm = self._choose_sample_algorithm(temperature, top_k, top_p)
            else:
                decoding_algorithm = GreedyAlgorithm()

        constraints = self._get_constraints(
            inputs=inputs,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            max_new_tokens=max_new_tokens,
            provided_constraints=provided_constraints,
        )
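        # Auto-regressive decoding loop: each iteration embeds the most recently produced tokens
        # (or the prompt on the first step), pushes them through the remote transformer blocks via
        # the inference session, applies the constraints to the last position's logits, and lets
        # the decoding algorithm pick the next token until every sequence emits eos_token_id.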
        with self.transformer.h.inference_session(max_length=max_length) as sess:
            outputs = []
            if torch.any(inputs == pad_token_id):  # TODO: move to prepare_inputs
                outputs += [inputs[:, : inputs.size(1) - (inputs == pad_token_id).sum(-1).max()]]
            else:
                outputs += [inputs]
            last_token_id = None
            seq_idx = outputs[0].size(1)
            hypo_ids = torch.arange(outputs[0].size(0))
            while True:
                embs = self.transformer.word_embeddings(outputs[-1])
                intermediate_prompts = None
                if self.config.pre_seq_len > 0 and len(outputs) == 1:
                    prompts, intermediate_prompts = self.transformer.get_prompt(embs.size(0))
                    embs = torch.cat([prompts, embs], dim=1)
                embs = self.transformer.word_embeddings_layernorm(embs)
                hidden_state = sess.step(embs, prompts=intermediate_prompts, hypo_ids=hypo_ids)[:, -1]
                hidden_state = self.transformer.ln_f(hidden_state)
                lm_logits = self.lm_head(hidden_state)

                for constraint in constraints:
                    lm_logits = constraint(last_token_id, lm_logits, hypo_ids)
                last_token_id, hypo_ids = decoding_algorithm(lm_logits)

                if seq_idx < inputs.size(1):  # TODO: why is it not a constraint?
                    pad_token_mask = inputs[:, seq_idx : seq_idx + 1] == pad_token_id
                    last_token_id = (~pad_token_mask) * inputs[
                        :, seq_idx : seq_idx + 1
                    ] + pad_token_mask * last_token_id

                if torch.all(last_token_id == eos_token_id):
                    break

                outputs.append(last_token_id)
                seq_idx += 1

        return torch.cat(outputs, dim=-1)

    def greedy_search(
        self,
        input_ids: torch.LongTensor,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        provided_constraints: List[ABCBloomConstraint] = [],
        **model_kwargs,
    ) -> torch.LongTensor:
        """
        Generates sequences of token ids for models with a language modeling head. Uses greedy search.

        :param input_ids: The input tokens to the model.
        :param max_length: The maximum length of the sequence to generate.
        :param pad_token_id: The id of the padding token.
        :param eos_token_id: The id of the end of sentence token.
        :param provided_constraints: A list of constraints to use.
        """
        return self.generate(
            inputs=input_ids,
            max_new_tokens=max_length,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            decoding_algorithm=GreedyAlgorithm(),
            provided_constraints=provided_constraints,
            **model_kwargs,
        )

    def sample(
        self,
        input_ids: torch.LongTensor,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        provided_constraints: List[ABCBloomConstraint] = [],
        **model_kwargs,
    ) -> torch.LongTensor:
        """
        Generates sequences of token ids for models with a language modeling head. Uses multinomial sampling.
        If top_k is provided, uses top-k sampling; if top_p is provided, uses nucleus sampling.

        :param input_ids: The input tokens to the model.
        :param temperature: The temperature to use for sampling.
        :param top_k: The number of highest-probability vocabulary tokens to keep for top-k sampling.
        :param top_p: The cumulative probability threshold for nucleus (top-p) sampling.
        :param max_length: The maximum length of the sequence to generate.
        :param pad_token_id: The id of the padding token.
        :param eos_token_id: The id of the end of sentence token.
        :param provided_constraints: A list of constraints to use.
        :param model_kwargs: Additional kwargs to pass to the model.
        """
        return self.generate(
            inputs=input_ids,
            max_new_tokens=max_length,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            decoding_algorithm=self._choose_sample_algorithm(temperature, top_k, top_p),
            provided_constraints=provided_constraints,
            **model_kwargs,
        )

    def beam_search(
        self,
        input_ids: torch.LongTensor,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        provided_constraints: List[ABCBloomConstraint] = [],
        **model_kwargs,
    ) -> torch.LongTensor:
        raise NotImplementedError

    def beam_sample(
        self,
        input_ids: torch.LongTensor,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        provided_constraints: List[ABCBloomConstraint] = [],
        **model_kwargs,
    ) -> torch.LongTensor:
        raise NotImplementedError

    def group_beam_search(
        self,
        input_ids: torch.LongTensor,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        provided_constraints: List[ABCBloomConstraint] = [],
        **model_kwargs,
    ) -> torch.LongTensor:
        raise NotImplementedError

    def _choose_sample_algorithm(
        self,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
    ) -> DecodingAlgorithm:
        if (top_k is not None) and (top_p is not None):
            raise ValueError("You have to provide only top_k or top_p for sampling")
        if top_k:
            return TopKAlgorithm(top_k, temperature)
        elif top_p:
            return NucleusAlgorithm(top_p, temperature)
        # Note: if neither top_k nor top_p is provided, this falls through and returns None;
        # callers are expected to pass exactly one of the two.

    def _get_constraints(
        self,
        inputs: Optional[torch.Tensor] = None,
        eos_token_id: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        provided_constraints: List[ABCBloomConstraint] = [],
    ) -> List[ABCBloomConstraint]:
        constraints = []
        constraints.extend(provided_constraints)
        if max_new_tokens is not None:
            constraints.append(MaxNewTokensConstraint(inputs, max_new_tokens, eos_token_id, pad_token_id))
        constraints.append(EosConstraint(inputs, eos_token_id, pad_token_id))
        return constraints
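
# Custom filtering or stopping behavior can be plugged in through `provided_constraints`.
# A hedged sketch, assuming ABCBloomConstraint uses the __call__(last_token_id, logits, hypo_ids)
# interface applied in generate() above; `BanTokenConstraint` is an illustrative name, not part of
# this module:
#
#     class BanTokenConstraint(ABCBloomConstraint):
#         def __init__(self, banned_token_id: int):
#             self.banned_token_id = banned_token_id
#
#         def __call__(self, tokens_id, logits, hypo_ids):
#             logits[..., self.banned_token_id] = -float("inf")  # never sample the banned token
#             return logits
#
#     model.generate(inputs, max_new_tokens=8, provided_constraints=[BanTokenConstraint(3)])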