Artem Chumachenko 2 years ago
parent
commit
00c6828672
2 changed files with 30 additions and 27 deletions
  1. +10 -9
      src/client/remote_generation.py
  2. +20 -18
      src/utils/generation_algorithms.py

+10 -9
src/client/remote_generation.py

@@ -1,7 +1,7 @@
 from typing import List, Optional
 
 import torch
-import torch.nn.functional as F
+from hivemind import get_logger
 
 from src.utils.generation_algorithms import (
     BeamSearchAlgorithm,
@@ -12,6 +12,8 @@ from src.utils.generation_algorithms import (
 )
 from src.utils.generation_constraints import ABCBloomConstraint, EosConstraint
 
+logger = get_logger(__file__)
+
 
 class RemoteGenerationMixin:
     """
@@ -102,6 +104,11 @@ class RemoteGenerationMixin:
 
         if num_beams > 1:
             inputs = torch.cat([inputs] * num_beams, dim=0)
+            if batch_size > 1:
+                # TODO: resolve padding problem
+                logger.warning(
+                    f"You set batch_size {batch_size} within beam search generation. Be carefull, results on sequences with different length may be padded wrong way"
+                )
 
         if num_return_sequences is None:
             num_return_sequences = 1
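
For context on the layout this hunk relies on (a minimal sketch with made-up shapes, not part of the commit): torch.cat([inputs] * num_beams, dim=0) stacks the beam copies batch-major, so row batch_idx + beam_idx * batch_size holds beam beam_idx of sample batch_idx, which is the same indexing BeamSearchAlgorithm uses below.

    import torch

    batch_size, num_beams, seq_len = 2, 3, 4
    # Toy token ids: row i belongs to sample i of the batch (values are arbitrary).
    inputs = torch.arange(batch_size * seq_len).reshape(batch_size, seq_len)

    # Same replication as in the hunk above: beam copies are stacked batch-major.
    beamed = torch.cat([inputs] * num_beams, dim=0)  # [batch_size * num_beams, seq_len]

    for batch_idx in range(batch_size):
        for beam_idx in range(num_beams):
            row = batch_idx + beam_idx * batch_size  # indexing used by the algorithm
            assert torch.equal(beamed[row], inputs[batch_idx])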
@@ -129,7 +136,6 @@ class RemoteGenerationMixin:
             last_token_id = None
             seq_idx = outputs[0].size(1)
             hypo_ids = torch.arange(outputs[0].size(0))
-            hypo_ids_map = dict()
             while True:
                 embs = self.transformer.word_embeddings(outputs[-1])
                 intermediate_prompts = None
@@ -152,15 +158,10 @@ class RemoteGenerationMixin:
                         :, seq_idx : seq_idx + 1
                     ] + pad_token_mask * last_token_id
 
+                # TODO: refactor outputs
                 if num_beams > 1:
-                    outputs[-1] = outputs[-1][hypo_ids]
-
-                if num_beams > 1:
-                    hypo_ids_map[len(outputs)] = hypo_ids
-                    cur_hypo_ids = torch.tensor(hypo_ids)
                     for i in range(len(outputs), 1, -1):
-                        outputs[i - 1] = outputs[i - 1][cur_hypo_ids]
-                        cur_hypo_ids = hypo_ids[hypo_ids_map[i]]
+                        outputs[i - 1] = outputs[i - 1][hypo_ids]
 
                 outputs.append(last_token_id)
                 seq_idx += 1

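A minimal sketch (made-up values, not from the commit) of why the simplified reindexing above suffices: every buffered chunk already sits in the beam order produced by the previous step, so applying only the current step's hypo_ids realigns the whole history, and the removed hypo_ids_map bookkeeping is unnecessary.

    import torch

    # outputs[0] is the replicated prompt; each later entry holds one generated
    # token per hypothesis (3 hypotheses here, arbitrary token ids).
    outputs = [
        torch.tensor([[1, 2], [1, 2], [1, 2]]),  # prompt, identical across beams
        torch.tensor([[10], [20], [30]]),        # tokens appended at step 1
    ]

    # hypo_ids from the decoding algorithm: new hypothesis i extends old hypothesis hypo_ids[i].
    hypo_ids = torch.tensor([1, 1, 0])

    # Same loop as in the patch: reorder every buffered chunk except the prompt.
    # One pass with the current hypo_ids is enough, because earlier chunks were
    # already permuted on earlier steps.
    for i in range(len(outputs), 1, -1):
        outputs[i - 1] = outputs[i - 1][hypo_ids]

    assert torch.equal(outputs[1], torch.tensor([[20], [20], [10]]))
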
+20 -18
src/utils/generation_algorithms.py

@@ -81,40 +81,42 @@ class BeamSearchAlgorithm(DecodingAlgorithm):
         self._cur_num_beams = 1
         self.batch_size = batch_size
 
-        self._beams = []
+        self._batch_beams = [list() for _ in range(batch_size)]
 
     def __call__(self, logits: torch.Tensor):
         sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
         probs = torch.log_softmax(sorted_logits, -1)
 
-        self._beams = [(beam[0], beam[1] % self.num_beams) for beam in self._beams]
-        if len(self._beams) > 0:
-            new_beams = []
+        if len(self._batch_beams[0]) > 0:
             for batch_idx in range(self.batch_size):
-                for beam_idx in range(self.num_beams):
+                new_beams = []
+                cur_beams = self._batch_beams[batch_idx]
+                for beam_idx in range(len(cur_beams)):
                     probs_idx = batch_idx + beam_idx * self.batch_size
-                    new_beam = self._beams[probs_idx]
+                    new_beam = cur_beams[beam_idx]
                     for hypo_idx in range(self.num_beams):
-                        heappush(
-                            new_beams,
-                            (new_beam[0] + probs[probs_idx, hypo_idx].item(), beam_idx * self.num_beams + hypo_idx),
+                        new_beams.append(
+                            (new_beam[0] + probs[probs_idx, hypo_idx].item(), beam_idx * self.num_beams + hypo_idx)
                         )
-                        if len(new_beams) > self.batch_size * self.num_beams:
-                            heappop(new_beams)
-            self._beams = new_beams
+                self._batch_beams[batch_idx] = sorted(new_beams, reverse=True)[: self.num_beams]
         else:
             for batch_idx in range(self.batch_size):
                 for beam_idx in range(self.num_beams):
-                    self._beams.append((probs[batch_idx, beam_idx].item(), beam_idx))
+                    self._batch_beams[batch_idx].append((probs[batch_idx, beam_idx].item(), beam_idx))
 
         return_hypos = []
         return_tokens = []
         for batch_idx in range(self.batch_size):
-            for beam_idx in range(self.num_beams):
-                beam = self._beams[batch_idx + beam_idx * self.batch_size]
-                hypo_idx = beam[1] // self.num_beams
+            cur_beam = self._batch_beams[batch_idx]
+            return_hypos.append(list())
+            return_tokens.append(list())
+            for beam in cur_beam:
+                beam_idx = beam[1] // self.num_beams
+                hypo_idx = batch_idx + beam_idx * self.batch_size
                 token_idx = beam[1] % self.num_beams
-                return_hypos.append(hypo_idx)
-                return_tokens.append([sorted_indices[hypo_idx, token_idx].item()])
+                return_hypos[-1].append(hypo_idx)
+                return_tokens[-1].append([sorted_indices[hypo_idx, token_idx].item()])
+        return_hypos = [hypo_idx for hypo_indexes in zip(*return_hypos) for hypo_idx in hypo_indexes]
+        return_tokens = [token_idx for token_indexes in zip(*return_tokens) for token_idx in token_indexes]
 
         return torch.tensor(return_tokens), torch.tensor(return_hypos)
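
A minimal sketch (made-up values, not part of the commit) of the final flattening: zip(*...) interleaves the per-sample lists so that flat position batch_idx + beam_idx * batch_size holds beam beam_idx of sample batch_idx, matching how generate() replicated the inputs.

    # Per-sample hypothesis rows for batch_size=2, num_beams=3 (arbitrary ids).
    return_hypos = [[0, 2, 4], [1, 3, 5]]  # sample 0 picked rows 0, 2, 4; sample 1 picked 1, 3, 5

    # Same flattening as above: zip(*...) yields one tuple per beam position,
    # so samples end up interleaved in batch-major order.
    flat = [h for hypo_indexes in zip(*return_hypos) for h in hypo_indexes]
    assert flat == [0, 1, 2, 3, 4, 5]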