remote_generation.py

from typing import List, Optional

import torch
from hivemind.utils.logging import get_logger

from src.utils.generation_algorithms import (
    BeamSearchAlgorithm,
    DecodingAlgorithm,
    GreedyAlgorithm,
    NucleusAlgorithm,
    TopKAlgorithm,
)
from src.utils.generation_constraints import ABCBloomConstraint, EosConstraint

logger = get_logger(__file__)

class RemoteGenerationMixin:
    """
    A class containing all functions for auto-regressive text generation, to be used as a mixin in [`BloomForCausalLM`].

    This class can be used for:
    - *greedy decoding*
    - *multinomial sampling*
    - *beam-search decoding*

    This class is similar to Transformers' [`generation_utils.GenerationMixin`] and can be used instead of it.
    However, it has some differences related to remote usage.
    """

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        do_sample: Optional[bool] = None,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        num_beams: Optional[int] = 1,
        bos_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        max_length: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        decoding_algorithm: Optional[DecodingAlgorithm] = None,
        provided_constraints: List[ABCBloomConstraint] = [],
        num_return_sequences: Optional[int] = None,
        **model_kwargs,
    ) -> torch.LongTensor:
        """
        Generates sequences of token ids for models with a language modeling head.

        :param inputs: The input tokens to the model.
        :param do_sample: Whether to sample from the model predictions or take the argmax.
        :param temperature: The temperature to use for sampling.
        :param top_k: The number of highest-probability tokens to keep for top-k sampling.
        :param top_p: The cumulative probability threshold for nucleus (top-p) sampling.
        :param num_beams: The number of beams to use for beam search.
        :param bos_token_id: The id of the beginning-of-sentence token.
        :param eos_token_id: The id of the end-of-sentence token.
        :param pad_token_id: The id of the padding token.
        :param max_length: The maximum total sequence length (prefix plus generated tokens). Set either max_length or max_new_tokens, not both.
        :param max_new_tokens: The maximum number of tokens to generate.
        :param decoding_algorithm: The decoding algorithm to use.
        :param provided_constraints: A list of constraints to use.
        :param num_return_sequences: How many hypotheses from the beam will be returned.
        :param model_kwargs: Additional arguments to pass to the model.
        """

        assert (
            model_kwargs.get("logits_processor", None) is None
        ), "For RemoteGenerationMixin models use BloomConstraints instead of logits_processor"
        assert (
            model_kwargs.get("logits_wrapper", None) is None
        ), "For RemoteGenerationMixin models use DecodingAlgorithm instead of logits_wrapper"
        assert (
            model_kwargs.get("stopping_criteria", None) is None
        ), "For RemoteGenerationMixin models use BloomConstraints instead of stopping_criteria"

        if inputs is not None:
            assert isinstance(inputs, torch.Tensor) and inputs.ndim == 2, "inputs must be a 2d tensor [batch, length]"
        prefix_length = 0 if inputs is None else inputs.size(1)
        prefix_length += self.config.pre_seq_len

        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id

        assert (max_length is None) != (max_new_tokens is None), "please set max_length or max_new_tokens (not both)"
        if max_length is not None and max_new_tokens is None:
            max_new_tokens = max_length - prefix_length
            assert max_new_tokens > 0, f"Provided max_length is less than prefix size: {max_length} < {inputs.size(1)}"
        elif max_length is None and max_new_tokens is not None:
            max_length = prefix_length + max_new_tokens

        if inputs is None:
            assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs"
            inputs = torch.tensor([[bos_token_id]] * num_beams, dtype=torch.long, device=self.device)
        batch_size = inputs.size(0)

        if decoding_algorithm is None:
            if do_sample:
                decoding_algorithm = self._choose_sample_algorithm(temperature, top_k, top_p)
            elif num_beams is not None and num_beams > 1:
                decoding_algorithm = BeamSearchAlgorithm(num_beams, batch_size=batch_size)
            else:
                decoding_algorithm = GreedyAlgorithm()

        if num_beams > 1:
            inputs = torch.cat([inputs] * num_beams, dim=0)
            if batch_size > 1:
                # TODO: resolve padding problem
                logger.warning(
                    f"You set batch_size {batch_size} within beam search generation. "
                    "Be careful: results on sequences of different lengths may be padded incorrectly"
                )

        if num_return_sequences is None:
            num_return_sequences = 1

        assert num_return_sequences <= num_beams, (
            "You want more sequences than the beam has."
            f" Check num_return_sequences: {num_return_sequences} and num_beams: {num_beams}."
        )

        constraints = self._get_constraints(
            inputs=inputs,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            provided_constraints=provided_constraints,
        )

        with self.transformer.h.inference_session(max_length=max_length) as sess:
            outputs = []
            # Find samples with padded inputs and trim the batch to the shortest unpadded length.
            # The trimmed positions are re-filled from `inputs` inside the loop below until all samples reach their full prefix length.
            if torch.any(inputs == pad_token_id):  # TODO: move to prepare_inputs
                outputs += [inputs[:, : inputs.size(1) - (inputs == pad_token_id).sum(-1).max()]]
            else:
                outputs += [inputs]
            last_token_id = None
            seq_idx = outputs[0].size(1)
            hypo_ids = torch.arange(outputs[0].size(0))
            while True:
                embs = self.transformer.word_embeddings(outputs[-1])
                intermediate_prompts = None
                if self.config.pre_seq_len > 0 and len(outputs) == 1:
                    prompts, intermediate_prompts = self.transformer.get_prompt(embs.size(0))
                    embs = torch.cat([prompts, embs], dim=1)
                embs = self.transformer.word_embeddings_layernorm(embs)
                hidden_state = sess.step(embs, prompts=intermediate_prompts, hypo_ids=hypo_ids)[:, -1]
                hidden_state = self.transformer.ln_f(hidden_state)
                lm_logits = self.lm_head(hidden_state)

                for constraint in constraints:
                    lm_logits = constraint(last_token_id, lm_logits, hypo_ids)
                last_token_id, hypo_ids = decoding_algorithm(lm_logits)

                # If some samples were padded, change only those samples
                if seq_idx < inputs.size(1):
                    pad_token_mask = inputs[:, seq_idx : seq_idx + 1] == pad_token_id
                    last_token_id = (~pad_token_mask) * inputs[
                        :, seq_idx : seq_idx + 1
                    ] + pad_token_mask * last_token_id

                # TODO: refactor outputs
                if num_beams > 1:
                    for i in range(len(outputs), 1, -1):
                        outputs[i - 1] = outputs[i - 1][hypo_ids]

                outputs.append(last_token_id)
                seq_idx += 1
                if torch.all(last_token_id == eos_token_id) or len(outputs) > max_new_tokens:
                    break

            outputs = torch.cat(outputs, dim=-1)

            if num_beams > 1:
                pre_return_idx = [
                    torch.arange(idx, num_return_sequences * batch_size, batch_size) for idx in range(batch_size)
                ]
                return_idx = torch.cat(pre_return_idx, dim=0)
                outputs = outputs[return_idx]

        return outputs
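
    # Usage (a minimal sketch, not part of the original module): `generate()` is meant to be called on a
    # model class that mixes in RemoteGenerationMixin (e.g. a distributed BloomForCausalLM-style model).
    # The model/tokenizer names below are placeholders chosen for illustration only.
    #
    #     tokenizer = transformers.AutoTokenizer.from_pretrained("bigscience/bloom")  # placeholder checkpoint
    #     model = DistributedBloomForCausalLM.from_pretrained(...)                    # hypothetical mixin user
    #     inputs = tokenizer("A cat sat on", return_tensors="pt")["input_ids"]
    #
    #     outputs = model.generate(inputs, max_new_tokens=8)                                       # greedy decoding
    #     outputs = model.generate(inputs, do_sample=True, top_p=0.9, max_new_tokens=8)            # nucleus sampling
    #     outputs = model.generate(inputs, num_beams=4, num_return_sequences=2, max_new_tokens=8)  # beam search
    #
    #     print(tokenizer.batch_decode(outputs))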

    def greedy_search(
        self,
        input_ids: torch.LongTensor,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        provided_constraints: List[ABCBloomConstraint] = [],
        **model_kwargs,
    ) -> torch.LongTensor:
        """
        Generates sequences of token ids for models with a language modeling head. Uses greedy search.

        :param input_ids: The input tokens to the model.
        :param max_length: The maximum length of the sequence to generate.
        :param pad_token_id: The id of the padding token.
        :param eos_token_id: The id of the end-of-sentence token.
        :param provided_constraints: A list of constraints to use.
        """
        return self.generate(
            inputs=input_ids,
            max_new_tokens=max_length,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            decoding_algorithm=GreedyAlgorithm(),
            provided_constraints=provided_constraints,
            **model_kwargs,
        )

    def sample(
        self,
        input_ids: torch.LongTensor,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        provided_constraints: List[ABCBloomConstraint] = [],
        **model_kwargs,
    ) -> torch.LongTensor:
        """
        Generates sequences of token ids for models with a language modeling head. Uses multinomial sampling:
        if top_k is provided, uses top-k sampling; if top_p is provided, uses nucleus (top-p) sampling.

        :param input_ids: The input tokens to the model.
        :param temperature: The temperature to use for sampling.
        :param top_k: The number of highest-probability tokens to keep for top-k sampling.
        :param top_p: The cumulative probability threshold for nucleus (top-p) sampling.
        :param max_length: The maximum length of the sequence to generate.
        :param pad_token_id: The id of the padding token.
        :param eos_token_id: The id of the end-of-sentence token.
        :param provided_constraints: A list of constraints to use.
        :param model_kwargs: Additional kwargs to pass to the model.
        """
        return self.generate(
            inputs=input_ids,
            max_new_tokens=max_length,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            decoding_algorithm=self._choose_sample_algorithm(temperature, top_k, top_p),
            provided_constraints=provided_constraints,
            **model_kwargs,
        )
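
    # A minimal usage sketch for `sample()` (same placeholder `model`/`inputs` as in the example after
    # `generate()` above). Exactly one of top_k / top_p should be set; otherwise `_choose_sample_algorithm`
    # below either raises (both set) or returns no algorithm (neither set).
    #
    #     outputs = model.sample(inputs, top_k=40, temperature=0.9, max_length=8)   # top-k sampling
    #     outputs = model.sample(inputs, top_p=0.9, temperature=0.9, max_length=8)  # nucleus sampling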

    def beam_search(
        self,
        input_ids: torch.LongTensor,
        num_beams: int = 1,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        provided_constraints: List[ABCBloomConstraint] = [],
        **model_kwargs,
    ) -> torch.LongTensor:
        """
        Generates sequences of token ids for models with a language modeling head. Uses beam search.

        :param input_ids: The input tokens to the model.
        :param num_beams: The number of beams to use.
        :param max_length: The maximum length of the sequence to generate.
        :param pad_token_id: The id of the padding token.
        :param eos_token_id: The id of the end-of-sentence token.
        :param provided_constraints: A list of constraints to use.
        :param model_kwargs: Additional kwargs to pass to the model.
        """
        decoding_algorithm = BeamSearchAlgorithm(
            num_beams=num_beams,
            batch_size=input_ids.size(0),
        )
        return self.generate(
            inputs=input_ids,
            num_beams=num_beams,
            max_new_tokens=max_length,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            decoding_algorithm=decoding_algorithm,
            provided_constraints=provided_constraints,
            **model_kwargs,
        )

    def beam_sample(
        self,
        input_ids: torch.LongTensor,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        provided_constraints: List[ABCBloomConstraint] = [],
        **model_kwargs,
    ) -> torch.LongTensor:
        raise NotImplementedError

    def group_beam_search(
        self,
        input_ids: torch.LongTensor,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        provided_constraints: List[ABCBloomConstraint] = [],
        **model_kwargs,
    ) -> torch.LongTensor:
        raise NotImplementedError

    def _choose_sample_algorithm(
        self,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
    ) -> DecodingAlgorithm:
        if (top_k is not None) and (top_p is not None):
            raise ValueError("Please provide only one of top_k or top_p for sampling")
        if top_k:
            return TopKAlgorithm(top_k, temperature)
        elif top_p:
            return NucleusAlgorithm(top_p, temperature)

    def _get_constraints(
        self,
        inputs: Optional[torch.Tensor] = None,
        eos_token_id: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        provided_constraints: List[ABCBloomConstraint] = [],
    ) -> List[ABCBloomConstraint]:
        constraints = []
        constraints.extend(provided_constraints)
        constraints.append(EosConstraint(inputs, eos_token_id, pad_token_id))
        return constraints
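

# A minimal sketch of a user-provided constraint for `provided_constraints`, inferred only from how
# constraints are invoked in `generate()` above (`lm_logits = constraint(last_token_id, lm_logits, hypo_ids)`)
# and from how EosConstraint is constructed in `_get_constraints()`. The exact abstract interface of
# ABCBloomConstraint may require more than this; treat it as an illustrative assumption, not the library's API.
#
#     class ForbidTokenConstraint(ABCBloomConstraint):
#         """Assigns a -inf logit to one forbidden token id at every decoding step."""
#
#         def __init__(self, forbidden_token_id: int):
#             self.forbidden_token_id = forbidden_token_id
#
#         def __call__(self, tokens_id, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
#             logits[:, self.forbidden_token_id] = -float("inf")
#             return logits
#
#     outputs = model.generate(inputs, max_new_tokens=8, provided_constraints=[ForbidTokenConstraint(3)])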