ソースを参照

Draft repetition penalty

Aleksandr Borzunov 2 年 前
コミット
e191ce2f4e

+ 1 - 1
src/petals/client/inference_session.py

@@ -171,7 +171,7 @@ class InferenceSession:
         self._server_inputs = []  # Used in case of server failures to regenerate attention caches on new servers
         self._position = 0
         self._max_length = max_length
-        self.last_token_id = None
+        self.token_ids = []
 
     @property
     def position(self) -> int:

+ 24 - 9
src/petals/client/remote_generation.py

@@ -10,6 +10,7 @@ from petals.utils.generation_algorithms import (
     DecodingAlgorithm,
     GreedyAlgorithm,
     NucleusAlgorithm,
+    RepetitionPenaltyAlgorithm,
     SamplingAlgorithm,
     TopKAlgorithm,
 )
@@ -48,6 +49,7 @@ class RemoteGenerationMixin:
         temperature: float = 1.0,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
+        repetition_penalty: Optional[float] = None,
         num_beams: Optional[int] = 1,
         bos_token_id: Optional[int] = None,
         eos_token_id: Optional[int] = None,
@@ -69,6 +71,7 @@ class RemoteGenerationMixin:
         :param temperature: The temperature to use for sampling.
         :param top_k: The number of results to return.
         :param top_p: The cumulative probability of results to return.
+        :param repetition_penalty: Repetition penalty (1.0 means no penalty). See https://arxiv.org/pdf/1909.05858.pdf
         :param num_beams: The number of beams to use for beam search.
         :param bos_token_id: The id of the beginning of sentence token.
         :param eos_token_id: The id of the end of sentence token.
@@ -111,11 +114,11 @@ class RemoteGenerationMixin:
 
         if inputs is not None:
             assert isinstance(inputs, torch.Tensor) and inputs.ndim == 2, "inputs must be a 2d tensor [batch, length]"
-            if session is not None and session.last_token_id is not None:
-                inputs = torch.cat([session.last_token_id, inputs], dim=1)
+            if session is not None and session.token_ids:
+                inputs = torch.cat([session.token_ids[-1], inputs], dim=1)
         else:
-            if session is not None and session.last_token_id is not None:
-                inputs = session.last_token_id
+            if session is not None and session.token_ids:
+                inputs = session.token_ids[-1]
             else:
                 assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs"
                 inputs = torch.tensor([[bos_token_id]] * num_beams, dtype=torch.long, device=self.device)
@@ -123,12 +126,14 @@ class RemoteGenerationMixin:
 
         if decoding_algorithm is None:
             if do_sample:
-                decoding_algorithm = self._choose_sample_algorithm(temperature, top_k, top_p)
+                decoding_algorithm = self._choose_sample_algorithm(temperature, top_k, top_p, repetition_penalty)
             elif num_beams is not None and num_beams > 1:
                 decoding_algorithm = BeamSearchAlgorithm(num_beams, batch_size=batch_size)
             else:
-                if top_k is not None or top_p is not None:
-                    logger.warning("You passed top_k or top_p but did pass do_sample=True. Running greedy sampling")
+                if top_k is not None or top_p is not None or repetition_penalty is not None:
+                    logger.warning(
+                        "You passed top_k, top_p, or repetition_penalty but did pass do_sample=True. Running greedy sampling"
+                    )
                 decoding_algorithm = GreedyAlgorithm()
 
         if num_beams > 1:
@@ -160,6 +165,12 @@ class RemoteGenerationMixin:
         else:
             context_manager = contextlib.nullcontext(session)  # Doesn't actually enter session or exit from it
         with context_manager as session:
+            if session.token_ids:
+                if inputs.shape[1] >= 2:
+                    session.token_ids.append(inputs[:, 1:])
+            else:
+                session.token_ids.append(inputs)
+
             outputs = []
             # Find samples with padded inputs.
             # They will be changed before all of the samples have right length.
@@ -183,7 +194,8 @@ class RemoteGenerationMixin:
 
                 for constraint in constraints:
                     lm_logits = constraint(last_token_id, lm_logits, hypo_ids)
-                last_token_id, hypo_ids = decoding_algorithm(lm_logits)
+                token_ids = torch.cat(session.token_ids, dim=1) if session.token_ids else torch.empty(batch_size, 0, dtype=torch.int64)
+                last_token_id, hypo_ids = decoding_algorithm(token_ids, lm_logits)
 
                 # If some samples were padded, change only these samples
                 if seq_idx < inputs.size(1):
@@ -198,7 +210,7 @@ class RemoteGenerationMixin:
                         outputs[i - 1] = outputs[i - 1][hypo_ids]
 
                 outputs.append(last_token_id)
-                session.last_token_id = last_token_id
+                session.token_ids.append(last_token_id)
                 seq_idx += 1
                 if torch.all(last_token_id == eos_token_id) or len(outputs) > max_new_tokens:
                     break
@@ -342,6 +354,7 @@ class RemoteGenerationMixin:
         temperature: float = 1.0,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
+        repetition_penalty: Optional[float] = None,
     ) -> DecodingAlgorithm:
         if (top_k is not None) and (top_p is not None):
             raise ValueError("You have to provide only top_k or top_p for sampling")
@@ -349,6 +362,8 @@ class RemoteGenerationMixin:
             return TopKAlgorithm(top_k, temperature)
         elif top_p is not None:
             return NucleusAlgorithm(top_p, temperature)
+        elif repetition_penalty is not None:
+            return RepetitionPenaltyAlgorithm(repetition_penalty, temperature)
         else:
             return SamplingAlgorithm(temperature)
 

+ 20 - 6
src/petals/utils/generation_algorithms.py

@@ -14,7 +14,7 @@ class DecodingAlgorithm(ABC):
     """
 
     @abstractmethod
-    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
+    def __call__(self, token_ids: torch.LongTensor, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
         """
         :param logits: A tensor of shape (batch_size, seq_lenth, vocab_size)
         :return: A tuple of selected token ids and corresponding hypotheses.
@@ -28,7 +28,7 @@ class GreedyAlgorithm(DecodingAlgorithm):
     The simplest algorithm for decoding. It selects the most probable token.
     """
 
-    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
+    def __call__(self, token_ids: torch.LongTensor, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
         """
         Returns the most probable token. The second returned object is always a range of integers
         from 0 to batch_size - 1.
@@ -51,7 +51,7 @@ class SamplingAlgorithm(DecodingAlgorithm):
         probs = torch.softmax(logits / self.temperature, -1)
         return torch.multinomial(probs, num_samples=1), torch.arange(logits.size(0))
 
-    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
+    def __call__(self, token_ids: torch.LongTensor, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
         indices_to_remove = torch.full_like(logits, False, dtype=torch.bool)
         return self.sample(logits, indices_to_remove)
 
@@ -61,7 +61,7 @@ class TopKAlgorithm(SamplingAlgorithm):
         super().__init__(temperature=temperature)
         self.top_k = top_k
 
-    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
+    def __call__(self, token_ids: torch.LongTensor, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
         indices_to_remove = logits < torch.topk(logits, self.top_k, dim=-1)[0][..., -1, None]
         return self.sample(logits, indices_to_remove)
 
@@ -71,7 +71,7 @@ class NucleusAlgorithm(SamplingAlgorithm):
         super().__init__(temperature=temperature)
         self.top_p = top_p
 
-    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
+    def __call__(self, token_ids: torch.LongTensor, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
         sorted_logits, sorted_indices = torch.sort(logits, descending=False, dim=-1)
         probs = torch.softmax(sorted_logits / self.temperature, -1)
         cumulative_probs = torch.cumsum(probs, dim=-1)
@@ -82,6 +82,20 @@ class NucleusAlgorithm(SamplingAlgorithm):
         return self.sample(logits, indices_to_remove)
 
 
+class RepetitionPenaltyAlgorithm(SamplingAlgorithm):
+    def __init__(self, repetition_penalty: float, temperature: float = 1.0) -> None:
+        super().__init__(temperature=temperature)
+        self.repetition_penalty = repetition_penalty
+
+    def __call__(self, token_ids: torch.LongTensor, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
+        score = torch.gather(logits, -1, token_ids)
+        # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
+        score = torch.where(score < 0, score * self.repetition_penalty, score / self.repetition_penalty)
+        logits.scatter_(-1, token_ids, score)
+
+        return super().__call__(token_ids, logits)
+
+
 class BeamSearchAlgorithm(DecodingAlgorithm):
     def __init__(self, num_beams: int, batch_size: int) -> None:
         self.num_beams = num_beams
@@ -90,7 +104,7 @@ class BeamSearchAlgorithm(DecodingAlgorithm):
 
         self._batch_beams = [list() for _ in range(batch_size)]
 
-    def __call__(self, logits: torch.Tensor):
+    def __call__(self, token_ids: torch.LongTensor, logits: torch.Tensor):
         sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
         probs = torch.log_softmax(sorted_logits, -1)