@@ -80,12 +80,17 @@ class BeamSearchAlgorithm(DecodingAlgorithm):
         self._cur_num_beams = 1
         self.batch_size = batch_size
 
-        self._logits = torch.zeros((self.batch_size, self._cur_num_beams,))
-
+        self._logits = torch.zeros(
+            (
+                self.batch_size,
+                self._cur_num_beams,
+            )
+        )
+
     def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
         sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
         probs = torch.softmax(sorted_logits, -1)
-
+
         new_logits = torch.cat([self._logits] * self.num_beams, dim=-1)
         for batch_idx in range(self.batch_size):
             for cur_beam_idx in range(self._cur_num_beams):
@@ -95,9 +100,9 @@ class BeamSearchAlgorithm(DecodingAlgorithm):
                 self._cur_num_beams = self.num_beams
 
         new_sorted_logits, new_sorted_indices = torch.sort(new_logits, descending=True, dim=-1)
-        new_sorted_indices = new_sorted_indices[:, :self.num_beams].T.flatten()
-        self._logits = new_sorted_logits[:, :self.num_beams]
+        new_sorted_indices = new_sorted_indices[:, : self.num_beams].T.flatten()
+        self._logits = new_sorted_logits[:, : self.num_beams]
         result_tokens = sorted_indices[torch.arange(self.num_beams * self.batch_size), new_sorted_indices]
-        result_hypos = torch.div(new_sorted_indices, self.num_beams, rounding_mode='floor')
+        result_hypos = torch.div(new_sorted_indices, self.num_beams, rounding_mode="floor")
 
         return result_tokens.unsqueeze(-1), result_hypos
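A note on the last hunk: result_hypos uses the standard flat-index decomposition for beam search. Candidates from all live hypotheses are sorted together in one flattened score array; since each hypothesis contributes a contiguous block of num_beams candidates, floor division by the block width recovers which hypothesis each winner came from. Below is a minimal standalone sketch of that trick, not part of this PR; the shapes, variable names, and the 3x5 toy example are illustrative assumptions only.

import torch

# Toy candidate scores for one decoding step:
# 3 live hypotheses, 5 candidate tokens per hypothesis.
num_hypos, num_candidates, num_beams = 3, 5, 3
scores = torch.randn(num_hypos, num_candidates)

# Sort all (hypothesis, token) pairs together in one flat array.
flat_scores = scores.flatten()
sorted_scores, sorted_idx = torch.sort(flat_scores, descending=True)
top_idx = sorted_idx[:num_beams]  # keep the num_beams best pairs overall

# Floor division recovers the source hypothesis of each winner
# (the same rounding_mode="floor" call as in the diff); the remainder
# is the candidate's rank within that hypothesis.
hypo_ids = torch.div(top_idx, num_candidates, rounding_mode="floor")
token_ranks = top_idx % num_candidates

This mirrors why the diff divides the sorted flat indices by num_beams before using them to index back into sorted_indices.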