Ver código fonte

Make client compatible with transformers' GenerationMixin (#464)

This PR drops custom generation codes and introduces compatibility with `transformers.GenerationMixin` instead. This includes support for more sampling options (`top_p`, `top_k`, `repetition_penalty` requested in #460) and beam search - all that is now identical to running model with transformers locally.

Most features (excluding beam search and other rarely used stuff) are also compatible with resuming existing sessions.

### Breaking changes

If `.generate()` or forward passes are being run inside an `.inference_session()` context, they now use the opened session by default. So, these snippets are now equivalent:

```python
# Using default session
with model.inference_session(max_length=100):
    output_ids = model.generate(input_ids, max_new_tokens=3)

# Explicitly specifying a session
with model.inference_session(max_length=100) as sess:
    output_ids = model.generate(input_ids, max_new_tokens=3, session=sess)
```

Earlier, the 1st snippet was creating a new session, which is not what most people expected (= such code was most likely to introduce a bug, which is now fixed).
Alexander Borzunov 2 anos atrás
pai
commit
de2475f31c

+ 1 - 0
.github/workflows/run-tests.yaml

@@ -41,6 +41,7 @@ jobs:
           pip install .[dev]
       - name: Test
         run: |
+          set -x  # Print executed commands
           export MODEL_NAME="${{ matrix.model }}"
           export REF_NAME="${{ matrix.model }}"
           export ADAPTER_NAME="${{ matrix.model == 'bigscience/bloom-560m' && 'artek0chumak/bloom-560m-safe-peft' || '' }}"

+ 6 - 8
src/petals/client/from_pretrained.py

@@ -3,7 +3,7 @@ import json
 import os
 import re
 import tempfile
-import threading
+from contextvars import ContextVar
 from typing import List, Optional, Tuple, Union
 
 import torch
@@ -47,18 +47,16 @@ class FromPretrainedMixin:
     )
 
 
-_shard_config = threading.local()
-_shard_config.ignored_keys = None
+_ignored_keys = ContextVar("ignored_keys", default=None)
 
 
 @contextlib.contextmanager
 def ignore_keys(patterns: List[str]):
+    token = _ignored_keys.set(patterns)
     try:
-        prev_patterns = _shard_config.ignored_keys
-        _shard_config.ignored_keys = patterns
         yield
     finally:
-        _shard_config.ignored_keys = prev_patterns
+        _ignored_keys.reset(token)
 
 
 def patched_get_checkpoint_shard_files(
@@ -66,7 +64,7 @@ def patched_get_checkpoint_shard_files(
 ) -> Tuple[List[str], dict]:
     """Same as modeling_utils.get_checkpoint_shard_files(), but does not download shards for the ignored keys."""
 
-    should_ignore_keys = _shard_config.ignored_keys is not None
+    should_ignore_keys = _ignored_keys.get() is not None
     tempdir_ctx = tempfile.TemporaryDirectory() if should_ignore_keys else contextlib.nullcontext()
     with tempdir_ctx as tempdir:
         if should_ignore_keys:
@@ -77,7 +75,7 @@ def patched_get_checkpoint_shard_files(
             index["weight_map"] = {
                 param_name: filename
                 for param_name, filename in index["weight_map"].items()
-                if all(re.search(pattern, param_name) is None for pattern in _shard_config.ignored_keys)
+                if all(re.search(pattern, param_name) is None for pattern in _ignored_keys.get())
             }
             n_loaded_shards = len(set(index["weight_map"].values()))
             logger.debug(f"Loading {n_loaded_shards} shards out of {n_original_shards}")

+ 11 - 1
src/petals/client/inference_session.py

@@ -230,7 +230,7 @@ class InferenceSession:
         self._server_sessions = []
         self._position = 0
         self._max_length = max_length
-        self.last_token_id = None
+        self.output_ids = None
 
     @property
     def num_blocks(self) -> int:
@@ -377,3 +377,13 @@ class InferenceSession:
 
     def __del__(self):
         self.close()
+
+    @property
+    def last_token_id(self) -> Optional[torch.Tensor]:  # Backward compatibility with Petals < 2.1.0
+        return self.output_ids[:, -1:] if self.output_ids is not None else None
+
+    @last_token_id.setter
+    def last_token_id(self, value: torch.Tensor):  # Backward compatibility with Petals < 2.1.0
+        if self.output_ids is None:
+            raise RuntimeError("Can't override `last_token_id` since the session has not stepped yet")
+        self.output_ids[:, -1:] = value

+ 2 - 2
src/petals/client/lm_head.py

@@ -70,8 +70,8 @@ class LMHead(nn.Module):
         if not self._bf16_warning_shown:
             if self.weight.numel() * 4 < 0.9 * psutil.virtual_memory().total:
                 logger.warning(
-                    "Running the client with dtype bfloat16 on CPU may be slow, since your CPU doesn't support AVX512. "
-                    "Consider loading the model with torch_dtype='float32'"
+                    "Running the model in bfloat16 on CPU will be slow since your CPU does not support AVX512. "
+                    "To speed it up, load the model in float32 using .from_pretrained(..., torch_dtype=torch.float32)"
                 )
             self._bf16_warning_shown = True
 

+ 2 - 2
src/petals/client/ptune.py

@@ -76,9 +76,9 @@ def force_non_empty_weights():
     [1] https://github.com/huggingface/transformers/blob/ab9fe45236cd99b8797df78219438f8f6662bb42/src/transformers/modeling_utils.py#L2515
     """
 
+    possibly_patched_register_parameter = nn.Module.register_parameter
+    nn.Module.register_parameter = _original_register_parameter
     try:
-        possibly_patched_register_parameter = nn.Module.register_parameter
-        nn.Module.register_parameter = _original_register_parameter
         yield
     finally:
         nn.Module.register_parameter = possibly_patched_register_parameter

+ 110 - 317
src/petals/client/remote_generation.py

@@ -1,349 +1,142 @@
 import contextlib
-from typing import List, Optional
+import dataclasses
+from contextvars import ContextVar
+from typing import ContextManager, List, Optional
 
 import torch
+import transformers
 from hivemind.utils.logging import get_logger
+from transformers.generation.utils import ModelOutput
 
 from petals.client.inference_session import InferenceSession
-from petals.utils.generation_algorithms import (
-    BeamSearchAlgorithm,
-    DecodingAlgorithm,
-    GreedyAlgorithm,
-    NucleusAlgorithm,
-    SamplingAlgorithm,
-    TopKAlgorithm,
-)
-from petals.utils.generation_constraints import ABCBloomConstraint, EosConstraint
+from petals.client.remote_sequential import RemoteSequential
+from petals.utils.misc import DUMMY, docstring_from
 
 logger = get_logger(__name__)
 
 
-class RemoteGenerationMixin:
-    """
-    A class containing all functions for auto-regressive text generation, to be used as a mixin in [`BloomForCausalLM`].
-    The class exposes can be used for:
-        - *greedy decoding*.
-        - *multinomial, top-k and top-p sampling*.
-        - *beam-search decoding*
-
-    This class is similar to transformer's [`generation_utils.GenerationMixin`], it can be used instead of it.
-    However, it has some differences for remote usage.
-    """
-
-    def inference_session(self, **kwargs) -> InferenceSession:
-        """
-        Returns an inference session for the model's RemoteSequential module.
+@dataclasses.dataclass(frozen=True)
+class RemotePastKeyValues:
+    """A mock class representing the fact that `past_key_values` do exist but are stored on remote servers."""
 
-        :param max_length: Maximal expected length of inference results. Servers use this parameter
-                           to calculate the size of attention caches allocated to this client.
-        """
+    hypo_ids: Optional[torch.LongTensor] = None
 
-        return self.transformer.h.inference_session(**kwargs)
+    def __getitem__(self, _index: int) -> List[torch.Tensor]:
+        return [DUMMY]  # For compatibility with BloomForCausalLM.prepare_inputs_for_generation()
 
-    @torch.inference_mode()
-    def generate(
-        self,
-        inputs: Optional[torch.Tensor] = None,
-        *,
-        do_sample: Optional[bool] = None,
-        temperature: float = 1.0,
-        top_k: Optional[int] = None,
-        top_p: Optional[float] = None,
-        num_beams: Optional[int] = 1,
-        bos_token_id: Optional[int] = None,
-        eos_token_id: Optional[int] = None,
-        pad_token_id: Optional[int] = None,
-        max_length: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-        decoding_algorithm: Optional[DecodingAlgorithm] = None,
-        provided_constraints: List[ABCBloomConstraint] = [],
-        num_return_sequences: Optional[int] = None,
-        session: Optional[InferenceSession] = None,
-    ) -> torch.LongTensor:
-        """
-        Generates sequences of token ids for models with a language modeling head.
 
-        :param inputs: The input tokens to the model.
-        :param do_sample: Whether to sample from the model predictions or take the argmax.
-        :param temperature: The temperature to use for sampling.
-        :param top_k: The number of results to return.
-        :param top_p: The cumulative probability of results to return.
-        :param num_beams: The number of beams to use for beam search.
-        :param bos_token_id: The id of the beginning of sentence token.
-        :param eos_token_id: The id of the end of sentence token.
-        :param pad_token_id: The id of the padding token.
-        :param max_length: The maximum number of tokens in the output (including input tokens).
-        :param max_new_tokens: The maximum number of tokens to generate.
-        :param decoding_algorithm: The decoding algorithm to use.
-        :param provided_constraints: A list of constraints to use.
-        :param num_return_sequences: How many hypothesis from the beam will be in output.
-        """
+_skipped_tokens = ContextVar("skipped_tokens", default=0)
 
-        prefix_length = 0 if inputs is None else inputs.size(1)
-        prefix_length += self.config.pre_seq_len
 
-        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
-        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
+class _SkipTokensMixin:
+    # This override is used in RemoteGenerationMixin by has to be defined in a class not named as "GenerationMixin"
+    # due to how transformers.PreTrainedModel.can_generate() works
+    def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> dict:
+        input_ids = input_ids[:, _skipped_tokens.get() :]
+        _skipped_tokens.set(0)
+        return super().prepare_inputs_for_generation(input_ids, **kwargs)
 
-        assert (max_length is None) != (max_new_tokens is None), "please set max_length or max_new_tokens (not both)"
-        if max_length is not None and max_new_tokens is None:
-            max_new_tokens = max_length - prefix_length
-            assert max_new_tokens > 0, f"Provided max_length is less than prefix size: {max_length} < {inputs.size(1)}"
-        elif max_length is None and max_new_tokens is not None:
-            max_length = prefix_length + max_new_tokens
 
-        resuming_session = session is not None and session.last_token_id is not None
-        if num_beams > 1 and resuming_session:
-            raise NotImplementedError(
-                "Resuming inference session in .generate() along with beam search is not supported yet"
-            )
+class RemoteGenerationMixin(_SkipTokensMixin):
+    """
+    This class is an upgrade to `transformers.GenerationMixin` that:
+
+    - Designed to be compatible with most `transformers.GenerationMixin` strategies and options
+    - Supports generation inside a remote InferenceSession, so that remote servers store your attention caches and
+      you don't have to rerun the prefix through all the servers to generate each new token
+    - Supports multiple `.generate()` calls inside one InferenceSession, so you can easily run interactive generation
+      by showing tokens on the fly (multiple calls like `.generate(None, max_new_tokens=1, ...)`) or
+      accept prompts from a user in a chat bot (multiple calls like `.generate(new_prompts, ...)`).
+    - If there is no active session, `.generate()` will create a new InferenceSession with proper `max_length`.
+      Otherwise, `.generate()` will use the active session. You can use the `session=...` argument to override that.
+    """
 
-        if inputs is not None:
-            assert isinstance(inputs, torch.Tensor) and inputs.ndim == 2, "inputs must be a 2d tensor [batch, length]"
-            if resuming_session:
-                inputs = torch.cat([session.last_token_id, inputs], dim=1)
-        else:
-            if resuming_session:
-                inputs = session.last_token_id
-            else:
-                assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs"
-                inputs = torch.tensor([[bos_token_id]] * num_beams, dtype=torch.long, device=self.device)
-        batch_size = inputs.size(0)
+    @docstring_from(RemoteSequential.active_session)
+    @property
+    def active_session(self) -> Optional[InferenceSession]:
+        return self.transformer.h.active_session
 
-        if decoding_algorithm is None:
-            if do_sample:
-                decoding_algorithm = self._choose_sample_algorithm(temperature, top_k, top_p)
-            elif num_beams is not None and num_beams > 1:
-                decoding_algorithm = BeamSearchAlgorithm(num_beams, batch_size=batch_size)
-            else:
-                if top_k is not None or top_p is not None:
-                    logger.warning("You passed top_k or top_p but did not pass do_sample=True. Running greedy sampling")
-                decoding_algorithm = GreedyAlgorithm()
+    @docstring_from(RemoteSequential.use_session)
+    def use_session(self, session: Optional[InferenceSession]) -> ContextManager[InferenceSession]:
+        return self.transformer.h.use_session(session)
 
-        if num_beams > 1:
-            inputs = torch.cat([inputs] * num_beams, dim=0)
-            if batch_size > 1:
-                # TODO: resolve padding problem
-                logger.warning(
-                    f"You set batch_size {batch_size} within beam search generation. "
-                    f"Be careful, results on sequences with different length may be padded wrong way"
-                )
+    @docstring_from(RemoteSequential.inference_session)
+    def inference_session(self, **kwargs) -> ContextManager[InferenceSession]:
+        return self.transformer.h.inference_session(**kwargs)
 
-        if num_return_sequences is None:
-            num_return_sequences = 1
+    @docstring_from(transformers.GenerationMixin.generate.__doc__)
+    def generate(
+        self, inputs: Optional[torch.Tensor] = None, *args, session: Optional[InferenceSession] = None, **kwargs
+    ):
+        self._fix_generate_kwargs(kwargs)
+
+        if session is not None:
+            # If a session specified explicitly, use it
+            context_manager = self.use_session(session)
+        elif self.active_session is not None:
+            # If there's an active session, don't do anything
+            context_manager = contextlib.nullcontext(self.active_session)
+        else:
+            # If there's no active session, create a new one
 
-        assert num_return_sequences <= num_beams, (
-            f"You want more sequences than the beam has."
-            " Check num_return_sequences: {num_return_sequences} and num_beams: {num_beams}."
-        )
+            max_length = kwargs.get("max_length")
+            max_new_tokens = kwargs.get("max_new_tokens")
+            assert (max_length is None) != (
+                max_new_tokens is None
+            ), "You should set `max_length` or `max_new_tokens` (but not both) to reserve server-side attention caches"
 
-        constraints = self._get_constraints(
-            inputs=inputs,
-            eos_token_id=eos_token_id,
-            pad_token_id=pad_token_id,
-            provided_constraints=provided_constraints,
-        )
+            if max_length is not None:
+                session_max_length = max_length
+            else:
+                session_max_length = (inputs.shape[1] if inputs is not None else 0) + max_new_tokens
+            context_manager = self.inference_session(max_length=session_max_length)
 
-        if session is None:
-            context_manager = self.inference_session(max_length=max_length)
-        else:
-            context_manager = contextlib.nullcontext(session)  # Doesn't actually enter session or exit from it
         with context_manager as session:
-            outputs = []
-            # Find samples with padded inputs.
-            # They will be changed before all of the samples have right length.
-            if torch.any(inputs == pad_token_id):  # TODO: move to prepare_inputs
-                outputs += [inputs[:, : inputs.size(1) - (inputs == pad_token_id).sum(-1).max()]]
+            # Prepend the tokens from the previous .generate() call
+            n_prev_tokens = session.output_ids.shape[1] if session.output_ids is not None else 0
+            if n_prev_tokens > 0:
+                if kwargs.get("num_beams", 1) > 1:
+                    logger.warning(
+                        "Beam search will not work properly in the resumed petals.InferenceSession "
+                        "since intermediate beam entries are lost"
+                    )
+
+                if inputs is not None:
+                    inputs = torch.cat([session.output_ids, inputs], dim=1)
+                else:
+                    inputs = session.output_ids
+
+                # Don't actually run all previous tokens through the transformer,
+                # but keep them for transformers.GenerationMixin (e.g., to compute repetition_penalty)
+                _skipped_tokens.set(max(0, n_prev_tokens - 1))
+
+            result = super().generate(inputs, *args, **kwargs)
+
+            sequences = result.sequences if isinstance(result, ModelOutput) else result
+            # Save tokens from this .generate() call
+            session.output_ids = sequences
+            # Crop the last tokens from the previous call
+            sequences = sequences[:, n_prev_tokens:].clone()
+            if isinstance(result, ModelOutput):
+                result.sequences = sequences
             else:
-                outputs += [inputs]
-            last_token_id = None
-            seq_idx = outputs[0].size(1)
-            hypo_ids = torch.arange(outputs[0].size(0))
-            while True:
-                hidden_state = self.transformer.word_embeddings(outputs[-1])
-                intermediate_prompts = None
-                if self.config.pre_seq_len > 0 and len(outputs) == 1:
-                    prompts, intermediate_prompts = self.transformer.get_prompt(hidden_state.size(0))
-                    hidden_state = torch.cat([prompts, hidden_state], dim=1)
-                hidden_state = self.transformer.word_embeddings_layernorm(hidden_state)
+                result = sequences
 
-                hidden_state = session.step(hidden_state, prompts=intermediate_prompts, hypo_ids=hypo_ids)[:, -1]
+        return result
 
-                hidden_state = self.transformer.ln_f(hidden_state)
-                lm_logits = self.lm_head(hidden_state)
+    @staticmethod
+    def _fix_generate_kwargs(kwargs: dict) -> dict:
+        # Suppress inappropriate "Both max_new_tokens and max_length" HF warning
+        if "max_length" in kwargs and kwargs["max_length"] is None:
+            del kwargs["max_length"]
 
-                for constraint in constraints:
-                    lm_logits = constraint(last_token_id, lm_logits, hypo_ids)
-                last_token_id, hypo_ids = decoding_algorithm(lm_logits)
+        # Support do_sample = {0, 1} for backward compatibility with Petals < 2.1.0
+        do_sample = kwargs.get("do_sample")
+        if isinstance(do_sample, int):
+            kwargs["do_sample"] = bool(do_sample)
 
-                # If some samples were padded, change only these samples
-                if seq_idx < inputs.size(1):
-                    pad_token_mask = inputs[:, seq_idx : seq_idx + 1] == pad_token_id
-                    last_token_id = (~pad_token_mask) * inputs[
-                        :, seq_idx : seq_idx + 1
-                    ] + pad_token_mask * last_token_id
-
-                # TODO: refactor outputs
-                if num_beams > 1:
-                    for i in range(len(outputs), 1, -1):
-                        outputs[i - 1] = outputs[i - 1][hypo_ids]
-
-                outputs.append(last_token_id)
-                session.last_token_id = last_token_id
-                seq_idx += 1
-                if torch.all(last_token_id == eos_token_id) or len(outputs) > max_new_tokens:
-                    break
-
-        outputs = torch.cat(outputs, dim=-1)
-
-        if resuming_session:
-            outputs = outputs[:, 1:]
-        if num_beams > 1:
-            pre_return_idx = [
-                torch.arange(idx, num_return_sequences * batch_size, batch_size) for idx in range(batch_size)
-            ]
-            return_idx = torch.cat(pre_return_idx, dim=0)
-            outputs = outputs[return_idx]
-
-        return outputs
-
-    def greedy_search(
-        self,
-        input_ids: torch.LongTensor,
-        max_length: Optional[int] = None,
-        pad_token_id: Optional[int] = None,
-        eos_token_id: Optional[int] = None,
-        provided_constraints: List[ABCBloomConstraint] = [],
-    ) -> torch.LongTensor:
-        """
-        Generates sequences of token ids for models with a language modeling head. Uses greedy search.
-
-        :param input_ids: The input tokens to the model.
-        :param max_length: The maximum length of the sequence to generate.
-        :param pad_token_id: The id of the padding token.
-        :param eos_token_id: The id of the end of sentence token.
-        :param provided_constraints: A list of constraints to use.
-        """
-        return self.generate(
-            inputs=input_ids,
-            max_new_tokens=max_length,
-            pad_token_id=pad_token_id,
-            eos_token_id=eos_token_id,
-            decoding_algorithm=GreedyAlgorithm(),
-            provided_constraints=provided_constraints,
-        )
-
-    def sample(
-        self,
-        input_ids: torch.LongTensor,
-        temperature: float = 1.0,
-        top_k: Optional[int] = None,
-        top_p: Optional[float] = None,
-        max_length: Optional[int] = None,
-        pad_token_id: Optional[int] = None,
-        eos_token_id: Optional[int] = None,
-        provided_constraints: List[ABCBloomConstraint] = [],
-    ) -> torch.LongTensor:
-        """
-        Generates sequences of token ids for models with a language modeling head. Uses multinomial sampling.
-        If top_k is provided, uses top_k sampling. If top_p is provided, uses nucleus sampling.
-
-        :param: input_ids: The input tokens to the model.
-        :param: temperature: The temperature to use for sampling.
-        :param: top_k: The number of samples to use for top_k sampling.
-        :param: top_p: The probability of using top_p sampling.
-        :param: max_length: The maximum length of the sequence to generate.
-        :param: pad_token_id: The id of the padding token.
-        :param: eos_token_id: The id of the end of sentence token.
-        :param: provided_constraints: A list of constraints to use.
-        """
-
-        return self.generate(
-            inputs=input_ids,
-            max_new_tokens=max_length,
-            pad_token_id=pad_token_id,
-            eos_token_id=eos_token_id,
-            decoding_algorithm=self._choose_sample_algorithm(temperature, top_k, top_p),
-            provided_constraints=provided_constraints,
-        )
-
-    def beam_search(
-        self,
-        input_ids: torch.LongTensor,
-        num_beams: int = 1,
-        max_length: Optional[int] = None,
-        pad_token_id: Optional[int] = None,
-        eos_token_id: Optional[int] = None,
-        provided_constraints: List[ABCBloomConstraint] = [],
-    ) -> torch.LongTensor:
-        """
-        Generates sequences of token ids for models with a language modeling head. Uses beam search.
-
-        :param input_ids: The input tokens to the model.
-        :param num_beams: The number of beams to use.
-        :param max_length: The maximum length of the sequence to generate.
-        :param pad_token_id: The id of the padding token.
-        :param eos_token_id: The id of the end of sentence token.
-        :param provided_constraints: A list of constraints to use.
-        """
-        decoding_algorithm = BeamSearchAlgorithm(
-            num_beams=num_beams,
-            batch_size=input_ids.size(0),
-        )
-        return self.generate(
-            inputs=input_ids,
-            num_beams=num_beams,
-            max_new_tokens=max_length,
-            pad_token_id=pad_token_id,
-            eos_token_id=eos_token_id,
-            decoding_algorithm=decoding_algorithm,
-            provided_constraints=provided_constraints,
-        )
-
-    def beam_sample(
-        self,
-        input_ids: torch.LongTensor,
-        max_length: Optional[int] = None,
-        pad_token_id: Optional[int] = None,
-        eos_token_id: Optional[int] = None,
-        provided_constraints: List[ABCBloomConstraint] = [],
-    ) -> torch.LongTensor:
-        raise NotImplementedError
-
-    def group_beam_search(
-        self,
-        input_ids: torch.LongTensor,
-        max_length: Optional[int] = None,
-        pad_token_id: Optional[int] = None,
-        eos_token_id: Optional[int] = None,
-        provided_constraints: List[ABCBloomConstraint] = [],
-    ) -> torch.LongTensor:
-        raise NotImplementedError
-
-    def _choose_sample_algorithm(
-        self,
-        temperature: float = 1.0,
-        top_k: Optional[int] = None,
-        top_p: Optional[float] = None,
-    ) -> DecodingAlgorithm:
-        if (top_k is not None) and (top_p is not None):
-            raise ValueError("You have to provide only top_k or top_p for sampling")
-        if top_k is not None:
-            return TopKAlgorithm(top_k, temperature)
-        elif top_p is not None:
-            return NucleusAlgorithm(top_p, temperature)
-        else:
-            return SamplingAlgorithm(temperature)
+        return kwargs
 
-    def _get_constraints(
-        self,
-        inputs: Optional[torch.Tensor] = None,
-        eos_token_id: Optional[int] = None,
-        pad_token_id: Optional[int] = None,
-        provided_constraints: List[ABCBloomConstraint] = [],
-    ) -> List[ABCBloomConstraint]:
-        constraints = []
-        constraints.extend(provided_constraints)
-        constraints.append(EosConstraint(inputs, eos_token_id, pad_token_id))
-        return constraints
+    @staticmethod
+    def _reorder_cache(past_key_values: RemotePastKeyValues, beam_idx: torch.LongTensor) -> RemotePastKeyValues:
+        return dataclasses.replace(past_key_values, hypo_ids=beam_idx)

+ 47 - 8
src/petals/client/remote_sequential.py

@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+from contextlib import contextmanager
+from contextvars import ContextVar
 from typing import Optional, Union
 
 import torch
@@ -11,7 +13,6 @@ from petals.client.inference_session import InferenceSession
 from petals.client.routing import RemoteSequenceManager
 from petals.client.sequential_autograd import _RemoteSequentialAutogradFunction
 from petals.data_structures import UID_DELIMITER
-from petals.utils.misc import DUMMY
 
 logger = get_logger(__name__)
 
@@ -46,11 +47,52 @@ class RemoteSequential(nn.Module):
             sequence_manager = RemoteSequenceManager(config, block_uids, dht=dht, **kwargs)
         self.sequence_manager = sequence_manager
 
-    def forward(self, inputs: torch.Tensor, prompts: torch.Tensor = DUMMY):
+        self._active_session = ContextVar("active_session", default=None)
+
+    def forward(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
         assert inputs.ndim == 3, "inputs must be a tensor of shape [batch_size, seq_length, hidden_size]"
-        assert inputs.shape[1] <= 2048, "The sequence length is capped at 2048 tokens in this version"
-        outputs = _RemoteSequentialAutogradFunction.apply(inputs, prompts, self.sequence_manager)
-        return outputs
+        if self.active_session is None:
+            assert all(v is None for v in kwargs.values()), f"Extra kwargs are not supported in forward: {kwargs}"
+            return _RemoteSequentialAutogradFunction.apply(inputs, prompts, self.sequence_manager)
+        else:
+            return self.active_session.step(inputs, prompts, **kwargs)
+
+    @property
+    def active_session(self) -> Optional[InferenceSession]:
+        """
+        If called inside `with model.inference_session(...):` or `with model.use_session(...):`,
+        returns an active InferenceSession. Otherwise, returns None.
+        """
+
+        return self._active_session.get()
+
+    @property
+    def position(self) -> int:
+        """Returns the prefix length (in tokens) in the active inference session or zero if no session is active."""
+
+        return self.active_session.position if self.active_session is not None else 0
+
+    @contextmanager
+    def use_session(self, session: Optional[InferenceSession]) -> InferenceSession:
+        """Inside this context, forward() will use an _existing_ InferenceSession provided as the argument."""
+
+        token = self._active_session.set(session)
+        try:
+            yield session
+        finally:
+            self._active_session.reset(token)
+
+    @contextmanager
+    def inference_session(self, **kwargs) -> InferenceSession:
+        """
+        Inside this context, forward() will use a _new_ InferenceSession created with given parameters.
+
+        :param max_length: Maximal expected length of inference results. Servers use this parameter
+                           to calculate the size of attention caches allocated to this client.
+        """
+
+        with InferenceSession(self.sequence_manager, **kwargs) as session, self.use_session(session):
+            yield session
 
     def __getitem__(self, ix: Union[int, slice]) -> RemoteSequential:
         return RemoteSequential(
@@ -65,8 +107,5 @@ class RemoteSequential(nn.Module):
     def __len__(self):
         return len(self.sequence_manager)
 
-    def inference_session(self, **kwargs) -> InferenceSession:
-        return InferenceSession(self.sequence_manager, **kwargs)
-
     def extra_repr(self) -> str:
         return f"modules={self.sequence_manager.block_uids[0]}..{self.sequence_manager.block_uids[-1]}"

+ 1 - 1
src/petals/client/sequential_autograd.py

@@ -230,7 +230,7 @@ class _RemoteSequentialAutogradFunction(torch.autograd.Function):
     def forward(ctx, inputs: torch.Tensor, prompts: torch.Tensor, sequence_manager: RemoteSequenceManager):
         batch_size = max(MAX_TOKENS_IN_BATCH // inputs.shape[1], 1)
         input_batches: Sequence[torch.Tensor] = inputs.detach().split(batch_size)
-        if is_dummy(prompts):
+        if prompts is None or is_dummy(prompts):
             prompt_batches = [DUMMY] * len(input_batches)
         else:
             prompt_batches: Sequence[torch.Tensor] = prompts.detach().split(batch_size, dim=1)

+ 27 - 15
src/petals/models/bloom/model.py

@@ -10,7 +10,7 @@ from transformers.models.bloom import BloomForCausalLM, BloomForSequenceClassifi
 from petals.client.from_pretrained import FromPretrainedMixin
 from petals.client.lm_head import LMHead
 from petals.client.ptune import PTuneMixin
-from petals.client.remote_generation import RemoteGenerationMixin
+from petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues
 from petals.client.remote_sequential import RemoteSequential
 from petals.models.bloom.config import DistributedBloomConfig
 
@@ -39,16 +39,15 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
+        past_key_values: Optional[RemotePastKeyValues] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        **kwargs,
+        head_mask: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
     ):
-        assert attention_mask is None, f"{self.__class__.__name__} does not support attention masks right now"
-
-        for k, v in kwargs.items():
-            if not (v is None or v is False):
-                logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})")
-
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
@@ -59,21 +58,34 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
 
+        # The causal mask will be added on the server-side
+        assert (
+            attention_mask is None or (attention_mask == 1).all()
+        ), f"Custom attention masks are not supported, {attention_mask=}"
+        assert head_mask is None, f"Custom head masks are not supported, {head_mask=}"
+        assert use_cache is None or use_cache, f"{use_cache=} is not supported"
+        assert not output_attentions, f"{output_attentions=} is not supported"
+        assert not output_hidden_states, f"{output_hidden_states=} is not supported"
+        assert return_dict is None or return_dict, f"{return_dict=} is not supported"
+
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
 
-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+        if self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.h.position == 0:
             batch_size = inputs_embeds.shape[0]
             prompts, intermediate_prompts = self.get_prompt(batch_size)
             inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
+        else:
+            prompts = intermediate_prompts = None
 
         hidden_states = self.word_embeddings_layernorm(inputs_embeds)
         output_shape = input_shape + (hidden_states.size(-1),)
 
-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
-            hidden_states = self.h(hidden_states, prompts=intermediate_prompts)
-        else:
-            hidden_states = self.h(hidden_states)
+        hidden_states = self.h(
+            hidden_states,
+            prompts=intermediate_prompts,
+            hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None,
+        )
 
         # Remove prefix
         if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
@@ -84,7 +96,7 @@ class DistributedBloomModel(FromPretrainedMixin, PTuneMixin, BloomModel):
         hidden_states = hidden_states.view(output_shape)
         return BaseModelOutputWithPastAndCrossAttentions(
             last_hidden_state=hidden_states,
-            past_key_values=None,
+            past_key_values=RemotePastKeyValues(),
             hidden_states=None,
             attentions=None,
         )

+ 29 - 15
src/petals/models/llama/model.py

@@ -10,7 +10,7 @@ from transformers.models.llama import LlamaForCausalLM, LlamaForSequenceClassifi
 from petals.client.from_pretrained import FromPretrainedMixin
 from petals.client.lm_head import LMHead
 from petals.client.ptune import PTuneMixin
-from petals.client.remote_generation import RemoteGenerationMixin
+from petals.client.remote_generation import RemoteGenerationMixin, RemotePastKeyValues
 from petals.client.remote_sequential import RemoteSequential
 from petals.models.llama.config import DistributedLlamaConfig
 
@@ -39,16 +39,15 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        **kwargs,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[RemotePastKeyValues] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
     ) -> BaseModelOutputWithPast:
-        assert attention_mask is None, f"{self.__class__.__name__} does not support attention masks right now"
-
-        for k, v in kwargs.items():
-            if not (v is None or v is False):
-                logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})")
-
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
@@ -59,21 +58,36 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
 
+        # The causal mask will be added on the server-side
+        assert (
+            attention_mask is None or (attention_mask == 1).all()
+        ), f"Custom attention masks are not supported, {attention_mask=}"
+        assert (
+            position_ids is None or (position_ids[:, 1:] - position_ids[:, :-1] == 1).all()
+        ), f"Non-consecutive position_ids are not supported, {position_ids=}"
+        assert use_cache is None or use_cache, f"{use_cache=} is not supported"
+        assert not output_attentions, f"{output_attentions=} is not supported"
+        assert not output_hidden_states, f"{output_hidden_states=} is not supported"
+        assert return_dict is None or return_dict, f"{return_dict=} is not supported"
+
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
+        if self.config.tuning_mode and "ptune" in self.config.tuning_mode and self.layers.position == 0:
             batch_size = inputs_embeds.shape[0]
             prompts, intermediate_prompts = self.get_prompt(batch_size)
             inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)
+        else:
+            prompts = intermediate_prompts = None
 
         hidden_states = inputs_embeds
         output_shape = input_shape + (hidden_states.size(-1),)
 
-        if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
-            hidden_states = self.layers(hidden_states, prompts=intermediate_prompts)
-        else:
-            hidden_states = self.layers(hidden_states)
+        hidden_states = self.layers(
+            hidden_states,
+            prompts=intermediate_prompts,
+            hypo_ids=past_key_values.hypo_ids if past_key_values is not None else None,
+        )
 
         # Remove prefix
         if self.config.tuning_mode and "ptune" in self.config.tuning_mode:
@@ -84,7 +98,7 @@ class DistributedLlamaModel(FromPretrainedMixin, PTuneMixin, LlamaModel):
         hidden_states = hidden_states.view(output_shape)
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
-            past_key_values=None,
+            past_key_values=RemotePastKeyValues(),
             hidden_states=None,
             attentions=None,
         )

+ 1 - 1
src/petals/server/block_functions.py

@@ -196,7 +196,7 @@ async def iterate_rpc_inference(
             hypo_ids,
             points=point_per_piece,
             requested_uids=requested_uids,
-            type="short_inference" if can_merge_pools else "inference",
+            type="inference",
         )
 
         # A client may pass a tensor with 0 tokens. This is a special case that occurs, e.g.

+ 3 - 5
src/petals/server/task_prioritizer.py

@@ -14,9 +14,7 @@ class TaskPrioritizerBase(ABC):
 
 class DummyTaskPrioritizer(TaskPrioritizerBase):
     def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:
-        # Inference steps (especially short ones) go first since they are more latency-sensitive
-        if kwargs.get("type") == "short_inference":
-            return 1.0
+        # Inference steps go first since they are more latency-sensitive
         if kwargs.get("type") == "inference":
-            return 2.0
-        return 3.0  # Forward, backward
+            return 1.0
+        return 2.0  # Forward, backward

+ 0 - 128
src/petals/utils/generation_algorithms.py

@@ -1,128 +0,0 @@
-from abc import ABC, abstractmethod
-from typing import Tuple
-
-import torch
-
-TokenIds = torch.Tensor
-HypoIds = torch.Tensor
-
-
-class DecodingAlgorithm(ABC):
-    """
-    An abstract class for decoding algorithms. Describes the base function of those algorithms:
-    they have to select new tokens and provide the corresponding hypotheses.
-    """
-
-    @abstractmethod
-    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
-        """
-        :param logits: A tensor of shape (batch_size, seq_length, vocab_size)
-        :return: A tuple of selected token ids and corresponding hypotheses.
-        The shape of the token ids is (batch_size, seq_length), and the shape of the hypotheses is (batch_size)
-        """
-        pass
-
-
-class GreedyAlgorithm(DecodingAlgorithm):
-    """
-    The simplest algorithm for decoding. It selects the most probable token.
-    """
-
-    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
-        """
-        Returns the most probable token. The second returned object is always a range of integers
-        from 0 to batch_size - 1.
-        """
-        return logits.max(-1)[1].unsqueeze(1), torch.arange(logits.size(0))
-
-
-class SamplingAlgorithm(DecodingAlgorithm):
-    def __init__(self, temperature: float = 1.0):
-        self.temperature = temperature
-
-    def sample(self, logits: torch.Tensor, indices_to_remove: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
-        """
-        :param logits: A tensor of shape (batch_size * num_hypos, vocab_size)
-        :param indices_to_remove: A bool tensor of shape (batch_size * num_hypos, vocab_size)
-        :return: A tuple of selected token ids and corresponding hypotheses.
-        The shape of the token ids is (batch_size, seq_length), and the shape of the hypotheses is (batch_size).
-        """
-        logits[indices_to_remove] = -float("Inf")
-        probs = torch.softmax(logits / self.temperature, -1)
-        return torch.multinomial(probs, num_samples=1), torch.arange(logits.size(0))
-
-    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
-        indices_to_remove = torch.full_like(logits, False, dtype=torch.bool)
-        return self.sample(logits, indices_to_remove)
-
-
-class TopKAlgorithm(SamplingAlgorithm):
-    def __init__(self, top_k: int, temperature: float = 1.0) -> None:
-        super().__init__(temperature=temperature)
-        self.top_k = top_k
-
-    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
-        indices_to_remove = logits < torch.topk(logits, self.top_k, dim=-1)[0][..., -1, None]
-        return self.sample(logits, indices_to_remove)
-
-
-class NucleusAlgorithm(SamplingAlgorithm):
-    def __init__(self, top_p: float, temperature: float = 1.0) -> None:
-        super().__init__(temperature=temperature)
-        self.top_p = top_p
-
-    def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
-        sorted_logits, sorted_indices = torch.sort(logits, descending=False, dim=-1)
-        probs = torch.softmax(sorted_logits / self.temperature, -1)
-        cumulative_probs = torch.cumsum(probs, dim=-1)
-
-        sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p)
-
-        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
-        return self.sample(logits, indices_to_remove)
-
-
-class BeamSearchAlgorithm(DecodingAlgorithm):
-    def __init__(self, num_beams: int, batch_size: int) -> None:
-        self.num_beams = num_beams
-        self.batch_size = batch_size
-
-        self._batch_beams = [list() for _ in range(batch_size)]
-
-    def __call__(self, logits: torch.Tensor):
-        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
-        probs = torch.log_softmax(sorted_logits, -1)
-
-        if len(self._batch_beams[0]) > 0:
-            for batch_idx in range(self.batch_size):
-                new_beams = []
-                cur_beams = self._batch_beams[batch_idx]
-                for beam_idx in range(len(cur_beams)):
-                    probs_idx = batch_idx + beam_idx * self.batch_size
-                    new_beam = cur_beams[beam_idx]
-                    for hypo_idx in range(self.num_beams):
-                        new_beams.append(
-                            (new_beam[0] + probs[probs_idx, hypo_idx].item(), beam_idx * self.num_beams + hypo_idx)
-                        )
-                self._batch_beams[batch_idx] = sorted(new_beams, reverse=True)[: self.num_beams]
-        else:
-            for batch_idx in range(self.batch_size):
-                for beam_idx in range(self.num_beams):
-                    self._batch_beams[batch_idx].append((probs[batch_idx, beam_idx].item(), beam_idx))
-
-        return_hypos = []
-        return_tokens = []
-        for batch_idx in range(self.batch_size):
-            cur_beam = self._batch_beams[batch_idx]
-            return_hypos.append(list())
-            return_tokens.append(list())
-            for beam in cur_beam:
-                beam_idx = beam[1] // self.num_beams
-                hypo_idx = batch_idx + beam_idx * self.batch_size
-                token_idx = beam[1] % self.num_beams
-                return_hypos[-1].append(hypo_idx)
-                return_tokens[-1].append([sorted_indices[hypo_idx, token_idx].item()])
-        return_hypos = [hypo_idx for hypo_indexes in zip(*return_hypos) for hypo_idx in hypo_indexes]
-        return_tokens = [token_idx for token_indexes in zip(*return_tokens) for token_idx in token_indexes]
-
-        return torch.tensor(return_tokens), torch.tensor(return_hypos)

+ 0 - 51
src/petals/utils/generation_constraints.py

@@ -1,51 +0,0 @@
-from abc import ABC
-
-import torch
-
-
-class ABCBloomConstraint(ABC):
-    """
-    Base class of all kind of decoding constraints. It can be used to implement a new constraint.
-    """
-
-    def __init__(self) -> None:
-        pass
-
-    def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
-        """
-        This method is called by the decoding algorithm to apply the constraint. It changes and returns new logits.
-        :param tokens_id: The token id of the last chosen token.
-        :param logits: The logits from the Bloom model.
-        :param hypo_ids: The hypothesis ids of the last tokens.
-        """
-        pass
-
-
-class EosConstraint(ABCBloomConstraint):
-    """
-    This constrained repeats EOS token if it was generated on the previous step.
-    Args:
-        prefix: The prefix of the sequence.
-        eos_token_id: The id of the end of sentence token.
-        pad_token_id: The id of the padding token.
-        min_logits: The minimum logits that can be generated. Default: -1e6.
-    """
-
-    def __init__(self, prefix: torch.Tensor, eos_token_id: int, pad_token_id: int, min_logits: float = -1e8) -> None:
-        self.eos_token_id = eos_token_id
-        self.min_logits = min_logits
-        self.past_tokens = None
-
-        self.wait_until_starting = (prefix == pad_token_id).sum(1).unsqueeze(1)
-
-    def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
-        if self.past_tokens is not None:
-            mask = (self.wait_until_starting < 0) & (self.past_tokens == self.eos_token_id)
-            logits += self.min_logits * mask
-            logits[mask[:, 0], self.eos_token_id] = 0
-
-        if tokens_id is not None:
-            self.past_tokens = tokens_id
-            self.wait_until_starting -= 1
-
-        return logits

+ 9 - 1
src/petals/utils/misc.py

@@ -5,5 +5,13 @@ DUMMY = torch.empty(0)  # dummy tensor that replaces empty prompt or adapter par
 DUMMY_INT64 = torch.empty(0, dtype=torch.int64)
 
 
-def is_dummy(tensor: torch.Tensor):
+def is_dummy(tensor: torch.Tensor) -> bool:
     return tensor.numel() == 0
+
+
+def docstring_from(source):
+    def add_docstring(dest):
+        dest.__doc__ = source.__doc__
+        return dest
+
+    return add_docstring

+ 83 - 122
tests/test_full_model.py

@@ -3,7 +3,6 @@ import pytest
 import torch
 import transformers
 from hivemind import get_logger
-from transformers.generation import BeamSearchScorer, GenerationMixin as HfGenerationMixin
 
 from petals import AutoDistributedModelForCausalLM
 from test_utils import *
@@ -17,18 +16,29 @@ def tokenizer():
     return transformers.AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
 
 
+@pytest.fixture
+def model():
+    return AutoDistributedModelForCausalLM.from_pretrained(
+        MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32
+    )
+
+
+@pytest.fixture
+def ref_model():
+    return transformers.AutoModelForCausalLM.from_pretrained(
+        REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
+    )
+
+
 @pytest.mark.forked
 @pytest.mark.parametrize("use_peft", (True, False) if ADAPTER_NAME else (False,))
 @pytest.mark.parametrize("pass_empty_tensors", (True, False))
-def test_full_model_exact_match(tokenizer, use_peft, pass_empty_tensors, atol_forward=1e-3, atol_inference=1e-3):
-    model = AutoDistributedModelForCausalLM.from_pretrained(
-        MODEL_NAME,
-        initial_peers=INITIAL_PEERS,
-        torch_dtype=torch.float32,
-        active_adapter=ADAPTER_NAME if use_peft else None,
-    )
-    config = model.config
-    assert len(model.transformer.h) == model.config.num_hidden_layers
+def test_full_model_exact_match(tokenizer, model, ref_model, use_peft, pass_empty_tensors, atol=1e-3):
+    if use_peft:
+        model.config.active_adapter = ADAPTER_NAME
+
+        ref_model = peft.PeftModel.from_pretrained(ref_model, ADAPTER_NAME)
+        ref_model.train(False)
 
     test_inputs = tokenizer("A quick brown fox was minding its own buisness", return_tensors="pt")["input_ids"]
 
@@ -42,7 +52,7 @@ def test_full_model_exact_match(tokenizer, use_peft, pass_empty_tensors, atol_fo
         recurrent_outputs = []
         with model.transformer.h.inference_session(max_length=embs.shape[1]) as sess:
             if pass_empty_tensors:
-                recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
+                recurrent_outputs.append(sess.step(torch.empty(1, 0, model.config.hidden_size)))
 
             for t in range(embs.shape[1]):
                 if t == 4:
@@ -53,52 +63,39 @@ def test_full_model_exact_match(tokenizer, use_peft, pass_empty_tensors, atol_fo
                     recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))
 
                 if t == 2 and pass_empty_tensors:
-                    recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
-                    recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
+                    recurrent_outputs.append(sess.step(torch.empty(1, 0, model.config.hidden_size)))
+                    recurrent_outputs.append(sess.step(torch.empty(1, 0, model.config.hidden_size)))
 
         recurrent_outputs = torch.cat(recurrent_outputs, dim=1)
         recurrent_outputs = model.transformer.ln_f(recurrent_outputs)
         recurrent_outputs = model.lm_head(recurrent_outputs)
-        assert torch.allclose(recurrent_outputs, parallel_outputs, rtol=0, atol=atol_inference)
-        logger.info("Inference is consistent with forward")
-
-        del model, embs, recurrent_outputs
-
-        if REF_NAME:
-            ref_model = transformers.AutoModelForCausalLM.from_pretrained(
-                REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
-            )
-            if use_peft:
-                ref_model = peft.PeftModel.from_pretrained(ref_model, ADAPTER_NAME)
-                ref_model.train(False)
-            if config.vocab_size < ref_model.config.vocab_size:
-                ref_model.resize_token_embeddings(config.vocab_size)
-                logger.warning(f"Resized the reference model embeddings, new total = {ref_model.config.vocab_size}")
-
-            dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
-            # note: this creates a dummy mask to make the test compatible with older transformer versions
-            # prior to https://github.com/huggingface/transformers/pull/17837
-            ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits.float()
-            assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward)
-            logger.warning(f"Distributed forward is consistent with {type(ref_model)}.forward")
-            del ref_model, ref_outputs, dummy_mask
-        else:
-            logger.warning("Did not test exact match with local model: REF_NAME environment variable is not set")
-            assert False
+        assert torch.allclose(
+            recurrent_outputs, parallel_outputs, rtol=0, atol=atol
+        ), "Inference differs from forward pass"
+
+        ref_outputs = ref_model.forward(test_inputs).logits.float()
+        assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol), "Outputs are not identical to HF"
+
+
+def make_generate_calls(model, inputs, *, max_new_tokens, multiple_calls=False, **kwargs):
+    if not multiple_calls:
+        return model.generate(inputs, max_new_tokens=max_new_tokens, **kwargs)
+
+    with model.inference_session(max_length=inputs.shape[1] + max_new_tokens) as sess:
+        return torch.cat(
+            [
+                # Sessions provided both explicitly and implicitly should work
+                model.generate(inputs, max_new_tokens=1, **kwargs, session=sess),
+                model.generate(None, max_new_tokens=max_new_tokens - 2, **kwargs),
+                model.generate(None, max_new_tokens=1, **kwargs),
+            ],
+            dim=1,
+        )
 
 
 @pytest.mark.forked
-def test_greedy_generation(tokenizer, max_new_tokens=4):
-    model = AutoDistributedModelForCausalLM.from_pretrained(
-        MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32
-    )
-    inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
-    remote_outputs = model.generate(
-        inputs,
-        max_new_tokens=max_new_tokens,
-    )
-    hf_outputs = HfGenerationMixin.greedy_search(model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens)
-    assert torch.allclose(remote_outputs, hf_outputs), "Greedy search results are not identical to HF"
+def test_greedy_generation(tokenizer, model, ref_model, max_new_tokens=4):
+    inputs_single = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
 
     if tokenizer.pad_token_id is None:
         tokenizer.pad_token_id = tokenizer.eos_token_id
@@ -106,85 +103,49 @@ def test_greedy_generation(tokenizer, max_new_tokens=4):
         "input_ids"
     ]
 
-    remote_outputs_batch = model.generate(
-        inputs_batch,
-        max_new_tokens=max_new_tokens,
-    )
-    hf_outputs_batch = HfGenerationMixin.greedy_search(
-        model, input_ids=inputs_batch, max_length=inputs_batch.size(1) + max_new_tokens
-    )
-    assert torch.allclose(
-        remote_outputs_batch, hf_outputs_batch
-    ), "Greedy search results are not identical to HF in multibatch mode"
+    options = dict(max_new_tokens=max_new_tokens, do_sample=False)
+    for multiple_calls in [False, True]:
+        for inputs in [inputs_single, inputs_batch]:
+            outputs = make_generate_calls(model, inputs, multiple_calls=multiple_calls, **options)
+            ref_outputs = ref_model.generate(inputs, **options)
+            assert torch.allclose(
+                outputs, ref_outputs
+            ), f"Greedy generation is not identical to HF with {multiple_calls=}, {inputs.shape=}"
 
 
 @pytest.mark.forked
-@pytest.mark.parametrize("sampling_options", [dict(), dict(temperature=100.0), dict(top_k=5), dict(top_p=0.9)])
-@pytest.mark.skip("Sampling is currently not consistent with outputs from Transformers")
-def test_sampling(tokenizer, sampling_options, max_new_tokens=4):
-    torch.manual_seed(0)
-
-    model = AutoDistributedModelForCausalLM.from_pretrained(
-        MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32
-    )
-    logits_warper = HfGenerationMixin._get_logits_warper(model, num_beams=1, **sampling_options)
-    inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
-    with torch.random.fork_rng():
-        remote_outputs = model.generate(
-            inputs,
-            max_new_tokens=max_new_tokens,
-            do_sample=True,
-            **sampling_options,
-        )
-    with torch.random.fork_rng():
-        hf_outputs = HfGenerationMixin.sample(
-            model, input_ids=inputs, max_length=inputs.size(1) + max_new_tokens, logits_warper=logits_warper
-        )
-    assert torch.allclose(remote_outputs, hf_outputs), "Sampling results are not identical to HF"
+def test_sampling(tokenizer, model, ref_model, max_new_tokens=10):
+    inputs_single = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
 
+    if tokenizer.pad_token_id is None:
+        tokenizer.pad_token_id = tokenizer.eos_token_id
     inputs_batch = tokenizer(["A cat sat on a mat", "A dog sat on a mat"], return_tensors="pt", padding=True)[
         "input_ids"
     ]
-    with torch.random.fork_rng():
-        remote_outputs_batch = model.generate(
-            inputs_batch,
-            max_new_tokens=max_new_tokens,
-            do_sample=True,
-            **sampling_options,
-        )
-    with torch.random.fork_rng():
-        hf_outputs_batch = HfGenerationMixin.sample(
-            model,
-            input_ids=inputs_batch,
-            max_length=inputs_batch.size(1) + max_new_tokens,
-            logits_warper=logits_warper,
-        )
-    assert torch.allclose(
-        remote_outputs_batch, hf_outputs_batch
-    ), "Sampling results are not identical to HF in multibatch mode"
+
+    for options in [
+        dict(do_sample=True, temperature=0.5, top_k=5, top_p=0.9),
+        dict(do_sample=True, temperature=0.5, repetition_penalty=1.2),
+    ]:
+        options.update(max_new_tokens=max_new_tokens)
+        for multiple_calls in [False, True]:
+            for inputs in [inputs_single, inputs_batch]:
+                torch.manual_seed(0)
+                outputs = make_generate_calls(model, inputs, multiple_calls=multiple_calls, **options)
+
+                torch.manual_seed(0)
+                ref_outputs = ref_model.generate(inputs, **options)
+
+                assert torch.allclose(
+                    outputs, ref_outputs
+                ), f"Sampling is not identical to HF with {options=}, {multiple_calls=}, {inputs.shape=}"
 
 
 @pytest.mark.forked
-def test_beam_search_generation(tokenizer, max_new_tokens=4, num_beams=2):
-    model = AutoDistributedModelForCausalLM.from_pretrained(
-        MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32
-    )
-    text = "A cat sat on a mat"
-    inputs = tokenizer(text, return_tensors="pt")["input_ids"]
-    remote_outputs = model.generate(
-        inputs,
-        max_new_tokens=max_new_tokens,
-        num_beams=num_beams,
-    )
-    beam_scorer = BeamSearchScorer(
-        batch_size=inputs.size(0),
-        num_beams=num_beams,
-        device=inputs.device,
-        length_penalty=0,
-        do_early_stopping=False,
-    )
-    hf_inputs = tokenizer([text] * 2, return_tensors="pt")["input_ids"]
-    hf_outputs = HfGenerationMixin.beam_search(
-        model, input_ids=hf_inputs, max_length=inputs.size(1) + max_new_tokens, beam_scorer=beam_scorer
-    )
-    assert torch.allclose(remote_outputs, hf_outputs), "Beam search results are not identical to HF"
+def test_beam_search_generation(tokenizer, model, ref_model, max_new_tokens=4, num_beams=5):
+    inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
+
+    options = dict(max_new_tokens=max_new_tokens, num_beams=num_beams, do_sample=False)
+    outputs = make_generate_calls(model, inputs, **options)
+    ref_outputs = ref_model.generate(inputs, **options)
+    assert torch.allclose(outputs, ref_outputs), f"Beam search results are not identical to HF"