|
@@ -80,29 +80,33 @@ class BeamSearchAlgorithm(DecodingAlgorithm):
|
|
|
self._cur_num_beams = 1
|
|
|
self.batch_size = batch_size
|
|
|
|
|
|
- self._logits = torch.zeros(
|
|
|
- (
|
|
|
- self.batch_size,
|
|
|
- self._cur_num_beams,
|
|
|
- )
|
|
|
- )
|
|
|
+ self._beams = []
|
|
|
|
|
|
def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
|
|
|
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
|
|
|
- probs = torch.softmax(sorted_logits, -1)
|
|
|
-
|
|
|
- new_logits = torch.cat([self._logits] * self.num_beams, dim=-1)
|
|
|
+ probs = torch.log_softmax(sorted_logits, -1)
|
|
|
+
|
|
|
+ if len(self._beams) > 0:
|
|
|
+ new_beams = []
|
|
|
+ for batch_idx in range(self.batch_size):
|
|
|
+ for beam_idx in range(self.num_beams):
|
|
|
+ new_beam = self._beams[beam_idx]
|
|
|
+ for hypo_idx in range(self.num_beams):
|
|
|
+ probs_idx = batch_idx + beam_idx * self.batch_size
|
|
|
+ new_beams.append((beam_idx, new_beam[1] + probs[probs_idx, hypo_idx].item()))
|
|
|
+ new_beams = sorted(new_beams, key=lambda x: x[1], reverse=True)
|
|
|
+ self._beams = new_beams[: self.batch_size * self.num_beams]
|
|
|
+ else:
|
|
|
+ for batch_idx in range(self.batch_size):
|
|
|
+ for beam_idx in range(self.num_beams):
|
|
|
+ self._beams.append((beam_idx, probs[batch_idx, beam_idx].item()))
|
|
|
+
|
|
|
+ return_hypos = []
|
|
|
+ return_tokens = []
|
|
|
for batch_idx in range(self.batch_size):
|
|
|
- for cur_beam_idx in range(self._cur_num_beams):
|
|
|
- for new_beam_idx in range(self.num_beams):
|
|
|
- logit = probs[cur_beam_idx * self.batch_size + batch_idx, new_beam_idx]
|
|
|
- new_logits[batch_idx, cur_beam_idx * self.num_beams + new_beam_idx] += logit
|
|
|
- self._cur_num_beams = self.num_beams
|
|
|
-
|
|
|
- new_sorted_logits, new_sorted_indices = torch.sort(new_logits, descending=True, dim=-1)
|
|
|
- new_sorted_indices = new_sorted_indices[:, : self.num_beams].T.flatten()
|
|
|
- self._logits = new_sorted_logits[:, : self.num_beams]
|
|
|
- result_tokens = sorted_indices[torch.arange(self.num_beams * self.batch_size), new_sorted_indices]
|
|
|
- result_hypos = torch.div(new_sorted_indices, self.num_beams, rounding_mode="floor")
|
|
|
-
|
|
|
- return result_tokens.unsqueeze(-1), result_hypos
|
|
|
+ for beam_idx in range(self.num_beams):
|
|
|
+ hypo_idx = batch_idx + beam_idx * self.batch_size
|
|
|
+ return_hypos.append(self._beams[hypo_idx][0])
|
|
|
+ return_tokens.append([sorted_indices[batch_idx, beam_idx].item()])
|
|
|
+
|
|
|
+ return torch.tensor(return_tokens), torch.tensor(return_hypos)
|