@@ -1,4 +1,4 @@
-from abc import ABC
+from abc import ABC, abstractmethod
 from typing import Tuple

 import torch
@@ -9,16 +9,16 @@ HypoIds = torch.Tensor


 class DecodingAlgorithm(ABC):
     """
-    An abstract class for decoding algorithms. Describe base function of those algorithms: they have to select new tokens and provide the corresponding hypothesis.
+    An abstract class for decoding algorithms. Describes the base function of those algorithms:
+    they have to select new tokens and provide the corresponding hypotheses.
     """

-    def __init__(self) -> None:
-        pass
-
+    @abstractmethod
     def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
         """
-        :param logits: A tensor of shape (batch_size, seq_lenth, vocab_size)
-        :return: A tuple of selected token ids and corresponding hypothesis. The shape of the token ids is (batch_size, seq_length) and the shape of the hypothesis is (batch_size)
+        :param logits: A tensor of shape (batch_size, seq_length, vocab_size)
+        :return: A tuple of selected token ids and corresponding hypotheses.
+            The shape of the token ids is (batch_size, seq_length), and the shape of the hypotheses is (batch_size)
         """
         pass
@@ -30,27 +30,36 @@ class GreedyAlgorithm(DecodingAlgorithm):

     def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
         """
-        Returns the most propable token. The second return object always are range of integers from 0 to batch_size - 1.
+        Returns the most probable token. The second returned object is always a range of integers
+        from 0 to batch_size - 1.
         """
         return logits.max(-1)[1].unsqueeze(1), torch.arange(logits.size(0))


 class SamplingAlgorithm(DecodingAlgorithm):
+    def __init__(self, temperature: float = 1.0):
+        self.temperature = temperature
+
     def sample(self, logits: torch.Tensor, indices_to_remove: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
         """
         :param logits: A tensor of shape (batch_size * num_hypos, vocab_size)
         :param indices_to_remove: A bool tensor of shape (batch_size * num_hypos, vocab_size)
-        :return: A tuple of selected token ids and corresponding hypothesis. The shape of the token ids is (batch_size, seq_length) and the shape of the hypothesis is (batch_size).
+        :return: A tuple of selected token ids and corresponding hypotheses.
+            The shape of the token ids is (batch_size, seq_length), and the shape of the hypotheses is (batch_size).
         """
         logits[indices_to_remove] = -float("Inf")
         probs = torch.softmax(logits / self.temperature, -1)
         return torch.multinomial(probs, num_samples=1), torch.arange(logits.size(0))

+    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
+        indices_to_remove = torch.full_like(logits, False, dtype=torch.bool)
+        return self.sample(logits, indices_to_remove)
+

 class TopKAlgorithm(SamplingAlgorithm):
     def __init__(self, top_k: int, temperature: float = 1.0) -> None:
+        super().__init__(temperature=temperature)
         self.top_k = top_k
-        self.temperature = temperature

     def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
         indices_to_remove = logits < torch.topk(logits, self.top_k, dim=-1)[0][..., -1, None]
@@ -59,18 +68,17 @@ class TopKAlgorithm(SamplingAlgorithm):


 class NucleusAlgorithm(SamplingAlgorithm):
     def __init__(self, top_p: float, temperature: float = 1.0) -> None:
+        super().__init__(temperature=temperature)
         self.top_p = top_p
-        self.temperature = temperature

     def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
-        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
+        sorted_logits, sorted_indices = torch.sort(logits, descending=False, dim=-1)
         probs = torch.softmax(sorted_logits / self.temperature, -1)
         cumulative_probs = torch.cumsum(probs, dim=-1)
-        sorted_indices_to_remove = cumulative_probs > self.top_p
-        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
-        sorted_indices_to_remove[..., 0] = False
-        indices_to_remove = torch.zeros_like(sorted_indices_to_remove)
-        indices_to_remove.scatter_(-1, sorted_indices, sorted_indices_to_remove)
+
+        sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p)
+
+        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
         return self.sample(logits, indices_to_remove)

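For context, here is a minimal sketch of how the shared interface is driven. It is an illustration, not part of the patch: it assumes it runs in the same module as the classes above and feeds random 2-D last-step logits in place of a real model's output.

import torch

batch_size, vocab_size = 2, 8
logits = torch.randn(batch_size, vocab_size)  # stand-in for the model's logits at the last position

for algo in (
    GreedyAlgorithm(),
    SamplingAlgorithm(temperature=0.8),  # plain temperature sampling via the new __call__
    TopKAlgorithm(top_k=3),
    NucleusAlgorithm(top_p=0.9),
):
    # sample() writes -inf into logits in place, so give each algorithm its own copy
    token_ids, hypo_ids = algo(logits.clone())
    print(type(algo).__name__, token_ids.shape, hypo_ids.shape)  # token ids (2, 1), hypo ids (2,)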
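The rewritten NucleusAlgorithm sorts ascending and removes every token whose ascending cumulative probability is at most 1 - top_p; what survives is the same nucleus the old descending-sort-and-shift version kept, up to ties exactly at the boundary. A toy walkthrough of just the masking step, with made-up numbers:

import torch

top_p = 0.9
probs = torch.tensor([0.5, 0.3, 0.15, 0.05])  # toy next-token distribution

sorted_probs, sorted_indices = torch.sort(probs, descending=False)
cumulative = torch.cumsum(sorted_probs, dim=-1)  # 0.05, 0.20, 0.50, 1.00

# the low-probability tail whose combined mass fits inside 1 - top_p = 0.1
sorted_mask = cumulative <= (1 - top_p)  # True, False, False, False

# undo the sort; dim=0 because this toy tensor is 1-D (the patched code uses dim=1)
mask = sorted_mask.scatter(0, sorted_indices, sorted_mask)
print(mask)  # tensor([False, False, False,  True]): only the 0.05 token is dropped,
             # keeping {0.5, 0.3, 0.15}, the smallest set with mass >= 0.9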
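One behavioral effect of swapping the no-op __init__ for @abstractmethod on __call__: the base class, and any subclass that forgets to implement __call__, now fails at construction time instead of silently returning None later. A generic sketch of the mechanism (not code from this patch):

from abc import ABC, abstractmethod

class Algorithm(ABC):
    @abstractmethod
    def __call__(self, logits):
        ...

class Broken(Algorithm):
    pass  # __call__ not implemented

try:
    Broken()
except TypeError as err:
    print(err)  # can't instantiate abstract class Broken with abstract method __call__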