5
0
Эх сурвалжийг харах

Allow .generate() to reuse existing inference session (#132)

Alexander Borzunov 2 жил өмнө
parent
commit
e8fac92e59

+ 5 - 0
src/petals/client/inference_session.py

@@ -171,6 +171,11 @@ class InferenceSession:
         self._server_inputs = []  # Used in case of server failures to regenerate attention caches on new servers
         self._position = 0
         self._max_length = max_length
+        self.last_token_id = None
+
+    @property
+    def position(self) -> int:
+        return self._position
 
     def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[_ServerInferenceSession]:
         server_sessions = []

+ 41 - 11
src/petals/client/remote_generation.py

@@ -1,8 +1,10 @@
+import contextlib
 from typing import List, Optional
 
 import torch
 from hivemind.utils.logging import get_logger
 
+from petals.client.inference_session import InferenceSession
 from petals.utils.generation_algorithms import (
     BeamSearchAlgorithm,
     DecodingAlgorithm,
@@ -23,9 +25,20 @@ class RemoteGenerationMixin:
         - *multinomial sampling*.
         - *beam-search decoding*
 
-    This class is similar to transformer's [`generation_utils.GenerationMixin`], it can be used instead of it. However, it has some differences for remote usage.
+    This class is similar to transformer's [`generation_utils.GenerationMixin`], it can be used instead of it.
+    However, it has some differences for remote usage.
     """
 
+    def inference_session(self, **kwargs) -> InferenceSession:
+        """
+        Returns an inference session for the model's RemoteSequential module.
+
+        :param max_length: Maximal expected length of inference results. Servers use this parameter
+                           to calculate the size of attention caches allocated to this client.
+        """
+
+        return self.transformer.h.inference_session(**kwargs)
+
     @torch.no_grad()
     def generate(
         self,
@@ -43,6 +56,8 @@ class RemoteGenerationMixin:
         decoding_algorithm: Optional[DecodingAlgorithm] = None,
         provided_constraints: List[ABCBloomConstraint] = [],
         num_return_sequences: Optional[int] = None,
+        *,
+        session: Optional[InferenceSession] = None,
         **model_kwargs,
     ) -> torch.LongTensor:
         """
@@ -74,8 +89,6 @@ class RemoteGenerationMixin:
         assert (
             model_kwargs.get("stopping_criteria", None) is None
         ), "For RemoteGenerationMixin models use BloomConstraints instead of stopping_criteria"
-        if inputs is not None:
-            assert isinstance(inputs, torch.Tensor) and inputs.ndim == 2, "inputs must be a 2d tensor [batch, length]"
         prefix_length = 0 if inputs is None else inputs.size(1)
         prefix_length += self.config.pre_seq_len
 
@@ -83,8 +96,6 @@ class RemoteGenerationMixin:
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
 
-        batch_size = inputs.size(0)
-
         assert (max_length is None) != (max_new_tokens is None), "please set max_length or max_new_tokens (not both)"
         if max_length is not None and max_new_tokens is None:
             max_new_tokens = max_length - prefix_length
@@ -92,9 +103,22 @@ class RemoteGenerationMixin:
         elif max_length is None and max_new_tokens is not None:
             max_length = prefix_length + max_new_tokens
 
-        if inputs is None:
-            assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs"
-            inputs = torch.tensor([[bos_token_id]] * num_beams, dtype=torch.long, device=self.device)
+        if num_beams > 1 and session is not None:
+            raise NotImplementedError(
+                "Reusing inference session in .generate() along with beam search is not supported yet"
+            )
+
+        if inputs is not None:
+            assert isinstance(inputs, torch.Tensor) and inputs.ndim == 2, "inputs must be a 2d tensor [batch, length]"
+            if session is not None and session.last_token_id is not None:
+                inputs = torch.cat([session.last_token_id, inputs], dim=1)
+        else:
+            if session is not None and session.last_token_id is not None:
+                inputs = session.last_token_id
+            else:
+                assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs"
+                inputs = torch.tensor([[bos_token_id]] * num_beams, dtype=torch.long, device=self.device)
+        batch_size = inputs.size(0)
 
         if decoding_algorithm is None:
             if do_sample:
@@ -109,7 +133,8 @@ class RemoteGenerationMixin:
             if batch_size > 1:
                 # TODO: resolve padding problem
                 logger.warning(
-                    f"You set batch_size {batch_size} within beam search generation. Be careful, results on sequences with different length may be padded wrong way"
+                    f"You set batch_size {batch_size} within beam search generation. "
+                    f"Be careful, results on sequences with different length may be padded wrong way"
                 )
 
         if num_return_sequences is None:
@@ -127,7 +152,11 @@ class RemoteGenerationMixin:
             provided_constraints=provided_constraints,
         )
 
-        with self.transformer.h.inference_session(max_length=max_length) as sess:
+        if session is None:
+            context_manager = self.inference_session(max_length=max_length)
+        else:
+            context_manager = contextlib.nullcontext(session)  # Doesn't actually enter session or exit from it
+        with context_manager as session:
             outputs = []
             # Find samples with padded inputs.
             # They will be changed before all of the samples have right length.
@@ -145,7 +174,7 @@ class RemoteGenerationMixin:
                     prompts, intermediate_prompts = self.transformer.get_prompt(embs.size(0))
                     embs = torch.cat([prompts, embs], dim=1)
                 embs = self.transformer.word_embeddings_layernorm(embs)
-                hidden_state = sess.step(embs, prompts=intermediate_prompts, hypo_ids=hypo_ids)[:, -1]
+                hidden_state = session.step(embs, prompts=intermediate_prompts, hypo_ids=hypo_ids)[:, -1]
                 hidden_state = self.transformer.ln_f(hidden_state)
                 lm_logits = self.lm_head(hidden_state)
 
@@ -166,6 +195,7 @@ class RemoteGenerationMixin:
                         outputs[i - 1] = outputs[i - 1][hypo_ids]
 
                 outputs.append(last_token_id)
+                session.last_token_id = last_token_id
                 seq_idx += 1
                 if torch.all(last_token_id == eos_token_id) or len(outputs) > max_new_tokens:
                     break