|
@@ -44,6 +44,7 @@ class RemoteGenerationMixin:
|
|
|
def generate(
|
|
|
self,
|
|
|
inputs: Optional[torch.Tensor] = None,
|
|
|
+ *,
|
|
|
do_sample: Optional[bool] = None,
|
|
|
temperature: float = 1.0,
|
|
|
top_k: Optional[int] = None,
|
|
@@ -57,9 +58,7 @@ class RemoteGenerationMixin:
|
|
|
decoding_algorithm: Optional[DecodingAlgorithm] = None,
|
|
|
provided_constraints: List[ABCBloomConstraint] = [],
|
|
|
num_return_sequences: Optional[int] = None,
|
|
|
- *,
|
|
|
session: Optional[InferenceSession] = None,
|
|
|
- **model_kwargs,
|
|
|
) -> torch.LongTensor:
|
|
|
"""
|
|
|
Generates sequences of token ids for models with a language modeling head.
|
|
@@ -77,19 +76,9 @@ class RemoteGenerationMixin:
|
|
|
:param max_new_tokens: The maximum number of tokens to generate.
|
|
|
:param decoding_algorithm: The decoding algorithm to use.
|
|
|
:param provided_constraints: A list of constraints to use.
|
|
|
- :param model_kwargs: Additional arguments to pass to the model.
|
|
|
:param num_return_sequences: How many hypothesis from the beam will be in output.
|
|
|
"""
|
|
|
|
|
|
- assert (
|
|
|
- model_kwargs.get("logits_processor", None) is None
|
|
|
- ), "For RemoteGenerationMixin models use BloomConstraints instead of logits_processor"
|
|
|
- assert (
|
|
|
- model_kwargs.get("logits_wrapper", None) is None
|
|
|
- ), "For RemoveGenerationMixin models use DecodingAlgorithm instead of logits_wrapper"
|
|
|
- assert (
|
|
|
- model_kwargs.get("stopping_criteria", None) is None
|
|
|
- ), "For RemoteGenerationMixin models use BloomConstraints instead of stopping_criteria"
|
|
|
prefix_length = 0 if inputs is None else inputs.size(1)
|
|
|
prefix_length += self.config.pre_seq_len
|
|
|
|
|
@@ -226,7 +215,6 @@ class RemoteGenerationMixin:
|
|
|
pad_token_id: Optional[int] = None,
|
|
|
eos_token_id: Optional[int] = None,
|
|
|
provided_constraints: List[ABCBloomConstraint] = [],
|
|
|
- **model_kwargs,
|
|
|
) -> torch.LongTensor:
|
|
|
"""
|
|
|
Generates sequences of token ids for models with a language modeling head. Uses greedy search.
|
|
@@ -244,7 +232,6 @@ class RemoteGenerationMixin:
|
|
|
eos_token_id=eos_token_id,
|
|
|
decoding_algorithm=GreedyAlgorithm(),
|
|
|
provided_constraints=provided_constraints,
|
|
|
- **model_kwargs,
|
|
|
)
|
|
|
|
|
|
def sample(
|
|
@@ -257,7 +244,6 @@ class RemoteGenerationMixin:
|
|
|
pad_token_id: Optional[int] = None,
|
|
|
eos_token_id: Optional[int] = None,
|
|
|
provided_constraints: List[ABCBloomConstraint] = [],
|
|
|
- **model_kwargs,
|
|
|
) -> torch.LongTensor:
|
|
|
"""
|
|
|
Generates sequences of token ids for models with a language modeling head. Uses multinomial sampling.
|
|
@@ -271,7 +257,6 @@ class RemoteGenerationMixin:
|
|
|
:param: pad_token_id: The id of the padding token.
|
|
|
:param: eos_token_id: The id of the end of sentence token.
|
|
|
:param: provided_constraints: A list of constraints to use.
|
|
|
- :param: model_kwargs: Additional kwargs to pass to the model.
|
|
|
"""
|
|
|
|
|
|
return self.generate(
|
|
@@ -281,7 +266,6 @@ class RemoteGenerationMixin:
|
|
|
eos_token_id=eos_token_id,
|
|
|
decoding_algorithm=self._choose_sample_algorithm(temperature, top_k, top_p),
|
|
|
provided_constraints=provided_constraints,
|
|
|
- **model_kwargs,
|
|
|
)
|
|
|
|
|
|
def beam_search(
|
|
@@ -292,7 +276,6 @@ class RemoteGenerationMixin:
|
|
|
pad_token_id: Optional[int] = None,
|
|
|
eos_token_id: Optional[int] = None,
|
|
|
provided_constraints: List[ABCBloomConstraint] = [],
|
|
|
- **model_kwargs,
|
|
|
) -> torch.LongTensor:
|
|
|
"""
|
|
|
Generates sequences of token ids for models with a language modeling head. Uses beam search.
|
|
@@ -303,7 +286,6 @@ class RemoteGenerationMixin:
|
|
|
:param pad_token_id: The id of the padding token.
|
|
|
:param eos_token_id: The id of the end of sentence token.
|
|
|
:param provided_constraints: A list of constraints to use.
|
|
|
- :param: model_kwargs: Additional kwargs to pass to the model.
|
|
|
"""
|
|
|
decoding_algorithm = BeamSearchAlgorithm(
|
|
|
num_beams=num_beams,
|
|
@@ -317,7 +299,6 @@ class RemoteGenerationMixin:
|
|
|
eos_token_id=eos_token_id,
|
|
|
decoding_algorithm=decoding_algorithm,
|
|
|
provided_constraints=provided_constraints,
|
|
|
- **model_kwargs,
|
|
|
)
|
|
|
|
|
|
def beam_sample(
|
|
@@ -327,7 +308,6 @@ class RemoteGenerationMixin:
|
|
|
pad_token_id: Optional[int] = None,
|
|
|
eos_token_id: Optional[int] = None,
|
|
|
provided_constraints: List[ABCBloomConstraint] = [],
|
|
|
- **model_kwargs,
|
|
|
) -> torch.LongTensor:
|
|
|
raise NotImplementedError
|
|
|
|
|
@@ -338,7 +318,6 @@ class RemoteGenerationMixin:
|
|
|
pad_token_id: Optional[int] = None,
|
|
|
eos_token_id: Optional[int] = None,
|
|
|
provided_constraints: List[ABCBloomConstraint] = [],
|
|
|
- **model_kwargs,
|
|
|
) -> torch.LongTensor:
|
|
|
raise NotImplementedError
|
|
|
|