
working version

Artem Chumachenko 2 years ago
Commit
eb1334e567

+ 11 - 10
src/client/remote_generation.py

@@ -10,7 +10,7 @@ from src.utils.generation_algorithms import (
     NucleusAlgorithm,
     TopKAlgorithm
 )
-from src.utils.generation_constraints import ABCBloomConstraint, EosConstraint, MaxNewTokensConstraint
+from src.utils.generation_constraints import ABCBloomConstraint, EosConstraint
 
 
 class RemoteGenerationMixin:
@@ -19,8 +19,9 @@ class RemoteGenerationMixin:
     The class can be used for:
         - *greedy decoding*.
         - *multinomial sampling*.
+        - *beam-search decoding*.
 
-    This class is similar to transformer's [`generation_utils.GenerationMixin`], it can be used instead of it. However, it has some differences.
+    This class is similar to transformers' [`generation_utils.GenerationMixin`] and can be used in its place. However, it has some differences tailored to remote usage.
     """
 
     @torch.no_grad()
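A minimal usage sketch for the three decoding modes listed above (`model` and `tokenizer` are hypothetical stand-ins for a model that mixes in RemoteGenerationMixin and its matching tokenizer; only parameters visible in this commit are used):

    inputs = tokenizer("A cat sat on", return_tensors="pt")["input_ids"]
    out_greedy = model.generate(inputs, max_new_tokens=8)                             # greedy decoding
    out_sample = model.generate(inputs, top_k=50, temperature=0.8, max_new_tokens=8)  # multinomial sampling
    out_beams = model.generate(inputs, num_beams=4, max_new_tokens=8)                 # beam-search decoding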
@@ -31,7 +32,7 @@ class RemoteGenerationMixin:
         temperature: float = 1.0,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
-        num_beams: Optional[int] = None,
+        num_beams: Optional[int] = 1,
         bos_token_id: Optional[int] = None,
         eos_token_id: Optional[int] = None,
         pad_token_id: Optional[int] = None,
@@ -49,6 +50,7 @@ class RemoteGenerationMixin:
         :param temperature: The temperature to use for sampling.
         :param top_k: The number of highest-probability tokens to keep for top-k sampling.
         :param top_p: The cumulative probability threshold for nucleus (top-p) sampling.
+        :param num_beams: The number of beams to use for beam search.
         :param bos_token_id: The id of the beginning of sentence token.
         :param eos_token_id: The id of the end of sentence token.
         :param pad_token_id: The id of the padding token.
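For reference, a self-contained sketch of what top_k and top_p mean for a logits vector; this mirrors the intent of the parameter docs above, not the exact code in generation_algorithms.py:

    import torch

    def filter_logits(logits: torch.Tensor, top_k: int = 0, top_p: float = 1.0) -> torch.Tensor:
        if top_k > 0:
            # Keep only the top_k highest-probability tokens.
            kth_value = torch.topk(logits, top_k).values[..., -1, None]
            logits = logits.masked_fill(logits < kth_value, float("-inf"))
        if top_p < 1.0:
            # Keep the smallest set of tokens whose cumulative probability reaches top_p.
            sorted_logits, sorted_idx = torch.sort(logits, descending=True)
            probs = torch.softmax(sorted_logits, dim=-1)
            drop_sorted = probs.cumsum(dim=-1) - probs > top_p  # mass before this token already >= top_p
            drop = drop_sorted.scatter(-1, sorted_idx, drop_sorted)  # map back to vocabulary order
            logits = logits.masked_fill(drop, float("-inf"))
        return logits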
@@ -102,12 +104,13 @@ class RemoteGenerationMixin:
             inputs=inputs,
             eos_token_id=eos_token_id,
             pad_token_id=pad_token_id,
-            max_new_tokens=max_new_tokens,
             provided_constraints=provided_constraints,
         )
 
         with self.transformer.h.inference_session(max_length=max_length) as sess:
             outputs = []
+            # Find samples with padded inputs.
+            # Their pad tokens will be overwritten below until every sample reaches the full input length.
             if torch.any(inputs == pad_token_id):  # TODO: move to prepare_inputs
                 outputs += [inputs[:, : inputs.size(1) - (inputs == pad_token_id).sum(-1).max()]]
             else:
@@ -129,13 +132,15 @@ class RemoteGenerationMixin:
                 for constraint in constraints:
                     lm_logits = constraint(last_token_id, lm_logits, hypo_ids)
                 last_token_id, hypo_ids = decoding_algorithm(lm_logits)
-                if seq_idx < inputs.size(1):  # TODO: why is it not a constraint?
+
+                # While still inside the prompt, keep input tokens and let only padded positions take generated tokens
+                if seq_idx < inputs.size(1):
                     pad_token_mask = inputs[:, seq_idx : seq_idx + 1] == pad_token_id
                     last_token_id = (~pad_token_mask) * inputs[
                         :, seq_idx : seq_idx + 1
                     ] + pad_token_mask * last_token_id
 
-                if torch.all(last_token_id == eos_token_id):
+                if torch.all(last_token_id == eos_token_id) or len(outputs) >= max_new_tokens:
                     break
 
                 outputs.append(last_token_id)
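The pad-mask arithmetic above can be checked in isolation; a small sketch with made-up token ids (pad_token_id = 0):

    import torch

    pad_token_id = 0
    inputs = torch.tensor([[5, 6, 7],
                           [8, 9, 0]])           # second sample is padded at the end
    last_token_id = torch.tensor([[42], [43]])   # tokens proposed by the decoder at seq_idx = 2
    seq_idx = 2
    pad_token_mask = inputs[:, seq_idx : seq_idx + 1] == pad_token_id
    last_token_id = (~pad_token_mask) * inputs[:, seq_idx : seq_idx + 1] + pad_token_mask * last_token_id
    # last_token_id is now tensor([[ 7], [43]]): the unpadded sample keeps its
    # input token, while the padded sample accepts the generated token.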
@@ -242,7 +247,6 @@ class RemoteGenerationMixin:
             provided_constraints=provided_constraints,
             **model_kwargs,
         )
-        raise NotImplementedError
 
     def beam_sample(
         self,
@@ -284,12 +288,9 @@ class RemoteGenerationMixin:
         inputs: Optional[torch.Tensor] = None,
         eos_token_id: Optional[int] = None,
         pad_token_id: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
         provided_constraints: List[ABCBloomConstraint] = [],
     ) -> List[ABCBloomConstraint]:
         constraints = []
         constraints.extend(provided_constraints)
-        if max_new_tokens is not None:
-            constraints.append(MaxNewTokensConstraint(inputs, max_new_tokens, eos_token_id, pad_token_id))
         constraints.append(EosConstraint(inputs, eos_token_id, pad_token_id))
         return constraints
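With MaxNewTokensConstraint removed from this helper, the token budget is enforced inside the generation loop instead (see the new break condition above). A toy rendering of the new control flow; step() is a hypothetical stand-in for one remote forward pass:

    outputs = []
    while True:
        lm_logits = step()                       # hypothetical: logits for the last position
        for constraint in constraints:           # provided constraints + EosConstraint
            lm_logits = constraint(last_token_id, lm_logits, hypo_ids)
        last_token_id, hypo_ids = decoding_algorithm(lm_logits)
        if torch.all(last_token_id == eos_token_id) or len(outputs) >= max_new_tokens:
            break                                # all samples finished, or budget spent
        outputs.append(last_token_id)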

+ 16 - 12
src/utils/generation_algorithms.py

@@ -77,23 +77,27 @@ class NucleusAlgorithm(SamplingAlgorithm):
 class BeamSearchAlgorithm(DecodingAlgorithm):
     def __init__(self, num_beams: int, batch_size: int) -> None:
         self.num_beams = num_beams
+        self._cur_num_beams = 1
         self.batch_size = batch_size
 
-        self._logits = torch.zeros((self.num_beams * self.batch_size))
+        self._logits = torch.zeros((self.batch_size, self._cur_num_beams,))
     
     def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
         sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
         probs = torch.softmax(sorted_logits, -1)
-        # self.batch_zise == 1
+        
         new_logits = torch.cat([self._logits] * self.num_beams, dim=-1)
-        for beam_idx in range(self.num_beams):
-            for token_idx in range(self.num_beams):
-                new_logits[beam_idx * self.num_beam + token_idx] += probs[beam_idx, token_idx]
+        for batch_idx in range(self.batch_size):
+            for cur_beam_idx in range(self._cur_num_beams):
+                for new_beam_idx in range(self.num_beams):
+                    logit = probs[cur_beam_idx * self.batch_size + batch_idx, new_beam_idx]
+                    new_logits[batch_idx, cur_beam_idx * self.num_beams + new_beam_idx] += logit
+        self._cur_num_beams = self.num_beams
+
         new_sorted_logits, new_sorted_indices = torch.sort(new_logits, descending=True, dim=-1)
-        self._logits = new_sorted_logits[:self.num_beams]
-        result_tokens = []
-        result_hypos = []
-        for beam_idx in range(self.num_beams):
-            result_tokens.append(sorted_indices[new_sorted_indices[beam_idx] % self.num_beams])
-            result_hypos.append(new_sorted_indices[beam_idx] // self.num_beams)
-        return torch.stack(result_tokens, dim=1), torch.stack(result_hypos, dim=1)
+        new_sorted_indices = new_sorted_indices[:, :self.num_beams].T.flatten()
+        self._logits = new_sorted_logits[:, :self.num_beams]
+        result_tokens = sorted_indices[torch.arange(self.num_beams * self.batch_size), new_sorted_indices]
+        result_hypos = torch.div(new_sorted_indices, self.num_beams, rounding_mode='floor')
+
+        return result_tokens.unsqueeze(-1), result_hypos
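A worked sketch of the new batched bookkeeping, replaying the first decoding step outside the class with toy sizes; the loop body and indexing follow the added code above, assuming the generation loop already runs num_beams * batch_size hypotheses in one batch:

    import torch

    batch_size, num_beams, vocab = 2, 3, 16
    cur_num_beams = 1                                     # grows to num_beams after the first step
    logits = torch.randn(num_beams * batch_size, vocab)
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    probs = torch.softmax(sorted_logits, -1)
    accum = torch.zeros(batch_size, cur_num_beams)        # running per-hypothesis scores (self._logits)
    new_logits = torch.cat([accum] * num_beams, dim=-1)   # (batch_size, cur_num_beams * num_beams)
    for batch_idx in range(batch_size):
        for cur in range(cur_num_beams):
            for new in range(num_beams):
                new_logits[batch_idx, cur * num_beams + new] += probs[cur * batch_size + batch_idx, new]
    _, new_sorted_indices = torch.sort(new_logits, descending=True, dim=-1)
    picked = new_sorted_indices[:, :num_beams].T.flatten()              # (num_beams * batch_size,)
    result_hypos = torch.div(picked, num_beams, rounding_mode="floor")  # source beam of each winner
    result_tokens = sorted_indices[torch.arange(num_beams * batch_size), picked].unsqueeze(-1)
    # result_tokens: (num_beams * batch_size, 1); result_hypos: (num_beams * batch_size,)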

+ 0 - 33
src/utils/generation_constraints.py

@@ -21,39 +21,6 @@ class ABCBloomConstraint(ABC):
         pass
 
 
-class MaxNewTokensConstraint(ABCBloomConstraint):
-    """
-    Constraint that forbids to generate more than max_new_tokens tokens after the prefix.
-
-    Args:
-        prefix: The prefix of the sequence.
-        max_new_tokens: The maximum number of tokens that can be generated after the prefix.
-        eos_token_id: The id of the end of sentence token.
-        pad_token_id: The id of the padding token.
-        min_logits: The minimum logits that can be generated. Default: -1e6.
-    """
-
-    def __init__(
-        self, prefix: torch.Tensor, max_new_tokens: int, eos_token_id: int, pad_token_id: int, min_logits: float = -1e8
-    ) -> None:
-        self.max_new_tokens = max_new_tokens
-        self.current_generated_tokens = None
-        self.eos_token_id = eos_token_id
-        self.min_logits = min_logits
-
-        max_pad_size = (prefix == pad_token_id).sum(1).unsqueeze(1).max()
-        self.current_generated_tokens = (prefix == pad_token_id).sum(1).unsqueeze(1) - max_pad_size
-
-    def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
-        if tokens_id is not None:
-            self.current_generated_tokens += 1
-
-        mask = self.current_generated_tokens >= self.max_new_tokens
-        logits += self.min_logits * mask
-        logits[mask[:, 0], self.eos_token_id] = 0
-        return logits
-
-
 class EosConstraint(ABCBloomConstraint):
     """
     This constraint repeats the EOS token if it was generated on the previous step.
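For context, a minimal sketch of an EOS-repeating constraint in the same style as the removed MaxNewTokensConstraint above. This is an illustrative stand-in, not the actual body of EosConstraint (which is truncated here); reordering of state by hypo_ids is omitted:

    import torch

    class RepeatEosSketch:
        def __init__(self, eos_token_id: int, min_logits: float = -1e8) -> None:
            self.eos_token_id = eos_token_id
            self.min_logits = min_logits
            self.finished = None  # (batch, 1) bool: hypotheses that already emitted EOS

        def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
            if self.finished is None:
                self.finished = torch.zeros(logits.shape[0], 1, dtype=torch.bool)
            if tokens_id is not None:
                self.finished |= tokens_id == self.eos_token_id
            # Suppress every token for finished hypotheses, then re-allow EOS only.
            logits = logits + self.min_logits * self.finished
            logits[self.finished[:, 0], self.eos_token_id] = 0
            return logits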