@@ -118,22 +118,6 @@ class RemoteGenerationMixin:
             " Check num_return_sequences: {num_return_sequences} and num_beams: {num_beams}."
         )

-        if num_return_sequences is None:
-            num_return_sequences = 1
-
-        assert num_return_sequences <= num_beams, (
-            f"You want more sequences that beam will have."
-            " Check num_return_sequences: {num_return_sequences} and num_beams: {num_beams}."
-        )
-
-        if num_return_sequences is None:
-            num_return_sequences = 1
-
-        assert num_return_sequences <= num_beams, (
-            f"You want more sequences that beam will have."
-            " Check num_return_sequences: {num_return_sequences} and num_beams: {num_beams}."
-        )
-
         constraints = self._get_constraints(
             inputs=inputs,
             eos_token_id=eos_token_id,
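This hunk deletes two verbatim copies of the num_return_sequences check; the tail of the surviving copy is visible in the context lines above. Note that the check itself carries a latent bug this cleanup does not touch: the f prefix sits on the literal with no placeholders, so {num_return_sequences} and {num_beams} in the second string are never interpolated. A minimal sketch of the corrected check (also fixing the "that"/"than" typo), not part of this diff:

    if num_return_sequences is None:
        num_return_sequences = 1

    assert num_return_sequences <= num_beams, (
        "You want more sequences than the beam will have."
        f" Check num_return_sequences: {num_return_sequences} and num_beams: {num_beams}."
    )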
@@ -152,7 +136,6 @@ class RemoteGenerationMixin:
             last_token_id = None
             seq_idx = outputs[0].size(1)
             hypo_ids = torch.arange(outputs[0].size(0))
-            hypo_ids_map = dict()
             while True:
                 embs = self.transformer.word_embeddings(outputs[-1])
                 intermediate_prompts = None
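The hypo_ids_map dict removed here is written and read only by the beam bookkeeping deleted in the next hunk (hypo_ids_map[len(outputs)] = hypo_ids and cur_hypo_ids = hypo_ids[hypo_ids_map[i]]); with that block gone, the dict would be populated but never consulted, so dropping it removes dead state rather than behavior.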
@@ -180,19 +163,6 @@ class RemoteGenerationMixin:
                     for i in range(len(outputs), 1, -1):
                         outputs[i - 1] = outputs[i - 1][hypo_ids]

-                if num_beams > 1:
-                    hypo_ids_map[len(outputs)] = hypo_ids
-                    cur_hypo_ids = torch.tensor(hypo_ids)
-                    for i in range(len(outputs), 1, -1):
-                        outputs[i - 1] = outputs[i - 1][hypo_ids]
-
-                if num_beams > 1:
-                    hypo_ids_map[len(outputs)] = hypo_ids
-                    cur_hypo_ids = torch.tensor(hypo_ids)
-                    for i in range(len(outputs), 1, -1):
-                        outputs[i - 1] = outputs[i - 1][cur_hypo_ids]
-                        cur_hypo_ids = hypo_ids[hypo_ids_map[i]]
-
                 outputs.append(last_token_id)
                 seq_idx += 1
                 if torch.all(last_token_id == eos_token_id) or len(outputs) > max_new_tokens:
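The two deleted blocks are stale variants of the beam reindexing that the kept loop in the context above already performs, and they disagree with each other: the first indexes every stored step by the current hypo_ids, while the second chains remappings through cur_hypo_ids and hypo_ids_map. The kept version relies on hypo_ids[j] naming the hypothesis that beam j continues after each decoding step, and re-gathers every stored token tensor except outputs[0] along the beam dimension. A minimal, self-contained sketch of that surviving behavior, with made-up tensors in place of the class's real state:

    import torch

    # Two stored steps for three beams; outputs[0] (the prompt) is never
    # touched because the loop counts down only as far as i == 2.
    outputs = [torch.tensor([[7], [7], [7]]), torch.tensor([[4], [5], [6]])]
    hypo_ids = torch.tensor([0, 0, 2])  # beams 0 and 1 both continue old beam 0

    for i in range(len(outputs), 1, -1):
        outputs[i - 1] = outputs[i - 1][hypo_ids]

    print(outputs[1])  # tensor([[4], [4], [6]])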