
Add Beam Search decoding algorithm (#87)

Add beam_search
Artem Chumachenko, 2 years ago
Parent
Commit
fdb3583a8c

+ 85 - 14
src/client/remote_generation.py

@@ -1,10 +1,18 @@
 from typing import List, Optional
 
 import torch
-import torch.nn.functional as F
+from hivemind.utils.logging import get_logger
 
-from src.utils.generation_algorithms import DecodingAlgorithm, GreedyAlgorithm, NucleusAlgorithm, TopKAlgorithm
-from src.utils.generation_constraints import ABCBloomConstraint, EosConstraint, MaxNewTokensConstraint
+from src.utils.generation_algorithms import (
+    BeamSearchAlgorithm,
+    DecodingAlgorithm,
+    GreedyAlgorithm,
+    NucleusAlgorithm,
+    TopKAlgorithm,
+)
+from src.utils.generation_constraints import ABCBloomConstraint, EosConstraint
+
+logger = get_logger(__file__)
 
 
 class RemoteGenerationMixin:
@@ -13,8 +21,9 @@ class RemoteGenerationMixin:
     The class can be used for:
         - *greedy decoding*.
         - *multinomial sampling*.
+        - *beam-search decoding*.
 
-    This class is similar to transformer's [`generation_utils.GenerationMixin`], it can be used instead of it. However, it has some differences.
+    This class is similar to transformers' [`generation_utils.GenerationMixin`] and can be used in place of it. However, it has some differences related to remote usage.
     """
 
     @torch.no_grad()
@@ -25,6 +34,7 @@ class RemoteGenerationMixin:
         temperature: float = 1.0,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
+        num_beams: Optional[int] = 1,
         bos_token_id: Optional[int] = None,
         eos_token_id: Optional[int] = None,
         pad_token_id: Optional[int] = None,
@@ -32,6 +42,7 @@ class RemoteGenerationMixin:
         max_new_tokens: Optional[int] = None,
         decoding_algorithm: Optional[DecodingAlgorithm] = None,
         provided_constraints: List[ABCBloomConstraint] = [],
+        num_return_sequences: Optional[int] = None,
         **model_kwargs,
     ) -> torch.LongTensor:
         """
@@ -42,6 +53,7 @@ class RemoteGenerationMixin:
         :param temperature: The temperature to use for sampling.
         :param top_k: The number of results to return.
         :param top_p: The cumulative probability of results to return.
+        :param num_beams: The number of beams to use for beam search.
         :param bos_token_id: The id of the beginning of sentence token.
         :param eos_token_id: The id of the end of sentence token.
         :param pad_token_id: The id of the padding token.
@@ -49,6 +61,7 @@ class RemoteGenerationMixin:
         :param decoding_algorithm: The decoding algorithm to use.
         :param provided_constraints: A list of constraints to use.
         :param model_kwargs: Additional arguments to pass to the model.
+        :param num_return_sequences: How many hypotheses from the beam to return in the output.
         """
 
         assert (
@@ -69,6 +82,8 @@ class RemoteGenerationMixin:
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
 
+        batch_size = inputs.size(0)
+
         assert (max_length is None) != (max_new_tokens is None), "please set max_length or max_new_tokens (not both)"
         if max_length is not None and max_new_tokens is None:
             max_new_tokens = max_length - prefix_length
@@ -78,24 +93,43 @@ class RemoteGenerationMixin:
 
         if inputs is None:
             assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs"
-            inputs = torch.tensor([[bos_token_id]])
+            inputs = torch.tensor([[bos_token_id]] * num_beams, dtype=torch.long, device=self.device)
 
         if decoding_algorithm is None:
             if do_sample:
                 decoding_algorithm = self._choose_sample_algorithm(temperature, top_k, top_p)
+            elif num_beams is not None and num_beams > 1:
+                decoding_algorithm = BeamSearchAlgorithm(num_beams, batch_size=batch_size)
             else:
                 decoding_algorithm = GreedyAlgorithm()
 
+        if num_beams > 1:
+            inputs = torch.cat([inputs] * num_beams, dim=0)
+            if batch_size > 1:
+                # TODO: resolve padding problem
+                logger.warning(
+                    f"You set batch_size {batch_size} within beam search generation. Be careful, results on sequences with different length may be padded wrong way"
+                )
+
+        if num_return_sequences is None:
+            num_return_sequences = 1
+
+        assert num_return_sequences <= num_beams, (
+            f"You want more sequences than the beam has."
+            " Check num_return_sequences: {num_return_sequences} and num_beams: {num_beams}."
+        )
+
         constraints = self._get_constraints(
             inputs=inputs,
             eos_token_id=eos_token_id,
             pad_token_id=pad_token_id,
-            max_new_tokens=max_new_tokens,
             provided_constraints=provided_constraints,
         )
 
         with self.transformer.h.inference_session(max_length=max_length) as sess:
             outputs = []
+            # Find samples with padded inputs.
+            # Their padded positions are filled with freshly generated tokens
+            # until every sample reaches the full input length.
             if torch.any(inputs == pad_token_id):  # TODO: move to prepare_inputs
                 outputs += [inputs[:, : inputs.size(1) - (inputs == pad_token_id).sum(-1).max()]]
             else:
@@ -117,19 +151,34 @@ class RemoteGenerationMixin:
                 for constraint in constraints:
                     lm_logits = constraint(last_token_id, lm_logits, hypo_ids)
                 last_token_id, hypo_ids = decoding_algorithm(lm_logits)
-                if seq_idx < inputs.size(1):  # TODO: why is it not a constraint?
+
+                # While still inside the input, keep the original non-pad tokens and generate only for the padded positions
+                if seq_idx < inputs.size(1):
                     pad_token_mask = inputs[:, seq_idx : seq_idx + 1] == pad_token_id
                     last_token_id = (~pad_token_mask) * inputs[
                         :, seq_idx : seq_idx + 1
                     ] + pad_token_mask * last_token_id
 
-                if torch.all(last_token_id == eos_token_id):
-                    break
+                # TODO: refactor outputs
+                if num_beams > 1:
+                    for i in range(len(outputs), 1, -1):
+                        outputs[i - 1] = outputs[i - 1][hypo_ids]
 
                 outputs.append(last_token_id)
                 seq_idx += 1
+                if torch.all(last_token_id == eos_token_id) or len(outputs) > max_new_tokens:
+                    break
+
+        outputs = torch.cat(outputs, dim=-1)
 
-        return torch.cat(outputs, dim=-1)
+        if num_beams > 1:
+            # Rows are laid out beam-major (row = batch_idx + beam_idx * batch_size);
+            # gather the first num_return_sequences hypotheses for every sample.
+            pre_return_idx = [
+                torch.arange(idx, num_return_sequences * batch_size, batch_size) for idx in range(batch_size)
+            ]
+            return_idx = torch.cat(pre_return_idx, dim=0)
+            outputs = outputs[return_idx]
+
+        return outputs
 
     def greedy_search(
         self,
@@ -198,13 +247,38 @@ class RemoteGenerationMixin:
     def beam_search(
         self,
         input_ids: torch.LongTensor,
+        num_beams: int = 1,
         max_length: Optional[int] = None,
         pad_token_id: Optional[int] = None,
         eos_token_id: Optional[int] = None,
         provided_constraints: List[ABCBloomConstraint] = [],
         **model_kwargs,
     ) -> torch.LongTensor:
-        raise NotImplementedError
+        """
+        Generates sequences of token ids for models with a language modeling head. Uses beam search.
+
+        :param input_ids: The input tokens to the model.
+        :param num_beams: The number of beams to use.
+        :param max_length: The maximum length of the sequence to generate.
+        :param pad_token_id: The id of the padding token.
+        :param eos_token_id: The id of the end of sentence token.
+        :param provided_constraints: A list of constraints to use.
+        :param model_kwargs: Additional kwargs to pass to the model.
+        """
+        decoding_algorithm = BeamSearchAlgorithm(
+            num_beams=num_beams,
+            batch_size=input_ids.size(0),
+        )
+        return self.generate(
+            inputs=input_ids,
+            num_beams=num_beams,
+            max_new_tokens=max_length,
+            pad_token_id=pad_token_id,
+            eos_token_id=eos_token_id,
+            decoding_algorithm=decoding_algorithm,
+            provided_constraints=provided_constraints,
+            **model_kwargs,
+        )
 
     def beam_sample(
         self,
@@ -246,12 +320,9 @@ class RemoteGenerationMixin:
         inputs: Optional[torch.Tensor] = None,
         eos_token_id: Optional[int] = None,
         pad_token_id: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
         provided_constraints: List[ABCBloomConstraint] = [],
     ) -> List[ABCBloomConstraint]:
         constraints = []
         constraints.extend(provided_constraints)
-        if max_new_tokens is not None:
-            constraints.append(MaxNewTokensConstraint(inputs, max_new_tokens, eos_token_id, pad_token_id))
         constraints.append(EosConstraint(inputs, eos_token_id, pad_token_id))
         return constraints

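For reference, a minimal usage sketch of the new beam-search path through `generate()` (not part of the diff; MODEL_NAME and INITIAL_PEERS are placeholders, and the call pattern mirrors the test added in tests/test_full_model.py below):

import torch
import transformers

from src.client.remote_model import DistributedBloomForCausalLM

tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)

inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
# num_beams > 1 selects BeamSearchAlgorithm; num_return_sequences must not exceed num_beams
outputs = model.generate(inputs, max_new_tokens=4, num_beams=2, num_return_sequences=1)
print(tokenizer.batch_decode(outputs))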
+ 1 - 0
src/server/backend.py

@@ -59,6 +59,7 @@ class TransformerBackend(ModuleBackend):
             with self.memory_cache.use_cache(attention_cache_handle) as cache:
                 assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5
                 if not is_dummy(hypo_ids):
+                    assert hypo_ids.shape[0] == cache.shape[1]
                     cache[:, :] = cache[:, hypo_ids]  # in-place reorder cache by hypo ids
                 layer_past = past_k, past_v = cache[0, :, :prefix_length], cache[1, :, :prefix_length]
                 logger.debug(f"Metadata: {cache_metadata}, past_k.shape={past_k.shape}, past_v.shape={past_v.shape}")

+ 45 - 2
src/utils/generation_algorithms.py

@@ -48,7 +48,6 @@ class SamplingAlgorithm(DecodingAlgorithm):
 
 
 class TopKAlgorithm(SamplingAlgorithm):
-    # TODO: Add NumHypos, maxBatchSize
     def __init__(self, top_k: int, temperature: float = 1.0) -> None:
         self.top_k = top_k
         self.temperature = temperature
@@ -75,4 +74,48 @@ class NucleusAlgorithm(SamplingAlgorithm):
         return self.sample(logits, indices_to_remove)
 
 
-# TODO: In generate function we need to check usage of top_k or sampling algorithm
+class BeamSearchAlgorithm(DecodingAlgorithm):
+    def __init__(self, num_beams: int, batch_size: int) -> None:
+        self.num_beams = num_beams
+        self._cur_num_beams = 1
+        self.batch_size = batch_size
+
+        self._batch_beams = [list() for _ in range(batch_size)]
+
+    def __call__(self, logits: torch.Tensor):
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
+        probs = torch.log_softmax(sorted_logits, -1)
+
+        if len(self._batch_beams[0]) > 0:
+            for batch_idx in range(self.batch_size):
+                new_beams = []
+                cur_beams = self._batch_beams[batch_idx]
+                for beam_idx in range(len(cur_beams)):
+                    probs_idx = batch_idx + beam_idx * self.batch_size
+                    new_beam = cur_beams[beam_idx]
+                    for hypo_idx in range(self.num_beams):
+                        new_beams.append(
+                            (new_beam[0] + probs[probs_idx, hypo_idx].item(), beam_idx * self.num_beams + hypo_idx)
+                        )
+                self._batch_beams[batch_idx] = sorted(new_beams, reverse=True)[: self.num_beams]
+        else:
+            for batch_idx in range(self.batch_size):
+                for beam_idx in range(self.num_beams):
+                    self._batch_beams[batch_idx].append((probs[batch_idx, beam_idx].item(), beam_idx))
+
+        return_hypos = []
+        return_tokens = []
+        for batch_idx in range(self.batch_size):
+            cur_beam = self._batch_beams[batch_idx]
+            return_hypos.append(list())
+            return_tokens.append(list())
+            for beam in cur_beam:
+                beam_idx = beam[1] // self.num_beams
+                hypo_idx = batch_idx + beam_idx * self.batch_size
+                token_idx = beam[1] % self.num_beams
+                return_hypos[-1].append(hypo_idx)
+                return_tokens[-1].append([sorted_indices[hypo_idx, token_idx].item()])
+        return_hypos = [hypo_idx for hypo_indexes in zip(*return_hypos) for hypo_idx in hypo_indexes]
+        return_tokens = [token_idx for token_indexes in zip(*return_tokens) for token_idx in token_indexes]
+
+        return torch.tensor(return_tokens), torch.tensor(return_hypos)
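
A small self-contained walkthrough of the new BeamSearchAlgorithm (illustration only, random logits): each call receives one row of logits per (beam, sample) pair laid out beam-major, and returns the chosen next tokens together with the hypothesis ids that tell the caller which row each surviving beam continues.

import torch

from src.utils.generation_algorithms import BeamSearchAlgorithm

batch_size, num_beams, vocab_size = 1, 2, 16
algo = BeamSearchAlgorithm(num_beams=num_beams, batch_size=batch_size)

for step in range(3):
    logits = torch.randn(batch_size * num_beams, vocab_size)
    next_tokens, hypo_ids = algo(logits)
    # next_tokens: (batch_size * num_beams, 1) token ids; hypo_ids: rows to continue from
    print(step, next_tokens.squeeze(-1).tolist(), hypo_ids.tolist())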

+ 0 - 33
src/utils/generation_constraints.py

@@ -21,39 +21,6 @@ class ABCBloomConstraint(ABC):
         pass
 
 
-class MaxNewTokensConstraint(ABCBloomConstraint):
-    """
-    Constraint that forbids to generate more than max_new_tokens tokens after the prefix.
-
-    Args:
-        prefix: The prefix of the sequence.
-        max_new_tokens: The maximum number of tokens that can be generated after the prefix.
-        eos_token_id: The id of the end of sentence token.
-        pad_token_id: The id of the padding token.
-        min_logits: The minimum logits that can be generated. Default: -1e6.
-    """
-
-    def __init__(
-        self, prefix: torch.Tensor, max_new_tokens: int, eos_token_id: int, pad_token_id: int, min_logits: float = -1e8
-    ) -> None:
-        self.max_new_tokens = max_new_tokens
-        self.current_generated_tokens = None
-        self.eos_token_id = eos_token_id
-        self.min_logits = min_logits
-
-        max_pad_size = (prefix == pad_token_id).sum(1).unsqueeze(1).max()
-        self.current_generated_tokens = (prefix == pad_token_id).sum(1).unsqueeze(1) - max_pad_size
-
-    def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
-        if tokens_id is not None:
-            self.current_generated_tokens += 1
-
-        mask = self.current_generated_tokens >= self.max_new_tokens
-        logits += self.min_logits * mask
-        logits[mask[:, 0], self.eos_token_id] = 0
-        return logits
-
-
 class EosConstraint(ABCBloomConstraint):
     """
     This constraint repeats the EOS token if it was generated on the previous step.

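With MaxNewTokensConstraint removed (the max_new_tokens limit is now enforced directly in the generation loop), `provided_constraints` remains the extension point. A hedged sketch of a hypothetical user-supplied constraint implementing the same `__call__(tokens_id, logits, hypo_ids)` interface shown above; this class is not part of the commit:

import torch

from src.utils.generation_constraints import ABCBloomConstraint


class BanTokensConstraint(ABCBloomConstraint):
    """Hypothetical constraint: mask out a fixed set of token ids at every step."""

    def __init__(self, banned_ids, min_logits: float = -1e8) -> None:
        self.banned_ids = list(banned_ids)
        self.min_logits = min_logits

    def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
        logits[:, self.banned_ids] = self.min_logits  # banned tokens can never win
        return logits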
+ 28 - 0
tests/test_full_model.py

@@ -3,6 +3,7 @@ import torch
 import transformers
 from hivemind import get_logger, use_hivemind_log_handler
 from test_utils import *
+from transformers.generation_utils import BeamSearchScorer
 
 from src.bloom.model import BloomForCausalLM
 from src.client.remote_model import DistributedBloomForCausalLM
@@ -89,3 +90,30 @@ def test_greedy_generation(max_new_tokens=4):
     assert torch.allclose(
         remote_outputs_batch, hf_outputs_batch
     ), "Greedy search are not identical to HF in multibatch mode"
+
+
+@pytest.mark.forked
+def test_beam_search_generation(max_new_tokens=4, num_beams=2):
+    tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
+    model = DistributedBloomForCausalLM.from_pretrained(
+        MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
+    )
+    text = "A cat sat on a mat"
+    inputs = tokenizer(text, return_tensors="pt")["input_ids"]
+    remote_outputs = model.generate(
+        inputs,
+        max_new_tokens=max_new_tokens,
+        num_beams=num_beams,
+    )
+    beam_scorer = BeamSearchScorer(
+        batch_size=inputs.size(0),
+        num_beams=num_beams,
+        device=inputs.device,
+        length_penalty=0,
+        do_early_stopping=False,
+    )
+    hf_inputs = tokenizer([text] * num_beams, return_tensors="pt")["input_ids"]
+    hf_outputs = BloomForCausalLM.beam_search(
+        model, input_ids=hf_inputs, max_length=inputs.size(1) + max_new_tokens, beam_scorer=beam_scorer
+    )
+    assert torch.allclose(remote_outputs, hf_outputs), "Beam search results are not identical to HF"