
Add batch_size=1 beam_search

Artem Chumachenko committed 2 years ago
Commit 346827bb00

+ 40 - 2
src/client/remote_generation.py

@@ -3,7 +3,13 @@ from typing import List, Optional
 import torch
 import torch.nn.functional as F
 
-from src.utils.generation_algorithms import DecodingAlgorithm, GreedyAlgorithm, NucleusAlgorithm, TopKAlgorithm
+from src.utils.generation_algorithms import (
+    BeamSearchAlgorithm,
+    DecodingAlgorithm,
+    GreedyAlgorithm,
+    NucleusAlgorithm,
+    TopKAlgorithm
+)
 from src.utils.generation_constraints import ABCBloomConstraint, EosConstraint, MaxNewTokensConstraint
 
 
@@ -25,6 +31,7 @@ class RemoteGenerationMixin:
         temperature: float = 1.0,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
+        num_beams: Optional[int] = None,
         bos_token_id: Optional[int] = None,
         eos_token_id: Optional[int] = None,
         pad_token_id: Optional[int] = None,
@@ -78,14 +85,19 @@ class RemoteGenerationMixin:
 
         if inputs is None:
             assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs"
-            inputs = torch.tensor([[bos_token_id]])
+            inputs = torch.tensor([[bos_token_id]], dtype=torch.long, device=self.device)  # expanded to num_beams copies below when beam search is requested
 
         if decoding_algorithm is None:
             if do_sample:
                 decoding_algorithm = self._choose_sample_algorithm(temperature, top_k, top_p)
+            elif num_beams is not None and num_beams > 1:
+                decoding_algorithm = BeamSearchAlgorithm(num_beams, batch_size=inputs.size(0))
             else:
                 decoding_algorithm = GreedyAlgorithm()
 
+        if num_beams is not None and num_beams > 1:
+            inputs = torch.cat([inputs] * num_beams, dim=0)
+
         constraints = self._get_constraints(
             inputs=inputs,
             eos_token_id=eos_token_id,
@@ -198,12 +210,38 @@ class RemoteGenerationMixin:
     def beam_search(
         self,
         input_ids: torch.LongTensor,
+        num_beams: int = 1,
         max_length: Optional[int] = None,
         pad_token_id: Optional[int] = None,
         eos_token_id: Optional[int] = None,
         provided_constraints: List[ABCBloomConstraint] = [],
         **model_kwargs,
     ) -> torch.LongTensor:
+        """
+        Generates sequences of token ids for models with a language modeling head. Uses beam search.
+
+        :param input_ids: The input tokens to the model.
+        :param num_beams: The number of beams to use.
+        :param max_length: The maximum length of the sequence to generate.
+        :param pad_token_id: The id of the padding token.
+        :param eos_token_id: The id of the end of sentence token.
+        :param provided_constraints: A list of constraints to use.
+        :param model_kwargs: Additional kwargs to pass to the model.
+        """
+        decoding_algorithm = BeamSearchAlgorithm(
+            num_beams=num_beams,
+            batch_size=input_ids.size(0),
+        )
+        return self.generate(
+            inputs=input_ids,
+            num_beams=num_beams,
+            max_new_tokens=max_length,
+            pad_token_id=pad_token_id,
+            eos_token_id=eos_token_id,
+            decoding_algorithm=decoding_algorithm,
+            provided_constraints=provided_constraints,
+            **model_kwargs,
+        )
         raise NotImplementedError
 
     def beam_sample(

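For reference, a minimal client-side usage sketch of the new num_beams path. This is not part of the commit: the model class, module path, and checkpoint name (DistributedBloomForCausalLM, src.client.remote_model, bigscience/bloom-petals) are assumptions for illustration only.

from transformers import BloomTokenizerFast

# assumption: the client model class that mixes in RemoteGenerationMixin
from src.client.remote_model import DistributedBloomForCausalLM

tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-petals")  # hypothetical checkpoint name
model = DistributedBloomForCausalLM.from_pretrained("bigscience/bloom-petals")

# this commit only supports beam search for a single prompt (batch_size=1)
input_ids = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]

# option 1: call beam_search() directly
beam_output = model.beam_search(input_ids, num_beams=4, max_length=16)

# option 2: go through generate(); BeamSearchAlgorithm is selected when num_beams > 1 and do_sample is falsy
gen_output = model.generate(input_ids, num_beams=4, max_new_tokens=16)
print(tokenizer.decode(gen_output[0]))
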
+ 1 - 0
src/server/backend.py

@@ -59,6 +59,7 @@ class TransformerBackend(ModuleBackend):
             with self.memory_cache.use_cache(attention_cache_handle) as cache:
                 assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5
                 if not is_dummy(hypo_ids):
+                    assert hypo_ids.shape[0] == cache.shape[1]
                     cache[:, :] = cache[:, hypo_ids]  # in-place reorder cache by hypo ids
                 layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length]
                 logger.debug(f"Metadata: {cache_metadata}, past_k.shape={past_k.shape}, past_v.shape={past_v.shape}")

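The assert added here guards the cache reorder that beam search depends on: at every step the server permutes the per-hypothesis attention cache so each surviving beam keeps the past keys/values of the hypothesis it continues. A self-contained toy illustration of the same reorder (shapes simplified; the real cache is 5-dimensional):

import torch

# toy cache with layout [2 (keys/values), batch * num_beams, prefix_length]
cache = torch.arange(2 * 3 * 4, dtype=torch.float32).reshape(2, 3, 4)
hypo_ids = torch.tensor([2, 0, 0])  # new beam i continues old hypothesis hypo_ids[i]

assert hypo_ids.shape[0] == cache.shape[1]  # the check added in this commit
cache[:, :] = cache[:, hypo_ids]            # in-place reorder of the cache by hypo ids
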
+ 23 - 2
src/utils/generation_algorithms.py

@@ -48,7 +48,6 @@ class SamplingAlgorithm(DecodingAlgorithm):
 
 
 class TopKAlgorithm(SamplingAlgorithm):
-    # TODO: Add NumHypos, maxBatchSize
     def __init__(self, top_k: int, temperature: float = 1.0) -> None:
         self.top_k = top_k
         self.temperature = temperature
@@ -75,4 +74,26 @@ class NucleusAlgorithm(SamplingAlgorithm):
         return self.sample(logits, indices_to_remove)
 
 
-# TODO: In generate function we need to check usage of top_k or sampling algorithm
+class BeamSearchAlgorithm(DecodingAlgorithm):
+    def __init__(self, num_beams: int, batch_size: int) -> None:
+        self.num_beams = num_beams
+        self.batch_size = batch_size
+
+        self._logits = torch.zeros((self.num_beams * self.batch_size))
+    
+    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
+        probs = torch.softmax(sorted_logits, -1)
+        # this implementation assumes self.batch_size == 1
+        # each previous hypothesis spawns num_beams candidates, so repeat its running score num_beams times
+        new_logits = self._logits.repeat_interleave(self.num_beams)
+        for beam_idx in range(self.num_beams):
+            for token_idx in range(self.num_beams):
+                new_logits[beam_idx * self.num_beams + token_idx] += probs[beam_idx, token_idx]
+        new_sorted_logits, new_sorted_indices = torch.sort(new_logits, descending=True, dim=-1)
+        self._logits = new_sorted_logits[:self.num_beams]
+        result_tokens = []
+        result_hypos = []
+        for beam_idx in range(self.num_beams):
+            hypo_idx = new_sorted_indices[beam_idx] // self.num_beams
+            token_idx = new_sorted_indices[beam_idx] % self.num_beams
+            result_tokens.append(sorted_indices[hypo_idx, token_idx])
+            result_hypos.append(hypo_idx)
+        # token ids have shape (num_beams, 1), hypo ids have shape (num_beams,)
+        return torch.stack(result_tokens).unsqueeze(-1), torch.stack(result_hypos)