|
@@ -3,7 +3,13 @@ from typing import List, Optional
|
|
|
import torch
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
-from src.utils.generation_algorithms import DecodingAlgorithm, GreedyAlgorithm, NucleusAlgorithm, TopKAlgorithm
|
|
|
+from src.utils.generation_algorithms import (
|
|
|
+ BeamSearchAlgorithm,
|
|
|
+ DecodingAlgorithm,
|
|
|
+ GreedyAlgorithm,
|
|
|
+ NucleusAlgorithm,
|
|
|
+ TopKAlgorithm
|
|
|
+)
|
|
|
from src.utils.generation_constraints import ABCBloomConstraint, EosConstraint, MaxNewTokensConstraint
|
|
|
|
|
|
|
|
@@ -25,6 +31,7 @@ class RemoteGenerationMixin:
|
|
|
temperature: float = 1.0,
|
|
|
top_k: Optional[int] = None,
|
|
|
top_p: Optional[float] = None,
|
|
|
+ num_beams: Optional[int] = None,
|
|
|
bos_token_id: Optional[int] = None,
|
|
|
eos_token_id: Optional[int] = None,
|
|
|
pad_token_id: Optional[int] = None,
|
|
@@ -78,14 +85,19 @@ class RemoteGenerationMixin:
|
|
|
|
|
|
if inputs is None:
|
|
|
assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs"
|
|
|
- inputs = torch.tensor([[bos_token_id]])
|
|
|
+ inputs = torch.tensor([[bos_token_id]] * num_beams, dtype=torch.long, device=self.device)
|
|
|
|
|
|
if decoding_algorithm is None:
|
|
|
if do_sample:
|
|
|
decoding_algorithm = self._choose_sample_algorithm(temperature, top_k, top_p)
|
|
|
+ elif num_beams is not None and num_beams > 1:
|
|
|
+ decoding_algorithm = BeamSearchAlgorithm(num_beams, batch_size=inputs.size(0))
|
|
|
else:
|
|
|
decoding_algorithm = GreedyAlgorithm()
|
|
|
|
|
|
+ if num_beams > 1:
|
|
|
+ inputs = torch.cat([inputs] * num_beams, dim=0)
|
|
|
+
|
|
|
constraints = self._get_constraints(
|
|
|
inputs=inputs,
|
|
|
eos_token_id=eos_token_id,
|
|
@@ -198,12 +210,38 @@ class RemoteGenerationMixin:
|
|
|
def beam_search(
    self,
    input_ids: torch.LongTensor,
    num_beams: int = 1,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    provided_constraints: Optional[List[ABCBloomConstraint]] = None,
    **model_kwargs,
) -> torch.LongTensor:
    """
    Generates sequences of token ids for models with a language modeling head. Uses beam search.

    Builds a ``BeamSearchAlgorithm`` sized from ``num_beams`` and the first dimension of
    ``input_ids``, then delegates the actual generation to ``self.generate``.

    :param input_ids: The input tokens to the model.
    :param num_beams: The number of beams to use.
    :param max_length: The maximum length of the sequence to generate. It is forwarded to
        ``generate`` as ``max_new_tokens``.
    :param pad_token_id: The id of the padding token.
    :param eos_token_id: The id of the end of sentence token.
    :param provided_constraints: A list of constraints to use. ``None`` (the default) is
        treated as an empty list.
    :param model_kwargs: Additional kwargs to pass to the model.
    :return: Whatever ``self.generate`` returns for the given arguments.
    """
    # A fresh list per call avoids sharing one default list object between invocations.
    constraints = [] if provided_constraints is None else provided_constraints

    # The keyword must be ``batch_size`` (as used when ``generate`` builds the same algorithm);
    # the previous ``bath_size`` spelling would be rejected as an unexpected keyword argument.
    decoding_algorithm = BeamSearchAlgorithm(
        num_beams=num_beams,
        batch_size=input_ids.size(0),
    )

    # NOTE(review): ``max_length`` is passed as ``max_new_tokens``; if callers expect it to
    # bound the total length (prompt included), this mapping needs adjusting -- TODO confirm.
    return self.generate(
        inputs=input_ids,
        num_beams=num_beams,
        max_new_tokens=max_length,
        pad_token_id=pad_token_id,
        eos_token_id=eos_token_id,
        decoding_algorithm=decoding_algorithm,
        provided_constraints=constraints,
        **model_kwargs,
    )
|
|
|
|
|
|
def beam_sample(
|