@@ -1,8 +1,10 @@
+import contextlib
 from typing import List, Optional

 import torch
 from hivemind.utils.logging import get_logger

+from petals.client.inference_session import InferenceSession
 from petals.utils.generation_algorithms import (
     BeamSearchAlgorithm,
     DecodingAlgorithm,
@@ -23,9 +25,20 @@ class RemoteGenerationMixin:
         - *multinomial sampling*.
         - *beam-search decoding*

-    This class is similar to transformer's [`generation_utils.GenerationMixin`], it can be used instead of it. However, it has some differences for remote usage.
+    This class is similar to transformer's [`generation_utils.GenerationMixin`], it can be used instead of it.
+    However, it has some differences for remote usage.
     """

+    def inference_session(self, **kwargs) -> InferenceSession:
+        """
+        Returns an inference session for the model's RemoteSequential module.
+
+        :param max_length: Maximal expected length of inference results. Servers use this parameter
+                           to calculate the size of attention caches allocated to this client.
+        """
+
+        return self.transformer.h.inference_session(**kwargs)
+
     @torch.no_grad()
     def generate(
         self,
@@ -43,6 +56,8 @@ class RemoteGenerationMixin:
         decoding_algorithm: Optional[DecodingAlgorithm] = None,
         provided_constraints: List[ABCBloomConstraint] = [],
         num_return_sequences: Optional[int] = None,
+        *,
+        session: Optional[InferenceSession] = None,
         **model_kwargs,
     ) -> torch.LongTensor:
         """
@@ -74,8 +89,6 @@ class RemoteGenerationMixin:
|
|
|
assert (
|
|
|
model_kwargs.get("stopping_criteria", None) is None
|
|
|
), "For RemoteGenerationMixin models use BloomConstraints instead of stopping_criteria"
|
|
|
- if inputs is not None:
|
|
|
- assert isinstance(inputs, torch.Tensor) and inputs.ndim == 2, "inputs must be a 2d tensor [batch, length]"
|
|
|
prefix_length = 0 if inputs is None else inputs.size(1)
|
|
|
prefix_length += self.config.pre_seq_len
|
|
|
|
|
@@ -83,8 +96,6 @@ class RemoteGenerationMixin:
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id

-        batch_size = inputs.size(0)
-
         assert (max_length is None) != (max_new_tokens is None), "please set max_length or max_new_tokens (not both)"
         if max_length is not None and max_new_tokens is None:
             max_new_tokens = max_length - prefix_length
@@ -92,9 +103,22 @@ class RemoteGenerationMixin:
         elif max_length is None and max_new_tokens is not None:
             max_length = prefix_length + max_new_tokens

-        if inputs is None:
-            assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs"
-            inputs = torch.tensor([[bos_token_id]] * num_beams, dtype=torch.long, device=self.device)
+        if num_beams > 1 and session is not None:
+            raise NotImplementedError(
+                "Reusing inference session in .generate() along with beam search is not supported yet"
+            )
+
+        if inputs is not None:
+            assert isinstance(inputs, torch.Tensor) and inputs.ndim == 2, "inputs must be a 2d tensor [batch, length]"
+            if session is not None and session.last_token_id is not None:
+                inputs = torch.cat([session.last_token_id, inputs], dim=1)
+        else:
+            if session is not None and session.last_token_id is not None:
+                inputs = session.last_token_id
+            else:
+                assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs"
+                inputs = torch.tensor([[bos_token_id]] * num_beams, dtype=torch.long, device=self.device)
+        batch_size = inputs.size(0)

         if decoding_algorithm is None:
             if do_sample:
@@ -109,7 +133,8 @@ class RemoteGenerationMixin:
             if batch_size > 1:
                 # TODO: resolve padding problem
                 logger.warning(
-                    f"You set batch_size {batch_size} within beam search generation. Be careful, results on sequences with different length may be padded wrong way"
+                    f"You set batch_size {batch_size} within beam search generation. "
+                    f"Be careful, results on sequences with different length may be padded wrong way"
                 )

         if num_return_sequences is None:
@@ -127,7 +152,11 @@ class RemoteGenerationMixin:
|
|
|
provided_constraints=provided_constraints,
|
|
|
)
|
|
|
|
|
|
- with self.transformer.h.inference_session(max_length=max_length) as sess:
|
|
|
+ if session is None:
|
|
|
+ context_manager = self.inference_session(max_length=max_length)
|
|
|
+ else:
|
|
|
+ context_manager = contextlib.nullcontext(session) # Doesn't actually enter session or exit from it
|
|
|
+ with context_manager as session:
|
|
|
outputs = []
|
|
|
# Find samples with padded inputs.
|
|
|
# They will be changed before all of the samples have right length.
|
|
@@ -145,7 +174,7 @@ class RemoteGenerationMixin:
                     prompts, intermediate_prompts = self.transformer.get_prompt(embs.size(0))
                     embs = torch.cat([prompts, embs], dim=1)
                 embs = self.transformer.word_embeddings_layernorm(embs)
-                hidden_state = sess.step(embs, prompts=intermediate_prompts, hypo_ids=hypo_ids)[:, -1]
+                hidden_state = session.step(embs, prompts=intermediate_prompts, hypo_ids=hypo_ids)[:, -1]
                 hidden_state = self.transformer.ln_f(hidden_state)
                 lm_logits = self.lm_head(hidden_state)

@@ -166,6 +195,7 @@ class RemoteGenerationMixin:
                         outputs[i - 1] = outputs[i - 1][hypo_ids]

                 outputs.append(last_token_id)
+                session.last_token_id = last_token_id
                 seq_idx += 1
                 if torch.all(last_token_id == eos_token_id) or len(outputs) > max_new_tokens:
                     break
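
Note for reviewers: a minimal usage sketch of the API this diff introduces. The `inference_session()` helper opens a session for the model's RemoteSequential module, and the new `session=` keyword of `generate()` lets consecutive calls continue from `session.last_token_id` instead of rebuilding the attention cache. The model, tokenizer, and prompt below are illustrative assumptions, not part of the change.

# Hypothetical usage sketch (assumes `model` is a RemoteGenerationMixin-based model,
# e.g. a DistributedBloomForCausalLM-style class, and `tokenizer` is a regular HF tokenizer).
inputs = tokenizer("A cat sat on", return_tensors="pt")["input_ids"]

with model.inference_session(max_length=64) as session:
    # First call feeds the prefix; the last generated token is stored in the session.
    outputs = model.generate(inputs, max_new_tokens=3, session=session)
    # A later call may pass inputs=None: generation resumes from session.last_token_id
    # without re-sending the prefix to the servers.
    outputs = model.generate(None, max_new_tokens=3, session=session)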