Artem Chumachenko il y a 2 ans
Parent
commit
ce22b6a47b
2 fichiers modifiés avec 12 ajouts et 7 suppressions
  1. 1 1
      src/client/remote_generation.py
  2. 11 6
      src/utils/generation_algorithms.py

+ 1 - 1
src/client/remote_generation.py

@@ -8,7 +8,7 @@ from src.utils.generation_algorithms import (
     DecodingAlgorithm,
     GreedyAlgorithm,
     NucleusAlgorithm,
-    TopKAlgorithm
+    TopKAlgorithm,
 )
 from src.utils.generation_constraints import ABCBloomConstraint, EosConstraint
 
 
 

+ 11 - 6
src/utils/generation_algorithms.py

@@ -80,12 +80,17 @@ class BeamSearchAlgorithm(DecodingAlgorithm):
         self._cur_num_beams = 1
         self.batch_size = batch_size
 
-        self._logits = torch.zeros((self.batch_size, self._cur_num_beams,))
-    
+        self._logits = torch.zeros(
+            (
+                self.batch_size,
+                self._cur_num_beams,
+            )
+        )
+
     def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
         sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
         probs = torch.softmax(sorted_logits, -1)
-        
+
         new_logits = torch.cat([self._logits] * self.num_beams, dim=-1)
         for batch_idx in range(self.batch_size):
             for cur_beam_idx in range(self._cur_num_beams):
@@ -95,9 +100,9 @@ class BeamSearchAlgorithm(DecodingAlgorithm):
         self._cur_num_beams = self.num_beams
 
         new_sorted_logits, new_sorted_indices = torch.sort(new_logits, descending=True, dim=-1)
-        new_sorted_indices = new_sorted_indices[:, :self.num_beams].T.flatten()
-        self._logits = new_sorted_logits[:, :self.num_beams]
+        new_sorted_indices = new_sorted_indices[:, : self.num_beams].T.flatten()
+        self._logits = new_sorted_logits[:, : self.num_beams]
         result_tokens = sorted_indices[torch.arange(self.num_beams * self.batch_size), new_sorted_indices]
 
         return result_tokens.unsqueeze(-1), result_hypos