|
@@ -10,7 +10,7 @@ from src.utils.generation_algorithms import (
|
|
|
NucleusAlgorithm,
|
|
|
TopKAlgorithm
|
|
|
)
|
|
|
-from src.utils.generation_constraints import ABCBloomConstraint, EosConstraint, MaxNewTokensConstraint
|
|
|
+from src.utils.generation_constraints import ABCBloomConstraint, EosConstraint
|
|
|
|
|
|
|
|
|
class RemoteGenerationMixin:
|
|
@@ -19,8 +19,9 @@ class RemoteGenerationMixin:
|
|
|
The class exposes can be used for:
|
|
|
- *greedy decoding*.
|
|
|
- *multinomial sampling*.
|
|
|
+ - *beam-search decoding*.
|
|
|
|
|
|
- This class is similar to transformer's [`generation_utils.GenerationMixin`], it can be used instead of it. However, it has some differences.
|
|
|
+ This class is similar to transformers' [`generation_utils.GenerationMixin`] and can be used instead of it. However, it has some differences for remote usage.
|
|
|
"""
|
|
|
|
|
|
@torch.no_grad()
|
|
@@ -31,7 +32,7 @@ class RemoteGenerationMixin:
|
|
|
temperature: float = 1.0,
|
|
|
top_k: Optional[int] = None,
|
|
|
top_p: Optional[float] = None,
|
|
|
- num_beams: Optional[int] = None,
|
|
|
+ num_beams: Optional[int] = 1,
|
|
|
bos_token_id: Optional[int] = None,
|
|
|
eos_token_id: Optional[int] = None,
|
|
|
pad_token_id: Optional[int] = None,
|
|
@@ -49,6 +50,7 @@ class RemoteGenerationMixin:
|
|
|
:param temperature: The temperature to use for sampling.
|
|
|
:param top_k: The number of results to return.
|
|
|
:param top_p: The cumulative probability of results to return.
|
|
|
+ :param num_beams: The number of beams to use for beam search.
|
|
|
:param bos_token_id: The id of the beginning of sentence token.
|
|
|
:param eos_token_id: The id of the end of sentence token.
|
|
|
:param pad_token_id: The id of the padding token.
|
|
@@ -102,12 +104,13 @@ class RemoteGenerationMixin:
|
|
|
inputs=inputs,
|
|
|
eos_token_id=eos_token_id,
|
|
|
pad_token_id=pad_token_id,
|
|
|
- max_new_tokens=max_new_tokens,
|
|
|
provided_constraints=provided_constraints,
|
|
|
)
|
|
|
|
|
|
with self.transformer.h.inference_session(max_length=max_length) as sess:
|
|
|
outputs = []
|
|
|
+ # Find samples with padded inputs.
|
|
|
+ # Their padded positions will be overwritten until all of the samples have the right length.
|
|
|
if torch.any(inputs == pad_token_id): # TODO: move to prepare_inputs
|
|
|
outputs += [inputs[:, : inputs.size(1) - (inputs == pad_token_id).sum(-1).max()]]
|
|
|
else:
|
|
@@ -129,13 +132,15 @@ class RemoteGenerationMixin:
|
|
|
for constraint in constraints:
|
|
|
lm_logits = constraint(last_token_id, lm_logits, hypo_ids)
|
|
|
last_token_id, hypo_ids = decoding_algorithm(lm_logits)
|
|
|
- if seq_idx < inputs.size(1): # TODO: why is it not a constraint?
|
|
|
+
|
|
|
+ # If some samples are padded, replace tokens only at their padded positions
|
|
|
+ if seq_idx < inputs.size(1):
|
|
|
pad_token_mask = inputs[:, seq_idx : seq_idx + 1] == pad_token_id
|
|
|
last_token_id = (~pad_token_mask) * inputs[
|
|
|
:, seq_idx : seq_idx + 1
|
|
|
] + pad_token_mask * last_token_id
|
|
|
|
|
|
- if torch.all(last_token_id == eos_token_id):
|
|
|
+ if torch.all(last_token_id == eos_token_id) or len(outputs) >= max_new_tokens:
|
|
|
break
|
|
|
|
|
|
outputs.append(last_token_id)
|
|
@@ -242,7 +247,6 @@ class RemoteGenerationMixin:
|
|
|
provided_constraints=provided_constraints,
|
|
|
**model_kwargs,
|
|
|
)
|
|
|
- raise NotImplementedError
|
|
|
|
|
|
def beam_sample(
|
|
|
self,
|
|
@@ -284,12 +288,9 @@ class RemoteGenerationMixin:
|
|
|
inputs: Optional[torch.Tensor] = None,
|
|
|
eos_token_id: Optional[int] = None,
|
|
|
pad_token_id: Optional[int] = None,
|
|
|
- max_new_tokens: Optional[int] = None,
|
|
|
provided_constraints: List[ABCBloomConstraint] = [],
|
|
|
) -> List[ABCBloomConstraint]:
|
|
|
constraints = []
|
|
|
constraints.extend(provided_constraints)
|
|
|
- if max_new_tokens is not None:
|
|
|
- constraints.append(MaxNewTokensConstraint(inputs, max_new_tokens, eos_token_id, pad_token_id))
|
|
|
constraints.append(EosConstraint(inputs, eos_token_id, pad_token_id))
|
|
|
return constraints
|