
Pack of fixes

Artem Chumachenko, 2 years ago
commit b604e778ac
2 changed files, with 15 additions and 28 deletions:

  1. src/petals/client/inference_session.py (+2, −1)
  2. src/petals/utils/generation_algorithms.py (+13, −27)

src/petals/client/inference_session.py (+2, −1)

@@ -207,7 +207,7 @@ class InferenceSession:
         assert not self._closed and not self._chosen_spans
         return self
 
-    def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
+    def step(self, inputs: torch.Tensor, hypo_ids: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
         assert not self._closed
         if torch.is_grad_enabled():
             logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
@@ -222,6 +222,7 @@ class InferenceSession:
         inputs_dtype = inputs.dtype
         inputs = inputs.cpu()
         prompts = prompts.cpu()
+        hypo_ids = hypo_ids.cpu()
 
         n_input_tokens = inputs.shape[1]
         if self._position + n_input_tokens > self._max_length:
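
The added `hypo_ids` argument lets a caller tell the remote servers which previously cached hypothesis each row of `inputs` continues, so per-hypothesis state can be reordered between beam-search steps. A toy illustration of that reordering, using nothing beyond plain tensors (the names `cache` and `hidden` and all sizes are invented for the example, not Petals API):

    import torch

    # hypo_ids names the past hypothesis each new beam extends, so any
    # per-hypothesis state (e.g. an attention cache) has to be gathered
    # along the hypothesis dimension before the next step runs.
    num_beams, hidden = 3, 4
    cache = torch.arange(num_beams * hidden, dtype=torch.float32).reshape(num_beams, hidden)
    hypo_ids = torch.tensor([2, 0, 0])  # beams 1 and 2 both continue old hypothesis 0
    reordered = cache[hypo_ids]         # same as cache.index_select(0, hypo_ids)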

src/petals/utils/generation_algorithms.py (+13, −27)

@@ -83,41 +83,27 @@ class BeamSearchAlgorithm(DecodingAlgorithm):
         self._batch_beams = torch.zeros((batch_size, num_beams))
 
     def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
-        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
-        probs = torch.log_softmax(sorted_logits, -1)
+        logits = torch.log_softmax(logits, -1)
+        probs, topk_indices = torch.topk(logits, k=self.num_beams, dim=-1)
 
         hypo_ids = None
         if self._cur_num_beams > 1:
-            permuted_indexes = torch.cat(
-                [torch.arange(0, self.num_beams) * self.batch_size + i for i in range(self.batch_size)], dim=0
-            )
-            probs = probs[:, : self.num_beams][permuted_indexes]
-            probs = probs.view(self.batch_size, self.num_beams, self.num_beams)
+            probs = probs.reshape(self.batch_size, self.num_beams, self.num_beams)
             self._batch_beams = self._batch_beams[:, :, None] + probs
             self._batch_beams = self._batch_beams.view(self.batch_size, -1)
-            sorted_batch_beams, sorted_hypo_ids = torch.sort(self._batch_beams, descending=True, dim=-1)
-            self._batch_beams = sorted_batch_beams[:, : self.num_beams]
-            hypo_ids = sorted_hypo_ids[:, : self.num_beams]
+            self._batch_beams, hypo_ids = torch.topk(self._batch_beams, k=self.num_beams, dim=-1)
         else:
-            self._batch_beams = probs[: self.batch_size, : self.num_beams]
+            self._batch_beams = probs[:self.batch_size, :self.num_beams]
             self._cur_num_beams = self.num_beams
             hypo_ids = torch.tile(
-                torch.arange(self.num_beams),
+                torch.arange(self.num_beams, device=probs.device),
                 (self.batch_size, 1),
             )
 
-        return_hypos = []
-        return_tokens = []
-        for batch_idx in range(self.batch_size):
-            cur_beam = hypo_ids[batch_idx]
-            hypo_idx = batch_idx + torch.floor_divide(cur_beam, self.num_beams) * self.batch_size
-            return_hypos.append(hypo_idx)
-            return_tokens.append(sorted_indices[hypo_idx, cur_beam % self.num_beams].unsqueeze(-1))
-
-        return_indexes = torch.cat(
-            [torch.arange(0, self.batch_size) * self.num_beams + i for i in range(self.num_beams)], dim=0
-        )
-        return_tokens = torch.cat(return_tokens, 0)
-        return_hypos = torch.cat(return_hypos, 0)
-
-        return return_tokens[return_indexes], return_hypos[return_indexes]
+        return_hypos = (
+            torch.arange(self.batch_size, device=probs.device)[:, None] +
+            torch.div(hypo_ids, self.num_beams, rounding_mode="floor") * self.batch_size
+        ).reshape(-1)
+        return_tokens = topk_indices[return_hypos, (hypo_ids % self.num_beams).reshape(-1)].unsqueeze(-1)
+
+        return return_tokens, return_hypos
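
For reference, a minimal self-contained sketch of the vectorized top-k beam update this rewrite switches to; all sizes are invented for the example. For readability it lays hypotheses out batch-major (row i * num_beams + b holds batch item i, beam b), whereas the diff keeps a beam-major layout (row = beam * batch_size + batch), which is why its `return_hypos` multiplies the parent beam index by `batch_size`:

    import torch

    batch_size, num_beams, vocab = 2, 3, 11
    beam_scores = torch.zeros(batch_size, num_beams)              # cumulative log-prob per beam
    logits = torch.randn(batch_size * num_beams, vocab)           # one row per (batch, beam)

    log_probs = torch.log_softmax(logits, dim=-1)
    top_lp, top_ids = torch.topk(log_probs, k=num_beams, dim=-1)  # each beam proposes its k best tokens

    # Score all num_beams * num_beams candidates per batch item, keep the best num_beams.
    cand = beam_scores[:, :, None] + top_lp.reshape(batch_size, num_beams, num_beams)
    beam_scores, flat = torch.topk(cand.reshape(batch_size, -1), k=num_beams, dim=-1)

    parent = torch.div(flat, num_beams, rounding_mode="floor")    # which beam each winner extends
    slot = flat % num_beams                                       # which of that beam's top-k tokens
    rows = torch.arange(batch_size)[:, None] * num_beams + parent # batch-major row into top_ids
    tokens = top_ids[rows, slot]                                  # next token for each surviving beam

Using torch.topk in place of a full descending sort keeps only the k entries that are actually needed, which is the main simplification in this hunk.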