Artem Chumachenko 2 rokov pred
rodič
commit
0328fbce84

+ 16 - 0
src/client/remote_generation.py

@@ -126,6 +126,14 @@ class RemoteGenerationMixin:
             " Check num_return_sequences: {num_return_sequences} and num_beams: {num_beams}."
         )
 
+        if num_return_sequences is None:
+            num_return_sequences = 1
+
+        assert num_return_sequences <= num_beams, (
+            f"You want more sequences that beam will have."
+            " Check num_return_sequences: {num_return_sequences} and num_beams: {num_beams}."
+        )
+
         constraints = self._get_constraints(
             inputs=inputs,
             eos_token_id=eos_token_id,
@@ -144,6 +152,7 @@ class RemoteGenerationMixin:
             last_token_id = None
             seq_idx = outputs[0].size(1)
             hypo_ids = torch.arange(outputs[0].size(0))
+            hypo_ids_map = dict()
             while True:
                 embs = self.transformer.word_embeddings(outputs[-1])
                 intermediate_prompts = None
@@ -177,6 +186,13 @@ class RemoteGenerationMixin:
                     for i in range(len(outputs), 1, -1):
                         outputs[i - 1] = outputs[i - 1][hypo_ids]
 
+                if num_beams > 1:
+                    hypo_ids_map[len(outputs)] = hypo_ids
+                    cur_hypo_ids = torch.tensor(hypo_ids)
+                    for i in range(len(outputs), 1, -1):
+                        outputs[i - 1] = outputs[i - 1][cur_hypo_ids]
+                        cur_hypo_ids = hypo_ids[hypo_ids_map[i]]
+
                 outputs.append(last_token_id)
                 seq_idx += 1
                 if torch.all(last_token_id == eos_token_id) or len(outputs) > max_new_tokens:

+ 1 - 0
src/utils/generation_algorithms.py

@@ -1,5 +1,6 @@
 from abc import ABC
 from typing import Tuple
+from heapq import heappush, heappop
 
 import torch