Artem Chumachenko před 2 roky
rodič
revize
95334833af

+ 2 - 3
src/client/remote_generation.py

@@ -168,11 +168,10 @@ class RemoteGenerationMixin:
                     break
 
         outputs = torch.cat(outputs, dim=-1)
-        
+
         if num_beams > 1:
             pre_return_idx = [
-                torch.arange(idx, num_return_sequences * batch_size, batch_size)
-                for idx in range(batch_size)
+                torch.arange(idx, num_return_sequences * batch_size, batch_size) for idx in range(batch_size)
             ]
             return_idx = torch.cat(pre_return_idx, dim=0)
             outputs = outputs[return_idx]

+ 5 - 8
src/utils/generation_algorithms.py

@@ -1,6 +1,6 @@
 from abc import ABC
+from heapq import heappop, heappush
 from typing import Tuple
-from heapq import heappush, heappop
 
 import torch
 
@@ -82,11 +82,11 @@ class BeamSearchAlgorithm(DecodingAlgorithm):
         self.batch_size = batch_size
 
         self._beams = []
-    
+
     def __call__(self, logits: torch.Tensor):
         sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
         probs = torch.log_softmax(sorted_logits, -1)
-        
+
         self._beams = [(beam[0], beam[1] % self.num_beams) for beam in self._beams]
         if len(self._beams) > 0:
             new_beams = []
@@ -97,10 +97,7 @@ class BeamSearchAlgorithm(DecodingAlgorithm):
                     for hypo_idx in range(self.num_beams):
                         heappush(
                             new_beams,
-                            (
-                                new_beam[0] + probs[probs_idx, hypo_idx].item(),
-                                beam_idx * self.num_beams + hypo_idx
-                            )
+                            (new_beam[0] + probs[probs_idx, hypo_idx].item(), beam_idx * self.num_beams + hypo_idx),
                         )
                         if len(new_beams) > self.batch_size * self.num_beams:
                             heappop(new_beams)
@@ -109,7 +106,7 @@ class BeamSearchAlgorithm(DecodingAlgorithm):
             for batch_idx in range(self.batch_size):
                 for beam_idx in range(self.num_beams):
                     self._beams.append((probs[batch_idx, beam_idx].item(), beam_idx))
-                    
+
         return_hypos = []
         return_tokens = []
         for batch_idx in range(self.batch_size):