# remote_generation.py
from typing import List, Optional

import torch

from src.utils.generation_algorithms import (
    BeamSearchAlgorithm,
    DecodingAlgorithm,
    GreedyAlgorithm,
    NucleusAlgorithm,
    TopKAlgorithm,
)
from src.utils.generation_constraints import ABCBloomConstraint, EosConstraint


class RemoteGenerationMixin:
    """
    A class containing all functions for auto-regressive text generation, to be used as a mixin in [`BloomForCausalLM`].
    The class can be used for:

    - *greedy decoding*
    - *multinomial sampling*
    - *beam-search decoding*

    This class is similar to transformers' [`generation_utils.GenerationMixin`] and can be used in its place.
    However, it has some differences required for remote usage.
    """
    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        do_sample: Optional[bool] = None,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        num_beams: Optional[int] = 1,
        bos_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        max_length: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        decoding_algorithm: Optional[DecodingAlgorithm] = None,
        provided_constraints: List[ABCBloomConstraint] = [],
        **model_kwargs,
    ) -> torch.LongTensor:
        """
        Generates sequences of token ids for models with a language modeling head.

        :param inputs: The input tokens to the model.
        :param do_sample: Whether to sample from the model predictions or take the argmax.
        :param temperature: The temperature to use for sampling.
        :param top_k: The number of highest-probability tokens to keep for top-k sampling.
        :param top_p: The cumulative probability cutoff for nucleus (top-p) sampling.
        :param num_beams: The number of beams to use for beam search.
        :param bos_token_id: The id of the beginning-of-sentence token.
        :param eos_token_id: The id of the end-of-sentence token.
        :param pad_token_id: The id of the padding token.
        :param max_length: The maximum total sequence length, including the prefix (mutually exclusive with max_new_tokens).
        :param max_new_tokens: The maximum number of tokens to generate (mutually exclusive with max_length).
        :param decoding_algorithm: The decoding algorithm to use.
        :param provided_constraints: A list of constraints to use.
        :param model_kwargs: Additional arguments to pass to the model.
        """
        assert (
            model_kwargs.get("logits_processor", None) is None
        ), "For RemoteGenerationMixin models use BloomConstraints instead of logits_processor"
        assert (
            model_kwargs.get("logits_wrapper", None) is None
        ), "For RemoteGenerationMixin models use DecodingAlgorithm instead of logits_wrapper"
        assert (
            model_kwargs.get("stopping_criteria", None) is None
        ), "For RemoteGenerationMixin models use BloomConstraints instead of stopping_criteria"
        if inputs is not None:
            assert isinstance(inputs, torch.Tensor) and inputs.ndim == 2, "inputs must be a 2d tensor [batch, length]"

        prefix_length = 0 if inputs is None else inputs.size(1)
        prefix_length += self.config.pre_seq_len

        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id

        assert (max_length is None) != (max_new_tokens is None), "please set max_length or max_new_tokens (not both)"
        if max_length is not None and max_new_tokens is None:
            max_new_tokens = max_length - prefix_length
            assert max_new_tokens > 0, f"Provided max_length is less than prefix size: {max_length} < {inputs.size(1)}"
        elif max_length is None and max_new_tokens is not None:
            max_length = prefix_length + max_new_tokens

        if inputs is None:
            assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs"
            inputs = torch.tensor([[bos_token_id]] * num_beams, dtype=torch.long, device=self.device)

        if decoding_algorithm is None:
            if do_sample:
                decoding_algorithm = self._choose_sample_algorithm(temperature, top_k, top_p)
            elif num_beams is not None and num_beams > 1:
                decoding_algorithm = BeamSearchAlgorithm(num_beams, batch_size=inputs.size(0))
            else:
                decoding_algorithm = GreedyAlgorithm()

        if num_beams > 1:
            inputs = torch.cat([inputs] * num_beams, dim=0)

        constraints = self._get_constraints(
            inputs=inputs,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            provided_constraints=provided_constraints,
        )

        with self.transformer.h.inference_session(max_length=max_length) as sess:
            outputs = []
            # Find samples with padded inputs: their generated tokens are
            # overwritten with the original input tokens until every sample
            # reaches the full input length.
            if torch.any(inputs == pad_token_id):  # TODO: move to prepare_inputs
                outputs += [inputs[:, : inputs.size(1) - (inputs == pad_token_id).sum(-1).max()]]
            else:
                outputs += [inputs]
            last_token_id = None
            seq_idx = outputs[0].size(1)
            hypo_ids = torch.arange(outputs[0].size(0))
            while True:
                embs = self.transformer.word_embeddings(outputs[-1])
                intermediate_prompts = None
                if self.config.pre_seq_len > 0 and len(outputs) == 1:
                    prompts, intermediate_prompts = self.transformer.get_prompt(embs.size(0))
                    embs = torch.cat([prompts, embs], dim=1)
                embs = self.transformer.word_embeddings_layernorm(embs)
                hidden_state = sess.step(embs, prompts=intermediate_prompts, hypo_ids=hypo_ids)[:, -1]
                hidden_state = self.transformer.ln_f(hidden_state)
                lm_logits = self.lm_head(hidden_state)

                for constraint in constraints:
                    lm_logits = constraint(last_token_id, lm_logits, hypo_ids)
                last_token_id, hypo_ids = decoding_algorithm(lm_logits)

                # If some samples were padded, overwrite their generated tokens
                # with the original input tokens at this position.
                if seq_idx < inputs.size(1):
                    pad_token_mask = inputs[:, seq_idx : seq_idx + 1] == pad_token_id
                    last_token_id = (~pad_token_mask) * inputs[
                        :, seq_idx : seq_idx + 1
                    ] + pad_token_mask * last_token_id

                if torch.all(last_token_id == eos_token_id) or len(outputs) >= max_new_tokens:
                    break
                outputs.append(last_token_id)
                seq_idx += 1

        return torch.cat(outputs, dim=-1)
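
    # Shape sketch (illustrative): for `inputs` of shape [2, 5] with num_beams=3,
    # the loop above runs on a tiled batch of shape [6, 5], and the returned
    # tensor has shape [6, 5 + n_generated_tokens].
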
    def greedy_search(
        self,
        input_ids: torch.LongTensor,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        provided_constraints: List[ABCBloomConstraint] = [],
        **model_kwargs,
    ) -> torch.LongTensor:
        """
        Generates sequences of token ids for models with a language modeling head. Uses greedy search.

        :param input_ids: The input tokens to the model.
        :param max_length: The maximum length of the sequence to generate.
        :param pad_token_id: The id of the padding token.
        :param eos_token_id: The id of the end-of-sentence token.
        :param provided_constraints: A list of constraints to use.
        """
        # Note: max_length is forwarded to generate() as max_new_tokens here.
        return self.generate(
            inputs=input_ids,
            max_new_tokens=max_length,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            decoding_algorithm=GreedyAlgorithm(),
            provided_constraints=provided_constraints,
            **model_kwargs,
        )
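
    # E.g. (with the hypothetical model from the class-level sketch):
    #   output_ids = model.greedy_search(input_ids, max_length=8)
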
    def sample(
        self,
        input_ids: torch.LongTensor,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        provided_constraints: List[ABCBloomConstraint] = [],
        **model_kwargs,
    ) -> torch.LongTensor:
        """
        Generates sequences of token ids for models with a language modeling head. Uses multinomial sampling:
        if top_k is provided, uses top-k sampling; if top_p is provided, uses nucleus sampling.

        :param input_ids: The input tokens to the model.
        :param temperature: The temperature to use for sampling.
        :param top_k: The number of highest-probability tokens to keep for top-k sampling.
        :param top_p: The cumulative probability cutoff for nucleus (top-p) sampling.
        :param max_length: The maximum length of the sequence to generate.
        :param pad_token_id: The id of the padding token.
        :param eos_token_id: The id of the end-of-sentence token.
        :param provided_constraints: A list of constraints to use.
        :param model_kwargs: Additional kwargs to pass to the model.
        """
        return self.generate(
            inputs=input_ids,
            max_new_tokens=max_length,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            decoding_algorithm=self._choose_sample_algorithm(temperature, top_k, top_p),
            provided_constraints=provided_constraints,
            **model_kwargs,
        )
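
    # E.g. nucleus sampling at temperature 0.8 (hypothetical model, as above):
    #   output_ids = model.sample(input_ids, temperature=0.8, top_p=0.9, max_length=8)
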
    def beam_search(
        self,
        input_ids: torch.LongTensor,
        num_beams: int = 1,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        provided_constraints: List[ABCBloomConstraint] = [],
        **model_kwargs,
    ) -> torch.LongTensor:
        """
        Generates sequences of token ids for models with a language modeling head. Uses beam search.

        :param input_ids: The input tokens to the model.
        :param num_beams: The number of beams to use.
        :param max_length: The maximum length of the sequence to generate.
        :param pad_token_id: The id of the padding token.
        :param eos_token_id: The id of the end-of-sentence token.
        :param provided_constraints: A list of constraints to use.
        :param model_kwargs: Additional kwargs to pass to the model.
        """
        decoding_algorithm = BeamSearchAlgorithm(
            num_beams=num_beams,
            batch_size=input_ids.size(0),
        )
        return self.generate(
            inputs=input_ids,
            num_beams=num_beams,
            max_new_tokens=max_length,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            decoding_algorithm=decoding_algorithm,
            provided_constraints=provided_constraints,
            **model_kwargs,
        )
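
    # E.g. beam search with 4 beams (hypothetical model, as above); note that
    # the returned batch dimension is expanded to batch_size * num_beams:
    #   output_ids = model.beam_search(input_ids, num_beams=4, max_length=8)
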
    def beam_sample(
        self,
        input_ids: torch.LongTensor,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        provided_constraints: List[ABCBloomConstraint] = [],
        **model_kwargs,
    ) -> torch.LongTensor:
        raise NotImplementedError

    def group_beam_search(
        self,
        input_ids: torch.LongTensor,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        provided_constraints: List[ABCBloomConstraint] = [],
        **model_kwargs,
    ) -> torch.LongTensor:
        raise NotImplementedError

    def _choose_sample_algorithm(
        self,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
    ) -> DecodingAlgorithm:
        if (top_k is not None) and (top_p is not None):
            raise ValueError("You have to provide only top_k or top_p for sampling")
        if top_k:
            return TopKAlgorithm(top_k, temperature)
        elif top_p:
            return NucleusAlgorithm(top_p, temperature)
        else:
            raise ValueError("You have to provide either top_k or top_p for sampling")

    def _get_constraints(
        self,
        inputs: Optional[torch.Tensor] = None,
        eos_token_id: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        provided_constraints: List[ABCBloomConstraint] = [],
    ) -> List[ABCBloomConstraint]:
        constraints = []
        constraints.extend(provided_constraints)
        constraints.append(EosConstraint(inputs, eos_token_id, pad_token_id))
        return constraints
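
# A minimal sketch of a custom constraint that could be passed via
# `provided_constraints`. The ABCBloomConstraint call signature is assumed from
# its usage in generate() above: constraint(last_token_id, logits, hypo_ids).
#
#   class ForbidTokenConstraint(ABCBloomConstraint):
#       """Masks out one token id at every decoding step."""
#
#       def __init__(self, forbidden_token_id: int):
#           self.forbidden_token_id = forbidden_token_id
#
#       def __call__(self, tokens_id, logits, hypo_ids):
#           logits[..., self.forbidden_token_id] = -float("inf")
#           return logits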