generation_constraints.py

from abc import ABC, abstractmethod

import torch


class ABCBloomConstraint(ABC):
    """
    Base class for all decoding constraints. Subclass it to implement a new constraint.
    """

    def __init__(self) -> None:
        pass

    @abstractmethod
    def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
        """
        This method is called by the decoding algorithm to apply the constraint. It changes and returns new logits.
        :param tokens_id: The token id of the last chosen token.
        :param logits: The logits from the Bloom model.
        :param hypo_ids: The hypothesis ids of the last tokens.
        """
        pass
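

# A minimal example subclass (a sketch, not part of the original file): it
# suppresses a fixed set of banned token ids at every step. The class name and
# token ids are hypothetical; it exists only to illustrate the __call__
# interface defined above.
class BanTokensConstraintExample(ABCBloomConstraint):
    def __init__(self, banned_token_ids, min_logits: float = -1e8) -> None:
        self.banned_token_ids = list(banned_token_ids)
        self.min_logits = min_logits

    def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
        # Push banned tokens' logits low enough that they are never chosen.
        logits[:, self.banned_token_ids] = self.min_logits
        return logits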


class MaxNewTokensConstraint(ABCBloomConstraint):
    """
    Constraint that forbids generating more than max_new_tokens tokens after the prefix.

    Args:
        prefix: The prefix of the sequence.
        max_new_tokens: The maximum number of tokens that can be generated after the prefix.
        eos_token_id: The id of the end of sentence token.
        pad_token_id: The id of the padding token.
        min_logits: The minimum logit value assigned to forbidden tokens. Default: -1e8.
    """

    def __init__(
        self, prefix: torch.Tensor, max_new_tokens: int, eos_token_id: int, pad_token_id: int, min_logits: float = -1e8
    ) -> None:
        self.max_new_tokens = max_new_tokens
        self.eos_token_id = eos_token_id
        self.min_logits = min_logits

        # Prefixes are left-padded to a common length, so each sequence's counter
        # starts at (its pad count - the largest pad count in the batch), i.e. at <= 0.
        max_pad_size = (prefix == pad_token_id).sum(1).unsqueeze(1).max()
        self.current_generated_tokens = (prefix == pad_token_id).sum(1).unsqueeze(1) - max_pad_size

    def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
        if tokens_id is not None:
            self.current_generated_tokens += 1

        # For sequences that reached the limit, suppress every logit, then restore
        # the EOS logit to 0 so that EOS is the only token that can be chosen.
        mask = self.current_generated_tokens >= self.max_new_tokens
        logits += self.min_logits * mask
        logits[mask[:, 0], self.eos_token_id] = 0
        return logits
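

# A quick sanity check for MaxNewTokensConstraint (a sketch, not part of the
# original file; the pad/eos ids and vocabulary size below are made up).
def _demo_max_new_tokens_constraint() -> None:
    pad_id, eos_id, vocab = 0, 2, 10
    prefix = torch.tensor([[pad_id, pad_id, 5, 6], [pad_id, 7, 8, 9]])  # left-padded batch of 2
    constraint = MaxNewTokensConstraint(prefix, max_new_tokens=1, eos_token_id=eos_id, pad_token_id=pad_id)
    logits = constraint(torch.tensor([[5], [6]]), torch.zeros(2, vocab), hypo_ids=None)
    # The first row carries the most padding, so its counter started at 0 and has
    # now reached the limit: only EOS survives. The second row is still unconstrained.
    assert logits[0].argmax().item() == eos_id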


class EosConstraint(ABCBloomConstraint):
    """
    This constraint repeats the EOS token if it was generated at the previous step.

    Args:
        prefix: The prefix of the sequence.
        eos_token_id: The id of the end of sentence token.
        pad_token_id: The id of the padding token.
        min_logits: The minimum logit value assigned to forbidden tokens. Default: -1e8.
    """

    def __init__(self, prefix: torch.Tensor, eos_token_id: int, pad_token_id: int, min_logits: float = -1e8) -> None:
        self.eos_token_id = eos_token_id
        self.min_logits = min_logits
        self.past_tokens = None

        # Skip as many steps as there are padding tokens in each row before the
        # constraint becomes active.
        self.wait_until_starting = (prefix == pad_token_id).sum(1).unsqueeze(1)

    def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
        if self.past_tokens is not None:
            # Once a sequence has emitted EOS (and its waiting period is over),
            # force it to keep emitting EOS.
            mask = (self.wait_until_starting < 0) & (self.past_tokens == self.eos_token_id)
            logits += self.min_logits * mask
            logits[mask[:, 0], self.eos_token_id] = 0

        if tokens_id is not None:
            self.past_tokens = tokens_id
            self.wait_until_starting -= 1

        return logits
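

# A quick sanity check for EosConstraint (a sketch, not part of the original
# file): once a row emits EOS, every subsequent step is forced back to EOS.
def _demo_eos_constraint() -> None:
    pad_id, eos_id, vocab = 0, 2, 10
    prefix = torch.tensor([[5, 6], [7, 8]])  # unpadded batch of 2
    constraint = EosConstraint(prefix, eos_token_id=eos_id, pad_token_id=pad_id)
    # Step 1: the first row samples EOS, the second row samples a regular token.
    logits = constraint(torch.tensor([[eos_id], [4]]), torch.zeros(2, vocab), hypo_ids=None)
    # Step 2: the constraint now pins the first row to EOS, whatever the model says.
    logits = constraint(torch.tensor([[3], [5]]), torch.zeros(2, vocab), hypo_ids=None)
    assert logits[0].argmax().item() == eos_id


if __name__ == "__main__":
    # Run the illustrative sketches above (not part of the original file).
    _demo_max_new_tokens_constraint()
    _demo_eos_constraint()
    print("constraint demos passed")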