remote_generation.py

from typing import List, Optional

import torch
import torch.nn.functional as F

from src.utils.generation_algorithms import DecodingAlgorithm, GreedyAlgorithm, NucleusAlgorithm, TopKAlgorithm
from src.utils.generation_constraints import ABCBloomConstraint, EosConstraint, MaxNewTokensConstraint


class RemoteGenerationMixin:
    """
    A class containing all functions for auto-regressive text generation, to be used as a mixin in [`BloomForCausalLM`].
    The class can be used for:
    - *greedy decoding*.
    - *multinomial sampling*.

    This class is similar to transformers' [`generation_utils.GenerationMixin`] and can be used as a drop-in
    replacement for it, with a few differences.
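
    Example (a hedged usage sketch; `model` and `tokenizer` are illustrative names for a causal LM that
    mixes in this class and its matching tokenizer, neither of which is defined in this module):

        >>> inputs = tokenizer("A cat sat on", return_tensors="pt")["input_ids"]
        >>> outputs = model.generate(inputs, max_new_tokens=8)  # greedy decoding by default
        >>> print(tokenizer.decode(outputs[0]))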
  13. """
  14. @torch.no_grad()
  15. def generate(
  16. self,
  17. inputs: Optional[torch.Tensor] = None,
  18. do_sample: Optional[bool] = None,
  19. temperature: float = 1.0,
  20. top_k: Optional[int] = None,
  21. top_p: Optional[float] = None,
  22. bos_token_id: Optional[int] = None,
  23. eos_token_id: Optional[int] = None,
  24. pad_token_id: Optional[int] = None,
  25. max_length: Optional[int] = None,
  26. max_new_tokens: Optional[int] = None,
  27. decoding_algorithm: Optional[DecodingAlgorithm] = None,
  28. provided_constraints: List[ABCBloomConstraint] = [],
  29. **model_kwargs,
  30. ) -> torch.LongTensor:
  31. """
  32. Generates sequences of token ids for models with a language modeling head.
  33. :param inputs: The input tokens to the model.
  34. :param do_sample: Whether to sample from the model predictions or take the argmax.
  35. :param temperature: The temperature to use for sampling.
  36. :param top_k: The number of results to return.
  37. :param top_p: The cumulative probability of results to return.
  38. :param bos_token_id: The id of the beginning of sentence token.
  39. :param eos_token_id: The id of the end of sentence token.
  40. :param pad_token_id: The id of the padding token.
  41. :param max_new_tokens: The maximum number of tokens to generate.
  42. :param decoding_algorithm: The decoding algorithm to use.
  43. :param provided_constraints: A list of constraints to use.
  44. :param model_kwargs: Additional arguments to pass to the model.
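
        Example (a hedged sketch of nucleus sampling; `inputs` as in the class-level example):

            >>> outputs = model.generate(inputs, do_sample=True, temperature=0.8, top_p=0.9, max_new_tokens=16)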
  45. """
  46. assert (
  47. model_kwargs.get("logits_processor", None) is None
  48. ), "For RemoteGenerationMixin models use BloomConstraints instead of logits_processor"
  49. assert (
  50. model_kwargs.get("logits_wrapper", None) is None
  51. ), "For RemoveGenerationMixin models use DecodingAlgorithm instead of logits_wrapper"
  52. assert (
  53. model_kwargs.get("stopping_criteria", None) is None
  54. ), "For RemoteGenerationMixin models use BloomConstraints instead of stopping_criteria"
        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id

        if inputs is None:
            assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs"
            inputs = torch.tensor([[bos_token_id]])

        if max_length is not None and max_new_tokens is None:
            # Translate a total-length budget into a new-token budget relative to the prefix
            max_new_tokens = max_length - inputs.size(1)
            assert max_new_tokens > 0, f"Provided max_length is less than prefix size: {max_length} < {inputs.size(1)}"

        if decoding_algorithm is None:
            if do_sample:
                decoding_algorithm = self._choose_sample_algorithm(temperature, top_k, top_p)
            else:
                decoding_algorithm = GreedyAlgorithm()

        constraints = self._get_constraints(
            inputs=inputs,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            max_new_tokens=max_new_tokens,
            provided_constraints=provided_constraints,
        )
        with self.transformer.h.inference_session() as sess:
            outputs = []
            if torch.any(inputs == pad_token_id):  # TODO: move to prepare_inputs
                # Trim trailing padding from the prefix; the padded positions are re-filled below
                outputs += [inputs[:, : inputs.size(1) - (inputs == pad_token_id).sum(-1).max()]]
            else:
                outputs += [inputs]
            last_token_id = None
            seq_idx = outputs[0].size(1)
            hypo_ids = torch.arange(outputs[0].size(0))
            while True:
                # Embeddings and the LM head run locally; only the transformer blocks
                # run remotely through the inference session
                embs = self.transformer.word_embeddings(outputs[-1])
                embs = self.transformer.word_embeddings_layernorm(embs)
                hidden_state = sess.step(embs)[:, -1]
                hidden_state = self.transformer.ln_f(hidden_state)
                lm_logits = self.lm_head(hidden_state)

                for constraint in constraints:
                    lm_logits = constraint(last_token_id, lm_logits, hypo_ids)
                last_token_id, hypo_ids = decoding_algorithm(lm_logits)

                if seq_idx < inputs.size(1):  # TODO: why is it not a constraint?
                    # While still inside the prefix, keep the true prefix tokens and
                    # only substitute generated tokens at padding positions
                    pad_token_mask = inputs[:, seq_idx : seq_idx + 1] == pad_token_id
                    last_token_id = (~pad_token_mask) * inputs[
                        :, seq_idx : seq_idx + 1
                    ] + pad_token_mask * last_token_id

                if torch.all(last_token_id == eos_token_id):
                    break

                outputs.append(last_token_id)
                seq_idx += 1

        return torch.cat(outputs, dim=-1)
    def greedy_search(
        self,
        input_ids: torch.LongTensor,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        provided_constraints: List[ABCBloomConstraint] = [],
        **model_kwargs,
    ) -> torch.LongTensor:
        """
        Generates sequences of token ids for models with a language modeling head. Uses greedy search.

        :param input_ids: The input tokens to the model.
        :param max_length: The maximum total length of the sequence to generate, counting the prefix.
        :param pad_token_id: The id of the padding token.
        :param eos_token_id: The id of the end of sentence token.
        :param provided_constraints: A list of constraints to use.
        """
        return self.generate(
            inputs=input_ids,
            max_length=max_length,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            decoding_algorithm=GreedyAlgorithm(),
            provided_constraints=provided_constraints,
            **model_kwargs,
        )
    def sample(
        self,
        input_ids: torch.LongTensor,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        provided_constraints: List[ABCBloomConstraint] = [],
        **model_kwargs,
    ) -> torch.LongTensor:
  142. """
  143. Generates sequences of token ids for models with a language modeling head. Uses sampling. Uses multinomial sampling algorithm. If top_k is provided, uses top_k sampling. If top_p is provided, uses nucleus sampling.
  144. :param: input_ids: The input tokens to the model.
  145. :param: temperature: The temperature to use for sampling.
  146. :param: top_k: The number of samples to use for top_k sampling.
  147. :param: top_p: The probability of using top_p sampling.
  148. :param: max_length: The maximum length of the sequence to generate.
  149. :param: pad_token_id: The id of the padding token.
  150. :param: eos_token_id: The id of the end of sentence token.
  151. :param: provided_constraints: A list of constraints to use.
  152. :param: model_kwargs: Additional kwargs to pass to the model.
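
        Example (a hedged sketch; `inputs` as in the class-level example):

            >>> outputs = model.sample(inputs, temperature=0.7, top_k=50, max_length=32)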
  153. """
  154. return self.generate(
  155. inputs=input_ids,
  156. max_new_tokens=max_length,
  157. pad_token_id=pad_token_id,
  158. eos_token_id=eos_token_id,
  159. decoding_algorithm=self._choose_sample_algorithm(temperature, top_k, top_p),
  160. provided_constraints=provided_constraints,
  161. **model_kwargs,
  162. )
    def beam_search(
        self,
        input_ids: torch.LongTensor,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        provided_constraints: List[ABCBloomConstraint] = [],
        **model_kwargs,
    ) -> torch.LongTensor:
        raise NotImplementedError

    def beam_sample(
        self,
        input_ids: torch.LongTensor,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        provided_constraints: List[ABCBloomConstraint] = [],
        **model_kwargs,
    ) -> torch.LongTensor:
        raise NotImplementedError

    def group_beam_search(
        self,
        input_ids: torch.LongTensor,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        provided_constraints: List[ABCBloomConstraint] = [],
        **model_kwargs,
    ) -> torch.LongTensor:
        raise NotImplementedError
    def _choose_sample_algorithm(
        self,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
    ) -> DecodingAlgorithm:
        if (top_k is not None) and (top_p is not None):
            raise ValueError("You have to provide only top_k or top_p for sampling")
        if top_k is not None:
            return TopKAlgorithm(top_k, temperature)
        elif top_p is not None:
            return NucleusAlgorithm(top_p, temperature)
        else:
            raise ValueError("You have to provide either top_k or top_p for sampling")
    def _get_constraints(
        self,
        inputs: Optional[torch.Tensor] = None,
        eos_token_id: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        provided_constraints: List[ABCBloomConstraint] = [],
    ) -> List[ABCBloomConstraint]:
        constraints = []
        constraints.extend(provided_constraints)
        if max_new_tokens is not None:
            constraints.append(MaxNewTokensConstraint(inputs, max_new_tokens, eos_token_id, pad_token_id))
        constraints.append(EosConstraint(inputs, eos_token_id, pad_token_id))
        return constraints
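
# A hypothetical sketch of a user-provided constraint, shown only to illustrate the call
# shape that `generate` uses above (`lm_logits = constraint(last_token_id, lm_logits, hypo_ids)`).
# `ABCBloomConstraint` may require more methods than shown here; consult
# src/utils/generation_constraints.py for the actual interface.
#
#     class ForbidTokenConstraint(ABCBloomConstraint):
#         """Assigns a -inf logit to one forbidden token id at every step."""
#
#         def __init__(self, forbidden_token_id: int):
#             self.forbidden_token_id = forbidden_token_id
#
#         def __call__(self, token_id, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
#             logits[..., self.forbidden_token_id] = -float("inf")
#             return logits
#
#     outputs = model.generate(inputs, provided_constraints=[ForbidTokenConstraint(bad_token_id)])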