@@ -13,9 +13,10 @@ class RemoteGenerationMixin:
     The class exposes methods that can be used for:
     - *greedy decoding*.
     - *multinomial sampling*.
-
+
     This class is similar to transformers' [`generation_utils.GenerationMixin`] and can be used in its place. However, it has some differences.
     """
+
     def generate(
         self,
         inputs: Optional[torch.Tensor] = None,
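
For context, the two decoding modes named in the docstring differ only in how the next token is picked from the logits. A minimal, self-contained sketch of that distinction (illustrative; not this class's actual `DecodingAlgorithm` implementations):

```python
import torch
import torch.nn.functional as F

def pick_next_token(lm_logits: torch.Tensor, do_sample: bool, temperature: float = 1.0) -> torch.Tensor:
    """Select the next token id from logits of shape (batch, vocab_size)."""
    if not do_sample:
        # Greedy decoding: deterministically take the highest-scoring token.
        return lm_logits.argmax(dim=-1, keepdim=True)
    # Multinomial sampling: temperature-scale the logits, then draw one token.
    probs = F.softmax(lm_logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1)
```
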
@@ -33,7 +34,7 @@ class RemoteGenerationMixin:
     ) -> torch.LongTensor:
         """
         Generates sequences of token ids for models with a language modeling head.
-
+
         :param inputs: The input tokens to the model.
         :param do_sample: Whether to sample from the model predictions or take the argmax.
         :param temperature: The temperature to use for sampling.
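
A hypothetical call, assuming `model` is a causal LM that mixes in `RemoteGenerationMixin` and `tokenizer` is its matching tokenizer (both names are illustrative, not from this diff):

```python
inputs = tokenizer("A cat sat on", return_tensors="pt")["input_ids"]
# do_sample=True switches from argmax to multinomial sampling;
# temperature < 1.0 sharpens the distribution, > 1.0 flattens it.
outputs = model.generate(inputs, do_sample=True, temperature=0.8, max_new_tokens=16)
print(tokenizer.decode(outputs[0]))
```
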
@@ -48,9 +49,15 @@ class RemoteGenerationMixin:
         :param model_kwargs: Additional arguments to pass to the model.
         """

-        assert model_kwargs.get("logits_processor", None) is None, "For RemoteGenerationMixin models use BloomConstraints instead of logits_processor"
-        assert model_kwargs.get("logits_wrapper", None) is None, "For RemoveGenerationMixin models use DecodingAlgorithm instead of logits_wrapper"
-        assert model_kwargs.get("stopping_criteria", None) is None, "For RemoteGenerationMixin models use BloomConstraints instead of stopping_criteria"
+        assert (
+            model_kwargs.get("logits_processor", None) is None
+        ), "For RemoteGenerationMixin models use BloomConstraints instead of logits_processor"
+        assert (
+            model_kwargs.get("logits_wrapper", None) is None
+        ), "For RemoteGenerationMixin models use DecodingAlgorithm instead of logits_wrapper"
+        assert (
+            model_kwargs.get("stopping_criteria", None) is None
+        ), "For RemoteGenerationMixin models use BloomConstraints instead of stopping_criteria"

         bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
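
These asserts steer callers away from transformers' `logits_processor`/`stopping_criteria` hooks toward this class's own constraint mechanism. Judging from how constraints are invoked in the generation loop below (`constraint(last_token_id, lm_logits, hypo_ids)`), a constraint is a callable that rewrites the logits at every step. A minimal sketch of that interface with an illustrative (not real) constraint:

```python
from abc import ABC, abstractmethod
import torch

class ABCConstraint(ABC):
    """Sketch of the interface implied by the generation loop: each
    constraint receives the previous token ids and rewrites the logits."""

    @abstractmethod
    def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
        ...

class ForbidTokenConstraint(ABCConstraint):
    """Illustrative only: mask out a single token id on every step."""

    def __init__(self, forbidden_id: int):
        self.forbidden_id = forbidden_id

    def __call__(self, tokens_id, logits, hypo_ids):
        logits[..., self.forbidden_id] = -float("inf")
        return logits
```
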
@@ -68,16 +75,16 @@ class RemoteGenerationMixin:

         constraints = self._get_constraints(
             inputs=inputs,
-            eos_token_id=eos_token_id,
-            pad_token_id=pad_token_id,
-            max_new_tokens=max_new_tokens,
+            eos_token_id=eos_token_id,
+            pad_token_id=pad_token_id,
+            max_new_tokens=max_new_tokens,
             provided_constraints=provided_constraints,
         )

         with self.transformer.h.inference_session() as sess:
             outputs = []
-            if torch.any(inputs == pad_token_id): # TODO: move to prepare_inputs
-                outputs += [inputs[:, :inputs.size(1) - (inputs == pad_token_id).sum(-1).max()]]
+            if torch.any(inputs == pad_token_id):  # TODO: move to prepare_inputs
+                outputs += [inputs[:, : inputs.size(1) - (inputs == pad_token_id).sum(-1).max()]]
             else:
                 outputs += [inputs]
             last_token_id = None
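
The TODO'd branch above trims the padded tail before opening the inference session: `(inputs == pad_token_id).sum(-1).max()` is the largest per-row pad count in the batch, so slicing off that many trailing columns keeps only the prefix that every row covers with real tokens; the loop below re-feeds the remaining prompt tokens step by step. The same arithmetic on a toy batch:

```python
import torch

pad_token_id = 3
inputs = torch.tensor([[5, 8, 9, 3, 3],
                       [7, 3, 3, 3, 3]])
# The largest per-row pad count decides how many trailing columns to drop.
max_pads = (inputs == pad_token_id).sum(-1).max()   # tensor(4)
prefix = inputs[:, : inputs.size(1) - max_pads]     # shape (2, 1)
print(prefix)  # tensor([[5], [7]]) -- the prefix every row fully covers
```
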
@@ -93,9 +100,11 @@ class RemoteGenerationMixin:
                 for constraint in constraints:
                     lm_logits = constraint(last_token_id, lm_logits, hypo_ids)
                 last_token_id, hypo_ids = decoding_algorithm(lm_logits)
-                if seq_idx < inputs.size(1): # TODO: why is it not a constraint?
-                    pad_token_mask = inputs[:, seq_idx:seq_idx + 1] == pad_token_id
-                    last_token_id = (~pad_token_mask) * inputs[:, seq_idx:seq_idx + 1] + pad_token_mask * last_token_id
+                if seq_idx < inputs.size(1):  # TODO: why is it not a constraint?
+                    pad_token_mask = inputs[:, seq_idx : seq_idx + 1] == pad_token_id
+                    last_token_id = (~pad_token_mask) * inputs[
+                        :, seq_idx : seq_idx + 1
+                    ] + pad_token_mask * last_token_id

                 if torch.all(last_token_id == eos_token_id):
                     break
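
The masked blend above keeps feeding real prompt tokens while they last: at positions where the prompt column holds `pad_token_id`, the freshly sampled token wins; everywhere else the prompt token overrides whatever the decoding algorithm produced. The same arithmetic on a toy batch:

```python
import torch

pad_token_id = 3
prompt_col = torch.tensor([[9], [3]])   # inputs[:, seq_idx : seq_idx + 1]
sampled = torch.tensor([[4], [6]])      # last_token_id from the decoding algorithm
pad_token_mask = prompt_col == pad_token_id
# Row 0 is still inside its prompt, so it keeps token 9;
# row 1 is already padded, so it keeps the sampled token 6.
next_token = (~pad_token_mask) * prompt_col + pad_token_mask * sampled
print(next_token)  # tensor([[9], [6]])
```
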
@@ -147,7 +156,7 @@ class RemoteGenerationMixin:
     ) -> torch.LongTensor:
         """
         Generates sequences of token ids for models with a language modeling head, using multinomial sampling. If top_k is provided, top-k sampling is used; if top_p is provided, nucleus (top-p) sampling is used.
-
+
         :param input_ids: The input tokens to the model.
         :param temperature: The temperature to use for sampling.
         :param top_k: The number of highest-probability tokens to keep for top-k sampling.
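
For reference, a compact sketch of how top-k and nucleus (top-p) filtering typically restrict the distribution before the multinomial draw; this is a generic implementation, not this class's actual sampling algorithm:

```python
import torch
import torch.nn.functional as F

def sample_top_k_top_p(logits: torch.Tensor, top_k: int = 0, top_p: float = 1.0,
                       temperature: float = 1.0) -> torch.Tensor:
    logits = logits / temperature
    if top_k > 0:
        # Keep only the k highest-scoring tokens.
        kth_best = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth_best, -float("inf"))
    if top_p < 1.0:
        # Keep the smallest set of tokens whose cumulative probability reaches top_p.
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        probs = F.softmax(sorted_logits, dim=-1)
        cumulative = probs.cumsum(dim=-1)
        drop = cumulative - probs > top_p  # mass accumulated *before* each token
        sorted_logits = sorted_logits.masked_fill(drop, -float("inf"))
        logits = torch.full_like(logits, -float("inf")).scatter(-1, sorted_idx, sorted_logits)
    return torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
```
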
@@ -229,4 +238,3 @@ class RemoteGenerationMixin:
         constraints.append(MaxNewTokensConstraint(inputs, max_new_tokens, eos_token_id, pad_token_id))
         constraints.append(EosConstraint(inputs, eos_token_id, pad_token_id))
         return constraints
-
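
Based purely on its constructor arguments, `EosConstraint` presumably keeps a hypothesis inert once it emits `eos_token_id`. The following is an illustrative guess at that behavior, not the real implementation:

```python
import torch

class EosConstraintSketch:
    """Illustrative guess: once a row has produced eos_token_id,
    force pad_token_id on every later step so it stops changing."""

    def __init__(self, eos_token_id: int, pad_token_id: int):
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.finished = None  # bool tensor of shape (batch, 1)

    def __call__(self, tokens_id, logits, hypo_ids):
        if tokens_id is not None:
            just_done = tokens_id == self.eos_token_id
            self.finished = just_done if self.finished is None else self.finished | just_done
            rows = self.finished.squeeze(-1)
            logits[rows] = -float("inf")
            logits[rows, self.pad_token_id] = 0.0  # pad becomes the only choice
        return logits
```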