Browse Source

Add vectorized version of beam_search

Artem Chumachenko 2 years ago
parent
commit
c242232c52
1 changed files with 32 additions and 30 deletions
  1. 32 30
      src/petals/utils/generation_algorithms.py

+ 32 - 30
src/petals/utils/generation_algorithms.py

@@ -80,42 +80,44 @@ class BeamSearchAlgorithm(DecodingAlgorithm):
         self._cur_num_beams = 1
         self._cur_num_beams = 1
         self.batch_size = batch_size
         self.batch_size = batch_size
 
 
-        self._batch_beams = [list() for _ in range(batch_size)]
+        self._batch_beams = torch.zeros((batch_size, num_beams))
 
 
-    def __call__(self, logits: torch.Tensor):
+    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
         sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
         sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
         probs = torch.log_softmax(sorted_logits, -1)
         probs = torch.log_softmax(sorted_logits, -1)
 
 
-        if len(self._batch_beams[0]) > 0:
-            for batch_idx in range(self.batch_size):
-                new_beams = []
-                cur_beams = self._batch_beams[batch_idx]
-                for beam_idx in range(len(cur_beams)):
-                    probs_idx = batch_idx + beam_idx * self.batch_size
-                    new_beam = cur_beams[beam_idx]
-                    for hypo_idx in range(self.num_beams):
-                        new_beams.append(
-                            (new_beam[0] + probs[probs_idx, hypo_idx].item(), beam_idx * self.num_beams + hypo_idx)
-                        )
-                self._batch_beams[batch_idx] = sorted(new_beams, reverse=True)[: self.num_beams]
+        hypo_ids = None
+        if self._cur_num_beams > 1:
+            permuted_indexes = torch.cat(
+                [torch.arange(0, self.num_beams) * self.batch_size + i for i in range(self.batch_size)], dim=0
+            )
+            probs = probs[:, : self.num_beams][permuted_indexes]
+            probs = probs.view(self.batch_size, self.num_beams, self.num_beams)
+            self._batch_beams = self._batch_beams[:, :, None] + probs
+            self._batch_beams = self._batch_beams.view(self.batch_size, -1)
+            sorted_batch_beams, sorted_hypo_ids = torch.sort(self._batch_beams, descending=True, dim=-1)
+            self._batch_beams = sorted_batch_beams[:, : self.num_beams]
+            hypo_ids = sorted_hypo_ids[:, : self.num_beams]
         else:
         else:
-            for batch_idx in range(self.batch_size):
-                for beam_idx in range(self.num_beams):
-                    self._batch_beams[batch_idx].append((probs[batch_idx, beam_idx].item(), beam_idx))
+            self._batch_beams = probs[: self.batch_size, : self.num_beams]
+            self._cur_num_beams = self.num_beams
+            hypo_ids = torch.tile(
+                torch.arange(self.num_beams),
+                (self.batch_size, 1),
+            )
 
 
         return_hypos = []
         return_hypos = []
         return_tokens = []
         return_tokens = []
         for batch_idx in range(self.batch_size):
         for batch_idx in range(self.batch_size):
-            cur_beam = self._batch_beams[batch_idx]
-            return_hypos.append(list())
-            return_tokens.append(list())
-            for beam in cur_beam:
-                beam_idx = beam[1] // self.num_beams
-                hypo_idx = batch_idx + beam_idx * self.batch_size
-                token_idx = beam[1] % self.num_beams
-                return_hypos[-1].append(hypo_idx)
-                return_tokens[-1].append([sorted_indices[hypo_idx, token_idx].item()])
-        return_hypos = [hypo_idx for hypo_indexes in zip(*return_hypos) for hypo_idx in hypo_indexes]
-        return_tokens = [token_idx for token_indexes in zip(*return_tokens) for token_idx in token_indexes]
-
-        return torch.tensor(return_tokens), torch.tensor(return_hypos)
+            cur_beam = hypo_ids[batch_idx]
+            hypo_idx = batch_idx + torch.floor_divide(cur_beam, self.num_beams) * self.batch_size
+            return_hypos.append(hypo_idx)
+            return_tokens.append(sorted_indices[hypo_idx, cur_beam % self.num_beams].unsqueeze(-1))
+
+        return_indexes = torch.cat(
+            [torch.arange(0, self.batch_size) * self.num_beams + i for i in range(self.num_beams)], dim=0
+        )
+        return_tokens = torch.cat(return_tokens, 0)
+        return_hypos = torch.cat(return_hypos, 0)
+
+        return return_tokens[return_indexes], return_hypos[return_indexes]