@@ -1,6 +1,6 @@
 from abc import ABC
+from heapq import heappop, heappush
 from typing import Tuple
-from heapq import heappush, heappop

 import torch

@@ -82,11 +82,11 @@ class BeamSearchAlgorithm(DecodingAlgorithm):
         self.batch_size = batch_size

         self._beams = []
-
+
     def __call__(self, logits: torch.Tensor):
         sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
         probs = torch.log_softmax(sorted_logits, -1)
-
+
         self._beams = [(beam[0], beam[1] % self.num_beams) for beam in self._beams]
         if len(self._beams) > 0:
             new_beams = []
@@ -97,10 +97,7 @@ class BeamSearchAlgorithm(DecodingAlgorithm):
                     for hypo_idx in range(self.num_beams):
                         heappush(
                             new_beams,
-                            (
-                                new_beam[0] + probs[probs_idx, hypo_idx].item(),
-                                beam_idx * self.num_beams + hypo_idx
-                            )
+                            (new_beam[0] + probs[probs_idx, hypo_idx].item(), beam_idx * self.num_beams + hypo_idx),
                         )
                         if len(new_beams) > self.batch_size * self.num_beams:
                             heappop(new_beams)
@@ -109,7 +106,7 @@ class BeamSearchAlgorithm(DecodingAlgorithm):
             for batch_idx in range(self.batch_size):
                 for beam_idx in range(self.num_beams):
                     self._beams.append((probs[batch_idx, beam_idx].item(), beam_idx))
-
+
         return_hypos = []
         return_tokens = []
         for batch_idx in range(self.batch_size):
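
For reference, the heappush/heappop pair in the hunks above is the standard size-capped min-heap pattern for keeping only the batch_size * num_beams highest-scoring hypotheses while streaming over candidates: the heap root is always the worst score currently kept, so evicting it whenever the heap grows past the budget leaves exactly the best entries. The sketch below shows that pattern in isolation; the names top_k_by_score, candidates, and k are illustrative and are not part of this patch.

from heapq import heappop, heappush

def top_k_by_score(candidates, k):
    # Keep at most k (score, payload) pairs; the heap root is the worst kept score.
    heap = []
    for score, payload in candidates:
        heappush(heap, (score, payload))
        if len(heap) > k:   # over budget: evict the current minimum
            heappop(heap)
    return sorted(heap, reverse=True)  # best first

# Example: retain the 3 best of 6 scored hypotheses.
print(top_k_by_score([(0.1, "a"), (0.9, "b"), (0.4, "c"), (0.7, "d"), (0.2, "e"), (0.8, "f")], 3))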