@@ -126,6 +126,14 @@ class RemoteGenerationMixin:
             " Check num_return_sequences: {num_return_sequences} and num_beams: {num_beams}."
         )

+        if num_return_sequences is None:
+            num_return_sequences = 1
+
+        assert num_return_sequences <= num_beams, (
+            "You want more sequences than the number of beams."
+            f" Check num_return_sequences: {num_return_sequences} and num_beams: {num_beams}."
+        )
+
         constraints = self._get_constraints(
             inputs=inputs,
             eos_token_id=eos_token_id,
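
For context, a hypothetical call showing the constraint the new assert enforces. Only the argument names come from this diff; the model and inputs objects are made up, and this assumes the mixin backs the model's generate():

    outputs = model.generate(
        inputs,
        num_beams=4,              # beam search tracks 4 hypotheses per step
        num_return_sequences=2,   # OK: 2 <= 4; None now defaults to 1
        max_new_tokens=16,
    )
    # num_return_sequences=8 with num_beams=4 would now fail fast in the assert above.
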
@@ -144,6 +152,7 @@ class RemoteGenerationMixin:
         last_token_id = None
         seq_idx = outputs[0].size(1)
         hypo_ids = torch.arange(outputs[0].size(0))
+        hypo_ids_map = {}  # generation step -> hypo_ids permutation chosen at that step
         while True:
             embs = self.transformer.word_embeddings(outputs[-1])
             intermediate_prompts = None
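
A note on `hypo_ids`: it starts as `torch.arange(...)`, i.e. the identity permutation, so indexing with it is a no-op until the decoding algorithm reassigns it. A minimal runnable check (toy tensor, names made up):

    import torch

    x = torch.tensor([[10], [20], [30]])
    identity = torch.arange(x.size(0))  # tensor([0, 1, 2]) -- the identity permutation
    assert torch.equal(x[identity], x)  # reordering by the identity changes nothing
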
@@ -177,6 +186,13 @@ class RemoteGenerationMixin:
             for i in range(len(outputs), 1, -1):
                 outputs[i - 1] = outputs[i - 1][hypo_ids]

+            if num_beams > 1:
+                hypo_ids_map[len(outputs)] = hypo_ids  # remember the beam ancestry chosen at this step
+                cur_hypo_ids = hypo_ids.clone()  # clone() avoids the torch.tensor(tensor) copy-construct warning
+                for i in range(len(outputs), 1, -1):
+                    outputs[i - 1] = outputs[i - 1][cur_hypo_ids]  # reorder step i-1 to follow the surviving beams
+                    cur_hypo_ids = hypo_ids[hypo_ids_map[i]]  # compose with the permutation recorded at step i
+
             outputs.append(last_token_id)
             seq_idx += 1
             if torch.all(last_token_id == eos_token_id) or len(outputs) > max_new_tokens:
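
To see what the reordering loop does to past outputs, here is a self-contained toy version of it, outside the mixin (all tensors invented; `outputs[0]` stands in for the prompt, which the loop deliberately skips):

    import torch

    # outputs[0] plays the role of the prompt; each later entry holds
    # one generated token per beam for one step.
    outputs = [
        torch.tensor([[7, 7], [7, 7], [7, 7]]),  # prompt, 3 beams
        torch.tensor([[1], [2], [3]]),           # tokens from step 1
        torch.tensor([[4], [5], [6]]),           # tokens from step 2
    ]

    # Say beam search keeps beam 1 twice and beam 2 once at the current step.
    hypo_ids = torch.tensor([1, 1, 2])

    # Same loop shape as in the diff: reorder every generated step, skip the prompt.
    for i in range(len(outputs), 1, -1):
        outputs[i - 1] = outputs[i - 1][hypo_ids]

    print(torch.cat(outputs, dim=1))
    # tensor([[7, 7, 2, 5],
    #         [7, 7, 2, 5],
    #         [7, 7, 3, 6]])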