@@ -81,40 +81,42 @@ class BeamSearchAlgorithm(DecodingAlgorithm):
         self._cur_num_beams = 1
         self.batch_size = batch_size
 
-        self._beams = []
+        self._batch_beams = [list() for _ in range(batch_size)]
 
     def __call__(self, logits: torch.Tensor):
         sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
         probs = torch.log_softmax(sorted_logits, -1)
 
-        self._beams = [(beam[0], beam[1] % self.num_beams) for beam in self._beams]
-        if len(self._beams) > 0:
-            new_beams = []
+        if len(self._batch_beams[0]) > 0:
             for batch_idx in range(self.batch_size):
-                for beam_idx in range(self.num_beams):
+                new_beams = []
+                cur_beams = self._batch_beams[batch_idx]
+                for beam_idx in range(len(cur_beams)):
                     probs_idx = batch_idx + beam_idx * self.batch_size
-                    new_beam = self._beams[probs_idx]
+                    new_beam = cur_beams[beam_idx]
                     for hypo_idx in range(self.num_beams):
-                        heappush(
-                            new_beams,
-                            (new_beam[0] + probs[probs_idx, hypo_idx].item(), beam_idx * self.num_beams + hypo_idx),
+                        new_beams.append(
+                            (new_beam[0] + probs[probs_idx, hypo_idx].item(), beam_idx * self.num_beams + hypo_idx)
                         )
-                        if len(new_beams) > self.batch_size * self.num_beams:
-                            heappop(new_beams)
-            self._beams = new_beams
+                self._batch_beams[batch_idx] = sorted(new_beams, reverse=True)[: self.num_beams]
         else:
             for batch_idx in range(self.batch_size):
                 for beam_idx in range(self.num_beams):
-                    self._beams.append((probs[batch_idx, beam_idx].item(), beam_idx))
+                    self._batch_beams[batch_idx].append((probs[batch_idx, beam_idx].item(), beam_idx))
 
         return_hypos = []
         return_tokens = []
         for batch_idx in range(self.batch_size):
-            for beam_idx in range(self.num_beams):
-                beam = self._beams[batch_idx + beam_idx * self.batch_size]
-                hypo_idx = beam[1] // self.num_beams
+            cur_beam = self._batch_beams[batch_idx]
+            return_hypos.append(list())
+            return_tokens.append(list())
+            for beam in cur_beam:
+                beam_idx = beam[1] // self.num_beams
+                hypo_idx = batch_idx + beam_idx * self.batch_size
                 token_idx = beam[1] % self.num_beams
-                return_hypos.append(hypo_idx)
-                return_tokens.append([sorted_indices[hypo_idx, token_idx].item()])
+                return_hypos[-1].append(hypo_idx)
+                return_tokens[-1].append([sorted_indices[hypo_idx, token_idx].item()])
+        return_hypos = [hypo_idx for hypo_indexes in zip(*return_hypos) for hypo_idx in hypo_indexes]
+        return_tokens = [token_idx for token_indexes in zip(*return_tokens) for token_idx in token_indexes]
 
         return torch.tensor(return_tokens), torch.tensor(return_hypos)
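
Note on the bookkeeping introduced above: each beam entry is a tuple of (cumulative_log_prob, packed_idx), where packed_idx = parent_beam_idx * num_beams + hypo_idx folds the parent beam and the rank of the chosen token into one integer, and the return loop unpacks it again with // and %. A minimal sketch of that arithmetic (the variable names here are illustrative, not taken from the patch):

    num_beams = 3
    parent_beam_idx, hypo_idx = 2, 1                     # beam 2 extended by its 2nd-best token
    packed_idx = parent_beam_idx * num_beams + hypo_idx  # 2 * 3 + 1 == 7
    assert packed_idx // num_beams == parent_beam_idx    # recovers the parent beam
    assert packed_idx % num_beams == hypo_idx            # recovers the token rank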
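
For completeness, a hedged usage sketch. It assumes the constructor takes num_beams and batch_size (suggested by the attributes set above, but not shown in this hunk), and that logits rows are laid out batch-major so that row batch_idx + beam_idx * batch_size belongs to that (batch, beam) pair, as the indexing in the patch implies:

    import torch

    batch_size, num_beams, vocab = 2, 3, 10
    algo = BeamSearchAlgorithm(num_beams=num_beams, batch_size=batch_size)  # assumed signature

    # First step: the beams are still empty, so one logits row per batch item.
    tokens, hypos = algo(torch.randn(batch_size, vocab))

    # Later steps: one row per (batch, beam) pair, batch-major as described above.
    tokens, hypos = algo(torch.randn(batch_size * num_beams, vocab))
    # tokens: shape (batch_size * num_beams, 1), the next token for each surviving beam
    # hypos:  shape (batch_size * num_beams,), the logits row each beam extends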