# generation_algorithms.py
  1. from abc import ABC
  2. from typing import Tuple
  3. import torch
# Tensor of token ids selected by a decoding algorithm (first element of the
# tuple returned by the algorithms' ``__call__``).
TokenIds = torch.Tensor
# Tensor of hypothesis ids that accompany the selected tokens (second element
# of the tuple returned by the algorithms' ``__call__``).
HypoIds = torch.Tensor
  6. class DecodingAlgorithm(ABC):
  7. """
  8. An abstract class for decoding algorithms. Describe base function of those algorithms: they have to select new tokens and provide the corresponding hypothesis.
  9. """
  10. def __init__(self) -> None:
  11. pass
  12. def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
  13. """
  14. :param logits: A tensor of shape (batch_size, seq_lenth, vocab_size)
  15. :return: A tuple of selected token ids and corresponding hypothesis. The shape of the token ids is (batch_size, seq_length) and the shape of the hypothesis is (batch_size)
  16. """
  17. pass
  18. class GreedyAlgorithm(DecodingAlgorithm):
  19. """
  20. The simpliest algorithm for decoding. It selects the most probable token.
  21. """
  22. def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
  23. """
  24. Returns the most propable token. The second return object always are range of integers from 0 to batch_size - 1.
  25. """
  26. return logits.max(-1)[1].unsqueeze(1), torch.arange(logits.size(0))
  27. class SamplingAlgorithm(DecodingAlgorithm):
  28. def sample(self, logits: torch.Tensor, indices_to_remove: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
  29. """
  30. :param logits: A tensor of shape (batch_size * num_hypos, vocab_size)
  31. :param indices_to_remove: A bool tensor of shape (batch_size * num_hypos, vocab_size)
  32. :return: A tuple of selected token ids and corresponding hypothesis. The shape of the token ids is (batch_size, seq_length) and the shape of the hypothesis is (batch_size).
  33. """
  34. logits[indices_to_remove] = -float("Inf")
  35. probs = torch.softmax(logits / self.temperature, -1)
  36. return torch.multinomial(probs, num_samples=1), torch.arange(logits.size(0))
  37. class TopKAlgorithm(SamplingAlgorithm):
  38. def __init__(self, top_k: int, temperature: float = 1.0) -> None:
  39. self.top_k = top_k
  40. self.temperature = temperature
  41. def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
  42. indices_to_remove = logits < torch.topk(logits, self.top_k, dim=-1)[0][..., -1, None]
  43. return self.sample(logits, indices_to_remove)
  44. class NucleusAlgorithm(SamplingAlgorithm):
  45. def __init__(self, top_p: float, temperature: float = 1.0) -> None:
  46. self.top_p = top_p
  47. self.temperature = temperature
  48. def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
  49. sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
  50. probs = torch.softmax(sorted_logits / self.temperature, -1)
  51. cumulative_probs = torch.cumsum(probs, dim=-1)
  52. sorted_indices_to_remove = cumulative_probs > self.top_p
  53. sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
  54. sorted_indices_to_remove[..., 0] = False
  55. indices_to_remove = torch.zeros_like(sorted_indices_to_remove)
  56. indices_to_remove.scatter_(-1, sorted_indices, sorted_indices_to_remove)
  57. return self.sample(logits, indices_to_remove)
  58. class BeamSearchAlgorithm(DecodingAlgorithm):
  59. def __init__(self, num_beams: int, batch_size: int) -> None:
  60. self.num_beams = num_beams
  61. self._cur_num_beams = 1
  62. self.batch_size = batch_size
  63. self._logits = torch.zeros((self.batch_size, self._cur_num_beams,))
  64. def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
  65. sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
  66. probs = torch.softmax(sorted_logits, -1)
  67. new_logits = torch.cat([self._logits] * self.num_beams, dim=-1)
  68. for batch_idx in range(self.batch_size):
  69. for cur_beam_idx in range(self._cur_num_beams):
  70. for new_beam_idx in range(self.num_beams):
  71. logit = probs[cur_beam_idx * self.batch_size + batch_idx, new_beam_idx]
  72. new_logits[batch_idx, cur_beam_idx * self.num_beams + new_beam_idx] += logit
  73. self._cur_num_beams = self.num_beams
  74. new_sorted_logits, new_sorted_indices = torch.sort(new_logits, descending=True, dim=-1)
  75. new_sorted_indices = new_sorted_indices[:, :self.num_beams].T.flatten()
  76. self._logits = new_sorted_logits[:, :self.num_beams]
  77. result_tokens = sorted_indices[torch.arange(self.num_beams * self.batch_size), new_sorted_indices]
  78. result_hypos = torch.div(new_sorted_indices, self.num_beams, rounding_mode='floor')
  79. return result_tokens.unsqueeze(-1), result_hypos