Artem Chumachenko 2 年之前
父節點
當前提交
ccb55bcfe2
共有 1 個文件被更改,包括 0 次插入30 次删除
  1. 0 30
      src/client/remote_generation.py

+ 0 - 30
src/client/remote_generation.py

@@ -118,22 +118,6 @@ class RemoteGenerationMixin:
             " Check num_return_sequences: {num_return_sequences} and num_beams: {num_beams}."
             " Check num_return_sequences: {num_return_sequences} and num_beams: {num_beams}."
         )
         )
 
 
-        if num_return_sequences is None:
-            num_return_sequences = 1
-
-        assert num_return_sequences <= num_beams, (
-            f"You want more sequences that beam will have."
-            " Check num_return_sequences: {num_return_sequences} and num_beams: {num_beams}."
-        )
-
-        if num_return_sequences is None:
-            num_return_sequences = 1
-
-        assert num_return_sequences <= num_beams, (
-            f"You want more sequences that beam will have."
-            " Check num_return_sequences: {num_return_sequences} and num_beams: {num_beams}."
-        )
-
         constraints = self._get_constraints(
         constraints = self._get_constraints(
             inputs=inputs,
             inputs=inputs,
             eos_token_id=eos_token_id,
             eos_token_id=eos_token_id,
@@ -152,7 +136,6 @@ class RemoteGenerationMixin:
             last_token_id = None
             last_token_id = None
             seq_idx = outputs[0].size(1)
             seq_idx = outputs[0].size(1)
             hypo_ids = torch.arange(outputs[0].size(0))
             hypo_ids = torch.arange(outputs[0].size(0))
-            hypo_ids_map = dict()
             while True:
             while True:
                 embs = self.transformer.word_embeddings(outputs[-1])
                 embs = self.transformer.word_embeddings(outputs[-1])
                 intermediate_prompts = None
                 intermediate_prompts = None
@@ -180,19 +163,6 @@ class RemoteGenerationMixin:
                     for i in range(len(outputs), 1, -1):
                     for i in range(len(outputs), 1, -1):
                         outputs[i - 1] = outputs[i - 1][hypo_ids]
                         outputs[i - 1] = outputs[i - 1][hypo_ids]
 
 
-                if num_beams > 1:
-                    hypo_ids_map[len(outputs)] = hypo_ids
-                    cur_hypo_ids = torch.tensor(hypo_ids)
-                    for i in range(len(outputs), 1, -1):
-                        outputs[i - 1] = outputs[i - 1][hypo_ids]
-
-                if num_beams > 1:
-                    hypo_ids_map[len(outputs)] = hypo_ids
-                    cur_hypo_ids = torch.tensor(hypo_ids)
-                    for i in range(len(outputs), 1, -1):
-                        outputs[i - 1] = outputs[i - 1][cur_hypo_ids]
-                        cur_hypo_ids = hypo_ids[hypo_ids_map[i]]
-
                 outputs.append(last_token_id)
                 outputs.append(last_token_id)
                 seq_idx += 1
                 seq_idx += 1
                 if torch.all(last_token_id == eos_token_id) or len(outputs) > max_new_tokens:
                 if torch.all(last_token_id == eos_token_id) or len(outputs) > max_new_tokens: