generation_constraints.py

import torch
from abc import ABC


class ABCBloomConstraint(ABC):
    """
    Base class for all kinds of decoding constraints. It can be used to implement a new constraint.
    """

    def __init__(self) -> None:
        pass

    def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
        """
        This method is called by the decoding algorithm to apply the constraint. It changes and returns new logits.

        :param tokens_id: The token id of the last chosen token.
        :param logits: The logits from the Bloom model.
        :param hypo_ids: The hypothesis ids of the last tokens.
        """
        pass
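

# --- Illustrative example (not part of the original file) ---
# A minimal sketch of how the __call__ protocol above can be implemented.
# The class name and its ban-one-token behavior are assumptions made for
# the sake of the example, not part of the petals API.
class BanTokenConstraint(ABCBloomConstraint):
    """Hypothetical constraint that forbids a single token id at every step."""

    def __init__(self, banned_token_id: int, min_logits: float = -1e6) -> None:
        self.banned_token_id = banned_token_id
        self.min_logits = min_logits

    def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
        # Push the banned token's logit far below the rest so it is never picked.
        logits[:, self.banned_token_id] += self.min_logits
        return logits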


class MaxNewTokensConstraint(ABCBloomConstraint):
    """
    Constraint that forbids generating more than max_new_tokens tokens after the prefix.

    Args:
        prefix: The prefix of the sequence.
        max_new_tokens: The maximum number of tokens that can be generated after the prefix.
        eos_token_id: The id of the end-of-sequence token.
        pad_token_id: The id of the padding token.
        min_logits: The logit value added to forbidden tokens to suppress them. Default: -1e6.
    """

    def __init__(
        self, prefix: torch.Tensor, max_new_tokens: int, eos_token_id: int, pad_token_id: int, min_logits: float = -1e6
    ) -> None:
        self.max_new_tokens = max_new_tokens
        self.eos_token_id = eos_token_id
        self.min_logits = min_logits
        # Start the per-sequence counter below zero for padded rows, so leading
        # padding does not count toward the max_new_tokens budget.
        self.current_generated_tokens = -(prefix == pad_token_id).sum(-1)

    def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
        if tokens_id is not None:
            self.current_generated_tokens += 1

        # For sequences that exhausted their budget, suppress every token with
        # min_logits and leave EOS at 0, so EOS is the only viable choice.
        mask = (self.current_generated_tokens > self.max_new_tokens).unsqueeze(1)
        logits += self.min_logits * mask
        logits[mask[:, 0], self.eos_token_id] = 0
        return logits
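
# A quick check of the padding offset above (illustrative values; pad id 0 and
# eos id 2 are assumptions). A row with two leading pads still gets the full budget:
#
#   prefix = torch.tensor([[0, 0, 5, 6], [3, 4, 5, 6]])
#   c = MaxNewTokensConstraint(prefix, max_new_tokens=3, eos_token_id=2, pad_token_id=0)
#   c.current_generated_tokens  # tensor([-2, 0])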


class EosConstraint(ABCBloomConstraint):
    """
    This constraint repeats the EOS token if it was generated on the previous step.

    Args:
        prefix: The prefix of the sequence.
        eos_token_id: The id of the end-of-sequence token.
        pad_token_id: The id of the padding token.
        min_logits: The logit value added to forbidden tokens to suppress them. Default: -1e6.
    """

    def __init__(self, prefix: torch.Tensor, eos_token_id: int, pad_token_id: int, min_logits: float = -1e6) -> None:
        self.eos_token_id = eos_token_id
        self.min_logits = min_logits
        self.past_tokens = None
        # Number of leading pad tokens per sequence; the constraint stays
        # inactive until real (non-padding) generation has started.
        self.wait_until_starting = (prefix == pad_token_id).sum(-1).unsqueeze(1)

    def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
        if self.past_tokens is not None:
            # Once a sequence has emitted EOS (after its padding ran out),
            # force it to keep emitting EOS on every subsequent step.
            mask = (self.wait_until_starting < 0) & (self.past_tokens == self.eos_token_id)
            logits += self.min_logits * mask
            logits[mask[:, 0], self.eos_token_id] = 0

        if tokens_id is not None:
            self.past_tokens = tokens_id
            self.wait_until_starting -= 1

        return logits
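

# --- Illustrative example (not part of the original file) ---
# A minimal greedy-decoding loop showing how the constraints compose: each one
# rewrites the logits before the next token is picked. The random "model"
# logits, vocabulary size, and token ids are assumptions made for this sketch;
# in petals the constraints are driven by the remote generation loop instead.
if __name__ == "__main__":
    vocab_size, eos_token_id, pad_token_id = 16, 2, 0
    prefix = torch.tensor([[pad_token_id, 5, 6], [3, 4, 5]])  # batch of 2, left-padded

    constraints = [
        MaxNewTokensConstraint(prefix, max_new_tokens=4, eos_token_id=eos_token_id, pad_token_id=pad_token_id),
        EosConstraint(prefix, eos_token_id=eos_token_id, pad_token_id=pad_token_id),
    ]

    last_tokens = None
    hypo_ids = torch.arange(prefix.shape[0])
    for step in range(8):
        logits = torch.randn(prefix.shape[0], vocab_size)  # stand-in for real model logits
        for constraint in constraints:
            logits = constraint(last_tokens, logits, hypo_ids)
        last_tokens = logits.argmax(-1, keepdim=True)  # greedy pick, shape (batch, 1)
        print(f"step {step}: {last_tokens.squeeze(-1).tolist()}")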