Artem Chumachenko 2 år sedan
förälder
incheckning
ab141245af
1 ändrade filer med 1 tillägg och 3 borttagningar
  1. 1 3
      src/client/remote_generation.py

+ 1 - 3
src/client/remote_generation.py

@@ -144,7 +144,6 @@ class RemoteGenerationMixin:
             last_token_id = None
             seq_idx = outputs[0].size(1)
             hypo_ids = torch.arange(outputs[0].size(0))
-            hypo_ids_map = dict()
             while True:
                 embs = self.transformer.word_embeddings(outputs[-1])
                 intermediate_prompts = None
@@ -176,8 +175,7 @@ class RemoteGenerationMixin:
                     hypo_ids_map[len(outputs)] = hypo_ids
                     cur_hypo_ids = torch.tensor(hypo_ids)
                     for i in range(len(outputs), 1, -1):
-                        outputs[i - 1] = outputs[i - 1][cur_hypo_ids]
-                        cur_hypo_ids = hypo_ids[hypo_ids_map[i]]
+                        outputs[i - 1] = outputs[i - 1][hypo_ids]
 
                 outputs.append(last_token_id)
                 seq_idx += 1